refactor: streamline MCP service handling and improve IPC registration

* Refactored MCPService to implement a singleton pattern for better instance management.
* Updated IPC registration to utilize the new getMcpInstance method for handling MCP-related requests.
* Removed redundant IPC handlers from the main index file and centralized them in the ipc module.
* Added background throttling option in WindowService configuration to enhance performance.
* Introduced delays in MCPToolsButton to optimize resource and prompt fetching after initial load.
This commit is contained in:
kangfenmao 2025-05-18 18:47:01 +08:00 committed by 亢奋猫
parent 4b2417ce37
commit 8bfbbd497c
7 changed files with 320 additions and 252 deletions

View File

@ -2,12 +2,11 @@ import '@main/config'
import { electronApp, optimizer } from '@electron-toolkit/utils' import { electronApp, optimizer } from '@electron-toolkit/utils'
import { replaceDevtoolsFont } from '@main/utils/windowUtil' import { replaceDevtoolsFont } from '@main/utils/windowUtil'
import { IpcChannel } from '@shared/IpcChannel' import { app } from 'electron'
import { app, BrowserWindow, ipcMain } from 'electron'
import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer' import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
import Logger from 'electron-log' import Logger from 'electron-log'
import { isDev, isMac, isWin } from './constant' import { isDev } from './constant'
import { registerIpc } from './ipc' import { registerIpc } from './ipc'
import { configManager } from './services/ConfigManager' import { configManager } from './services/ConfigManager'
import mcpService from './services/MCPService' import mcpService from './services/MCPService'
@ -85,18 +84,6 @@ if (!app.requestSingleInstanceLock()) {
.then((name) => console.log(`Added Extension: ${name}`)) .then((name) => console.log(`Added Extension: ${name}`))
.catch((err) => console.log('An error occurred: ', err)) .catch((err) => console.log('An error occurred: ', err))
} }
ipcMain.handle(IpcChannel.System_GetDeviceType, () => {
return isMac ? 'mac' : isWin ? 'windows' : 'linux'
})
ipcMain.handle(IpcChannel.System_GetHostname, () => {
return require('os').hostname()
})
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
const win = BrowserWindow.fromWebContents(e.sender)
win && win.webContents.toggleDevTools()
})
}) })
registerProtocolClient(app) registerProtocolClient(app)
@ -128,7 +115,7 @@ if (!app.requestSingleInstanceLock()) {
app.on('will-quit', async () => { app.on('will-quit', async () => {
// event.preventDefault() // event.preventDefault()
try { try {
await mcpService.cleanup() await mcpService().cleanup()
} catch (error) { } catch (error) {
Logger.error('Error cleaning up MCP service:', error) Logger.error('Error cleaning up MCP service:', error)
} }

View File

@ -19,7 +19,7 @@ import FileService from './services/FileService'
import FileStorage from './services/FileStorage' import FileStorage from './services/FileStorage'
import { GeminiService } from './services/GeminiService' import { GeminiService } from './services/GeminiService'
import KnowledgeService from './services/KnowledgeService' import KnowledgeService from './services/KnowledgeService'
import mcpService from './services/MCPService' import { getMcpInstance } from './services/MCPService'
import * as NutstoreService from './services/NutstoreService' import * as NutstoreService from './services/NutstoreService'
import ObsidianVaultService from './services/ObsidianVaultService' import ObsidianVaultService from './services/ObsidianVaultService'
import { ProxyConfig, proxyManager } from './services/ProxyManager' import { ProxyConfig, proxyManager } from './services/ProxyManager'
@ -204,6 +204,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.Zip_Compress, (_, text: string) => compress(text)) ipcMain.handle(IpcChannel.Zip_Compress, (_, text: string) => compress(text))
ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text)) ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text))
// system
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
const win = BrowserWindow.fromWebContents(e.sender)
win && win.webContents.toggleDevTools()
})
// backup // backup
ipcMain.handle(IpcChannel.Backup_Backup, backupManager.backup) ipcMain.handle(IpcChannel.Backup_Backup, backupManager.backup)
ipcMain.handle(IpcChannel.Backup_Restore, backupManager.restore) ipcMain.handle(IpcChannel.Backup_Restore, backupManager.restore)
@ -301,16 +309,16 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
) )
// Register MCP handlers // Register MCP handlers
ipcMain.handle(IpcChannel.Mcp_RemoveServer, mcpService.removeServer) ipcMain.handle(IpcChannel.Mcp_RemoveServer, (event, server) => getMcpInstance().removeServer(event, server))
ipcMain.handle(IpcChannel.Mcp_RestartServer, mcpService.restartServer) ipcMain.handle(IpcChannel.Mcp_RestartServer, (event, server) => getMcpInstance().restartServer(event, server))
ipcMain.handle(IpcChannel.Mcp_StopServer, mcpService.stopServer) ipcMain.handle(IpcChannel.Mcp_StopServer, (event, server) => getMcpInstance().stopServer(event, server))
ipcMain.handle(IpcChannel.Mcp_ListTools, mcpService.listTools) ipcMain.handle(IpcChannel.Mcp_ListTools, (event, server) => getMcpInstance().listTools(event, server))
ipcMain.handle(IpcChannel.Mcp_CallTool, mcpService.callTool) ipcMain.handle(IpcChannel.Mcp_CallTool, (event, params) => getMcpInstance().callTool(event, params))
ipcMain.handle(IpcChannel.Mcp_ListPrompts, mcpService.listPrompts) ipcMain.handle(IpcChannel.Mcp_ListPrompts, (event, server) => getMcpInstance().listPrompts(event, server))
ipcMain.handle(IpcChannel.Mcp_GetPrompt, mcpService.getPrompt) ipcMain.handle(IpcChannel.Mcp_GetPrompt, (event, params) => getMcpInstance().getPrompt(event, params))
ipcMain.handle(IpcChannel.Mcp_ListResources, mcpService.listResources) ipcMain.handle(IpcChannel.Mcp_ListResources, (event, server) => getMcpInstance().listResources(event, server))
ipcMain.handle(IpcChannel.Mcp_GetResource, mcpService.getResource) ipcMain.handle(IpcChannel.Mcp_GetResource, (event, params) => getMcpInstance().getResource(event, params))
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo) ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, () => getMcpInstance().getInstallInfo())
ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name)) ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name))
ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name)) ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name))

