diff --git a/package.json b/package.json
index 880116d774..f0e8a02b3d 100644
--- a/package.json
+++ b/package.json
@@ -79,7 +79,7 @@
     "node-stream-zip": "^1.15.0",
     "officeparser": "^4.2.0",
     "os-proxy-config": "^1.1.2",
-    "selection-hook": "^1.0.8",
+    "selection-hook": "^1.0.9",
     "turndown": "7.2.0"
   },
   "devDependencies": {
@@ -216,7 +216,7 @@
     "lucide-react": "^0.525.0",
     "macos-release": "^3.4.0",
     "markdown-it": "^14.1.0",
-    "mermaid": "^11.7.0",
+    "mermaid": "^11.9.0",
     "mime": "^4.0.4",
     "motion": "^12.10.5",
     "notion-helper": "^1.3.22",
diff --git a/src/main/knowledge/preprocess/Doc2xPreprocessProvider.ts b/src/main/knowledge/preprocess/Doc2xPreprocessProvider.ts
index afc8d1ba9b..834ff2f27e 100644
--- a/src/main/knowledge/preprocess/Doc2xPreprocessProvider.ts
+++ b/src/main/knowledge/preprocess/Doc2xPreprocessProvider.ts
@@ -5,7 +5,7 @@ import { loggerService } from '@logger'
 import { fileStorage } from '@main/services/FileStorage'
 import { FileMetadata, PreprocessProvider } from '@types'
 import AdmZip from 'adm-zip'
-import axios, { AxiosRequestConfig } from 'axios'
+import { net } from 'electron'
 
 import BasePreprocessProvider from './BasePreprocessProvider'
 
@@ -38,19 +38,24 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider {
   }
 
   private async validateFile(filePath: string): Promise<void> {
-    const pdfBuffer = await fs.promises.readFile(filePath)
+    // Check the file size first to avoid reading a large file into memory
+    const stats = await fs.promises.stat(filePath)
+    const fileSizeBytes = stats.size
+    // File size must be under 300MB
+    if (fileSizeBytes >= 300 * 1024 * 1024) {
+      const fileSizeMB = Math.round(fileSizeBytes / (1024 * 1024))
+      throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 300MB`)
+    }
+
+    // Only read the file contents to check the page count once the size is acceptable
+    const pdfBuffer = await fs.promises.readFile(filePath)
     const doc = await this.readPdf(pdfBuffer)
 
     // Page count must be under 1000 pages
     if (doc.numPages >= 1000) {
       throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 1000 pages`)
     }
-
-    // File size must be under 300MB
-    if (pdfBuffer.length >= 300 * 1024 * 1024) {
-      const fileSizeMB = Math.round(pdfBuffer.length / (1024 * 1024))
-      throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 300MB`)
-    }
   }
 
   public async parseFile(sourceId: string, file: FileMetadata): Promise<{ processedFile: FileMetadata }> {
@@ -160,11 +165,23 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider {
    * @returns the url and uid from the preupload response
    */
   private async preupload(): Promise {
-    const config = this.createAuthConfig()
     const endpoint = `${this.provider.apiHost}/api/v2/parse/preupload`
 
     try {
-      const { data } = await axios.post>(endpoint, null, config)
+      const response = await net.fetch(endpoint, {
+        method: 'POST',
+        headers: {
+          'Content-Type': 'application/json',
+          Authorization: `Bearer ${this.provider.apiKey}`
+        },
+        body: null
+      })
+
+      if (!response.ok) {
+        throw new Error(`HTTP ${response.status}: ${response.statusText}`)
+      }
+
+      const data = (await response.json()) as ApiResponse
 
       if (data.code === 'success' && data.data) {
         return data.data
@@ -178,17 +195,29 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider {
   }
 
   /**
-   * Upload the file
+   * Upload the file (streamed upload)
    * @param filePath path to the file
    * @param url upload url from the preupload response
    */
   private async putFile(filePath: string, url: string): Promise<void> {
     try {
-      const fileStream = fs.createReadStream(filePath)
-      const response = await axios.put(url, fileStream)
+      // Get the file size to set Content-Length
+      const stats = await fs.promises.stat(filePath)
+      const fileSize = stats.size
-      if (response.status !== 200) {
-        throw new Error(`HTTP status ${response.status}: ${response.statusText}`)
+      // Create a readable stream
+      const fileStream = fs.createReadStream(filePath)
+
+      const response = await net.fetch(url, {
+        method: 'PUT',
+        body: fileStream as any, // type cast: net.fetch accepts a ReadableStream body
+        headers: {
+          'Content-Length': fileSize.toString()
+        }
+      })
+
+      if (!response.ok) {
+        throw new Error(`HTTP ${response.status}: ${response.statusText}`)
       }
     } catch (error) {
       logger.error(`Failed to upload file ${filePath}: ${error instanceof Error ? error.message : String(error)}`)
@@ -197,16 +226,25 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider {
   }
 
   private async getStatus(uid: string): Promise {
-    const config = this.createAuthConfig()
     const endpoint = `${this.provider.apiHost}/api/v2/parse/status?uid=${uid}`
 
     try {
-      const response = await axios.get>(endpoint, config)
+      const response = await net.fetch(endpoint, {
+        method: 'GET',
+        headers: {
+          Authorization: `Bearer ${this.provider.apiKey}`
+        }
+      })
 
-      if (response.data.code === 'success' && response.data.data) {
-        return response.data.data
+      if (!response.ok) {
+        throw new Error(`HTTP ${response.status}: ${response.statusText}`)
+      }
+
+      const data = (await response.json()) as ApiResponse
+      if (data.code === 'success' && data.data) {
+        return data.data
       } else {
-        throw new Error(`API returned error: ${response.data.message || JSON.stringify(response.data)}`)
+        throw new Error(`API returned error: ${data.message || JSON.stringify(data)}`)
       }
     } catch (error) {
       logger.error(`Failed to get status for uid ${uid}: ${error instanceof Error ? error.message : String(error)}`)
@@ -221,13 +259,6 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider {
    */
   private async convertFile(uid: string, filePath: string): Promise<void> {
     const fileName = path.parse(filePath).name
-    const config = {
-      ...this.createAuthConfig(),
-      headers: {
-        ...this.createAuthConfig().headers,
-        'Content-Type': 'application/json'
-      }
-    }
 
     const payload = {
       uid,
@@ -239,10 +270,22 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider {
     const endpoint = `${this.provider.apiHost}/api/v2/convert/parse`
 
     try {
-      const response = await axios.post>(endpoint, payload, config)
+      const response = await net.fetch(endpoint, {
+        method: 'POST',
+        headers: {
+          'Content-Type': 'application/json',
+          Authorization: `Bearer ${this.provider.apiKey}`
+        },
+        body: JSON.stringify(payload)
+      })
 
-      if (response.data.code !== 'success') {
-        throw new Error(`API returned error: ${response.data.message || JSON.stringify(response.data)}`)
+      if (!response.ok) {
+        throw new Error(`HTTP ${response.status}: ${response.statusText}`)
+      }
+
+      const data = (await response.json()) as ApiResponse
+      if (data.code !== 'success') {
+        throw new Error(`API returned error: ${data.message || JSON.stringify(data)}`)
       }
     } catch (error) {
      logger.error(`Failed to convert file ${filePath}: ${error instanceof Error ? error.message : String(error)}`)
@@ -256,16 +299,25 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider {
    * @returns information about the parsed file
    */
   private async getParsedFile(uid: string): Promise {
-    const config = this.createAuthConfig()
     const endpoint = `${this.provider.apiHost}/api/v2/convert/parse/result?uid=${uid}`
 
     try {
-      const response = await axios.get>(endpoint, config)
+      const response = await net.fetch(endpoint, {
+        method: 'GET',
+        headers: {
+          Authorization: `Bearer ${this.provider.apiKey}`
+        }
+      })
 
-      if (response.status === 200 && response.data.data) {
-        return response.data.data
+      if (!response.ok) {
+        throw new Error(`HTTP ${response.status}: ${response.statusText}`)
+      }
+
+      const data = (await response.json()) as ApiResponse
+      if (data.data) {
+        return data.data
       } else {
-        throw new Error(`HTTP status ${response.status}: ${response.statusText}`)
+        throw new Error(`No data in response`)
       }
     } catch (error) {
       logger.error(
@@ -295,8 +347,12 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider {
     try {
       // Download the file
-      const response = await axios.get(url, { responseType: 'arraybuffer' })
-      fs.writeFileSync(zipPath, response.data)
+      const response = await net.fetch(url, { method: 'GET' })
+      if (!response.ok) {
+        throw new Error(`HTTP ${response.status}: ${response.statusText}`)
+      }
+      const arrayBuffer = await response.arrayBuffer()
+      fs.writeFileSync(zipPath, Buffer.from(arrayBuffer))
 
       // Ensure the extraction directory exists
       if (!fs.existsSync(extractPath)) {
@@ -318,14 +374,6 @@ export default class Doc2xPreprocessProvider extends BasePreprocessProvider {
     }
   }
 
-  private createAuthConfig(): AxiosRequestConfig {
-    return {
-      headers: {
-        Authorization: `Bearer ${this.provider.apiKey}`
-      }
-    }
-  }
-
   public checkQuota(): Promise {
     throw new Error('Method not implemented.')
   }
diff --git a/src/main/knowledge/preprocess/MineruPreprocessProvider.ts b/src/main/knowledge/preprocess/MineruPreprocessProvider.ts
index 0e29a6443f..1976f64c05 100644
--- a/src/main/knowledge/preprocess/MineruPreprocessProvider.ts
+++ b/src/main/knowledge/preprocess/MineruPreprocessProvider.ts
@@ -5,7 +5,7 @@ import { loggerService } from '@logger'
 import { fileStorage } from '@main/services/FileStorage'
 import { FileMetadata, PreprocessProvider } from '@types'
 import AdmZip from 'adm-zip'
-import axios from 'axios'
+import { net } from 'electron'
 
 import BasePreprocessProvider from './BasePreprocessProvider'
 
@@ -95,7 +95,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
   public async checkQuota() {
     try {
-      const quota = await fetch(`${this.provider.apiHost}/api/v4/quota`, {
+      const quota = await net.fetch(`${this.provider.apiHost}/api/v4/quota`, {
         method: 'GET',
         headers: {
           'Content-Type': 'application/json',
@@ -179,8 +179,12 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
     try {
       // Download the ZIP file
-      const response = await axios.get(zipUrl, { responseType: 'arraybuffer' })
-      fs.writeFileSync(zipPath, Buffer.from(response.data))
+      const response = await net.fetch(zipUrl, { method: 'GET' })
+      if (!response.ok) {
+        throw new Error(`HTTP ${response.status}: ${response.statusText}`)
+      }
+      const arrayBuffer = await response.arrayBuffer()
+      fs.writeFileSync(zipPath, Buffer.from(arrayBuffer))
       logger.info(`Downloaded ZIP file: ${zipPath}`)
 
       // Ensure the extraction directory exists
@@ -236,7 +240,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
     }
 
     try {
-      const response = await fetch(endpoint, {
+      const response = await net.fetch(endpoint, {
         method: 'POST',
         headers: {
           'Content-Type': 'application/json',
@@ -271,7 +275,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
     try {
       const fileBuffer = await fs.promises.readFile(filePath)
-      const response = await fetch(uploadUrl, {
+      const response = await net.fetch(uploadUrl, {
         method: 'PUT',
         body: fileBuffer,
         headers: {
@@ -316,7 +320,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
     const endpoint = `${this.provider.apiHost}/api/v4/extract-results/batch/${batchId}`
 
     try {
-      const response = await fetch(endpoint, {
+      const response = await net.fetch(endpoint, {
         method: 'GET',
         headers: {
           'Content-Type': 'application/json',
diff --git a/src/main/knowledge/reranker/GeneralReranker.ts b/src/main/knowledge/reranker/GeneralReranker.ts
index 1252ecad57..5a0e240a9d 100644
--- a/src/main/knowledge/reranker/GeneralReranker.ts
+++ b/src/main/knowledge/reranker/GeneralReranker.ts
@@ -1,6 +1,6 @@
 import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
 import { KnowledgeBaseParams } from '@types'
-import axios from 'axios'
+import { net } from 'electron'
 
 import BaseReranker from './BaseReranker'
 
@@ -15,7 +15,17 @@ export default class GeneralReranker extends BaseReranker {
     const requestBody = this.getRerankRequestBody(query, searchResults)
 
     try {
-      const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() })
+      const response = await net.fetch(url, {
+        method: 'POST',
+        headers: this.defaultHeaders(),
+        body: JSON.stringify(requestBody)
+      })
+
+      if (!response.ok) {
+        throw new Error(`HTTP ${response.status}: ${response.statusText}`)
+      }
+
+      const data = await response.json()
 
       const rerankResults = this.extractRerankResult(data)
       return this.getRerankResult(searchResults, rerankResults)
diff --git a/src/main/mcpServers/brave-search.ts b/src/main/mcpServers/brave-search.ts
index 6f219e1eb8..d11a4f2580 100644
--- a/src/main/mcpServers/brave-search.ts
+++ b/src/main/mcpServers/brave-search.ts
@@ -3,6 +3,7 @@
 import { Server } from '@modelcontextprotocol/sdk/server/index.js'
 import { CallToolRequestSchema, ListToolsRequestSchema, Tool } from '@modelcontextprotocol/sdk/types.js'
+import { net } from 'electron'
 
 const WEB_SEARCH_TOOL: Tool = {
   name: 'brave_web_search',
@@ -159,7 +160,7 @@ async function performWebSearch(apiKey: string, query: string, count: number = 1
   url.searchParams.set('count', Math.min(count, 20).toString()) // API limit
   url.searchParams.set('offset', offset.toString())
 
-  const response = await fetch(url, {
+  const response = await net.fetch(url.toString(), {
     headers: {
       Accept: 'application/json',
       'Accept-Encoding': 'gzip',
@@ -192,7 +193,7 @@ async function performLocalSearch(apiKey: string, query: string, count: number =
   webUrl.searchParams.set('result_filter', 'locations')
   webUrl.searchParams.set('count', Math.min(count, 20).toString())
 
-  const webResponse = await fetch(webUrl, {
+  const webResponse = await net.fetch(webUrl.toString(), {
     headers: {
       Accept: 'application/json',
       'Accept-Encoding': 'gzip',
@@ -225,7 +226,7 @@ async function getPoisData(apiKey: string, ids: string[]): Promise
   url.searchParams.append('ids', id))
-  const response = await fetch(url, {
+  const response = await net.fetch(url.toString(), {
     headers: {
       Accept: 'application/json',
       'Accept-Encoding': 'gzip',
@@ -244,7 +245,7 @@ async function getDescriptionsData(apiKey: string, ids: string[]): Promise
   url.searchParams.append('ids', id))
-  const response = await fetch(url, {
+  const response = await net.fetch(url.toString(), {
     headers: {
       Accept: 'application/json',
       'Accept-Encoding': 'gzip',
diff --git a/src/main/mcpServers/dify-knowledge.ts b/src/main/mcpServers/dify-knowledge.ts
index 2bd2c4adda..83a352fd4f 100644
--- a/src/main/mcpServers/dify-knowledge.ts
+++ b/src/main/mcpServers/dify-knowledge.ts
@@ -2,6 +2,7 @@ import { loggerService } from '@logger'
 import { Server } from '@modelcontextprotocol/sdk/server/index.js'
 import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
+import { net } from 'electron'
 import * as z from 'zod/v4'
 
 const logger = loggerService.withContext('DifyKnowledgeServer')
@@ -134,7 +135,7 @@ class DifyKnowledgeServer {
   private async performListKnowledges(difyKey: string, apiHost: string): Promise {
     try {
       const url = `${apiHost.replace(/\/$/, '')}/datasets`
-      const response = await fetch(url, {
+      const response = await net.fetch(url, {
         method: 'GET',
         headers: {
           Authorization: `Bearer ${difyKey}`
@@ -186,7 +187,7 @@ class DifyKnowledgeServer {
     try {
       const url = `${apiHost.replace(/\/$/, '')}/datasets/${id}/retrieve`
 
-      const response = await fetch(url, {
+      const response = await net.fetch(url, {
         method: 'POST',
         headers: {
           Authorization: `Bearer ${difyKey}`,
diff --git a/src/main/mcpServers/fetch.ts b/src/main/mcpServers/fetch.ts
index 04839d8a92..e55b114776 100644
--- a/src/main/mcpServers/fetch.ts
+++ b/src/main/mcpServers/fetch.ts
@@ -2,6 +2,7 @@
 import { Server } from '@modelcontextprotocol/sdk/server/index.js'
 import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
+import { net } from 'electron'
 import { JSDOM } from 'jsdom'
 import TurndownService from 'turndown'
 import { z } from 'zod'
@@ -16,7 +17,7 @@ export type RequestPayload = z.infer
 export class Fetcher {
   private static async _fetch({ url, headers }: RequestPayload): Promise {
     try {
-      const response = await fetch(url, {
+      const response = await net.fetch(url, {
         headers: {
           'User-Agent':
             'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
diff --git a/src/main/services/AppUpdater.ts b/src/main/services/AppUpdater.ts
index e60dac31f0..bdfb8e3cc8 100644
--- a/src/main/services/AppUpdater.ts
+++ b/src/main/services/AppUpdater.ts
@@ -6,9 +6,10 @@ import { generateUserAgent } from '@main/utils/systemInfo'
 import { FeedUrl, UpgradeChannel } from '@shared/config/constant'
 import { IpcChannel } from '@shared/IpcChannel'
 import { CancellationToken, UpdateInfo } from 'builder-util-runtime'
-import { app, BrowserWindow, dialog } from 'electron'
+import { app, BrowserWindow, dialog, net } from 'electron'
 import { AppUpdater as _AppUpdater, autoUpdater, Logger, NsisUpdater, UpdateCheckResult } from 'electron-updater'
 import path from 'path'
+import semver from 'semver'
 
 import icon from '../../../build/icon.png?asset'
 import { configManager } from './ConfigManager'
@@ -44,12 +45,6 @@ export default class AppUpdater {
     // When no update is available
     autoUpdater.on('update-not-available', () => {
-      if (configManager.getTestPlan() && this.autoUpdater.channel !== UpgradeChannel.LATEST) {
-        logger.info('test plan is enabled, but update is not available, do not send update not available event')
-        // will not send update not available event, because will check for updates with latest channel
-        return
-      }
-
       windowService.getMainWindow()?.webContents.send(IpcChannel.UpdateNotAvailable)
     })
 
@@ -72,18 +67,24 @@ export default class AppUpdater {
     this.autoUpdater = autoUpdater
   }
 
-  private async _getPreReleaseVersionFromGithub(channel: UpgradeChannel) {
+  private async _getReleaseVersionFromGithub(channel: UpgradeChannel) {
+    const headers = {
+      Accept: 'application/vnd.github+json',
+      'X-GitHub-Api-Version': '2022-11-28',
+      'Accept-Language': 'en-US,en;q=0.9'
+    }
     try {
-      logger.info(`get pre release version from github: ${channel}`)
-      const responses = await fetch('https://api.github.com/repos/CherryHQ/cherry-studio/releases?per_page=8', {
-        headers: {
-          Accept: 'application/vnd.github+json',
-          'X-GitHub-Api-Version': '2022-11-28',
-          'Accept-Language': 'en-US,en;q=0.9'
-        }
+      logger.info(`get release version from github: ${channel}`)
+      const responses = await net.fetch('https://api.github.com/repos/CherryHQ/cherry-studio/releases?per_page=8', {
+        headers
       })
       const data = (await responses.json()) as GithubReleaseInfo[]
+      let mightHaveLatest = false
       const release: GithubReleaseInfo | undefined = data.find((item: GithubReleaseInfo) => {
+        if (!item.draft && !item.prerelease) {
+          mightHaveLatest = true
+        }
+
         return item.prerelease && item.tag_name.includes(`-${channel}.`)
       })
 
@@ -91,8 +92,29 @@
         return null
       }
 
-      logger.info(`prerelease url is ${release.tag_name}, set channel to ${channel}`)
+      // if the release version is the same as the current version, return null
+      if (release.tag_name === app.getVersion()) {
+        return null
+      }
+
+      if (mightHaveLatest) {
+        logger.info(`might have latest release, get latest release`)
+        const latestReleaseResponse = await net.fetch(
+          'https://api.github.com/repos/CherryHQ/cherry-studio/releases/latest',
+          {
+            headers
+          }
+        )
+        const latestRelease = (await latestReleaseResponse.json()) as GithubReleaseInfo
+        if (semver.gt(latestRelease.tag_name, release.tag_name)) {
+          logger.info(
+            `latest release version is ${latestRelease.tag_name}, prerelease version is ${release.tag_name}, return null`
+          )
+          return null
+        }
+      }
+
+      logger.info(`release url is ${release.tag_name}, set channel to ${channel}`)
       return `https://github.com/CherryHQ/cherry-studio/releases/download/${release.tag_name}`
     } catch (error) {
       logger.error('Failed to get latest not draft version from github:', error as Error)
@@ -151,14 +173,14 @@ export default class AppUpdater {
       return
     }
 
-    const preReleaseUrl = await this._getPreReleaseVersionFromGithub(channel)
-    if (preReleaseUrl) {
-      logger.info(`prerelease url is ${preReleaseUrl}, set channel to ${channel}`)
-      this._setChannel(channel, preReleaseUrl)
+    const releaseUrl = await this._getReleaseVersionFromGithub(channel)
+    if (releaseUrl) {
+      logger.info(`release url is ${releaseUrl}, set channel to ${channel}`)
+      this._setChannel(channel, releaseUrl)
       return
     }
 
-    // if no prerelease url, use github latest to avoid error
+    // if no prerelease url, use github latest to get release
     this._setChannel(UpgradeChannel.LATEST, FeedUrl.GITHUB_LATEST)
     return
   }
@@ -195,17 +217,6 @@ export default class AppUpdater {
       `update check result: ${this.updateCheckResult?.isUpdateAvailable}, channel: ${this.autoUpdater.channel}, currentVersion: ${this.autoUpdater.currentVersion}`
     )
 
-    // if the update is not available, and the test plan is enabled, set the feed url to the github latest
-    if (
-      !this.updateCheckResult?.isUpdateAvailable &&
-      configManager.getTestPlan() &&
-      this.autoUpdater.channel !== UpgradeChannel.LATEST
-    ) {
-      logger.info('test plan is enabled, but update is not available, set channel to latest')
-      this._setChannel(UpgradeChannel.LATEST, FeedUrl.GITHUB_LATEST)
-      this.updateCheckResult = await this.autoUpdater.checkForUpdates()
-    }
-
     if (this.updateCheckResult?.isUpdateAvailable && !this.autoUpdater.autoDownload) {
       // If autoDownload is false, call the function below again to trigger the download
       // do not use await, because it will block the return of this function
diff --git a/src/main/services/BackupManager.ts b/src/main/services/BackupManager.ts
index 56d3a97379..c6d3ee1841 100644
--- a/src/main/services/BackupManager.ts
+++ b/src/main/services/BackupManager.ts
@@ -21,6 +21,27 @@ class BackupManager {
   private tempDir = path.join(app.getPath('temp'), 'cherry-studio', 'backup', 'temp')
   private backupDir = path.join(app.getPath('temp'), 'cherry-studio', 'backup')
 
+  // Cache instances to avoid re-creating them
+  private s3Storage: S3Storage | null = null
+  private webdavInstance: WebDav | null = null
+
+  // Cache the core connection config so connection-config changes can be detected
+  private cachedS3ConnectionConfig: {
+    endpoint: string
+    region: string
+    bucket: string
+    accessKeyId: string
+    secretAccessKey: string
+    root?: string
+  } | null = null
+
+  private cachedWebdavConnectionConfig: {
+    webdavHost: string
+    webdavUser?: string
+    webdavPass?: string
+    webdavPath?: string
+  } | null = null
+
   constructor() {
     this.checkConnection = this.checkConnection.bind(this)
     this.backup = this.backup.bind(this)
@@ -87,6 +108,88 @@ class BackupManager {
     }
   }
 
+  /**
+   * Compare two config objects for equality, checking only the core fields that affect the client connection and ignoring volatile fields such as fileName
+   */
+  private isS3ConfigEqual(cachedConfig: typeof this.cachedS3ConnectionConfig, config: S3Config): boolean {
+    if (!cachedConfig) return false
+
+    return (
+      cachedConfig.endpoint === config.endpoint &&
+      cachedConfig.region === config.region &&
+      cachedConfig.bucket === config.bucket &&
+      cachedConfig.accessKeyId === config.accessKeyId &&
+      cachedConfig.secretAccessKey === config.secretAccessKey &&
+      cachedConfig.root === config.root
+    )
+  }
+
+  /**
+   * Deep-compare two WebDAV config objects for equality, checking only the core fields that affect the client connection and ignoring volatile fields such as fileName
+   */
+  private isWebDavConfigEqual(cachedConfig: typeof this.cachedWebdavConnectionConfig, config: WebDavConfig): boolean {
+    if (!cachedConfig) return false
+
+    return (
+      cachedConfig.webdavHost === config.webdavHost &&
+      cachedConfig.webdavUser === config.webdavUser &&
+      cachedConfig.webdavPass === config.webdavPass &&
+      cachedConfig.webdavPath === config.webdavPath
+    )
+  }
+
+  /**
+   * Get the S3Storage instance; reuse it if the connection config is unchanged and an instance already exists, otherwise create a new one.
+   * Note: only connection-related config changes trigger re-creation; other config changes do not affect instance reuse.
+   */
+  private getS3Storage(config: S3Config): S3Storage {
+    // Check whether the core connection config has changed
+    const configChanged = !this.isS3ConfigEqual(this.cachedS3ConnectionConfig, config)
+
+    if (configChanged || !this.s3Storage) {
+      this.s3Storage = new S3Storage(config)
+      // Cache only the connection-related config fields
+      this.cachedS3ConnectionConfig = {
+        endpoint: config.endpoint,
+        region: config.region,
+        bucket: config.bucket,
+        accessKeyId: config.accessKeyId,
+        secretAccessKey: config.secretAccessKey,
+        root: config.root
+      }
+      logger.debug('[BackupManager] Created new S3Storage instance')
+    } else {
+      logger.debug('[BackupManager] Reusing existing S3Storage instance')
+    }
+
+    return this.s3Storage
+  }
+
+  /**
+   * Get the WebDav instance; reuse it if the connection config is unchanged and an instance already exists, otherwise create a new one.
+   * Note: only connection-related config changes trigger re-creation; other config changes do not affect instance reuse.
+   */
+  private getWebDavInstance(config: WebDavConfig): WebDav {
+    // Check whether the core connection config has changed
+    const configChanged = !this.isWebDavConfigEqual(this.cachedWebdavConnectionConfig, config)
+
+    if (configChanged || !this.webdavInstance) {
+      this.webdavInstance = new WebDav(config)
+      // Cache only the connection-related config fields
+      this.cachedWebdavConnectionConfig = {
+        webdavHost: config.webdavHost,
+        webdavUser: config.webdavUser,
+        webdavPass: config.webdavPass,
+        webdavPath: config.webdavPath
+      }
+      logger.debug('[BackupManager] Created new WebDav instance')
+    } else {
+      logger.debug('[BackupManager] Reusing existing WebDav instance')
+    }
+
+    return this.webdavInstance
+  }
+
   async backup(
     _: Electron.IpcMainInvokeEvent,
     fileName: string,
@@ -322,7 +425,7 @@ class BackupManager {
   async backupToWebdav(_: Electron.IpcMainInvokeEvent, data: string, webdavConfig: WebDavConfig) {
     const filename = webdavConfig.fileName || 'cherry-studio.backup.zip'
     const backupedFilePath = await this.backup(_, filename, data, undefined, webdavConfig.skipBackupFile)
-    const webdavClient = new WebDav(webdavConfig)
+    const webdavClient = this.getWebDavInstance(webdavConfig)
     try {
       let result
       if (webdavConfig.disableStream) {
@@ -349,7 +452,7 @@ class BackupManager {
   async restoreFromWebdav(_: Electron.IpcMainInvokeEvent, webdavConfig: WebDavConfig) {
     const filename = webdavConfig.fileName || 'cherry-studio.backup.zip'
-    const webdavClient = new WebDav(webdavConfig)
+    const webdavClient = this.getWebDavInstance(webdavConfig)
     try {
       const retrievedFile = await webdavClient.getFileContents(filename)
       const backupedFilePath = path.join(this.backupDir, filename)
@@ -377,7 +480,7 @@ class BackupManager {
   listWebdavFiles = async (_: Electron.IpcMainInvokeEvent, config: WebDavConfig) => {
     try {
-      const client = new WebDav(config)
+      const client = this.getWebDavInstance(config)
       const response = await client.getDirectoryContents()
       const files = Array.isArray(response) ? response : response.data
 
@@ -467,7 +570,7 @@ class BackupManager {
   }
 
   async checkConnection(_: Electron.IpcMainInvokeEvent, webdavConfig: WebDavConfig) {
-    const webdavClient = new WebDav(webdavConfig)
+    const webdavClient = this.getWebDavInstance(webdavConfig)
     return await webdavClient.checkConnection()
   }
 
@@ -477,13 +580,13 @@ class BackupManager {
     path: string,
     options?: CreateDirectoryOptions
   ) {
-    const webdavClient = new WebDav(webdavConfig)
+    const webdavClient = this.getWebDavInstance(webdavConfig)
     return await webdavClient.createDirectory(path, options)
   }
 
   async deleteWebdavFile(_: Electron.IpcMainInvokeEvent, fileName: string, webdavConfig: WebDavConfig) {
     try {
-      const webdavClient = new WebDav(webdavConfig)
+      const webdavClient = this.getWebDavInstance(webdavConfig)
       return await webdavClient.deleteFile(fileName)
     } catch (error: any) {
       logger.error('Failed to delete WebDAV file:', error)
@@ -525,7 +628,7 @@ class BackupManager {
     logger.debug(`Starting S3 backup to ${filename}`)
 
     const backupedFilePath = await this.backup(_, filename, data, undefined, s3Config.skipBackupFile)
-    const s3Client = new S3Storage(s3Config)
+    const s3Client = this.getS3Storage(s3Config)
     try {
       const fileBuffer = await fs.promises.readFile(backupedFilePath)
       const result = await s3Client.putFileContents(filename, fileBuffer)
@@ -603,7 +706,7 @@ class BackupManager {
     logger.debug(`Starting restore from S3: ${filename}`)
 
-    const s3Client = new S3Storage(s3Config)
+    const s3Client = this.getS3Storage(s3Config)
    try {
       const retrievedFile = await s3Client.getFileContents(filename)
       const backupedFilePath = path.join(this.backupDir, filename)
@@ -628,7 +731,7 @@ class BackupManager {
   listS3Files = async (_: Electron.IpcMainInvokeEvent, s3Config: S3Config) => {
     try {
-      const s3Client = new S3Storage(s3Config)
+      const s3Client = this.getS3Storage(s3Config)
       const objects = await s3Client.listFiles()
 
       const files = objects
@@ -652,7 +755,7 @@ class BackupManager {
   async deleteS3File(_: Electron.IpcMainInvokeEvent, fileName: string, s3Config: S3Config) {
     try {
-      const s3Client = new S3Storage(s3Config)
+      const s3Client = this.getS3Storage(s3Config)
       return await s3Client.deleteFile(fileName)
     } catch (error: any) {
       logger.error('Failed to delete S3 file:', error)
@@ -661,7 +764,7 @@ class BackupManager {
   }
 
   async checkS3Connection(_: Electron.IpcMainInvokeEvent, s3Config: S3Config) {
-    const s3Client = new S3Storage(s3Config)
+    const s3Client = this.getS3Storage(s3Config)
     return await s3Client.checkConnection()
   }
 }
diff --git a/src/main/services/CopilotService.ts b/src/main/services/CopilotService.ts
index bb54e74932..f5c773a7cc 100644
--- a/src/main/services/CopilotService.ts
+++ b/src/main/services/CopilotService.ts
@@ -1,6 +1,5 @@
 import { loggerService } from '@logger'
-import { AxiosRequestConfig } from 'axios'
-import axios from 'axios'
+import { net } from 'electron'
 import { app, safeStorage } from 'electron'
 import fs from 'fs/promises'
 import path from 'path'
@@ -86,7 +85,8 @@ class CopilotService {
    */
   public getUser = async (_: Electron.IpcMainInvokeEvent, token: string): Promise => {
     try {
-      const config: AxiosRequestConfig = {
+      const response = await net.fetch(CONFIG.API_URLS.GITHUB_USER, {
+        method: 'GET',
         headers: {
           Connection: 'keep-alive',
           'user-agent': 'Visual Studio Code (desktop)',
@@ -95,12 +95,16 @@ class CopilotService {
           'Sec-Fetch-Dest': 'empty',
           authorization: `token ${token}`
         }
+      })
+
+      if (!response.ok) {
+        throw new Error(`HTTP ${response.status}: ${response.statusText}`)
       }
 
-      const response = await axios.get(CONFIG.API_URLS.GITHUB_USER, config)
+      const data = await response.json()
 
       return {
-        login: response.data.login,
-        avatar: response.data.avatar_url
+        login: data.login,
+        avatar: data.avatar_url
       }
     } catch (error) {
       logger.error('Failed to get user information:', error as Error)
@@ -118,16 +122,23 @@ class CopilotService {
     try {
       this.updateHeaders(headers)
 
-      const response = await axios.post(
-        CONFIG.API_URLS.GITHUB_DEVICE_CODE,
-        {
+      const response = await net.fetch(CONFIG.API_URLS.GITHUB_DEVICE_CODE, {
+        method: 'POST',
+        headers: {
+          ...this.headers,
+          'Content-Type': 'application/json'
+        },
+        body: JSON.stringify({
          client_id: CONFIG.GITHUB_CLIENT_ID,
           scope: 'read:user'
-        },
-        { headers: this.headers }
-      )
+        })
+      })
 
-      return response.data
+      if (!response.ok) {
+        throw new Error(`HTTP ${response.status}: ${response.statusText}`)
+      }
+
+      return (await response.json()) as AuthResponse
     } catch (error) {
       logger.error('Failed to get auth message:', error as Error)
       throw new CopilotServiceError('无法获取GitHub授权信息', error)
@@ -150,17 +161,25 @@ class CopilotService {
       await this.delay(currentDelay)
 
       try {
-        const response = await axios.post(
-          CONFIG.API_URLS.GITHUB_ACCESS_TOKEN,
-          {
+        const response = await net.fetch(CONFIG.API_URLS.GITHUB_ACCESS_TOKEN, {
+          method: 'POST',
+          headers: {
+            ...this.headers,
+            'Content-Type': 'application/json'
+          },
+          body: JSON.stringify({
            client_id: CONFIG.GITHUB_CLIENT_ID,
             device_code,
             grant_type: 'urn:ietf:params:oauth:grant-type:device_code'
-          },
-          { headers: this.headers }
-        )
+          })
+        })
 
-        const { access_token } = response.data
+        if (!response.ok) {
+          throw new Error(`HTTP ${response.status}: ${response.statusText}`)
+        }
+
+        const data = (await response.json()) as TokenResponse
+        const { access_token } = data
         if (access_token) {
           return { access_token }
         }
@@ -205,16 +224,19 @@ class CopilotService {
       const encryptedToken = await fs.readFile(this.tokenFilePath)
       const access_token = safeStorage.decryptString(Buffer.from(encryptedToken))
 
-      const config: AxiosRequestConfig = {
+      const response = await net.fetch(CONFIG.API_URLS.COPILOT_TOKEN, {
+        method: 'GET',
         headers: {
           ...this.headers,
           authorization: `token ${access_token}`
         }
+      })
+
+      if (!response.ok) {
+        throw new Error(`HTTP ${response.status}: ${response.statusText}`)
       }
 
-      const response = await axios.get(CONFIG.API_URLS.COPILOT_TOKEN, config)
-
-      return response.data
+      return (await response.json()) as CopilotTokenResponse
     } catch (error) {
       logger.error('Failed to get Copilot token:', error as Error)
       throw new CopilotServiceError('无法获取Copilot令牌,请重新授权', error)
diff --git a/src/main/services/FileStorage.ts b/src/main/services/FileStorage.ts
index e34c51f299..f5df9ed3f7 100644
--- a/src/main/services/FileStorage.ts
+++ b/src/main/services/FileStorage.ts
@@ -5,6 +5,7 @@ import { FileMetadata } from '@types'
 import * as crypto from 'crypto'
 import {
   dialog,
+  net,
   OpenDialogOptions,
   OpenDialogReturnValue,
   SaveDialogOptions,
@@ -509,7 +510,7 @@ class FileStorage {
     isUseContentType?: boolean
   ): Promise => {
     try {
-      const response = await fetch(url)
+      const response = await net.fetch(url)
       if (!response.ok) {
         throw new Error(`HTTP error! status: ${response.status}`)
       }
diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts
index d3909cc86f..a7f907f65f 100644
--- a/src/main/services/MCPService.ts
+++ b/src/main/services/MCPService.ts
@@ -29,7 +29,7 @@ import {
 } from '@modelcontextprotocol/sdk/types.js'
 import { nanoid } from '@reduxjs/toolkit'
 import type { GetResourceResponse, MCPCallToolResponse, MCPPrompt, MCPResource, MCPServer, MCPTool } from '@types'
-import { app } from 'electron'
+import { app, net } from 'electron'
 import { EventEmitter } from 'events'
 import { memoize } from 'lodash'
 import { v4 as uuidv4 } from 'uuid'
@@ -205,7 +205,7 @@ class McpService {
             }
           }
 
-          return fetch(url, { ...init, headers })
+          return net.fetch(typeof url === 'string' ? url : url.toString(), { ...init, headers })
         }
       },
       requestInit: {
diff --git a/src/main/services/NutstoreService.ts b/src/main/services/NutstoreService.ts
index 4422ea8a07..f4ad5a2c33 100644
--- a/src/main/services/NutstoreService.ts
+++ b/src/main/services/NutstoreService.ts
@@ -2,6 +2,7 @@ import path from 'node:path'
 
 import { loggerService } from '@logger'
 import { NUTSTORE_HOST } from '@shared/config/nutstore'
+import { net } from 'electron'
 import { XMLParser } from 'fast-xml-parser'
 import { isNil, partial } from 'lodash'
 import { type FileStat } from 'webdav'
@@ -62,7 +63,7 @@ export async function getDirectoryContents(token: string, target: string): Promi
   let currentUrl = `${NUTSTORE_HOST}${target}`
 
   while (true) {
-    const response = await fetch(currentUrl, {
+    const response = await net.fetch(currentUrl, {
       method: 'PROPFIND',
       headers: {
         Authorization: `Basic ${token}`,
diff --git a/src/main/utils/ipService.ts b/src/main/utils/ipService.ts
index ec5ab78215..3180f9457c 100644
--- a/src/main/utils/ipService.ts
+++ b/src/main/utils/ipService.ts
@@ -1,4 +1,5 @@
 import { loggerService } from '@logger'
+import { net } from 'electron'
 
 const logger = loggerService.withContext('IpService')
 
@@ -12,7 +13,7 @@ export async function getIpCountry(): Promise {
     const controller = new AbortController()
     const timeoutId = setTimeout(() => controller.abort(), 5000)
 
-    const ipinfo = await fetch('https://ipinfo.io/json', {
+    const ipinfo = await net.fetch('https://ipinfo.io/json', {
       signal: controller.signal,
       headers: {
         'User-Agent':
diff --git a/src/renderer/miniWindow.html b/src/renderer/miniWindow.html
index 83b108b8a4..7f3b936444 100644
--- a/src/renderer/miniWindow.html
+++ b/src/renderer/miniWindow.html
@@ -6,7 +6,7 @@
-    <title>Cherry Studio</title>
+    <title>Cherry Studio Quick Assistant</title>