refactor: streamline MCP service handling and improve IPC registration

* Refactored MCPService to implement a singleton pattern for better instance management.
* Updated IPC registration to utilize the new getMcpInstance method for handling MCP-related requests.
* Removed redundant IPC handlers from the main index file and centralized them in the ipc module.
* Added background throttling option in WindowService configuration to enhance performance.
* Introduced delays in MCPToolsButton to optimize resource and prompt fetching after initial load.
This commit is contained in:
kangfenmao 2025-05-18 18:47:01 +08:00 committed by 亢奋猫
parent 4b2417ce37
commit 8bfbbd497c
7 changed files with 320 additions and 252 deletions

View File

@ -2,12 +2,11 @@ import '@main/config'
import { electronApp, optimizer } from '@electron-toolkit/utils' import { electronApp, optimizer } from '@electron-toolkit/utils'
import { replaceDevtoolsFont } from '@main/utils/windowUtil' import { replaceDevtoolsFont } from '@main/utils/windowUtil'
import { IpcChannel } from '@shared/IpcChannel' import { app } from 'electron'
import { app, BrowserWindow, ipcMain } from 'electron'
import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer' import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
import Logger from 'electron-log' import Logger from 'electron-log'
import { isDev, isMac, isWin } from './constant' import { isDev } from './constant'
import { registerIpc } from './ipc' import { registerIpc } from './ipc'
import { configManager } from './services/ConfigManager' import { configManager } from './services/ConfigManager'
import mcpService from './services/MCPService' import mcpService from './services/MCPService'
@ -85,18 +84,6 @@ if (!app.requestSingleInstanceLock()) {
.then((name) => console.log(`Added Extension: ${name}`)) .then((name) => console.log(`Added Extension: ${name}`))
.catch((err) => console.log('An error occurred: ', err)) .catch((err) => console.log('An error occurred: ', err))
} }
ipcMain.handle(IpcChannel.System_GetDeviceType, () => {
return isMac ? 'mac' : isWin ? 'windows' : 'linux'
})
ipcMain.handle(IpcChannel.System_GetHostname, () => {
return require('os').hostname()
})
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
const win = BrowserWindow.fromWebContents(e.sender)
win && win.webContents.toggleDevTools()
})
}) })
registerProtocolClient(app) registerProtocolClient(app)
@ -128,7 +115,7 @@ if (!app.requestSingleInstanceLock()) {
app.on('will-quit', async () => { app.on('will-quit', async () => {
// event.preventDefault() // event.preventDefault()
try { try {
await mcpService.cleanup() await mcpService().cleanup()
} catch (error) { } catch (error) {
Logger.error('Error cleaning up MCP service:', error) Logger.error('Error cleaning up MCP service:', error)
} }

View File

@ -19,7 +19,7 @@ import FileService from './services/FileService'
import FileStorage from './services/FileStorage' import FileStorage from './services/FileStorage'
import { GeminiService } from './services/GeminiService' import { GeminiService } from './services/GeminiService'
import KnowledgeService from './services/KnowledgeService' import KnowledgeService from './services/KnowledgeService'
import mcpService from './services/MCPService' import { getMcpInstance } from './services/MCPService'
import * as NutstoreService from './services/NutstoreService' import * as NutstoreService from './services/NutstoreService'
import ObsidianVaultService from './services/ObsidianVaultService' import ObsidianVaultService from './services/ObsidianVaultService'
import { ProxyConfig, proxyManager } from './services/ProxyManager' import { ProxyConfig, proxyManager } from './services/ProxyManager'
@ -204,6 +204,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.Zip_Compress, (_, text: string) => compress(text)) ipcMain.handle(IpcChannel.Zip_Compress, (_, text: string) => compress(text))
ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text)) ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text))
// system
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
const win = BrowserWindow.fromWebContents(e.sender)
win && win.webContents.toggleDevTools()
})
// backup // backup
ipcMain.handle(IpcChannel.Backup_Backup, backupManager.backup) ipcMain.handle(IpcChannel.Backup_Backup, backupManager.backup)
ipcMain.handle(IpcChannel.Backup_Restore, backupManager.restore) ipcMain.handle(IpcChannel.Backup_Restore, backupManager.restore)
@ -301,16 +309,16 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
) )
// Register MCP handlers // Register MCP handlers
ipcMain.handle(IpcChannel.Mcp_RemoveServer, mcpService.removeServer) ipcMain.handle(IpcChannel.Mcp_RemoveServer, (event, server) => getMcpInstance().removeServer(event, server))
ipcMain.handle(IpcChannel.Mcp_RestartServer, mcpService.restartServer) ipcMain.handle(IpcChannel.Mcp_RestartServer, (event, server) => getMcpInstance().restartServer(event, server))
ipcMain.handle(IpcChannel.Mcp_StopServer, mcpService.stopServer) ipcMain.handle(IpcChannel.Mcp_StopServer, (event, server) => getMcpInstance().stopServer(event, server))
ipcMain.handle(IpcChannel.Mcp_ListTools, mcpService.listTools) ipcMain.handle(IpcChannel.Mcp_ListTools, (event, server) => getMcpInstance().listTools(event, server))
ipcMain.handle(IpcChannel.Mcp_CallTool, mcpService.callTool) ipcMain.handle(IpcChannel.Mcp_CallTool, (event, params) => getMcpInstance().callTool(event, params))
ipcMain.handle(IpcChannel.Mcp_ListPrompts, mcpService.listPrompts) ipcMain.handle(IpcChannel.Mcp_ListPrompts, (event, server) => getMcpInstance().listPrompts(event, server))
ipcMain.handle(IpcChannel.Mcp_GetPrompt, mcpService.getPrompt) ipcMain.handle(IpcChannel.Mcp_GetPrompt, (event, params) => getMcpInstance().getPrompt(event, params))
ipcMain.handle(IpcChannel.Mcp_ListResources, mcpService.listResources) ipcMain.handle(IpcChannel.Mcp_ListResources, (event, server) => getMcpInstance().listResources(event, server))
ipcMain.handle(IpcChannel.Mcp_GetResource, mcpService.getResource) ipcMain.handle(IpcChannel.Mcp_GetResource, (event, params) => getMcpInstance().getResource(event, params))
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo) ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, () => getMcpInstance().getInstallInfo())
ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name)) ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name))
ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name)) ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name))

