From 8bfbbd497cff3ffc693e29bc9f66f4586908178f Mon Sep 17 00:00:00 2001 From: kangfenmao Date: Sun, 18 May 2025 18:47:01 +0800 Subject: [PATCH] refactor: streamline MCP service handling and improve IPC registration * Refactored MCPService to implement a singleton pattern for better instance management. * Updated IPC registration to utilize the new getMcpInstance method for handling MCP-related requests. * Removed redundant IPC handlers from the main index file and centralized them in the ipc module. * Added background throttling option in WindowService configuration to enhance performance. * Introduced delays in MCPToolsButton to optimize resource and prompt fetching after initial load. --- src/main/index.ts | 19 +- src/main/ipc.ts | 30 +- src/main/services/MCPService.ts | 440 ++++++++++-------- src/main/services/WindowService.ts | 3 +- src/main/services/mcp/shell-env.ts | 2 +- src/renderer/src/hooks/useMCPServers.ts | 11 +- .../pages/home/Inputbar/MCPToolsButton.tsx | 67 ++- 7 files changed, 320 insertions(+), 252 deletions(-) diff --git a/src/main/index.ts b/src/main/index.ts index f85803ed84..44d516a5ca 100644 --- a/src/main/index.ts +++ b/src/main/index.ts @@ -2,12 +2,11 @@ import '@main/config' import { electronApp, optimizer } from '@electron-toolkit/utils' import { replaceDevtoolsFont } from '@main/utils/windowUtil' -import { IpcChannel } from '@shared/IpcChannel' -import { app, BrowserWindow, ipcMain } from 'electron' +import { app } from 'electron' import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer' import Logger from 'electron-log' -import { isDev, isMac, isWin } from './constant' +import { isDev } from './constant' import { registerIpc } from './ipc' import { configManager } from './services/ConfigManager' import mcpService from './services/MCPService' @@ -85,18 +84,6 @@ if (!app.requestSingleInstanceLock()) { .then((name) => console.log(`Added Extension: ${name}`)) .catch((err) => console.log('An error occurred: ', err)) } - ipcMain.handle(IpcChannel.System_GetDeviceType, () => { - return isMac ? 'mac' : isWin ? 'windows' : 'linux' - }) - - ipcMain.handle(IpcChannel.System_GetHostname, () => { - return require('os').hostname() - }) - - ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => { - const win = BrowserWindow.fromWebContents(e.sender) - win && win.webContents.toggleDevTools() - }) }) registerProtocolClient(app) @@ -128,7 +115,7 @@ if (!app.requestSingleInstanceLock()) { app.on('will-quit', async () => { // event.preventDefault() try { - await mcpService.cleanup() + await mcpService().cleanup() } catch (error) { Logger.error('Error cleaning up MCP service:', error) } diff --git a/src/main/ipc.ts b/src/main/ipc.ts index 665e8114b7..439ed63e3d 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -19,7 +19,7 @@ import FileService from './services/FileService' import FileStorage from './services/FileStorage' import { GeminiService } from './services/GeminiService' import KnowledgeService from './services/KnowledgeService' -import mcpService from './services/MCPService' +import { getMcpInstance } from './services/MCPService' import * as NutstoreService from './services/NutstoreService' import ObsidianVaultService from './services/ObsidianVaultService' import { ProxyConfig, proxyManager } from './services/ProxyManager' @@ -204,6 +204,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle(IpcChannel.Zip_Compress, (_, text: string) => compress(text)) ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text)) + // system + ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux')) + ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname()) + ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => { + const win = BrowserWindow.fromWebContents(e.sender) + win && win.webContents.toggleDevTools() + }) + // backup ipcMain.handle(IpcChannel.Backup_Backup, backupManager.backup) ipcMain.handle(IpcChannel.Backup_Restore, backupManager.restore) @@ -301,16 +309,16 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ) // Register MCP handlers - ipcMain.handle(IpcChannel.Mcp_RemoveServer, mcpService.removeServer) - ipcMain.handle(IpcChannel.Mcp_RestartServer, mcpService.restartServer) - ipcMain.handle(IpcChannel.Mcp_StopServer, mcpService.stopServer) - ipcMain.handle(IpcChannel.Mcp_ListTools, mcpService.listTools) - ipcMain.handle(IpcChannel.Mcp_CallTool, mcpService.callTool) - ipcMain.handle(IpcChannel.Mcp_ListPrompts, mcpService.listPrompts) - ipcMain.handle(IpcChannel.Mcp_GetPrompt, mcpService.getPrompt) - ipcMain.handle(IpcChannel.Mcp_ListResources, mcpService.listResources) - ipcMain.handle(IpcChannel.Mcp_GetResource, mcpService.getResource) - ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo) + ipcMain.handle(IpcChannel.Mcp_RemoveServer, (event, server) => getMcpInstance().removeServer(event, server)) + ipcMain.handle(IpcChannel.Mcp_RestartServer, (event, server) => getMcpInstance().restartServer(event, server)) + ipcMain.handle(IpcChannel.Mcp_StopServer, (event, server) => getMcpInstance().stopServer(event, server)) + ipcMain.handle(IpcChannel.Mcp_ListTools, (event, server) => getMcpInstance().listTools(event, server)) + ipcMain.handle(IpcChannel.Mcp_CallTool, (event, params) => getMcpInstance().callTool(event, params)) + ipcMain.handle(IpcChannel.Mcp_ListPrompts, (event, server) => getMcpInstance().listPrompts(event, server)) + ipcMain.handle(IpcChannel.Mcp_GetPrompt, (event, params) => getMcpInstance().getPrompt(event, params)) + ipcMain.handle(IpcChannel.Mcp_ListResources, (event, server) => getMcpInstance().listResources(event, server)) + ipcMain.handle(IpcChannel.Mcp_GetResource, (event, params) => getMcpInstance().getResource(event, params)) + ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, () => getMcpInstance().getInstallInfo()) ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name)) ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name)) diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index 237e709deb..5ea91343f5 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -68,20 +68,18 @@ function withCache( } class McpService { + private static instance: McpService | null = null private clients: Map = new Map() + private pendingClients: Map> = new Map() - private getServerKey(server: MCPServer): string { - return JSON.stringify({ - baseUrl: server.baseUrl, - command: server.command, - args: server.args, - registryUrl: server.registryUrl, - env: server.env, - id: server.id - }) + public static getInstance(): McpService { + if (!McpService.instance) { + McpService.instance = new McpService() + } + return McpService.instance } - constructor() { + private constructor() { this.initClient = this.initClient.bind(this) this.listTools = this.listTools.bind(this) this.callTool = this.callTool.bind(this) @@ -96,9 +94,26 @@ class McpService { this.cleanup = this.cleanup.bind(this) } + private getServerKey(server: MCPServer): string { + return JSON.stringify({ + baseUrl: server.baseUrl, + command: server.command, + args: server.args, + registryUrl: server.registryUrl, + env: server.env, + id: server.id + }) + } + async initClient(server: MCPServer): Promise { const serverKey = this.getServerKey(server) + // If there's a pending initialization, wait for it + const pendingClient = this.pendingClients.get(serverKey) + if (pendingClient) { + return pendingClient + } + // Check if we already have a client for this server configuration const existingClient = this.clients.get(serverKey) if (existingClient) { @@ -113,209 +128,226 @@ class McpService { } else { return existingClient } - } catch (error) { - Logger.error(`[MCP] Error pinging server ${server.name}:`, error) + } catch (error: any) { + Logger.error(`[MCP] Error pinging server ${server.name}:`, error?.message) this.clients.delete(serverKey) } } - // Create new client instance for each connection - const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} }) - const args = [...(server.args || [])] + // Create a promise for the initialization process + const initPromise = (async () => { + try { + // Create new client instance for each connection + const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} }) - // let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport - const authProvider = new McpOAuthClientProvider({ - serverUrlHash: crypto - .createHash('md5') - .update(server.baseUrl || '') - .digest('hex') - }) + const args = [...(server.args || [])] - const initTransport = async (): Promise< - StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport - > => { - // Create appropriate transport based on configuration - if (server.type === 'inMemory') { - Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`) - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair() - // start the in-memory server with the given name and environment variables - const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {}) - try { - await inMemoryServer.connect(serverTransport) - Logger.info(`[MCP] In-memory server started: ${server.name}`) - } catch (error: Error | any) { - Logger.error(`[MCP] Error starting in-memory server: ${error}`) - throw new Error(`Failed to start in-memory server: ${error.message}`) - } - // set the client transport to the client - return clientTransport - } else if (server.baseUrl) { - if (server.type === 'streamableHttp') { - const options: StreamableHTTPClientTransportOptions = { - requestInit: { - headers: server.headers || {} - }, - authProvider - } - return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options) - } else if (server.type === 'sse') { - const options: SSEClientTransportOptions = { - eventSourceInit: { - fetch: async (url, init) => { - const headers = { ...(server.headers || {}), ...(init?.headers || {}) } + // let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport + const authProvider = new McpOAuthClientProvider({ + serverUrlHash: crypto + .createHash('md5') + .update(server.baseUrl || '') + .digest('hex') + }) - // Get tokens from authProvider to make sure using the latest tokens - if (authProvider && typeof authProvider.tokens === 'function') { - try { - const tokens = await authProvider.tokens() - if (tokens && tokens.access_token) { - headers['Authorization'] = `Bearer ${tokens.access_token}` + const initTransport = async (): Promise< + StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport + > => { + // Create appropriate transport based on configuration + if (server.type === 'inMemory') { + Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`) + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair() + // start the in-memory server with the given name and environment variables + const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {}) + try { + await inMemoryServer.connect(serverTransport) + Logger.info(`[MCP] In-memory server started: ${server.name}`) + } catch (error: Error | any) { + Logger.error(`[MCP] Error starting in-memory server: ${error}`) + throw new Error(`Failed to start in-memory server: ${error.message}`) + } + // set the client transport to the client + return clientTransport + } else if (server.baseUrl) { + if (server.type === 'streamableHttp') { + const options: StreamableHTTPClientTransportOptions = { + requestInit: { + headers: server.headers || {} + }, + authProvider + } + return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options) + } else if (server.type === 'sse') { + const options: SSEClientTransportOptions = { + eventSourceInit: { + fetch: async (url, init) => { + const headers = { ...(server.headers || {}), ...(init?.headers || {}) } + + // Get tokens from authProvider to make sure using the latest tokens + if (authProvider && typeof authProvider.tokens === 'function') { + try { + const tokens = await authProvider.tokens() + if (tokens && tokens.access_token) { + headers['Authorization'] = `Bearer ${tokens.access_token}` + } + } catch (error) { + Logger.error('Failed to fetch tokens:', error) + } } - } catch (error) { - Logger.error('Failed to fetch tokens:', error) + + return fetch(url, { ...init, headers }) } + }, + requestInit: { + headers: server.headers || {} + }, + authProvider + } + return new SSEClientTransport(new URL(server.baseUrl!), options) + } else { + throw new Error('Invalid server type') + } + } else if (server.command) { + let cmd = server.command + + if (server.command === 'npx') { + cmd = await getBinaryPath('bun') + Logger.info(`[MCP] Using command: ${cmd}`) + + // add -x to args if args exist + if (args && args.length > 0) { + if (!args.includes('-y')) { + args.unshift('-y') + } + if (!args.includes('x')) { + args.unshift('x') + } + } + if (server.registryUrl) { + server.env = { + ...server.env, + NPM_CONFIG_REGISTRY: server.registryUrl } - return fetch(url, { ...init, headers }) + // if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory + if (server.name.includes('mcp-auto-install')) { + const binPath = await getBinaryPath() + makeSureDirExists(binPath) + server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json') + } + } + } else if (server.command === 'uvx' || server.command === 'uv') { + cmd = await getBinaryPath(server.command) + if (server.registryUrl) { + server.env = { + ...server.env, + UV_DEFAULT_INDEX: server.registryUrl, + PIP_INDEX_URL: server.registryUrl + } } - }, - requestInit: { - headers: server.headers || {} - }, - authProvider - } - return new SSEClientTransport(new URL(server.baseUrl!), options) - } else { - throw new Error('Invalid server type') - } - } else if (server.command) { - let cmd = server.command - - if (server.command === 'npx') { - cmd = await getBinaryPath('bun') - Logger.info(`[MCP] Using command: ${cmd}`) - - // add -x to args if args exist - if (args && args.length > 0) { - if (!args.includes('-y')) { - args.unshift('-y') - } - if (!args.includes('x')) { - args.unshift('x') - } - } - if (server.registryUrl) { - server.env = { - ...server.env, - NPM_CONFIG_REGISTRY: server.registryUrl } - // if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory - if (server.name.includes('mcp-auto-install')) { - const binPath = await getBinaryPath() - makeSureDirExists(binPath) - server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json') - } - } - } else if (server.command === 'uvx' || server.command === 'uv') { - cmd = await getBinaryPath(server.command) - if (server.registryUrl) { - server.env = { - ...server.env, - UV_DEFAULT_INDEX: server.registryUrl, - PIP_INDEX_URL: server.registryUrl - } + Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`) + // Logger.info(`[MCP] Environment variables for server:`, server.env) + const loginShellEnv = await this.getLoginShellEnv() + const stdioTransport = new StdioClientTransport({ + command: cmd, + args, + env: { + ...loginShellEnv, + ...server.env + }, + stderr: 'pipe' + }) + stdioTransport.stderr?.on('data', (data) => + Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString()) + ) + return stdioTransport + } else { + throw new Error('Either baseUrl or command must be provided') } } - Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`) - // Logger.info(`[MCP] Environment variables for server:`, server.env) - const loginShellEnv = await this.getLoginShellEnv() - const stdioTransport = new StdioClientTransport({ - command: cmd, - args, - env: { - ...loginShellEnv, - ...server.env - }, - stderr: 'pipe' - }) - stdioTransport.stderr?.on('data', (data) => - Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString()) - ) - return stdioTransport - } else { - throw new Error('Either baseUrl or command must be provided') - } - } + const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => { + Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`) + // Create an event emitter for the OAuth callback + const events = new EventEmitter() - const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => { - Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`) - // Create an event emitter for the OAuth callback - const events = new EventEmitter() + // Create a callback server + const callbackServer = new CallBackServer({ + port: authProvider.config.callbackPort, + path: authProvider.config.callbackPath || '/oauth/callback', + events + }) - // Create a callback server - const callbackServer = new CallBackServer({ - port: authProvider.config.callbackPort, - path: authProvider.config.callbackPath || '/oauth/callback', - events - }) + // Set a timeout to close the callback server + const timeoutId = setTimeout(() => { + Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`) + callbackServer.close() + }, 300000) // 5 minutes timeout - // Set a timeout to close the callback server - const timeoutId = setTimeout(() => { - Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`) - callbackServer.close() - }, 300000) // 5 minutes timeout + try { + // Wait for the authorization code + const authCode = await callbackServer.waitForAuthCode() + Logger.info(`[MCP] Received auth code: ${authCode}`) - try { - // Wait for the authorization code - const authCode = await callbackServer.waitForAuthCode() - Logger.info(`[MCP] Received auth code: ${authCode}`) + // Complete the OAuth flow + await transport.finishAuth(authCode) - // Complete the OAuth flow - await transport.finishAuth(authCode) + Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`) - Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`) + const newTransport = await initTransport() + // Try to connect again + await client.connect(newTransport) - const newTransport = await initTransport() - // Try to connect again - await client.connect(newTransport) + Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`) + } catch (oauthError) { + Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError) + throw new Error( + `OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}` + ) + } finally { + // Clear the timeout and close the callback server + clearTimeout(timeoutId) + callbackServer.close() + } + } - Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`) - } catch (oauthError) { - Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError) - throw new Error( - `OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}` - ) + try { + const transport = await initTransport() + try { + await client.connect(transport) + } catch (error: Error | any) { + if ( + error instanceof Error && + (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized')) + ) { + Logger.info(`[MCP] Authentication required for server: ${server.name}`) + await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport) + } else { + throw error + } + } + + // Store the new client in the cache + this.clients.set(serverKey, client) + + Logger.info(`[MCP] Activated server: ${server.name}`) + return client + } catch (error: any) { + Logger.error(`[MCP] Error activating server ${server.name}:`, error?.message) + throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`) + } } finally { - // Clear the timeout and close the callback server - clearTimeout(timeoutId) - callbackServer.close() + // Clean up the pending promise when done + this.pendingClients.delete(serverKey) } - } + })() - try { - const transport = await initTransport() - try { - await client.connect(transport) - } catch (error: Error | any) { - if (error instanceof Error && (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))) { - Logger.info(`[MCP] Authentication required for server: ${server.name}`) - await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport) - } else { - throw error - } - } + // Store the pending promise + this.pendingClients.set(serverKey, initPromise) - // Store the new client in the cache - this.clients.set(serverKey, client) - - Logger.info(`[MCP] Activated server: ${server.name}`) - return client - } catch (error: any) { - Logger.error(`[MCP] Error activating server ${server.name}:`, error) - throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`) - } + return initPromise } async closeClient(serverKey: string) { @@ -357,8 +389,8 @@ class McpService { for (const [key] of this.clients) { try { await this.closeClient(key) - } catch (error) { - Logger.error(`[MCP] Failed to close client: ${error}`) + } catch (error: any) { + Logger.error(`[MCP] Failed to close client: ${error?.message}`) } } } @@ -379,8 +411,8 @@ class McpService { serverTools.push(serverTool) }) return serverTools - } catch (error) { - Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error) + } catch (error: any) { + Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error?.message) return [] } } @@ -439,8 +471,8 @@ class McpService { * List prompts available on an MCP server */ private async listPromptsImpl(server: MCPServer): Promise { - Logger.info(`[MCP] Listing prompts for server: ${server.name}`) const client = await this.initClient(server) + Logger.info(`[MCP] Listing prompts for server: ${server.name}`) try { const { prompts } = await client.listPrompts() return prompts.map((prompt: any) => ({ @@ -449,8 +481,11 @@ class McpService { serverId: server.id, serverName: server.name })) - } catch (error) { - Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error) + } catch (error: any) { + // -32601 is the code for the method not found + if (error?.code !== -32601) { + Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error?.message) + } return [] } } @@ -508,8 +543,8 @@ class McpService { * List resources available on an MCP server (implementation) */ private async listResourcesImpl(server: MCPServer): Promise { - Logger.info(`[MCP] Listing resources for server: ${server.name}`) const client = await this.initClient(server) + Logger.info(`[MCP] Listing resources for server: ${server.name}`) try { const result = await client.listResources() const resources = result.resources || [] @@ -519,8 +554,11 @@ class McpService { serverName: server.name })) return serverResources - } catch (error) { - Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error) + } catch (error: any) { + // -32601 is the code for the method not found + if (error?.code !== -32601) { + Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error?.message) + } return [] } } @@ -563,7 +601,7 @@ class McpService { contents: contents } } catch (error: Error | any) { - Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error) + Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error.message) throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`) } } @@ -602,5 +640,13 @@ class McpService { }) } -const mcpService = new McpService() -export default mcpService +let mcpInstance: ReturnType | null = null + +export const getMcpInstance = () => { + if (!mcpInstance) { + mcpInstance = McpService.getInstance() + } + return mcpInstance +} + +export default McpService.getInstance diff --git a/src/main/services/WindowService.ts b/src/main/services/WindowService.ts index eb4e2f104a..f033cc82bf 100644 --- a/src/main/services/WindowService.ts +++ b/src/main/services/WindowService.ts @@ -446,7 +446,8 @@ export class WindowService { preload: join(__dirname, '../preload/index.js'), sandbox: false, webSecurity: false, - webviewTag: true + webviewTag: true, + backgroundThrottling: false } }) diff --git a/src/main/services/mcp/shell-env.ts b/src/main/services/mcp/shell-env.ts index 54cc21280f..9901417024 100644 --- a/src/main/services/mcp/shell-env.ts +++ b/src/main/services/mcp/shell-env.ts @@ -47,7 +47,7 @@ function getLoginShellEnvironment(): Promise> { commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command } - Logger.log(`Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`) + Logger.log(`[ShellEnv] Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`) const child = spawn(shellPath, commandArgs, { cwd: homeDirectory, // Run the command in the user's home directory diff --git a/src/renderer/src/hooks/useMCPServers.ts b/src/renderer/src/hooks/useMCPServers.ts index 90d6e6fec4..49fde29f60 100644 --- a/src/renderer/src/hooks/useMCPServers.ts +++ b/src/renderer/src/hooks/useMCPServers.ts @@ -1,8 +1,8 @@ +import { createSelector } from '@reduxjs/toolkit' import store, { useAppDispatch, useAppSelector } from '@renderer/store' import { addMCPServer, deleteMCPServer, setMCPServers, updateMCPServer } from '@renderer/store/mcp' import { MCPServer } from '@renderer/types' import { IpcChannel } from '@shared/IpcChannel' -import { useMemo } from 'react' const ipcRenderer = window.electron.ipcRenderer @@ -14,9 +14,14 @@ ipcRenderer.on(IpcChannel.Mcp_AddServer, (_event, server: MCPServer) => { store.dispatch(addMCPServer(server)) }) +const selectMcpServers = (state) => state.mcp.servers +const selectActiveMcpServers = createSelector([selectMcpServers], (servers) => + servers.filter((server) => server.isActive) +) + export const useMCPServers = () => { - const mcpServers = useAppSelector((state) => state.mcp.servers) - const activedMcpServers = useMemo(() => mcpServers.filter((server) => server.isActive), [mcpServers]) + const mcpServers = useAppSelector(selectMcpServers) + const activedMcpServers = useAppSelector(selectActiveMcpServers) const dispatch = useAppDispatch() return { diff --git a/src/renderer/src/pages/home/Inputbar/MCPToolsButton.tsx b/src/renderer/src/pages/home/Inputbar/MCPToolsButton.tsx index 60eff1cd23..4aa461c9e9 100644 --- a/src/renderer/src/pages/home/Inputbar/MCPToolsButton.tsx +++ b/src/renderer/src/pages/home/Inputbar/MCPToolsButton.tsx @@ -3,6 +3,7 @@ import { useAssistant } from '@renderer/hooks/useAssistant' import { useMCPServers } from '@renderer/hooks/useMCPServers' import { EventEmitter } from '@renderer/services/EventService' import { Assistant, MCPPrompt, MCPResource, MCPServer } from '@renderer/types' +import { delay, runAsyncFunction } from '@renderer/utils' import { Form, Input, Tooltip } from 'antd' import { Plus, SquareTerminal } from 'lucide-react' import { FC, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react' @@ -109,6 +110,11 @@ const extractPromptContent = (response: any): string | null => { return null } +// Add static variable before component definition +let isFirstResourcesListCall = true +let isFirstPromptListCall = true +const initMcpDelay = 3 + const MCPToolsButton: FC = ({ ref, setInputValue, resizeTextArea, ToolbarButton, ...props }) => { const { activedMcpServers } = useMCPServers() const { t } = useTranslation() @@ -308,6 +314,11 @@ const MCPToolsButton: FC = ({ ref, setInputValue, resizeTextArea, Toolbar const promptList = useMemo(async () => { const prompts: MCPPrompt[] = [] + if (isFirstPromptListCall) { + await delay(initMcpDelay) + isFirstPromptListCall = false + } + for (const server of activedMcpServers) { const serverPrompts = await window.api.mcp.listPrompts(server) prompts.push(...serverPrompts) @@ -319,7 +330,8 @@ const MCPToolsButton: FC = ({ ref, setInputValue, resizeTextArea, Toolbar icon: , action: () => handlePromptSelect(prompt as MCPPromptWithArgs) })) - }, [handlePromptSelect, activedMcpServers]) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [activedMcpServers]) const openPromptList = useCallback(async () => { const prompts = await promptList @@ -380,33 +392,42 @@ const MCPToolsButton: FC = ({ ref, setInputValue, resizeTextArea, Toolbar const [resourcesList, setResourcesList] = useState([]) useEffect(() => { - let isMounted = true + runAsyncFunction(async () => { + let isMounted = true - const fetchResources = async () => { - const resources: MCPResource[] = [] - for (const server of activedMcpServers) { - const serverResources = await window.api.mcp.listResources(server) - resources.push(...serverResources) + const fetchResources = async () => { + const resources: MCPResource[] = [] + + for (const server of activedMcpServers) { + const serverResources = await window.api.mcp.listResources(server) + resources.push(...serverResources) + } + + if (isMounted) { + setResourcesList( + resources.map((resource) => ({ + label: resource.name, + description: resource.description, + icon: , + action: () => handleResourceSelect(resource) + })) + ) + } } - if (isMounted) { - setResourcesList( - resources.map((resource) => ({ - label: resource.name, - description: resource.description, - icon: , - action: () => handleResourceSelect(resource) - })) - ) + // Avoid mcp following the software startup, affecting the startup speed + if (isFirstResourcesListCall) { + await delay(initMcpDelay) + isFirstResourcesListCall = false + fetchResources() } - } - fetchResources() - - return () => { - isMounted = false - } - }, [activedMcpServers, handleResourceSelect]) + return () => { + isMounted = false + } + }) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [activedMcpServers]) const openResourcesList = useCallback(async () => { const resources = resourcesList