View File

@ -68,20 +68,18 @@ function withCache<T extends unknown[], R>(
} }
class McpService { class McpService {
private static instance: McpService | null = null
private clients: Map<string, Client> = new Map() private clients: Map<string, Client> = new Map()
private pendingClients: Map<string, Promise<Client>> = new Map()
private getServerKey(server: MCPServer): string { public static getInstance(): McpService {
return JSON.stringify({ if (!McpService.instance) {
baseUrl: server.baseUrl, McpService.instance = new McpService()
command: server.command, }
args: server.args, return McpService.instance
registryUrl: server.registryUrl,
env: server.env,
id: server.id
})
} }
constructor() { private constructor() {
this.initClient = this.initClient.bind(this) this.initClient = this.initClient.bind(this)
this.listTools = this.listTools.bind(this) this.listTools = this.listTools.bind(this)
this.callTool = this.callTool.bind(this) this.callTool = this.callTool.bind(this)
@ -96,9 +94,26 @@ class McpService {
this.cleanup = this.cleanup.bind(this) this.cleanup = this.cleanup.bind(this)
} }
private getServerKey(server: MCPServer): string {
return JSON.stringify({
baseUrl: server.baseUrl,
command: server.command,
args: server.args,
registryUrl: server.registryUrl,
env: server.env,
id: server.id
})
}
async initClient(server: MCPServer): Promise<Client> { async initClient(server: MCPServer): Promise<Client> {
const serverKey = this.getServerKey(server) const serverKey = this.getServerKey(server)
// If there's a pending initialization, wait for it
const pendingClient = this.pendingClients.get(serverKey)
if (pendingClient) {
return pendingClient
}
// Check if we already have a client for this server configuration // Check if we already have a client for this server configuration
const existingClient = this.clients.get(serverKey) const existingClient = this.clients.get(serverKey)
if (existingClient) { if (existingClient) {
@ -113,11 +128,15 @@ class McpService {
} else { } else {
return existingClient return existingClient
} }
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Error pinging server ${server.name}:`, error) Logger.error(`[MCP] Error pinging server ${server.name}:`, error?.message)
this.clients.delete(serverKey) this.clients.delete(serverKey)
} }
} }
// Create a promise for the initialization process
const initPromise = (async () => {
try {
// Create new client instance for each connection // Create new client instance for each connection
const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} }) const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} })
@ -299,7 +318,10 @@ class McpService {
try { try {
await client.connect(transport) await client.connect(transport)
} catch (error: Error | any) { } catch (error: Error | any) {
if (error instanceof Error && (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))) { if (
error instanceof Error &&
(error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))
) {
Logger.info(`[MCP] Authentication required for server: ${server.name}`) Logger.info(`[MCP] Authentication required for server: ${server.name}`)
await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport) await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
} else { } else {
@ -313,9 +335,19 @@ class McpService {
Logger.info(`[MCP] Activated server: ${server.name}`) Logger.info(`[MCP] Activated server: ${server.name}`)
return client return client
} catch (error: any) { } catch (error: any) {
Logger.error(`[MCP] Error activating server ${server.name}:`, error) Logger.error(`[MCP] Error activating server ${server.name}:`, error?.message)
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`) throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
} }
} finally {
// Clean up the pending promise when done
this.pendingClients.delete(serverKey)
}
})()
// Store the pending promise
this.pendingClients.set(serverKey, initPromise)
return initPromise
} }
async closeClient(serverKey: string) { async closeClient(serverKey: string) {
@ -357,8 +389,8 @@ class McpService {
for (const [key] of this.clients) { for (const [key] of this.clients) {
try { try {
await this.closeClient(key) await this.closeClient(key)
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Failed to close client: ${error}`) Logger.error(`[MCP] Failed to close client: ${error?.message}`)
} }
} }
} }
@ -379,8 +411,8 @@ class McpService {
serverTools.push(serverTool) serverTools.push(serverTool)
}) })
return serverTools return serverTools
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error) Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error?.message)
return [] return []
} }
} }
@ -439,8 +471,8 @@ class McpService {
* List prompts available on an MCP server * List prompts available on an MCP server
*/ */
private async listPromptsImpl(server: MCPServer): Promise<MCPPrompt[]> { private async listPromptsImpl(server: MCPServer): Promise<MCPPrompt[]> {
Logger.info(`[MCP] Listing prompts for server: ${server.name}`)
const client = await this.initClient(server) const client = await this.initClient(server)
Logger.info(`[MCP] Listing prompts for server: ${server.name}`)
try { try {
const { prompts } = await client.listPrompts() const { prompts } = await client.listPrompts()
return prompts.map((prompt: any) => ({ return prompts.map((prompt: any) => ({
@ -449,8 +481,11 @@ class McpService {
serverId: server.id, serverId: server.id,
serverName: server.name serverName: server.name
})) }))
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error) // -32601 is the code for the method not found
if (error?.code !== -32601) {
Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error?.message)
}
return [] return []
} }
} }
@ -508,8 +543,8 @@ class McpService {
* List resources available on an MCP server (implementation) * List resources available on an MCP server (implementation)
*/ */
private async listResourcesImpl(server: MCPServer): Promise<MCPResource[]> { private async listResourcesImpl(server: MCPServer): Promise<MCPResource[]> {
Logger.info(`[MCP] Listing resources for server: ${server.name}`)
const client = await this.initClient(server) const client = await this.initClient(server)
Logger.info(`[MCP] Listing resources for server: ${server.name}`)
try { try {
const result = await client.listResources() const result = await client.listResources()
const resources = result.resources || [] const resources = result.resources || []
@ -519,8 +554,11 @@ class McpService {
serverName: server.name serverName: server.name
})) }))
return serverResources return serverResources
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error) // -32601 is the code for the method not found
if (error?.code !== -32601) {
Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error?.message)
}
return [] return []
} }
} }
@ -563,7 +601,7 @@ class McpService {
contents: contents contents: contents
} }
} catch (error: Error | any) { } catch (error: Error | any) {
Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error) Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error.message)
throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`) throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`)
} }
} }
@ -602,5 +640,13 @@ class McpService {
}) })
} }
const mcpService = new McpService() let mcpInstance: ReturnType<typeof McpService.getInstance> | null = null
export default mcpService
export const getMcpInstance = () => {
if (!mcpInstance) {
mcpInstance = McpService.getInstance()
}
return mcpInstance
}
export default McpService.getInstance

View File

@ -446,7 +446,8 @@ export class WindowService {
preload: join(__dirname, '../preload/index.js'), preload: join(__dirname, '../preload/index.js'),
sandbox: false, sandbox: false,
webSecurity: false, webSecurity: false,
webviewTag: true webviewTag: true,
backgroundThrottling: false
} }
}) })

View File

@ -47,7 +47,7 @@ function getLoginShellEnvironment(): Promise<Record<string, string>> {
commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command
} }
Logger.log(`Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`) Logger.log(`[ShellEnv] Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`)
const child = spawn(shellPath, commandArgs, { const child = spawn(shellPath, commandArgs, {
cwd: homeDirectory, // Run the command in the user's home directory cwd: homeDirectory, // Run the command in the user's home directory

View File

@ -1,8 +1,8 @@
import { createSelector } from '@reduxjs/toolkit'
import store, { useAppDispatch, useAppSelector } from '@renderer/store' import store, { useAppDispatch, useAppSelector } from '@renderer/store'
import { addMCPServer, deleteMCPServer, setMCPServers, updateMCPServer } from '@renderer/store/mcp' import { addMCPServer, deleteMCPServer, setMCPServers, updateMCPServer } from '@renderer/store/mcp'
import { MCPServer } from '@renderer/types' import { MCPServer } from '@renderer/types'
import { IpcChannel } from '@shared/IpcChannel' import { IpcChannel } from '@shared/IpcChannel'
import { useMemo } from 'react'
const ipcRenderer = window.electron.ipcRenderer const ipcRenderer = window.electron.ipcRenderer
@ -14,9 +14,14 @@ ipcRenderer.on(IpcChannel.Mcp_AddServer, (_event, server: MCPServer) => {
store.dispatch(addMCPServer(server)) store.dispatch(addMCPServer(server))
}) })
const selectMcpServers = (state) => state.mcp.servers
const selectActiveMcpServers = createSelector([selectMcpServers], (servers) =>
servers.filter((server) => server.isActive)
)
export const useMCPServers = () => { export const useMCPServers = () => {
const mcpServers = useAppSelector((state) => state.mcp.servers) const mcpServers = useAppSelector(selectMcpServers)
const activedMcpServers = useMemo(() => mcpServers.filter((server) => server.isActive), [mcpServers]) const activedMcpServers = useAppSelector(selectActiveMcpServers)
const dispatch = useAppDispatch() const dispatch = useAppDispatch()
return { return {

View File

@ -3,6 +3,7 @@ import { useAssistant } from '@renderer/hooks/useAssistant'
import { useMCPServers } from '@renderer/hooks/useMCPServers' import { useMCPServers } from '@renderer/hooks/useMCPServers'
import { EventEmitter } from '@renderer/services/EventService' import { EventEmitter } from '@renderer/services/EventService'
import { Assistant, MCPPrompt, MCPResource, MCPServer } from '@renderer/types' import { Assistant, MCPPrompt, MCPResource, MCPServer } from '@renderer/types'
import { delay, runAsyncFunction } from '@renderer/utils'
import { Form, Input, Tooltip } from 'antd' import { Form, Input, Tooltip } from 'antd'
import { Plus, SquareTerminal } from 'lucide-react' import { Plus, SquareTerminal } from 'lucide-react'
import { FC, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react' import { FC, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react'
@ -109,6 +110,11 @@ const extractPromptContent = (response: any): string | null => {
return null return null
} }
// Add static variable before component definition
let isFirstResourcesListCall = true
let isFirstPromptListCall = true
const initMcpDelay = 3
const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, ToolbarButton, ...props }) => { const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, ToolbarButton, ...props }) => {
const { activedMcpServers } = useMCPServers() const { activedMcpServers } = useMCPServers()
const { t } = useTranslation() const { t } = useTranslation()
@ -308,6 +314,11 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
const promptList = useMemo(async () => { const promptList = useMemo(async () => {
const prompts: MCPPrompt[] = [] const prompts: MCPPrompt[] = []
if (isFirstPromptListCall) {
await delay(initMcpDelay)
isFirstPromptListCall = false
}
for (const server of activedMcpServers) { for (const server of activedMcpServers) {
const serverPrompts = await window.api.mcp.listPrompts(server) const serverPrompts = await window.api.mcp.listPrompts(server)
prompts.push(...serverPrompts) prompts.push(...serverPrompts)
@ -319,7 +330,8 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
icon: <SquareTerminal />, icon: <SquareTerminal />,
action: () => handlePromptSelect(prompt as MCPPromptWithArgs) action: () => handlePromptSelect(prompt as MCPPromptWithArgs)
})) }))
}, [handlePromptSelect, activedMcpServers]) // eslint-disable-next-line react-hooks/exhaustive-deps
}, [activedMcpServers])
const openPromptList = useCallback(async () => { const openPromptList = useCallback(async () => {
const prompts = await promptList const prompts = await promptList
@ -380,10 +392,12 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
const [resourcesList, setResourcesList] = useState<QuickPanelListItem[]>([]) const [resourcesList, setResourcesList] = useState<QuickPanelListItem[]>([])
useEffect(() => { useEffect(() => {
runAsyncFunction(async () => {
let isMounted = true let isMounted = true
const fetchResources = async () => { const fetchResources = async () => {
const resources: MCPResource[] = [] const resources: MCPResource[] = []
for (const server of activedMcpServers) { for (const server of activedMcpServers) {
const serverResources = await window.api.mcp.listResources(server) const serverResources = await window.api.mcp.listResources(server)
resources.push(...serverResources) resources.push(...serverResources)
@ -401,12 +415,19 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
} }
} }
// Avoid mcp following the software startup, affecting the startup speed
if (isFirstResourcesListCall) {
await delay(initMcpDelay)
isFirstResourcesListCall = false
fetchResources() fetchResources()
}
return () => { return () => {
isMounted = false isMounted = false
} }
}, [activedMcpServers, handleResourceSelect]) })
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [activedMcpServers])
const openResourcesList = useCallback(async () => { const openResourcesList = useCallback(async () => {
const resources = resourcesList const resources = resourcesList