View File

@ -68,20 +68,18 @@ function withCache<T extends unknown[], R>(
} }
class McpService { class McpService {
private static instance: McpService | null = null
private clients: Map<string, Client> = new Map() private clients: Map<string, Client> = new Map()
private pendingClients: Map<string, Promise<Client>> = new Map()
private getServerKey(server: MCPServer): string { public static getInstance(): McpService {
return JSON.stringify({ if (!McpService.instance) {
baseUrl: server.baseUrl, McpService.instance = new McpService()
command: server.command, }
args: server.args, return McpService.instance
registryUrl: server.registryUrl,
env: server.env,
id: server.id
})
} }
constructor() { private constructor() {
this.initClient = this.initClient.bind(this) this.initClient = this.initClient.bind(this)
this.listTools = this.listTools.bind(this) this.listTools = this.listTools.bind(this)
this.callTool = this.callTool.bind(this) this.callTool = this.callTool.bind(this)
@ -96,9 +94,26 @@ class McpService {
this.cleanup = this.cleanup.bind(this) this.cleanup = this.cleanup.bind(this)
} }
private getServerKey(server: MCPServer): string {
return JSON.stringify({
baseUrl: server.baseUrl,
command: server.command,
args: server.args,
registryUrl: server.registryUrl,
env: server.env,
id: server.id
})
}
async initClient(server: MCPServer): Promise<Client> { async initClient(server: MCPServer): Promise<Client> {
const serverKey = this.getServerKey(server) const serverKey = this.getServerKey(server)
// If there's a pending initialization, wait for it
const pendingClient = this.pendingClients.get(serverKey)
if (pendingClient) {
return pendingClient
}
// Check if we already have a client for this server configuration // Check if we already have a client for this server configuration
const existingClient = this.clients.get(serverKey) const existingClient = this.clients.get(serverKey)
if (existingClient) { if (existingClient) {
@ -113,209 +128,226 @@ class McpService {
} else { } else {
return existingClient return existingClient
} }
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Error pinging server ${server.name}:`, error) Logger.error(`[MCP] Error pinging server ${server.name}:`, error?.message)
this.clients.delete(serverKey) this.clients.delete(serverKey)
} }
} }
// Create new client instance for each connection
const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} })
const args = [...(server.args || [])] // Create a promise for the initialization process
const initPromise = (async () => {
try {
// Create new client instance for each connection
const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} })
// let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport const args = [...(server.args || [])]
const authProvider = new McpOAuthClientProvider({
serverUrlHash: crypto
.createHash('md5')
.update(server.baseUrl || '')
.digest('hex')
})
const initTransport = async (): Promise< // let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport const authProvider = new McpOAuthClientProvider({
> => { serverUrlHash: crypto
// Create appropriate transport based on configuration .createHash('md5')
if (server.type === 'inMemory') { .update(server.baseUrl || '')
Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`) .digest('hex')
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair() })
// start the in-memory server with the given name and environment variables
const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {})
try {
await inMemoryServer.connect(serverTransport)
Logger.info(`[MCP] In-memory server started: ${server.name}`)
} catch (error: Error | any) {
Logger.error(`[MCP] Error starting in-memory server: ${error}`)
throw new Error(`Failed to start in-memory server: ${error.message}`)
}
// set the client transport to the client
return clientTransport
} else if (server.baseUrl) {
if (server.type === 'streamableHttp') {
const options: StreamableHTTPClientTransportOptions = {
requestInit: {
headers: server.headers || {}
},
authProvider
}
return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
} else if (server.type === 'sse') {
const options: SSEClientTransportOptions = {
eventSourceInit: {
fetch: async (url, init) => {
const headers = { ...(server.headers || {}), ...(init?.headers || {}) }
// Get tokens from authProvider to make sure using the latest tokens const initTransport = async (): Promise<
if (authProvider && typeof authProvider.tokens === 'function') { StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
try { > => {
const tokens = await authProvider.tokens() // Create appropriate transport based on configuration
if (tokens && tokens.access_token) { if (server.type === 'inMemory') {
headers['Authorization'] = `Bearer ${tokens.access_token}` Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`)
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair()
// start the in-memory server with the given name and environment variables
const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {})
try {
await inMemoryServer.connect(serverTransport)
Logger.info(`[MCP] In-memory server started: ${server.name}`)
} catch (error: Error | any) {
Logger.error(`[MCP] Error starting in-memory server: ${error}`)
throw new Error(`Failed to start in-memory server: ${error.message}`)
}
// set the client transport to the client
return clientTransport
} else if (server.baseUrl) {
if (server.type === 'streamableHttp') {
const options: StreamableHTTPClientTransportOptions = {
requestInit: {
headers: server.headers || {}
},
authProvider
}
return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
} else if (server.type === 'sse') {
const options: SSEClientTransportOptions = {
eventSourceInit: {
fetch: async (url, init) => {
const headers = { ...(server.headers || {}), ...(init?.headers || {}) }
// Get tokens from authProvider to make sure using the latest tokens
if (authProvider && typeof authProvider.tokens === 'function') {
try {
const tokens = await authProvider.tokens()
if (tokens && tokens.access_token) {
headers['Authorization'] = `Bearer ${tokens.access_token}`
}
} catch (error) {
Logger.error('Failed to fetch tokens:', error)
}
} }
} catch (error) {
Logger.error('Failed to fetch tokens:', error) return fetch(url, { ...init, headers })
} }
},
requestInit: {
headers: server.headers || {}
},
authProvider
}
return new SSEClientTransport(new URL(server.baseUrl!), options)
} else {
throw new Error('Invalid server type')
}
} else if (server.command) {
let cmd = server.command
if (server.command === 'npx') {
cmd = await getBinaryPath('bun')
Logger.info(`[MCP] Using command: ${cmd}`)
// add -x to args if args exist
if (args && args.length > 0) {
if (!args.includes('-y')) {
args.unshift('-y')
}
if (!args.includes('x')) {
args.unshift('x')
}
}
if (server.registryUrl) {
server.env = {
...server.env,
NPM_CONFIG_REGISTRY: server.registryUrl
} }
return fetch(url, { ...init, headers }) // if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory
if (server.name.includes('mcp-auto-install')) {
const binPath = await getBinaryPath()
makeSureDirExists(binPath)
server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json')
}
}
} else if (server.command === 'uvx' || server.command === 'uv') {
cmd = await getBinaryPath(server.command)
if (server.registryUrl) {
server.env = {
...server.env,
UV_DEFAULT_INDEX: server.registryUrl,
PIP_INDEX_URL: server.registryUrl
}
} }
},
requestInit: {
headers: server.headers || {}
},
authProvider
}
return new SSEClientTransport(new URL(server.baseUrl!), options)
} else {
throw new Error('Invalid server type')
}
} else if (server.command) {
let cmd = server.command
if (server.command === 'npx') {
cmd = await getBinaryPath('bun')
Logger.info(`[MCP] Using command: ${cmd}`)
// add -x to args if args exist
if (args && args.length > 0) {
if (!args.includes('-y')) {
args.unshift('-y')
}
if (!args.includes('x')) {
args.unshift('x')
}
}
if (server.registryUrl) {
server.env = {
...server.env,
NPM_CONFIG_REGISTRY: server.registryUrl
} }
// if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`)
if (server.name.includes('mcp-auto-install')) { // Logger.info(`[MCP] Environment variables for server:`, server.env)
const binPath = await getBinaryPath() const loginShellEnv = await this.getLoginShellEnv()
makeSureDirExists(binPath) const stdioTransport = new StdioClientTransport({
server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json') command: cmd,
} args,
} env: {
} else if (server.command === 'uvx' || server.command === 'uv') { ...loginShellEnv,
cmd = await getBinaryPath(server.command) ...server.env
if (server.registryUrl) { },
server.env = { stderr: 'pipe'
...server.env, })
UV_DEFAULT_INDEX: server.registryUrl, stdioTransport.stderr?.on('data', (data) =>
PIP_INDEX_URL: server.registryUrl Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString())
} )
return stdioTransport
} else {
throw new Error('Either baseUrl or command must be provided')
} }
} }
Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`) const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => {
// Logger.info(`[MCP] Environment variables for server:`, server.env) Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`)
const loginShellEnv = await this.getLoginShellEnv() // Create an event emitter for the OAuth callback
const stdioTransport = new StdioClientTransport({ const events = new EventEmitter()
command: cmd,
args,
env: {
...loginShellEnv,
...server.env
},
stderr: 'pipe'
})
stdioTransport.stderr?.on('data', (data) =>
Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString())
)
return stdioTransport
} else {
throw new Error('Either baseUrl or command must be provided')
}
}
const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => { // Create a callback server
Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`) const callbackServer = new CallBackServer({
// Create an event emitter for the OAuth callback port: authProvider.config.callbackPort,
const events = new EventEmitter() path: authProvider.config.callbackPath || '/oauth/callback',
events
})
// Create a callback server // Set a timeout to close the callback server
const callbackServer = new CallBackServer({ const timeoutId = setTimeout(() => {
port: authProvider.config.callbackPort, Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`)
path: authProvider.config.callbackPath || '/oauth/callback', callbackServer.close()
events }, 300000) // 5 minutes timeout
})
// Set a timeout to close the callback server try {
const timeoutId = setTimeout(() => { // Wait for the authorization code
Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`) const authCode = await callbackServer.waitForAuthCode()
callbackServer.close() Logger.info(`[MCP] Received auth code: ${authCode}`)
}, 300000) // 5 minutes timeout
try { // Complete the OAuth flow
// Wait for the authorization code await transport.finishAuth(authCode)
const authCode = await callbackServer.waitForAuthCode()
Logger.info(`[MCP] Received auth code: ${authCode}`)
// Complete the OAuth flow Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`)
await transport.finishAuth(authCode)
Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`) const newTransport = await initTransport()
// Try to connect again
await client.connect(newTransport)
const newTransport = await initTransport() Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`)
// Try to connect again } catch (oauthError) {
await client.connect(newTransport) Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError)
throw new Error(
`OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}`
)
} finally {
// Clear the timeout and close the callback server
clearTimeout(timeoutId)
callbackServer.close()
}
}
Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`) try {
} catch (oauthError) { const transport = await initTransport()
Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError) try {
throw new Error( await client.connect(transport)
`OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}` } catch (error: Error | any) {
) if (
error instanceof Error &&
(error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))
) {
Logger.info(`[MCP] Authentication required for server: ${server.name}`)
await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
} else {
throw error
}
}
// Store the new client in the cache
this.clients.set(serverKey, client)
Logger.info(`[MCP] Activated server: ${server.name}`)
return client
} catch (error: any) {
Logger.error(`[MCP] Error activating server ${server.name}:`, error?.message)
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
}
} finally { } finally {
// Clear the timeout and close the callback server // Clean up the pending promise when done
clearTimeout(timeoutId) this.pendingClients.delete(serverKey)
callbackServer.close()
} }
} })()
try { // Store the pending promise
const transport = await initTransport() this.pendingClients.set(serverKey, initPromise)
try {
await client.connect(transport)
} catch (error: Error | any) {
if (error instanceof Error && (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))) {
Logger.info(`[MCP] Authentication required for server: ${server.name}`)
await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
} else {
throw error
}
}
// Store the new client in the cache return initPromise
this.clients.set(serverKey, client)
Logger.info(`[MCP] Activated server: ${server.name}`)
return client
} catch (error: any) {
Logger.error(`[MCP] Error activating server ${server.name}:`, error)
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
}
} }
async closeClient(serverKey: string) { async closeClient(serverKey: string) {
@ -357,8 +389,8 @@ class McpService {
for (const [key] of this.clients) { for (const [key] of this.clients) {
try { try {
await this.closeClient(key) await this.closeClient(key)
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Failed to close client: ${error}`) Logger.error(`[MCP] Failed to close client: ${error?.message}`)
} }
} }
} }
@ -379,8 +411,8 @@ class McpService {
serverTools.push(serverTool) serverTools.push(serverTool)
}) })
return serverTools return serverTools
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error) Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error?.message)
return [] return []
} }
} }
@ -439,8 +471,8 @@ class McpService {
* List prompts available on an MCP server * List prompts available on an MCP server
*/ */
private async listPromptsImpl(server: MCPServer): Promise<MCPPrompt[]> { private async listPromptsImpl(server: MCPServer): Promise<MCPPrompt[]> {
Logger.info(`[MCP] Listing prompts for server: ${server.name}`)
const client = await this.initClient(server) const client = await this.initClient(server)
Logger.info(`[MCP] Listing prompts for server: ${server.name}`)
try { try {
const { prompts } = await client.listPrompts() const { prompts } = await client.listPrompts()
return prompts.map((prompt: any) => ({ return prompts.map((prompt: any) => ({
@ -449,8 +481,11 @@ class McpService {
serverId: server.id, serverId: server.id,
serverName: server.name serverName: server.name
})) }))
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error) // -32601 is the code for the method not found
if (error?.code !== -32601) {
Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error?.message)
}
return [] return []
} }
} }
@ -508,8 +543,8 @@ class McpService {
* List resources available on an MCP server (implementation) * List resources available on an MCP server (implementation)
*/ */
private async listResourcesImpl(server: MCPServer): Promise<MCPResource[]> { private async listResourcesImpl(server: MCPServer): Promise<MCPResource[]> {
Logger.info(`[MCP] Listing resources for server: ${server.name}`)
const client = await this.initClient(server) const client = await this.initClient(server)
Logger.info(`[MCP] Listing resources for server: ${server.name}`)
try { try {
const result = await client.listResources() const result = await client.listResources()
const resources = result.resources || [] const resources = result.resources || []
@ -519,8 +554,11 @@ class McpService {
serverName: server.name serverName: server.name
})) }))
return serverResources return serverResources
} catch (error) { } catch (error: any) {
Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error) // -32601 is the code for the method not found
if (error?.code !== -32601) {
Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error?.message)
}
return [] return []
} }
} }
@ -563,7 +601,7 @@ class McpService {
contents: contents contents: contents
} }
} catch (error: Error | any) { } catch (error: Error | any) {
Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error) Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error.message)
throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`) throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`)
} }
} }
@ -602,5 +640,13 @@ class McpService {
}) })
} }
const mcpService = new McpService() let mcpInstance: ReturnType<typeof McpService.getInstance> | null = null
export default mcpService
export const getMcpInstance = () => {
if (!mcpInstance) {
mcpInstance = McpService.getInstance()
}
return mcpInstance
}
export default McpService.getInstance

