refactor: streamline MCP service handling and improve IPC registration

* Refactored MCPService to implement a singleton pattern for better instance management.
* Updated IPC registration to utilize the new getMcpInstance method for handling MCP-related requests.
* Removed redundant IPC handlers from the main index file and centralized them in the ipc module.
* Added background throttling option in WindowService configuration to enhance performance.
* Introduced delays in MCPToolsButton to optimize resource and prompt fetching after initial load.
This commit is contained in:
kangfenmao 2025-05-18 18:47:01 +08:00 committed by 亢奋猫
parent 4b2417ce37
commit 8bfbbd497c
7 changed files with 320 additions and 252 deletions

View File

@ -2,12 +2,11 @@ import '@main/config'
import { electronApp, optimizer } from '@electron-toolkit/utils'
import { replaceDevtoolsFont } from '@main/utils/windowUtil'
import { IpcChannel } from '@shared/IpcChannel'
import { app, BrowserWindow, ipcMain } from 'electron'
import { app } from 'electron'
import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
import Logger from 'electron-log'
import { isDev, isMac, isWin } from './constant'
import { isDev } from './constant'
import { registerIpc } from './ipc'
import { configManager } from './services/ConfigManager'
import mcpService from './services/MCPService'
@ -85,18 +84,6 @@ if (!app.requestSingleInstanceLock()) {
.then((name) => console.log(`Added Extension: ${name}`))
.catch((err) => console.log('An error occurred: ', err))
}
ipcMain.handle(IpcChannel.System_GetDeviceType, () => {
return isMac ? 'mac' : isWin ? 'windows' : 'linux'
})
ipcMain.handle(IpcChannel.System_GetHostname, () => {
return require('os').hostname()
})
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
const win = BrowserWindow.fromWebContents(e.sender)
win && win.webContents.toggleDevTools()
})
})
registerProtocolClient(app)
@ -128,7 +115,7 @@ if (!app.requestSingleInstanceLock()) {
app.on('will-quit', async () => {
// event.preventDefault()
try {
await mcpService.cleanup()
await mcpService().cleanup()
} catch (error) {
Logger.error('Error cleaning up MCP service:', error)
}

View File

@ -19,7 +19,7 @@ import FileService from './services/FileService'
import FileStorage from './services/FileStorage'
import { GeminiService } from './services/GeminiService'
import KnowledgeService from './services/KnowledgeService'
import mcpService from './services/MCPService'
import { getMcpInstance } from './services/MCPService'
import * as NutstoreService from './services/NutstoreService'
import ObsidianVaultService from './services/ObsidianVaultService'
import { ProxyConfig, proxyManager } from './services/ProxyManager'
@ -204,6 +204,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.Zip_Compress, (_, text: string) => compress(text))
ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text))
// system
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
const win = BrowserWindow.fromWebContents(e.sender)
win && win.webContents.toggleDevTools()
})
// backup
ipcMain.handle(IpcChannel.Backup_Backup, backupManager.backup)
ipcMain.handle(IpcChannel.Backup_Restore, backupManager.restore)
@ -301,16 +309,16 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
)
// Register MCP handlers
ipcMain.handle(IpcChannel.Mcp_RemoveServer, mcpService.removeServer)
ipcMain.handle(IpcChannel.Mcp_RestartServer, mcpService.restartServer)
ipcMain.handle(IpcChannel.Mcp_StopServer, mcpService.stopServer)
ipcMain.handle(IpcChannel.Mcp_ListTools, mcpService.listTools)
ipcMain.handle(IpcChannel.Mcp_CallTool, mcpService.callTool)
ipcMain.handle(IpcChannel.Mcp_ListPrompts, mcpService.listPrompts)
ipcMain.handle(IpcChannel.Mcp_GetPrompt, mcpService.getPrompt)
ipcMain.handle(IpcChannel.Mcp_ListResources, mcpService.listResources)
ipcMain.handle(IpcChannel.Mcp_GetResource, mcpService.getResource)
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo)
ipcMain.handle(IpcChannel.Mcp_RemoveServer, (event, server) => getMcpInstance().removeServer(event, server))
ipcMain.handle(IpcChannel.Mcp_RestartServer, (event, server) => getMcpInstance().restartServer(event, server))
ipcMain.handle(IpcChannel.Mcp_StopServer, (event, server) => getMcpInstance().stopServer(event, server))
ipcMain.handle(IpcChannel.Mcp_ListTools, (event, server) => getMcpInstance().listTools(event, server))
ipcMain.handle(IpcChannel.Mcp_CallTool, (event, params) => getMcpInstance().callTool(event, params))
ipcMain.handle(IpcChannel.Mcp_ListPrompts, (event, server) => getMcpInstance().listPrompts(event, server))
ipcMain.handle(IpcChannel.Mcp_GetPrompt, (event, params) => getMcpInstance().getPrompt(event, params))
ipcMain.handle(IpcChannel.Mcp_ListResources, (event, server) => getMcpInstance().listResources(event, server))
ipcMain.handle(IpcChannel.Mcp_GetResource, (event, params) => getMcpInstance().getResource(event, params))
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, () => getMcpInstance().getInstallInfo())
ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name))
ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name))

