diff --git a/electron-builder.yml b/electron-builder.yml index 59de8a4f50..4598455544 100644 --- a/electron-builder.yml +++ b/electron-builder.yml @@ -92,9 +92,9 @@ artifactBuildCompleted: scripts/artifact-build-completed.js releaseInfo: releaseNotes: | ⚠️ 注意:升级前请备份数据,否则将无法降级 - 重构消息结构,支持不同类型消息按时间顺序显示 - 智能体支持导入和导出 - 快捷面板增加网络搜索引擎选择 - 显示设置增加缩放控制按钮 - 支持添加自定义小程序 - 性能优化和错误修复 + 优化软件启动速度 + 优化软件进入后台后性能问题 + 修复导出对话时自动重命名失败问题 + 防止输入法切换期间误发消息问题 + 修复群组消息重发功能问题及富文本粘贴兼容性问题 + 改进 MCP 服务处理及 IPC 注册逻辑 diff --git a/electron.vite.config.ts b/electron.vite.config.ts index 0a435c5493..bcbb53d079 100644 --- a/electron.vite.config.ts +++ b/electron.vite.config.ts @@ -85,12 +85,19 @@ export default defineConfig({ miniWindow: resolve(__dirname, 'src/renderer/miniWindow.html') }, output: { - manualChunks: (id) => { + manualChunks: (id: string) => { // 检测所有 worker 文件,提取 worker 名称作为 chunk 名 if (id.includes('.worker') && id.endsWith('?worker')) { const workerName = id.split('/').pop()?.split('.')[0] || 'worker' return `workers/${workerName}` } + + // All node_modules are in the vendor chunk + if (id.includes('node_modules')) { + return 'vendor' + } + + // Other modules use default chunk splitting strategy return undefined } } diff --git a/package.json b/package.json index 8531b1bb1e..fb900e0650 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "CherryStudio", - "version": "1.3.5", + "version": "1.3.6", "private": true, "description": "A powerful AI assistant for producer.", "main": "./out/main/index.js", diff --git a/src/main/index.ts b/src/main/index.ts index f85803ed84..44d516a5ca 100644 --- a/src/main/index.ts +++ b/src/main/index.ts @@ -2,12 +2,11 @@ import '@main/config' import { electronApp, optimizer } from '@electron-toolkit/utils' import { replaceDevtoolsFont } from '@main/utils/windowUtil' -import { IpcChannel } from '@shared/IpcChannel' -import { app, BrowserWindow, ipcMain } from 'electron' +import { app } from 'electron' import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer' import Logger from 'electron-log' -import { isDev, isMac, isWin } from './constant' +import { isDev } from './constant' import { registerIpc } from './ipc' import { configManager } from './services/ConfigManager' import mcpService from './services/MCPService' @@ -85,18 +84,6 @@ if (!app.requestSingleInstanceLock()) { .then((name) => console.log(`Added Extension: ${name}`)) .catch((err) => console.log('An error occurred: ', err)) } - ipcMain.handle(IpcChannel.System_GetDeviceType, () => { - return isMac ? 'mac' : isWin ? 'windows' : 'linux' - }) - - ipcMain.handle(IpcChannel.System_GetHostname, () => { - return require('os').hostname() - }) - - ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => { - const win = BrowserWindow.fromWebContents(e.sender) - win && win.webContents.toggleDevTools() - }) }) registerProtocolClient(app) @@ -128,7 +115,7 @@ if (!app.requestSingleInstanceLock()) { app.on('will-quit', async () => { // event.preventDefault() try { - await mcpService.cleanup() + await mcpService().cleanup() } catch (error) { Logger.error('Error cleaning up MCP service:', error) } diff --git a/src/main/ipc.ts b/src/main/ipc.ts index 665e8114b7..439ed63e3d 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -19,7 +19,7 @@ import FileService from './services/FileService' import FileStorage from './services/FileStorage' import { GeminiService } from './services/GeminiService' import KnowledgeService from './services/KnowledgeService' -import mcpService from './services/MCPService' +import { getMcpInstance } from './services/MCPService' import * as NutstoreService from './services/NutstoreService' import ObsidianVaultService from './services/ObsidianVaultService' import { ProxyConfig, proxyManager } from './services/ProxyManager' @@ -204,6 +204,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle(IpcChannel.Zip_Compress, (_, text: string) => compress(text)) ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text)) + // system + ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux')) + ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname()) + ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => { + const win = BrowserWindow.fromWebContents(e.sender) + win && win.webContents.toggleDevTools() + }) + // backup ipcMain.handle(IpcChannel.Backup_Backup, backupManager.backup) ipcMain.handle(IpcChannel.Backup_Restore, backupManager.restore) @@ -301,16 +309,16 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ) // Register MCP handlers - ipcMain.handle(IpcChannel.Mcp_RemoveServer, mcpService.removeServer) - ipcMain.handle(IpcChannel.Mcp_RestartServer, mcpService.restartServer) - ipcMain.handle(IpcChannel.Mcp_StopServer, mcpService.stopServer) - ipcMain.handle(IpcChannel.Mcp_ListTools, mcpService.listTools) - ipcMain.handle(IpcChannel.Mcp_CallTool, mcpService.callTool) - ipcMain.handle(IpcChannel.Mcp_ListPrompts, mcpService.listPrompts) - ipcMain.handle(IpcChannel.Mcp_GetPrompt, mcpService.getPrompt) - ipcMain.handle(IpcChannel.Mcp_ListResources, mcpService.listResources) - ipcMain.handle(IpcChannel.Mcp_GetResource, mcpService.getResource) - ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo) + ipcMain.handle(IpcChannel.Mcp_RemoveServer, (event, server) => getMcpInstance().removeServer(event, server)) + ipcMain.handle(IpcChannel.Mcp_RestartServer, (event, server) => getMcpInstance().restartServer(event, server)) + ipcMain.handle(IpcChannel.Mcp_StopServer, (event, server) => getMcpInstance().stopServer(event, server)) + ipcMain.handle(IpcChannel.Mcp_ListTools, (event, server) => getMcpInstance().listTools(event, server)) + ipcMain.handle(IpcChannel.Mcp_CallTool, (event, params) => getMcpInstance().callTool(event, params)) + ipcMain.handle(IpcChannel.Mcp_ListPrompts, (event, server) => getMcpInstance().listPrompts(event, server)) + ipcMain.handle(IpcChannel.Mcp_GetPrompt, (event, params) => getMcpInstance().getPrompt(event, params)) + ipcMain.handle(IpcChannel.Mcp_ListResources, (event, server) => getMcpInstance().listResources(event, server)) + ipcMain.handle(IpcChannel.Mcp_GetResource, (event, params) => getMcpInstance().getResource(event, params)) + ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, () => getMcpInstance().getInstallInfo()) ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name)) ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name)) diff --git a/src/main/reranker/BaseReranker.ts b/src/main/reranker/BaseReranker.ts index a88d0883ae..f956a0573f 100644 --- a/src/main/reranker/BaseReranker.ts +++ b/src/main/reranker/BaseReranker.ts @@ -38,7 +38,7 @@ export default abstract class BaseReranker { protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) { const provider = this.base.rerankModelProvider const documents = searchResults.map((doc) => doc.pageContent) - const topN = this.base.topN || 10 + const topN = this.base.documentCount if (provider === 'voyageai') { return { diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index 237e709deb..90e50ec65a 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -4,6 +4,7 @@ import path from 'node:path' import { createInMemoryMCPServer } from '@main/mcpServers/factory' import { makeSureDirExists } from '@main/utils' +import { buildFunctionCallToolName } from '@main/utils/mcp' import { getBinaryName, getBinaryPath } from '@main/utils/process' import { Client } from '@modelcontextprotocol/sdk/client/index.js' import { SSEClientTransport, SSEClientTransportOptions } from '@modelcontextprotocol/sdk/client/sse.js' @@ -68,20 +69,18 @@ function withCache( } class McpService { + private static instance: McpService | null = null private clients: Map = new Map() + private pendingClients: Map> = new Map() - private getServerKey(server: MCPServer): string { - return JSON.stringify({ - baseUrl: server.baseUrl, - command: server.command, - args: server.args, - registryUrl: server.registryUrl, - env: server.env, - id: server.id - }) + public static getInstance(): McpService { + if (!McpService.instance) { + McpService.instance = new McpService() + } + return McpService.instance } - constructor() { + private constructor() { this.initClient = this.initClient.bind(this) this.listTools = this.listTools.bind(this) this.callTool = this.callTool.bind(this) @@ -96,9 +95,26 @@ class McpService { this.cleanup = this.cleanup.bind(this) } + private getServerKey(server: MCPServer): string { + return JSON.stringify({ + baseUrl: server.baseUrl, + command: server.command, + args: server.args, + registryUrl: server.registryUrl, + env: server.env, + id: server.id + }) + } + async initClient(server: MCPServer): Promise { const serverKey = this.getServerKey(server) + // If there's a pending initialization, wait for it + const pendingClient = this.pendingClients.get(serverKey) + if (pendingClient) { + return pendingClient + } + // Check if we already have a client for this server configuration const existingClient = this.clients.get(serverKey) if (existingClient) { @@ -113,209 +129,226 @@ class McpService { } else { return existingClient } - } catch (error) { - Logger.error(`[MCP] Error pinging server ${server.name}:`, error) + } catch (error: any) { + Logger.error(`[MCP] Error pinging server ${server.name}:`, error?.message) this.clients.delete(serverKey) } } - // Create new client instance for each connection - const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} }) - const args = [...(server.args || [])] + // Create a promise for the initialization process + const initPromise = (async () => { + try { + // Create new client instance for each connection + const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} }) - // let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport - const authProvider = new McpOAuthClientProvider({ - serverUrlHash: crypto - .createHash('md5') - .update(server.baseUrl || '') - .digest('hex') - }) + const args = [...(server.args || [])] - const initTransport = async (): Promise< - StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport - > => { - // Create appropriate transport based on configuration - if (server.type === 'inMemory') { - Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`) - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair() - // start the in-memory server with the given name and environment variables - const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {}) - try { - await inMemoryServer.connect(serverTransport) - Logger.info(`[MCP] In-memory server started: ${server.name}`) - } catch (error: Error | any) { - Logger.error(`[MCP] Error starting in-memory server: ${error}`) - throw new Error(`Failed to start in-memory server: ${error.message}`) - } - // set the client transport to the client - return clientTransport - } else if (server.baseUrl) { - if (server.type === 'streamableHttp') { - const options: StreamableHTTPClientTransportOptions = { - requestInit: { - headers: server.headers || {} - }, - authProvider - } - return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options) - } else if (server.type === 'sse') { - const options: SSEClientTransportOptions = { - eventSourceInit: { - fetch: async (url, init) => { - const headers = { ...(server.headers || {}), ...(init?.headers || {}) } + // let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport + const authProvider = new McpOAuthClientProvider({ + serverUrlHash: crypto + .createHash('md5') + .update(server.baseUrl || '') + .digest('hex') + }) - // Get tokens from authProvider to make sure using the latest tokens - if (authProvider && typeof authProvider.tokens === 'function') { - try { - const tokens = await authProvider.tokens() - if (tokens && tokens.access_token) { - headers['Authorization'] = `Bearer ${tokens.access_token}` + const initTransport = async (): Promise< + StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport + > => { + // Create appropriate transport based on configuration + if (server.type === 'inMemory') { + Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`) + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair() + // start the in-memory server with the given name and environment variables + const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {}) + try { + await inMemoryServer.connect(serverTransport) + Logger.info(`[MCP] In-memory server started: ${server.name}`) + } catch (error: Error | any) { + Logger.error(`[MCP] Error starting in-memory server: ${error}`) + throw new Error(`Failed to start in-memory server: ${error.message}`) + } + // set the client transport to the client + return clientTransport + } else if (server.baseUrl) { + if (server.type === 'streamableHttp') { + const options: StreamableHTTPClientTransportOptions = { + requestInit: { + headers: server.headers || {} + }, + authProvider + } + return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options) + } else if (server.type === 'sse') { + const options: SSEClientTransportOptions = { + eventSourceInit: { + fetch: async (url, init) => { + const headers = { ...(server.headers || {}), ...(init?.headers || {}) } + + // Get tokens from authProvider to make sure using the latest tokens + if (authProvider && typeof authProvider.tokens === 'function') { + try { + const tokens = await authProvider.tokens() + if (tokens && tokens.access_token) { + headers['Authorization'] = `Bearer ${tokens.access_token}` + } + } catch (error) { + Logger.error('Failed to fetch tokens:', error) + } } - } catch (error) { - Logger.error('Failed to fetch tokens:', error) + + return fetch(url, { ...init, headers }) } + }, + requestInit: { + headers: server.headers || {} + }, + authProvider + } + return new SSEClientTransport(new URL(server.baseUrl!), options) + } else { + throw new Error('Invalid server type') + } + } else if (server.command) { + let cmd = server.command + + if (server.command === 'npx') { + cmd = await getBinaryPath('bun') + Logger.info(`[MCP] Using command: ${cmd}`) + + // add -x to args if args exist + if (args && args.length > 0) { + if (!args.includes('-y')) { + args.unshift('-y') + } + if (!args.includes('x')) { + args.unshift('x') + } + } + if (server.registryUrl) { + server.env = { + ...server.env, + NPM_CONFIG_REGISTRY: server.registryUrl } - return fetch(url, { ...init, headers }) + // if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory + if (server.name.includes('mcp-auto-install')) { + const binPath = await getBinaryPath() + makeSureDirExists(binPath) + server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json') + } + } + } else if (server.command === 'uvx' || server.command === 'uv') { + cmd = await getBinaryPath(server.command) + if (server.registryUrl) { + server.env = { + ...server.env, + UV_DEFAULT_INDEX: server.registryUrl, + PIP_INDEX_URL: server.registryUrl + } } - }, - requestInit: { - headers: server.headers || {} - }, - authProvider - } - return new SSEClientTransport(new URL(server.baseUrl!), options) - } else { - throw new Error('Invalid server type') - } - } else if (server.command) { - let cmd = server.command - - if (server.command === 'npx') { - cmd = await getBinaryPath('bun') - Logger.info(`[MCP] Using command: ${cmd}`) - - // add -x to args if args exist - if (args && args.length > 0) { - if (!args.includes('-y')) { - args.unshift('-y') - } - if (!args.includes('x')) { - args.unshift('x') - } - } - if (server.registryUrl) { - server.env = { - ...server.env, - NPM_CONFIG_REGISTRY: server.registryUrl } - // if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory - if (server.name.includes('mcp-auto-install')) { - const binPath = await getBinaryPath() - makeSureDirExists(binPath) - server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json') - } - } - } else if (server.command === 'uvx' || server.command === 'uv') { - cmd = await getBinaryPath(server.command) - if (server.registryUrl) { - server.env = { - ...server.env, - UV_DEFAULT_INDEX: server.registryUrl, - PIP_INDEX_URL: server.registryUrl - } + Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`) + // Logger.info(`[MCP] Environment variables for server:`, server.env) + const loginShellEnv = await this.getLoginShellEnv() + const stdioTransport = new StdioClientTransport({ + command: cmd, + args, + env: { + ...loginShellEnv, + ...server.env + }, + stderr: 'pipe' + }) + stdioTransport.stderr?.on('data', (data) => + Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString()) + ) + return stdioTransport + } else { + throw new Error('Either baseUrl or command must be provided') } } - Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`) - // Logger.info(`[MCP] Environment variables for server:`, server.env) - const loginShellEnv = await this.getLoginShellEnv() - const stdioTransport = new StdioClientTransport({ - command: cmd, - args, - env: { - ...loginShellEnv, - ...server.env - }, - stderr: 'pipe' - }) - stdioTransport.stderr?.on('data', (data) => - Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString()) - ) - return stdioTransport - } else { - throw new Error('Either baseUrl or command must be provided') - } - } + const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => { + Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`) + // Create an event emitter for the OAuth callback + const events = new EventEmitter() - const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => { - Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`) - // Create an event emitter for the OAuth callback - const events = new EventEmitter() + // Create a callback server + const callbackServer = new CallBackServer({ + port: authProvider.config.callbackPort, + path: authProvider.config.callbackPath || '/oauth/callback', + events + }) - // Create a callback server - const callbackServer = new CallBackServer({ - port: authProvider.config.callbackPort, - path: authProvider.config.callbackPath || '/oauth/callback', - events - }) + // Set a timeout to close the callback server + const timeoutId = setTimeout(() => { + Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`) + callbackServer.close() + }, 300000) // 5 minutes timeout - // Set a timeout to close the callback server - const timeoutId = setTimeout(() => { - Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`) - callbackServer.close() - }, 300000) // 5 minutes timeout + try { + // Wait for the authorization code + const authCode = await callbackServer.waitForAuthCode() + Logger.info(`[MCP] Received auth code: ${authCode}`) - try { - // Wait for the authorization code - const authCode = await callbackServer.waitForAuthCode() - Logger.info(`[MCP] Received auth code: ${authCode}`) + // Complete the OAuth flow + await transport.finishAuth(authCode) - // Complete the OAuth flow - await transport.finishAuth(authCode) + Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`) - Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`) + const newTransport = await initTransport() + // Try to connect again + await client.connect(newTransport) - const newTransport = await initTransport() - // Try to connect again - await client.connect(newTransport) + Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`) + } catch (oauthError) { + Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError) + throw new Error( + `OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}` + ) + } finally { + // Clear the timeout and close the callback server + clearTimeout(timeoutId) + callbackServer.close() + } + } - Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`) - } catch (oauthError) { - Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError) - throw new Error( - `OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}` - ) + try { + const transport = await initTransport() + try { + await client.connect(transport) + } catch (error: Error | any) { + if ( + error instanceof Error && + (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized')) + ) { + Logger.info(`[MCP] Authentication required for server: ${server.name}`) + await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport) + } else { + throw error + } + } + + // Store the new client in the cache + this.clients.set(serverKey, client) + + Logger.info(`[MCP] Activated server: ${server.name}`) + return client + } catch (error: any) { + Logger.error(`[MCP] Error activating server ${server.name}:`, error?.message) + throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`) + } } finally { - // Clear the timeout and close the callback server - clearTimeout(timeoutId) - callbackServer.close() + // Clean up the pending promise when done + this.pendingClients.delete(serverKey) } - } + })() - try { - const transport = await initTransport() - try { - await client.connect(transport) - } catch (error: Error | any) { - if (error instanceof Error && (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))) { - Logger.info(`[MCP] Authentication required for server: ${server.name}`) - await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport) - } else { - throw error - } - } + // Store the pending promise + this.pendingClients.set(serverKey, initPromise) - // Store the new client in the cache - this.clients.set(serverKey, client) - - Logger.info(`[MCP] Activated server: ${server.name}`) - return client - } catch (error: any) { - Logger.error(`[MCP] Error activating server ${server.name}:`, error) - throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`) - } + return initPromise } async closeClient(serverKey: string) { @@ -357,8 +390,8 @@ class McpService { for (const [key] of this.clients) { try { await this.closeClient(key) - } catch (error) { - Logger.error(`[MCP] Failed to close client: ${error}`) + } catch (error: any) { + Logger.error(`[MCP] Failed to close client: ${error?.message}`) } } } @@ -372,15 +405,15 @@ class McpService { tools.map((tool: any) => { const serverTool: MCPTool = { ...tool, - id: `f${nanoid()}`, + id: buildFunctionCallToolName(server.name, tool.name), serverId: server.id, serverName: server.name } serverTools.push(serverTool) }) return serverTools - } catch (error) { - Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error) + } catch (error: any) { + Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error?.message) return [] } } @@ -439,8 +472,8 @@ class McpService { * List prompts available on an MCP server */ private async listPromptsImpl(server: MCPServer): Promise { - Logger.info(`[MCP] Listing prompts for server: ${server.name}`) const client = await this.initClient(server) + Logger.info(`[MCP] Listing prompts for server: ${server.name}`) try { const { prompts } = await client.listPrompts() return prompts.map((prompt: any) => ({ @@ -449,8 +482,11 @@ class McpService { serverId: server.id, serverName: server.name })) - } catch (error) { - Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error) + } catch (error: any) { + // -32601 is the code for the method not found + if (error?.code !== -32601) { + Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error?.message) + } return [] } } @@ -508,8 +544,8 @@ class McpService { * List resources available on an MCP server (implementation) */ private async listResourcesImpl(server: MCPServer): Promise { - Logger.info(`[MCP] Listing resources for server: ${server.name}`) const client = await this.initClient(server) + Logger.info(`[MCP] Listing resources for server: ${server.name}`) try { const result = await client.listResources() const resources = result.resources || [] @@ -519,8 +555,11 @@ class McpService { serverName: server.name })) return serverResources - } catch (error) { - Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error) + } catch (error: any) { + // -32601 is the code for the method not found + if (error?.code !== -32601) { + Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error?.message) + } return [] } } @@ -563,7 +602,7 @@ class McpService { contents: contents } } catch (error: Error | any) { - Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error) + Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error.message) throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`) } } @@ -602,5 +641,13 @@ class McpService { }) } -const mcpService = new McpService() -export default mcpService +let mcpInstance: ReturnType | null = null + +export const getMcpInstance = () => { + if (!mcpInstance) { + mcpInstance = McpService.getInstance() + } + return mcpInstance +} + +export default McpService.getInstance diff --git a/src/main/services/WindowService.ts b/src/main/services/WindowService.ts index aff511d748..f033cc82bf 100644 --- a/src/main/services/WindowService.ts +++ b/src/main/services/WindowService.ts @@ -75,7 +75,8 @@ export class WindowService { sandbox: false, webSecurity: false, webviewTag: true, - allowRunningInsecureContent: true + allowRunningInsecureContent: true, + backgroundThrottling: false } }) @@ -323,6 +324,11 @@ export class WindowService { event.preventDefault() + if (mainWindow.isFullScreen()) { + mainWindow.setFullScreen(false) + return + } + mainWindow.hide() //for mac users, should hide dock icon if close to tray @@ -440,7 +446,8 @@ export class WindowService { preload: join(__dirname, '../preload/index.js'), sandbox: false, webSecurity: false, - webviewTag: true + webviewTag: true, + backgroundThrottling: false } }) diff --git a/src/main/services/mcp/shell-env.ts b/src/main/services/mcp/shell-env.ts index 54cc21280f..9901417024 100644 --- a/src/main/services/mcp/shell-env.ts +++ b/src/main/services/mcp/shell-env.ts @@ -47,7 +47,7 @@ function getLoginShellEnvironment(): Promise> { commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command } - Logger.log(`Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`) + Logger.log(`[ShellEnv] Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`) const child = spawn(shellPath, commandArgs, { cwd: homeDirectory, // Run the command in the user's home directory diff --git a/src/main/utils/mcp.ts b/src/main/utils/mcp.ts new file mode 100644 index 0000000000..23d19806d9 --- /dev/null +++ b/src/main/utils/mcp.ts @@ -0,0 +1,34 @@ +export function buildFunctionCallToolName(serverName: string, toolName: string) { + const sanitizedServer = serverName.trim().replace(/-/g, '_') + const sanitizedTool = toolName.trim().replace(/-/g, '_') + + // Combine server name and tool name + let name = sanitizedTool + if (!sanitizedTool.includes(sanitizedServer.slice(0, 7))) { + name = `${sanitizedServer.slice(0, 7) || ''}-${sanitizedTool || ''}` + } + + // Replace invalid characters with underscores or dashes + // Keep a-z, A-Z, 0-9, underscores and dashes + name = name.replace(/[^a-zA-Z0-9_-]/g, '_') + + // Ensure name starts with a letter or underscore (for valid JavaScript identifier) + if (!/^[a-zA-Z]/.test(name)) { + name = `tool-${name}` + } + + // Remove consecutive underscores/dashes (optional improvement) + name = name.replace(/[_-]{2,}/g, '_') + + // Truncate to 63 characters maximum + if (name.length > 63) { + name = name.slice(0, 63) + } + + // Handle edge case: ensure we still have a valid name if truncation left invalid chars at edges + if (name.endsWith('_') || name.endsWith('-')) { + name = name.slice(0, -1) + } + + return name +} diff --git a/src/renderer/__tests__/setup.ts b/src/renderer/__tests__/setup.ts index f847a40826..70b9cd70b0 100644 --- a/src/renderer/__tests__/setup.ts +++ b/src/renderer/__tests__/setup.ts @@ -18,3 +18,32 @@ vi.mock('electron-log/renderer', () => { } } }) + +vi.stubGlobal('window', { + electron: { + ipcRenderer: { + on: vi.fn(), // Mocking ipcRenderer.on + send: vi.fn() // Mocking ipcRenderer.send + } + }, + api: { + file: { + read: vi.fn().mockResolvedValue('[]'), // Mock file.read to return an empty array (you can customize this) + writeWithId: vi.fn().mockResolvedValue(undefined) // Mock file.writeWithId to do nothing + } + } +}) + +vi.mock('axios', () => ({ + default: { + get: vi.fn().mockResolvedValue({ data: {} }), // Mocking axios GET request + post: vi.fn().mockResolvedValue({ data: {} }) // Mocking axios POST request + // You can add other axios methods like put, delete etc. as needed + } +})) + +vi.stubGlobal('window', { + ...global.window, // Copy other global properties + addEventListener: vi.fn(), // Mock addEventListener + removeEventListener: vi.fn() // You can also mock removeEventListener if needed +}) diff --git a/src/renderer/index.html b/src/renderer/index.html index eebceeac66..c8832dc573 100644 --- a/src/renderer/index.html +++ b/src/renderer/index.html @@ -5,7 +5,7 @@ + content="default-src 'self'; connect-src blob: *; script-src 'self' 'unsafe-eval' 'unsafe-inline' *; worker-src 'self' blob:; style-src 'self' 'unsafe-inline' *; font-src 'self' data: *; img-src 'self' data: file: * blob:; frame-src * file:" /> Cherry Studio