View File

@ -446,7 +446,8 @@ export class WindowService {
preload: join(__dirname, '../preload/index.js'), preload: join(__dirname, '../preload/index.js'),
sandbox: false, sandbox: false,
webSecurity: false, webSecurity: false,
webviewTag: true webviewTag: true,
backgroundThrottling: false
} }
}) })

View File

@ -47,7 +47,7 @@ function getLoginShellEnvironment(): Promise<Record<string, string>> {
commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command
} }
Logger.log(`Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`) Logger.log(`[ShellEnv] Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`)
const child = spawn(shellPath, commandArgs, { const child = spawn(shellPath, commandArgs, {
cwd: homeDirectory, // Run the command in the user's home directory cwd: homeDirectory, // Run the command in the user's home directory

View File

@ -1,8 +1,8 @@
import { createSelector } from '@reduxjs/toolkit'
import store, { useAppDispatch, useAppSelector } from '@renderer/store' import store, { useAppDispatch, useAppSelector } from '@renderer/store'
import { addMCPServer, deleteMCPServer, setMCPServers, updateMCPServer } from '@renderer/store/mcp' import { addMCPServer, deleteMCPServer, setMCPServers, updateMCPServer } from '@renderer/store/mcp'
import { MCPServer } from '@renderer/types' import { MCPServer } from '@renderer/types'
import { IpcChannel } from '@shared/IpcChannel' import { IpcChannel } from '@shared/IpcChannel'
import { useMemo } from 'react'
const ipcRenderer = window.electron.ipcRenderer const ipcRenderer = window.electron.ipcRenderer
@ -14,9 +14,14 @@ ipcRenderer.on(IpcChannel.Mcp_AddServer, (_event, server: MCPServer) => {
store.dispatch(addMCPServer(server)) store.dispatch(addMCPServer(server))
}) })
const selectMcpServers = (state) => state.mcp.servers
const selectActiveMcpServers = createSelector([selectMcpServers], (servers) =>
servers.filter((server) => server.isActive)
)
export const useMCPServers = () => { export const useMCPServers = () => {
const mcpServers = useAppSelector((state) => state.mcp.servers) const mcpServers = useAppSelector(selectMcpServers)
const activedMcpServers = useMemo(() => mcpServers.filter((server) => server.isActive), [mcpServers]) const activedMcpServers = useAppSelector(selectActiveMcpServers)
const dispatch = useAppDispatch() const dispatch = useAppDispatch()
return { return {

View File

@ -3,6 +3,7 @@ import { useAssistant } from '@renderer/hooks/useAssistant'
import { useMCPServers } from '@renderer/hooks/useMCPServers' import { useMCPServers } from '@renderer/hooks/useMCPServers'
import { EventEmitter } from '@renderer/services/EventService' import { EventEmitter } from '@renderer/services/EventService'
import { Assistant, MCPPrompt, MCPResource, MCPServer } from '@renderer/types' import { Assistant, MCPPrompt, MCPResource, MCPServer } from '@renderer/types'
import { delay, runAsyncFunction } from '@renderer/utils'
import { Form, Input, Tooltip } from 'antd' import { Form, Input, Tooltip } from 'antd'
import { Plus, SquareTerminal } from 'lucide-react' import { Plus, SquareTerminal } from 'lucide-react'
import { FC, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react' import { FC, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react'
@ -109,6 +110,11 @@ const extractPromptContent = (response: any): string | null => {
return null return null
} }
// Add static variable before component definition
let isFirstResourcesListCall = true
let isFirstPromptListCall = true
const initMcpDelay = 3
const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, ToolbarButton, ...props }) => { const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, ToolbarButton, ...props }) => {
const { activedMcpServers } = useMCPServers() const { activedMcpServers } = useMCPServers()
const { t } = useTranslation() const { t } = useTranslation()
@ -308,6 +314,11 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
const promptList = useMemo(async () => { const promptList = useMemo(async () => {
const prompts: MCPPrompt[] = [] const prompts: MCPPrompt[] = []
if (isFirstPromptListCall) {
await delay(initMcpDelay)
isFirstPromptListCall = false
}
for (const server of activedMcpServers) { for (const server of activedMcpServers) {
const serverPrompts = await window.api.mcp.listPrompts(server) const serverPrompts = await window.api.mcp.listPrompts(server)
prompts.push(...serverPrompts) prompts.push(...serverPrompts)
@ -319,7 +330,8 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
icon: <SquareTerminal />, icon: <SquareTerminal />,
action: () => handlePromptSelect(prompt as MCPPromptWithArgs) action: () => handlePromptSelect(prompt as MCPPromptWithArgs)
})) }))
}, [handlePromptSelect, activedMcpServers]) // eslint-disable-next-line react-hooks/exhaustive-deps
}, [activedMcpServers])
const openPromptList = useCallback(async () => { const openPromptList = useCallback(async () => {
const prompts = await promptList const prompts = await promptList
@ -380,33 +392,42 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
const [resourcesList, setResourcesList] = useState<QuickPanelListItem[]>([]) const [resourcesList, setResourcesList] = useState<QuickPanelListItem[]>([])
useEffect(() => { useEffect(() => {
let isMounted = true runAsyncFunction(async () => {
let isMounted = true
const fetchResources = async () => { const fetchResources = async () => {
const resources: MCPResource[] = [] const resources: MCPResource[] = []
for (const server of activedMcpServers) {
const serverResources = await window.api.mcp.listResources(server) for (const server of activedMcpServers) {
resources.push(...serverResources) const serverResources = await window.api.mcp.listResources(server)
resources.push(...serverResources)
}
if (isMounted) {
setResourcesList(
resources.map((resource) => ({
label: resource.name,
description: resource.description,
icon: <SquareTerminal />,
action: () => handleResourceSelect(resource)
}))
)
}
} }
if (isMounted) { // Avoid mcp following the software startup, affecting the startup speed
setResourcesList( if (isFirstResourcesListCall) {
resources.map((resource) => ({ await delay(initMcpDelay)
label: resource.name, isFirstResourcesListCall = false
description: resource.description, fetchResources()
icon: <SquareTerminal />,
action: () => handleResourceSelect(resource)
}))
)
} }
}
fetchResources() return () => {
isMounted = false
return () => { }
isMounted = false })
} // eslint-disable-next-line react-hooks/exhaustive-deps
}, [activedMcpServers, handleResourceSelect]) }, [activedMcpServers])
const openResourcesList = useCallback(async () => { const openResourcesList = useCallback(async () => {
const resources = resourcesList const resources = resourcesList