View File

@ -68,20 +68,18 @@ function withCache<T extends unknown[], R>(
}
class McpService {
private static instance: McpService | null = null
private clients: Map<string, Client> = new Map()
private pendingClients: Map<string, Promise<Client>> = new Map()
private getServerKey(server: MCPServer): string {
return JSON.stringify({
baseUrl: server.baseUrl,
command: server.command,
args: server.args,
registryUrl: server.registryUrl,
env: server.env,
id: server.id
})
public static getInstance(): McpService {
if (!McpService.instance) {
McpService.instance = new McpService()
}
return McpService.instance
}
constructor() {
private constructor() {
this.initClient = this.initClient.bind(this)
this.listTools = this.listTools.bind(this)
this.callTool = this.callTool.bind(this)
@ -96,9 +94,26 @@ class McpService {
this.cleanup = this.cleanup.bind(this)
}
private getServerKey(server: MCPServer): string {
return JSON.stringify({
baseUrl: server.baseUrl,
command: server.command,
args: server.args,
registryUrl: server.registryUrl,
env: server.env,
id: server.id
})
}
async initClient(server: MCPServer): Promise<Client> {
const serverKey = this.getServerKey(server)
// If there's a pending initialization, wait for it
const pendingClient = this.pendingClients.get(serverKey)
if (pendingClient) {
return pendingClient
}
// Check if we already have a client for this server configuration
const existingClient = this.clients.get(serverKey)
if (existingClient) {
@ -113,209 +128,226 @@ class McpService {
} else {
return existingClient
}
} catch (error) {
Logger.error(`[MCP] Error pinging server ${server.name}:`, error)
} catch (error: any) {
Logger.error(`[MCP] Error pinging server ${server.name}:`, error?.message)
this.clients.delete(serverKey)
}
}
// Create new client instance for each connection
const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} })
const args = [...(server.args || [])]
// Create a promise for the initialization process
const initPromise = (async () => {
try {
// Create new client instance for each connection
const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} })
// let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
const authProvider = new McpOAuthClientProvider({
serverUrlHash: crypto
.createHash('md5')
.update(server.baseUrl || '')
.digest('hex')
})
const args = [...(server.args || [])]
const initTransport = async (): Promise<
StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
> => {
// Create appropriate transport based on configuration
if (server.type === 'inMemory') {
Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`)
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair()
// start the in-memory server with the given name and environment variables
const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {})
try {
await inMemoryServer.connect(serverTransport)
Logger.info(`[MCP] In-memory server started: ${server.name}`)
} catch (error: Error | any) {
Logger.error(`[MCP] Error starting in-memory server: ${error}`)
throw new Error(`Failed to start in-memory server: ${error.message}`)
}
// set the client transport to the client
return clientTransport
} else if (server.baseUrl) {
if (server.type === 'streamableHttp') {
const options: StreamableHTTPClientTransportOptions = {
requestInit: {
headers: server.headers || {}
},
authProvider
}
return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
} else if (server.type === 'sse') {
const options: SSEClientTransportOptions = {
eventSourceInit: {
fetch: async (url, init) => {
const headers = { ...(server.headers || {}), ...(init?.headers || {}) }
// let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
const authProvider = new McpOAuthClientProvider({
serverUrlHash: crypto
.createHash('md5')
.update(server.baseUrl || '')
.digest('hex')
})
// Get tokens from authProvider to make sure using the latest tokens
if (authProvider && typeof authProvider.tokens === 'function') {
try {
const tokens = await authProvider.tokens()
if (tokens && tokens.access_token) {
headers['Authorization'] = `Bearer ${tokens.access_token}`
const initTransport = async (): Promise<
StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
> => {
// Create appropriate transport based on configuration
if (server.type === 'inMemory') {
Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`)
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair()
// start the in-memory server with the given name and environment variables
const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {})
try {
await inMemoryServer.connect(serverTransport)
Logger.info(`[MCP] In-memory server started: ${server.name}`)
} catch (error: Error | any) {
Logger.error(`[MCP] Error starting in-memory server: ${error}`)
throw new Error(`Failed to start in-memory server: ${error.message}`)
}
// set the client transport to the client
return clientTransport
} else if (server.baseUrl) {
if (server.type === 'streamableHttp') {
const options: StreamableHTTPClientTransportOptions = {
requestInit: {
headers: server.headers || {}
},
authProvider
}
return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
} else if (server.type === 'sse') {
const options: SSEClientTransportOptions = {
eventSourceInit: {
fetch: async (url, init) => {
const headers = { ...(server.headers || {}), ...(init?.headers || {}) }
// Get tokens from authProvider to make sure using the latest tokens
if (authProvider && typeof authProvider.tokens === 'function') {
try {
const tokens = await authProvider.tokens()
if (tokens && tokens.access_token) {
headers['Authorization'] = `Bearer ${tokens.access_token}`
}
} catch (error) {
Logger.error('Failed to fetch tokens:', error)
}
}
} catch (error) {
Logger.error('Failed to fetch tokens:', error)
return fetch(url, { ...init, headers })
}
},
requestInit: {
headers: server.headers || {}
},
authProvider
}
return new SSEClientTransport(new URL(server.baseUrl!), options)
} else {
throw new Error('Invalid server type')
}
} else if (server.command) {
let cmd = server.command
if (server.command === 'npx') {
cmd = await getBinaryPath('bun')
Logger.info(`[MCP] Using command: ${cmd}`)
// add -x to args if args exist
if (args && args.length > 0) {
if (!args.includes('-y')) {
args.unshift('-y')
}
if (!args.includes('x')) {
args.unshift('x')
}
}
if (server.registryUrl) {
server.env = {
...server.env,
NPM_CONFIG_REGISTRY: server.registryUrl
}
return fetch(url, { ...init, headers })
// if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory
if (server.name.includes('mcp-auto-install')) {
const binPath = await getBinaryPath()
makeSureDirExists(binPath)
server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json')
}
}
} else if (server.command === 'uvx' || server.command === 'uv') {
cmd = await getBinaryPath(server.command)
if (server.registryUrl) {
server.env = {
...server.env,
UV_DEFAULT_INDEX: server.registryUrl,
PIP_INDEX_URL: server.registryUrl
}
}
},
requestInit: {
headers: server.headers || {}
},
authProvider
}
return new SSEClientTransport(new URL(server.baseUrl!), options)
} else {
throw new Error('Invalid server type')
}
} else if (server.command) {
let cmd = server.command
if (server.command === 'npx') {
cmd = await getBinaryPath('bun')
Logger.info(`[MCP] Using command: ${cmd}`)
// add -x to args if args exist
if (args && args.length > 0) {
if (!args.includes('-y')) {
args.unshift('-y')
}
if (!args.includes('x')) {
args.unshift('x')
}
}
if (server.registryUrl) {
server.env = {
...server.env,
NPM_CONFIG_REGISTRY: server.registryUrl
}
// if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory
if (server.name.includes('mcp-auto-install')) {
const binPath = await getBinaryPath()
makeSureDirExists(binPath)
server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json')
}
}
} else if (server.command === 'uvx' || server.command === 'uv') {
cmd = await getBinaryPath(server.command)
if (server.registryUrl) {
server.env = {
...server.env,
UV_DEFAULT_INDEX: server.registryUrl,
PIP_INDEX_URL: server.registryUrl
}
Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`)
// Logger.info(`[MCP] Environment variables for server:`, server.env)
const loginShellEnv = await this.getLoginShellEnv()
const stdioTransport = new StdioClientTransport({
command: cmd,
args,
env: {
...loginShellEnv,
...server.env
},
stderr: 'pipe'
})
stdioTransport.stderr?.on('data', (data) =>
Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString())
)
return stdioTransport
} else {
throw new Error('Either baseUrl or command must be provided')
}
}
Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`)
// Logger.info(`[MCP] Environment variables for server:`, server.env)
const loginShellEnv = await this.getLoginShellEnv()
const stdioTransport = new StdioClientTransport({
command: cmd,
args,
env: {
...loginShellEnv,
...server.env
},
stderr: 'pipe'
})
stdioTransport.stderr?.on('data', (data) =>
Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString())
)
return stdioTransport
} else {
throw new Error('Either baseUrl or command must be provided')
}
}
const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => {
Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`)
// Create an event emitter for the OAuth callback
const events = new EventEmitter()
const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => {
Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`)
// Create an event emitter for the OAuth callback
const events = new EventEmitter()
// Create a callback server
const callbackServer = new CallBackServer({
port: authProvider.config.callbackPort,
path: authProvider.config.callbackPath || '/oauth/callback',
events
})
// Create a callback server
const callbackServer = new CallBackServer({
port: authProvider.config.callbackPort,
path: authProvider.config.callbackPath || '/oauth/callback',
events
})
// Set a timeout to close the callback server
const timeoutId = setTimeout(() => {
Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`)
callbackServer.close()
}, 300000) // 5 minutes timeout
// Set a timeout to close the callback server
const timeoutId = setTimeout(() => {
Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`)
callbackServer.close()
}, 300000) // 5 minutes timeout
try {
// Wait for the authorization code
const authCode = await callbackServer.waitForAuthCode()
Logger.info(`[MCP] Received auth code: ${authCode}`)
try {
// Wait for the authorization code
const authCode = await callbackServer.waitForAuthCode()
Logger.info(`[MCP] Received auth code: ${authCode}`)
// Complete the OAuth flow
await transport.finishAuth(authCode)
// Complete the OAuth flow
await transport.finishAuth(authCode)
Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`)
Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`)
const newTransport = await initTransport()
// Try to connect again
await client.connect(newTransport)
const newTransport = await initTransport()
// Try to connect again
await client.connect(newTransport)
Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`)
} catch (oauthError) {
Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError)
throw new Error(
`OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}`
)
} finally {
// Clear the timeout and close the callback server
clearTimeout(timeoutId)
callbackServer.close()
}
}
Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`)
} catch (oauthError) {
Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError)
throw new Error(
`OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}`
)
try {
const transport = await initTransport()
try {
await client.connect(transport)
} catch (error: Error | any) {
if (
error instanceof Error &&
(error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))
) {
Logger.info(`[MCP] Authentication required for server: ${server.name}`)
await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
} else {
throw error
}
}
// Store the new client in the cache
this.clients.set(serverKey, client)
Logger.info(`[MCP] Activated server: ${server.name}`)
return client
} catch (error: any) {
Logger.error(`[MCP] Error activating server ${server.name}:`, error?.message)
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
}
} finally {
// Clear the timeout and close the callback server
clearTimeout(timeoutId)
callbackServer.close()
// Clean up the pending promise when done
this.pendingClients.delete(serverKey)
}
}
})()
try {
const transport = await initTransport()
try {
await client.connect(transport)
} catch (error: Error | any) {
if (error instanceof Error && (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))) {
Logger.info(`[MCP] Authentication required for server: ${server.name}`)
await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
} else {
throw error
}
}
// Store the pending promise
this.pendingClients.set(serverKey, initPromise)
// Store the new client in the cache
this.clients.set(serverKey, client)
Logger.info(`[MCP] Activated server: ${server.name}`)
return client
} catch (error: any) {
Logger.error(`[MCP] Error activating server ${server.name}:`, error)
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
}
return initPromise
}
async closeClient(serverKey: string) {
@ -357,8 +389,8 @@ class McpService {
for (const [key] of this.clients) {
try {
await this.closeClient(key)
} catch (error) {
Logger.error(`[MCP] Failed to close client: ${error}`)
} catch (error: any) {
Logger.error(`[MCP] Failed to close client: ${error?.message}`)
}
}
}
@ -379,8 +411,8 @@ class McpService {
serverTools.push(serverTool)
})
return serverTools
} catch (error) {
Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error)
} catch (error: any) {
Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error?.message)
return []
}
}
@ -439,8 +471,8 @@ class McpService {
* List prompts available on an MCP server
*/
private async listPromptsImpl(server: MCPServer): Promise<MCPPrompt[]> {
Logger.info(`[MCP] Listing prompts for server: ${server.name}`)
const client = await this.initClient(server)
Logger.info(`[MCP] Listing prompts for server: ${server.name}`)
try {
const { prompts } = await client.listPrompts()
return prompts.map((prompt: any) => ({
@ -449,8 +481,11 @@ class McpService {
serverId: server.id,
serverName: server.name
}))
} catch (error) {
Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error)
} catch (error: any) {
// -32601 is the code for the method not found
if (error?.code !== -32601) {
Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error?.message)
}
return []
}
}
@ -508,8 +543,8 @@ class McpService {
* List resources available on an MCP server (implementation)
*/
private async listResourcesImpl(server: MCPServer): Promise<MCPResource[]> {
Logger.info(`[MCP] Listing resources for server: ${server.name}`)
const client = await this.initClient(server)
Logger.info(`[MCP] Listing resources for server: ${server.name}`)
try {
const result = await client.listResources()
const resources = result.resources || []
@ -519,8 +554,11 @@ class McpService {
serverName: server.name
}))
return serverResources
} catch (error) {
Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error)
} catch (error: any) {
// -32601 is the code for the method not found
if (error?.code !== -32601) {
Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error?.message)
}
return []
}
}
@ -563,7 +601,7 @@ class McpService {
contents: contents
}
} catch (error: Error | any) {
Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error)
Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error.message)
throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`)
}
}
@ -602,5 +640,13 @@ class McpService {
})
}
const mcpService = new McpService()
export default mcpService
let mcpInstance: ReturnType<typeof McpService.getInstance> | null = null
export const getMcpInstance = () => {
if (!mcpInstance) {
mcpInstance = McpService.getInstance()
}
return mcpInstance
}
export default McpService.getInstance

View File

@ -446,7 +446,8 @@ export class WindowService {
preload: join(__dirname, '../preload/index.js'),
sandbox: false,
webSecurity: false,
webviewTag: true
webviewTag: true,
backgroundThrottling: false
}
})

View File

@ -47,7 +47,7 @@ function getLoginShellEnvironment(): Promise<Record<string, string>> {
commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command
}
Logger.log(`Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`)
Logger.log(`[ShellEnv] Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`)
const child = spawn(shellPath, commandArgs, {
cwd: homeDirectory, // Run the command in the user's home directory

View File

@ -1,8 +1,8 @@
import { createSelector } from '@reduxjs/toolkit'
import store, { useAppDispatch, useAppSelector } from '@renderer/store'
import { addMCPServer, deleteMCPServer, setMCPServers, updateMCPServer } from '@renderer/store/mcp'
import { MCPServer } from '@renderer/types'
import { IpcChannel } from '@shared/IpcChannel'
import { useMemo } from 'react'
const ipcRenderer = window.electron.ipcRenderer
@ -14,9 +14,14 @@ ipcRenderer.on(IpcChannel.Mcp_AddServer, (_event, server: MCPServer) => {
store.dispatch(addMCPServer(server))
})
const selectMcpServers = (state) => state.mcp.servers
const selectActiveMcpServers = createSelector([selectMcpServers], (servers) =>
servers.filter((server) => server.isActive)
)
export const useMCPServers = () => {
const mcpServers = useAppSelector((state) => state.mcp.servers)
const activedMcpServers = useMemo(() => mcpServers.filter((server) => server.isActive), [mcpServers])
const mcpServers = useAppSelector(selectMcpServers)
const activedMcpServers = useAppSelector(selectActiveMcpServers)
const dispatch = useAppDispatch()
return {

View File

@ -3,6 +3,7 @@ import { useAssistant } from '@renderer/hooks/useAssistant'
import { useMCPServers } from '@renderer/hooks/useMCPServers'
import { EventEmitter } from '@renderer/services/EventService'
import { Assistant, MCPPrompt, MCPResource, MCPServer } from '@renderer/types'
import { delay, runAsyncFunction } from '@renderer/utils'
import { Form, Input, Tooltip } from 'antd'
import { Plus, SquareTerminal } from 'lucide-react'
import { FC, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react'
@ -109,6 +110,11 @@ const extractPromptContent = (response: any): string | null => {
return null
}
// Add static variable before component definition
let isFirstResourcesListCall = true
let isFirstPromptListCall = true
const initMcpDelay = 3
const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, ToolbarButton, ...props }) => {
const { activedMcpServers } = useMCPServers()
const { t } = useTranslation()
@ -308,6 +314,11 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
const promptList = useMemo(async () => {
const prompts: MCPPrompt[] = []
if (isFirstPromptListCall) {
await delay(initMcpDelay)
isFirstPromptListCall = false
}
for (const server of activedMcpServers) {
const serverPrompts = await window.api.mcp.listPrompts(server)
prompts.push(...serverPrompts)
@ -319,7 +330,8 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
icon: <SquareTerminal />,
action: () => handlePromptSelect(prompt as MCPPromptWithArgs)
}))
}, [handlePromptSelect, activedMcpServers])
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [activedMcpServers])
const openPromptList = useCallback(async () => {
const prompts = await promptList
@ -380,33 +392,42 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
const [resourcesList, setResourcesList] = useState<QuickPanelListItem[]>([])
useEffect(() => {
let isMounted = true
runAsyncFunction(async () => {
let isMounted = true
const fetchResources = async () => {
const resources: MCPResource[] = []
for (const server of activedMcpServers) {
const serverResources = await window.api.mcp.listResources(server)
resources.push(...serverResources)
const fetchResources = async () => {
const resources: MCPResource[] = []
for (const server of activedMcpServers) {
const serverResources = await window.api.mcp.listResources(server)
resources.push(...serverResources)
}
if (isMounted) {
setResourcesList(
resources.map((resource) => ({
label: resource.name,
description: resource.description,
icon: <SquareTerminal />,
action: () => handleResourceSelect(resource)
}))
)
}
}
if (isMounted) {
setResourcesList(
resources.map((resource) => ({
label: resource.name,
description: resource.description,
icon: <SquareTerminal />,
action: () => handleResourceSelect(resource)
}))
)
// Avoid mcp following the software startup, affecting the startup speed
if (isFirstResourcesListCall) {
await delay(initMcpDelay)
isFirstResourcesListCall = false
fetchResources()
}
}
fetchResources()
return () => {
isMounted = false
}
}, [activedMcpServers, handleResourceSelect])
return () => {
isMounted = false
}
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [activedMcpServers])
const openResourcesList = useCallback(async () => {
const resources = resourcesList