mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-01 01:30:51 +08:00
refactor: streamline MCP service handling and improve IPC registration
* Refactored MCPService to implement a singleton pattern for better instance management. * Updated IPC registration to utilize the new getMcpInstance method for handling MCP-related requests. * Removed redundant IPC handlers from the main index file and centralized them in the ipc module. * Added background throttling option in WindowService configuration to enhance performance. * Introduced delays in MCPToolsButton to optimize resource and prompt fetching after initial load.
This commit is contained in:
parent
4b2417ce37
commit
8bfbbd497c
@ -2,12 +2,11 @@ import '@main/config'
|
|||||||
|
|
||||||
import { electronApp, optimizer } from '@electron-toolkit/utils'
|
import { electronApp, optimizer } from '@electron-toolkit/utils'
|
||||||
import { replaceDevtoolsFont } from '@main/utils/windowUtil'
|
import { replaceDevtoolsFont } from '@main/utils/windowUtil'
|
||||||
import { IpcChannel } from '@shared/IpcChannel'
|
import { app } from 'electron'
|
||||||
import { app, BrowserWindow, ipcMain } from 'electron'
|
|
||||||
import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
|
import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
|
||||||
import Logger from 'electron-log'
|
import Logger from 'electron-log'
|
||||||
|
|
||||||
import { isDev, isMac, isWin } from './constant'
|
import { isDev } from './constant'
|
||||||
import { registerIpc } from './ipc'
|
import { registerIpc } from './ipc'
|
||||||
import { configManager } from './services/ConfigManager'
|
import { configManager } from './services/ConfigManager'
|
||||||
import mcpService from './services/MCPService'
|
import mcpService from './services/MCPService'
|
||||||
@ -85,18 +84,6 @@ if (!app.requestSingleInstanceLock()) {
|
|||||||
.then((name) => console.log(`Added Extension: ${name}`))
|
.then((name) => console.log(`Added Extension: ${name}`))
|
||||||
.catch((err) => console.log('An error occurred: ', err))
|
.catch((err) => console.log('An error occurred: ', err))
|
||||||
}
|
}
|
||||||
ipcMain.handle(IpcChannel.System_GetDeviceType, () => {
|
|
||||||
return isMac ? 'mac' : isWin ? 'windows' : 'linux'
|
|
||||||
})
|
|
||||||
|
|
||||||
ipcMain.handle(IpcChannel.System_GetHostname, () => {
|
|
||||||
return require('os').hostname()
|
|
||||||
})
|
|
||||||
|
|
||||||
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
|
|
||||||
const win = BrowserWindow.fromWebContents(e.sender)
|
|
||||||
win && win.webContents.toggleDevTools()
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
||||||
registerProtocolClient(app)
|
registerProtocolClient(app)
|
||||||
@ -128,7 +115,7 @@ if (!app.requestSingleInstanceLock()) {
|
|||||||
app.on('will-quit', async () => {
|
app.on('will-quit', async () => {
|
||||||
// event.preventDefault()
|
// event.preventDefault()
|
||||||
try {
|
try {
|
||||||
await mcpService.cleanup()
|
await mcpService().cleanup()
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
Logger.error('Error cleaning up MCP service:', error)
|
Logger.error('Error cleaning up MCP service:', error)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,7 +19,7 @@ import FileService from './services/FileService'
|
|||||||
import FileStorage from './services/FileStorage'
|
import FileStorage from './services/FileStorage'
|
||||||
import { GeminiService } from './services/GeminiService'
|
import { GeminiService } from './services/GeminiService'
|
||||||
import KnowledgeService from './services/KnowledgeService'
|
import KnowledgeService from './services/KnowledgeService'
|
||||||
import mcpService from './services/MCPService'
|
import { getMcpInstance } from './services/MCPService'
|
||||||
import * as NutstoreService from './services/NutstoreService'
|
import * as NutstoreService from './services/NutstoreService'
|
||||||
import ObsidianVaultService from './services/ObsidianVaultService'
|
import ObsidianVaultService from './services/ObsidianVaultService'
|
||||||
import { ProxyConfig, proxyManager } from './services/ProxyManager'
|
import { ProxyConfig, proxyManager } from './services/ProxyManager'
|
||||||
@ -204,6 +204,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
|||||||
ipcMain.handle(IpcChannel.Zip_Compress, (_, text: string) => compress(text))
|
ipcMain.handle(IpcChannel.Zip_Compress, (_, text: string) => compress(text))
|
||||||
ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text))
|
ipcMain.handle(IpcChannel.Zip_Decompress, (_, text: Buffer) => decompress(text))
|
||||||
|
|
||||||
|
// system
|
||||||
|
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
|
||||||
|
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
|
||||||
|
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
|
||||||
|
const win = BrowserWindow.fromWebContents(e.sender)
|
||||||
|
win && win.webContents.toggleDevTools()
|
||||||
|
})
|
||||||
|
|
||||||
// backup
|
// backup
|
||||||
ipcMain.handle(IpcChannel.Backup_Backup, backupManager.backup)
|
ipcMain.handle(IpcChannel.Backup_Backup, backupManager.backup)
|
||||||
ipcMain.handle(IpcChannel.Backup_Restore, backupManager.restore)
|
ipcMain.handle(IpcChannel.Backup_Restore, backupManager.restore)
|
||||||
@ -301,16 +309,16 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Register MCP handlers
|
// Register MCP handlers
|
||||||
ipcMain.handle(IpcChannel.Mcp_RemoveServer, mcpService.removeServer)
|
ipcMain.handle(IpcChannel.Mcp_RemoveServer, (event, server) => getMcpInstance().removeServer(event, server))
|
||||||
ipcMain.handle(IpcChannel.Mcp_RestartServer, mcpService.restartServer)
|
ipcMain.handle(IpcChannel.Mcp_RestartServer, (event, server) => getMcpInstance().restartServer(event, server))
|
||||||
ipcMain.handle(IpcChannel.Mcp_StopServer, mcpService.stopServer)
|
ipcMain.handle(IpcChannel.Mcp_StopServer, (event, server) => getMcpInstance().stopServer(event, server))
|
||||||
ipcMain.handle(IpcChannel.Mcp_ListTools, mcpService.listTools)
|
ipcMain.handle(IpcChannel.Mcp_ListTools, (event, server) => getMcpInstance().listTools(event, server))
|
||||||
ipcMain.handle(IpcChannel.Mcp_CallTool, mcpService.callTool)
|
ipcMain.handle(IpcChannel.Mcp_CallTool, (event, params) => getMcpInstance().callTool(event, params))
|
||||||
ipcMain.handle(IpcChannel.Mcp_ListPrompts, mcpService.listPrompts)
|
ipcMain.handle(IpcChannel.Mcp_ListPrompts, (event, server) => getMcpInstance().listPrompts(event, server))
|
||||||
ipcMain.handle(IpcChannel.Mcp_GetPrompt, mcpService.getPrompt)
|
ipcMain.handle(IpcChannel.Mcp_GetPrompt, (event, params) => getMcpInstance().getPrompt(event, params))
|
||||||
ipcMain.handle(IpcChannel.Mcp_ListResources, mcpService.listResources)
|
ipcMain.handle(IpcChannel.Mcp_ListResources, (event, server) => getMcpInstance().listResources(event, server))
|
||||||
ipcMain.handle(IpcChannel.Mcp_GetResource, mcpService.getResource)
|
ipcMain.handle(IpcChannel.Mcp_GetResource, (event, params) => getMcpInstance().getResource(event, params))
|
||||||
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo)
|
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, () => getMcpInstance().getInstallInfo())
|
||||||
|
|
||||||
ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name))
|
ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name))
|
||||||
ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name))
|
ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name))
|
||||||
|
|||||||
@ -68,20 +68,18 @@ function withCache<T extends unknown[], R>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
class McpService {
|
class McpService {
|
||||||
|
private static instance: McpService | null = null
|
||||||
private clients: Map<string, Client> = new Map()
|
private clients: Map<string, Client> = new Map()
|
||||||
|
private pendingClients: Map<string, Promise<Client>> = new Map()
|
||||||
|
|
||||||
private getServerKey(server: MCPServer): string {
|
public static getInstance(): McpService {
|
||||||
return JSON.stringify({
|
if (!McpService.instance) {
|
||||||
baseUrl: server.baseUrl,
|
McpService.instance = new McpService()
|
||||||
command: server.command,
|
}
|
||||||
args: server.args,
|
return McpService.instance
|
||||||
registryUrl: server.registryUrl,
|
|
||||||
env: server.env,
|
|
||||||
id: server.id
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
constructor() {
|
private constructor() {
|
||||||
this.initClient = this.initClient.bind(this)
|
this.initClient = this.initClient.bind(this)
|
||||||
this.listTools = this.listTools.bind(this)
|
this.listTools = this.listTools.bind(this)
|
||||||
this.callTool = this.callTool.bind(this)
|
this.callTool = this.callTool.bind(this)
|
||||||
@ -96,9 +94,26 @@ class McpService {
|
|||||||
this.cleanup = this.cleanup.bind(this)
|
this.cleanup = this.cleanup.bind(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private getServerKey(server: MCPServer): string {
|
||||||
|
return JSON.stringify({
|
||||||
|
baseUrl: server.baseUrl,
|
||||||
|
command: server.command,
|
||||||
|
args: server.args,
|
||||||
|
registryUrl: server.registryUrl,
|
||||||
|
env: server.env,
|
||||||
|
id: server.id
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
async initClient(server: MCPServer): Promise<Client> {
|
async initClient(server: MCPServer): Promise<Client> {
|
||||||
const serverKey = this.getServerKey(server)
|
const serverKey = this.getServerKey(server)
|
||||||
|
|
||||||
|
// If there's a pending initialization, wait for it
|
||||||
|
const pendingClient = this.pendingClients.get(serverKey)
|
||||||
|
if (pendingClient) {
|
||||||
|
return pendingClient
|
||||||
|
}
|
||||||
|
|
||||||
// Check if we already have a client for this server configuration
|
// Check if we already have a client for this server configuration
|
||||||
const existingClient = this.clients.get(serverKey)
|
const existingClient = this.clients.get(serverKey)
|
||||||
if (existingClient) {
|
if (existingClient) {
|
||||||
@ -113,209 +128,226 @@ class McpService {
|
|||||||
} else {
|
} else {
|
||||||
return existingClient
|
return existingClient
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error: any) {
|
||||||
Logger.error(`[MCP] Error pinging server ${server.name}:`, error)
|
Logger.error(`[MCP] Error pinging server ${server.name}:`, error?.message)
|
||||||
this.clients.delete(serverKey)
|
this.clients.delete(serverKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Create new client instance for each connection
|
|
||||||
const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} })
|
|
||||||
|
|
||||||
const args = [...(server.args || [])]
|
// Create a promise for the initialization process
|
||||||
|
const initPromise = (async () => {
|
||||||
|
try {
|
||||||
|
// Create new client instance for each connection
|
||||||
|
const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} })
|
||||||
|
|
||||||
// let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
|
const args = [...(server.args || [])]
|
||||||
const authProvider = new McpOAuthClientProvider({
|
|
||||||
serverUrlHash: crypto
|
|
||||||
.createHash('md5')
|
|
||||||
.update(server.baseUrl || '')
|
|
||||||
.digest('hex')
|
|
||||||
})
|
|
||||||
|
|
||||||
const initTransport = async (): Promise<
|
// let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
|
||||||
StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
|
const authProvider = new McpOAuthClientProvider({
|
||||||
> => {
|
serverUrlHash: crypto
|
||||||
// Create appropriate transport based on configuration
|
.createHash('md5')
|
||||||
if (server.type === 'inMemory') {
|
.update(server.baseUrl || '')
|
||||||
Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`)
|
.digest('hex')
|
||||||
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair()
|
})
|
||||||
// start the in-memory server with the given name and environment variables
|
|
||||||
const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {})
|
|
||||||
try {
|
|
||||||
await inMemoryServer.connect(serverTransport)
|
|
||||||
Logger.info(`[MCP] In-memory server started: ${server.name}`)
|
|
||||||
} catch (error: Error | any) {
|
|
||||||
Logger.error(`[MCP] Error starting in-memory server: ${error}`)
|
|
||||||
throw new Error(`Failed to start in-memory server: ${error.message}`)
|
|
||||||
}
|
|
||||||
// set the client transport to the client
|
|
||||||
return clientTransport
|
|
||||||
} else if (server.baseUrl) {
|
|
||||||
if (server.type === 'streamableHttp') {
|
|
||||||
const options: StreamableHTTPClientTransportOptions = {
|
|
||||||
requestInit: {
|
|
||||||
headers: server.headers || {}
|
|
||||||
},
|
|
||||||
authProvider
|
|
||||||
}
|
|
||||||
return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
|
|
||||||
} else if (server.type === 'sse') {
|
|
||||||
const options: SSEClientTransportOptions = {
|
|
||||||
eventSourceInit: {
|
|
||||||
fetch: async (url, init) => {
|
|
||||||
const headers = { ...(server.headers || {}), ...(init?.headers || {}) }
|
|
||||||
|
|
||||||
// Get tokens from authProvider to make sure using the latest tokens
|
const initTransport = async (): Promise<
|
||||||
if (authProvider && typeof authProvider.tokens === 'function') {
|
StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
|
||||||
try {
|
> => {
|
||||||
const tokens = await authProvider.tokens()
|
// Create appropriate transport based on configuration
|
||||||
if (tokens && tokens.access_token) {
|
if (server.type === 'inMemory') {
|
||||||
headers['Authorization'] = `Bearer ${tokens.access_token}`
|
Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`)
|
||||||
|
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair()
|
||||||
|
// start the in-memory server with the given name and environment variables
|
||||||
|
const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {})
|
||||||
|
try {
|
||||||
|
await inMemoryServer.connect(serverTransport)
|
||||||
|
Logger.info(`[MCP] In-memory server started: ${server.name}`)
|
||||||
|
} catch (error: Error | any) {
|
||||||
|
Logger.error(`[MCP] Error starting in-memory server: ${error}`)
|
||||||
|
throw new Error(`Failed to start in-memory server: ${error.message}`)
|
||||||
|
}
|
||||||
|
// set the client transport to the client
|
||||||
|
return clientTransport
|
||||||
|
} else if (server.baseUrl) {
|
||||||
|
if (server.type === 'streamableHttp') {
|
||||||
|
const options: StreamableHTTPClientTransportOptions = {
|
||||||
|
requestInit: {
|
||||||
|
headers: server.headers || {}
|
||||||
|
},
|
||||||
|
authProvider
|
||||||
|
}
|
||||||
|
return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
|
||||||
|
} else if (server.type === 'sse') {
|
||||||
|
const options: SSEClientTransportOptions = {
|
||||||
|
eventSourceInit: {
|
||||||
|
fetch: async (url, init) => {
|
||||||
|
const headers = { ...(server.headers || {}), ...(init?.headers || {}) }
|
||||||
|
|
||||||
|
// Get tokens from authProvider to make sure using the latest tokens
|
||||||
|
if (authProvider && typeof authProvider.tokens === 'function') {
|
||||||
|
try {
|
||||||
|
const tokens = await authProvider.tokens()
|
||||||
|
if (tokens && tokens.access_token) {
|
||||||
|
headers['Authorization'] = `Bearer ${tokens.access_token}`
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
Logger.error('Failed to fetch tokens:', error)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} catch (error) {
|
|
||||||
Logger.error('Failed to fetch tokens:', error)
|
return fetch(url, { ...init, headers })
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
requestInit: {
|
||||||
|
headers: server.headers || {}
|
||||||
|
},
|
||||||
|
authProvider
|
||||||
|
}
|
||||||
|
return new SSEClientTransport(new URL(server.baseUrl!), options)
|
||||||
|
} else {
|
||||||
|
throw new Error('Invalid server type')
|
||||||
|
}
|
||||||
|
} else if (server.command) {
|
||||||
|
let cmd = server.command
|
||||||
|
|
||||||
|
if (server.command === 'npx') {
|
||||||
|
cmd = await getBinaryPath('bun')
|
||||||
|
Logger.info(`[MCP] Using command: ${cmd}`)
|
||||||
|
|
||||||
|
// add -x to args if args exist
|
||||||
|
if (args && args.length > 0) {
|
||||||
|
if (!args.includes('-y')) {
|
||||||
|
args.unshift('-y')
|
||||||
|
}
|
||||||
|
if (!args.includes('x')) {
|
||||||
|
args.unshift('x')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (server.registryUrl) {
|
||||||
|
server.env = {
|
||||||
|
...server.env,
|
||||||
|
NPM_CONFIG_REGISTRY: server.registryUrl
|
||||||
}
|
}
|
||||||
|
|
||||||
return fetch(url, { ...init, headers })
|
// if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory
|
||||||
|
if (server.name.includes('mcp-auto-install')) {
|
||||||
|
const binPath = await getBinaryPath()
|
||||||
|
makeSureDirExists(binPath)
|
||||||
|
server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (server.command === 'uvx' || server.command === 'uv') {
|
||||||
|
cmd = await getBinaryPath(server.command)
|
||||||
|
if (server.registryUrl) {
|
||||||
|
server.env = {
|
||||||
|
...server.env,
|
||||||
|
UV_DEFAULT_INDEX: server.registryUrl,
|
||||||
|
PIP_INDEX_URL: server.registryUrl
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
|
||||||
requestInit: {
|
|
||||||
headers: server.headers || {}
|
|
||||||
},
|
|
||||||
authProvider
|
|
||||||
}
|
|
||||||
return new SSEClientTransport(new URL(server.baseUrl!), options)
|
|
||||||
} else {
|
|
||||||
throw new Error('Invalid server type')
|
|
||||||
}
|
|
||||||
} else if (server.command) {
|
|
||||||
let cmd = server.command
|
|
||||||
|
|
||||||
if (server.command === 'npx') {
|
|
||||||
cmd = await getBinaryPath('bun')
|
|
||||||
Logger.info(`[MCP] Using command: ${cmd}`)
|
|
||||||
|
|
||||||
// add -x to args if args exist
|
|
||||||
if (args && args.length > 0) {
|
|
||||||
if (!args.includes('-y')) {
|
|
||||||
args.unshift('-y')
|
|
||||||
}
|
|
||||||
if (!args.includes('x')) {
|
|
||||||
args.unshift('x')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (server.registryUrl) {
|
|
||||||
server.env = {
|
|
||||||
...server.env,
|
|
||||||
NPM_CONFIG_REGISTRY: server.registryUrl
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory
|
Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`)
|
||||||
if (server.name.includes('mcp-auto-install')) {
|
// Logger.info(`[MCP] Environment variables for server:`, server.env)
|
||||||
const binPath = await getBinaryPath()
|
const loginShellEnv = await this.getLoginShellEnv()
|
||||||
makeSureDirExists(binPath)
|
const stdioTransport = new StdioClientTransport({
|
||||||
server.env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json')
|
command: cmd,
|
||||||
}
|
args,
|
||||||
}
|
env: {
|
||||||
} else if (server.command === 'uvx' || server.command === 'uv') {
|
...loginShellEnv,
|
||||||
cmd = await getBinaryPath(server.command)
|
...server.env
|
||||||
if (server.registryUrl) {
|
},
|
||||||
server.env = {
|
stderr: 'pipe'
|
||||||
...server.env,
|
})
|
||||||
UV_DEFAULT_INDEX: server.registryUrl,
|
stdioTransport.stderr?.on('data', (data) =>
|
||||||
PIP_INDEX_URL: server.registryUrl
|
Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString())
|
||||||
}
|
)
|
||||||
|
return stdioTransport
|
||||||
|
} else {
|
||||||
|
throw new Error('Either baseUrl or command must be provided')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`)
|
const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => {
|
||||||
// Logger.info(`[MCP] Environment variables for server:`, server.env)
|
Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`)
|
||||||
const loginShellEnv = await this.getLoginShellEnv()
|
// Create an event emitter for the OAuth callback
|
||||||
const stdioTransport = new StdioClientTransport({
|
const events = new EventEmitter()
|
||||||
command: cmd,
|
|
||||||
args,
|
|
||||||
env: {
|
|
||||||
...loginShellEnv,
|
|
||||||
...server.env
|
|
||||||
},
|
|
||||||
stderr: 'pipe'
|
|
||||||
})
|
|
||||||
stdioTransport.stderr?.on('data', (data) =>
|
|
||||||
Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString())
|
|
||||||
)
|
|
||||||
return stdioTransport
|
|
||||||
} else {
|
|
||||||
throw new Error('Either baseUrl or command must be provided')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => {
|
// Create a callback server
|
||||||
Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`)
|
const callbackServer = new CallBackServer({
|
||||||
// Create an event emitter for the OAuth callback
|
port: authProvider.config.callbackPort,
|
||||||
const events = new EventEmitter()
|
path: authProvider.config.callbackPath || '/oauth/callback',
|
||||||
|
events
|
||||||
|
})
|
||||||
|
|
||||||
// Create a callback server
|
// Set a timeout to close the callback server
|
||||||
const callbackServer = new CallBackServer({
|
const timeoutId = setTimeout(() => {
|
||||||
port: authProvider.config.callbackPort,
|
Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`)
|
||||||
path: authProvider.config.callbackPath || '/oauth/callback',
|
callbackServer.close()
|
||||||
events
|
}, 300000) // 5 minutes timeout
|
||||||
})
|
|
||||||
|
|
||||||
// Set a timeout to close the callback server
|
try {
|
||||||
const timeoutId = setTimeout(() => {
|
// Wait for the authorization code
|
||||||
Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`)
|
const authCode = await callbackServer.waitForAuthCode()
|
||||||
callbackServer.close()
|
Logger.info(`[MCP] Received auth code: ${authCode}`)
|
||||||
}, 300000) // 5 minutes timeout
|
|
||||||
|
|
||||||
try {
|
// Complete the OAuth flow
|
||||||
// Wait for the authorization code
|
await transport.finishAuth(authCode)
|
||||||
const authCode = await callbackServer.waitForAuthCode()
|
|
||||||
Logger.info(`[MCP] Received auth code: ${authCode}`)
|
|
||||||
|
|
||||||
// Complete the OAuth flow
|
Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`)
|
||||||
await transport.finishAuth(authCode)
|
|
||||||
|
|
||||||
Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`)
|
const newTransport = await initTransport()
|
||||||
|
// Try to connect again
|
||||||
|
await client.connect(newTransport)
|
||||||
|
|
||||||
const newTransport = await initTransport()
|
Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`)
|
||||||
// Try to connect again
|
} catch (oauthError) {
|
||||||
await client.connect(newTransport)
|
Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError)
|
||||||
|
throw new Error(
|
||||||
|
`OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}`
|
||||||
|
)
|
||||||
|
} finally {
|
||||||
|
// Clear the timeout and close the callback server
|
||||||
|
clearTimeout(timeoutId)
|
||||||
|
callbackServer.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`)
|
try {
|
||||||
} catch (oauthError) {
|
const transport = await initTransport()
|
||||||
Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError)
|
try {
|
||||||
throw new Error(
|
await client.connect(transport)
|
||||||
`OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}`
|
} catch (error: Error | any) {
|
||||||
)
|
if (
|
||||||
|
error instanceof Error &&
|
||||||
|
(error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))
|
||||||
|
) {
|
||||||
|
Logger.info(`[MCP] Authentication required for server: ${server.name}`)
|
||||||
|
await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
|
||||||
|
} else {
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the new client in the cache
|
||||||
|
this.clients.set(serverKey, client)
|
||||||
|
|
||||||
|
Logger.info(`[MCP] Activated server: ${server.name}`)
|
||||||
|
return client
|
||||||
|
} catch (error: any) {
|
||||||
|
Logger.error(`[MCP] Error activating server ${server.name}:`, error?.message)
|
||||||
|
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
|
||||||
|
}
|
||||||
} finally {
|
} finally {
|
||||||
// Clear the timeout and close the callback server
|
// Clean up the pending promise when done
|
||||||
clearTimeout(timeoutId)
|
this.pendingClients.delete(serverKey)
|
||||||
callbackServer.close()
|
|
||||||
}
|
}
|
||||||
}
|
})()
|
||||||
|
|
||||||
try {
|
// Store the pending promise
|
||||||
const transport = await initTransport()
|
this.pendingClients.set(serverKey, initPromise)
|
||||||
try {
|
|
||||||
await client.connect(transport)
|
|
||||||
} catch (error: Error | any) {
|
|
||||||
if (error instanceof Error && (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))) {
|
|
||||||
Logger.info(`[MCP] Authentication required for server: ${server.name}`)
|
|
||||||
await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
|
|
||||||
} else {
|
|
||||||
throw error
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store the new client in the cache
|
return initPromise
|
||||||
this.clients.set(serverKey, client)
|
|
||||||
|
|
||||||
Logger.info(`[MCP] Activated server: ${server.name}`)
|
|
||||||
return client
|
|
||||||
} catch (error: any) {
|
|
||||||
Logger.error(`[MCP] Error activating server ${server.name}:`, error)
|
|
||||||
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async closeClient(serverKey: string) {
|
async closeClient(serverKey: string) {
|
||||||
@ -357,8 +389,8 @@ class McpService {
|
|||||||
for (const [key] of this.clients) {
|
for (const [key] of this.clients) {
|
||||||
try {
|
try {
|
||||||
await this.closeClient(key)
|
await this.closeClient(key)
|
||||||
} catch (error) {
|
} catch (error: any) {
|
||||||
Logger.error(`[MCP] Failed to close client: ${error}`)
|
Logger.error(`[MCP] Failed to close client: ${error?.message}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -379,8 +411,8 @@ class McpService {
|
|||||||
serverTools.push(serverTool)
|
serverTools.push(serverTool)
|
||||||
})
|
})
|
||||||
return serverTools
|
return serverTools
|
||||||
} catch (error) {
|
} catch (error: any) {
|
||||||
Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error)
|
Logger.error(`[MCP] Failed to list tools for server: ${server.name}`, error?.message)
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -439,8 +471,8 @@ class McpService {
|
|||||||
* List prompts available on an MCP server
|
* List prompts available on an MCP server
|
||||||
*/
|
*/
|
||||||
private async listPromptsImpl(server: MCPServer): Promise<MCPPrompt[]> {
|
private async listPromptsImpl(server: MCPServer): Promise<MCPPrompt[]> {
|
||||||
Logger.info(`[MCP] Listing prompts for server: ${server.name}`)
|
|
||||||
const client = await this.initClient(server)
|
const client = await this.initClient(server)
|
||||||
|
Logger.info(`[MCP] Listing prompts for server: ${server.name}`)
|
||||||
try {
|
try {
|
||||||
const { prompts } = await client.listPrompts()
|
const { prompts } = await client.listPrompts()
|
||||||
return prompts.map((prompt: any) => ({
|
return prompts.map((prompt: any) => ({
|
||||||
@ -449,8 +481,11 @@ class McpService {
|
|||||||
serverId: server.id,
|
serverId: server.id,
|
||||||
serverName: server.name
|
serverName: server.name
|
||||||
}))
|
}))
|
||||||
} catch (error) {
|
} catch (error: any) {
|
||||||
Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error)
|
// -32601 is the code for the method not found
|
||||||
|
if (error?.code !== -32601) {
|
||||||
|
Logger.error(`[MCP] Failed to list prompts for server: ${server.name}`, error?.message)
|
||||||
|
}
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -508,8 +543,8 @@ class McpService {
|
|||||||
* List resources available on an MCP server (implementation)
|
* List resources available on an MCP server (implementation)
|
||||||
*/
|
*/
|
||||||
private async listResourcesImpl(server: MCPServer): Promise<MCPResource[]> {
|
private async listResourcesImpl(server: MCPServer): Promise<MCPResource[]> {
|
||||||
Logger.info(`[MCP] Listing resources for server: ${server.name}`)
|
|
||||||
const client = await this.initClient(server)
|
const client = await this.initClient(server)
|
||||||
|
Logger.info(`[MCP] Listing resources for server: ${server.name}`)
|
||||||
try {
|
try {
|
||||||
const result = await client.listResources()
|
const result = await client.listResources()
|
||||||
const resources = result.resources || []
|
const resources = result.resources || []
|
||||||
@ -519,8 +554,11 @@ class McpService {
|
|||||||
serverName: server.name
|
serverName: server.name
|
||||||
}))
|
}))
|
||||||
return serverResources
|
return serverResources
|
||||||
} catch (error) {
|
} catch (error: any) {
|
||||||
Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error)
|
// -32601 is the code for the method not found
|
||||||
|
if (error?.code !== -32601) {
|
||||||
|
Logger.error(`[MCP] Failed to list resources for server: ${server.name}`, error?.message)
|
||||||
|
}
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -563,7 +601,7 @@ class McpService {
|
|||||||
contents: contents
|
contents: contents
|
||||||
}
|
}
|
||||||
} catch (error: Error | any) {
|
} catch (error: Error | any) {
|
||||||
Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error)
|
Logger.error(`[MCP] Failed to get resource ${uri} from server: ${server.name}`, error.message)
|
||||||
throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`)
|
throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${error.message}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -602,5 +640,13 @@ class McpService {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
const mcpService = new McpService()
|
let mcpInstance: ReturnType<typeof McpService.getInstance> | null = null
|
||||||
export default mcpService
|
|
||||||
|
export const getMcpInstance = () => {
|
||||||
|
if (!mcpInstance) {
|
||||||
|
mcpInstance = McpService.getInstance()
|
||||||
|
}
|
||||||
|
return mcpInstance
|
||||||
|
}
|
||||||
|
|
||||||
|
export default McpService.getInstance
|
||||||
|
|||||||
@ -446,7 +446,8 @@ export class WindowService {
|
|||||||
preload: join(__dirname, '../preload/index.js'),
|
preload: join(__dirname, '../preload/index.js'),
|
||||||
sandbox: false,
|
sandbox: false,
|
||||||
webSecurity: false,
|
webSecurity: false,
|
||||||
webviewTag: true
|
webviewTag: true,
|
||||||
|
backgroundThrottling: false
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@ -47,7 +47,7 @@ function getLoginShellEnvironment(): Promise<Record<string, string>> {
|
|||||||
commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command
|
commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command
|
||||||
}
|
}
|
||||||
|
|
||||||
Logger.log(`Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`)
|
Logger.log(`[ShellEnv] Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`)
|
||||||
|
|
||||||
const child = spawn(shellPath, commandArgs, {
|
const child = spawn(shellPath, commandArgs, {
|
||||||
cwd: homeDirectory, // Run the command in the user's home directory
|
cwd: homeDirectory, // Run the command in the user's home directory
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit'
|
||||||
import store, { useAppDispatch, useAppSelector } from '@renderer/store'
|
import store, { useAppDispatch, useAppSelector } from '@renderer/store'
|
||||||
import { addMCPServer, deleteMCPServer, setMCPServers, updateMCPServer } from '@renderer/store/mcp'
|
import { addMCPServer, deleteMCPServer, setMCPServers, updateMCPServer } from '@renderer/store/mcp'
|
||||||
import { MCPServer } from '@renderer/types'
|
import { MCPServer } from '@renderer/types'
|
||||||
import { IpcChannel } from '@shared/IpcChannel'
|
import { IpcChannel } from '@shared/IpcChannel'
|
||||||
import { useMemo } from 'react'
|
|
||||||
|
|
||||||
const ipcRenderer = window.electron.ipcRenderer
|
const ipcRenderer = window.electron.ipcRenderer
|
||||||
|
|
||||||
@ -14,9 +14,14 @@ ipcRenderer.on(IpcChannel.Mcp_AddServer, (_event, server: MCPServer) => {
|
|||||||
store.dispatch(addMCPServer(server))
|
store.dispatch(addMCPServer(server))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const selectMcpServers = (state) => state.mcp.servers
|
||||||
|
const selectActiveMcpServers = createSelector([selectMcpServers], (servers) =>
|
||||||
|
servers.filter((server) => server.isActive)
|
||||||
|
)
|
||||||
|
|
||||||
export const useMCPServers = () => {
|
export const useMCPServers = () => {
|
||||||
const mcpServers = useAppSelector((state) => state.mcp.servers)
|
const mcpServers = useAppSelector(selectMcpServers)
|
||||||
const activedMcpServers = useMemo(() => mcpServers.filter((server) => server.isActive), [mcpServers])
|
const activedMcpServers = useAppSelector(selectActiveMcpServers)
|
||||||
const dispatch = useAppDispatch()
|
const dispatch = useAppDispatch()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import { useAssistant } from '@renderer/hooks/useAssistant'
|
|||||||
import { useMCPServers } from '@renderer/hooks/useMCPServers'
|
import { useMCPServers } from '@renderer/hooks/useMCPServers'
|
||||||
import { EventEmitter } from '@renderer/services/EventService'
|
import { EventEmitter } from '@renderer/services/EventService'
|
||||||
import { Assistant, MCPPrompt, MCPResource, MCPServer } from '@renderer/types'
|
import { Assistant, MCPPrompt, MCPResource, MCPServer } from '@renderer/types'
|
||||||
|
import { delay, runAsyncFunction } from '@renderer/utils'
|
||||||
import { Form, Input, Tooltip } from 'antd'
|
import { Form, Input, Tooltip } from 'antd'
|
||||||
import { Plus, SquareTerminal } from 'lucide-react'
|
import { Plus, SquareTerminal } from 'lucide-react'
|
||||||
import { FC, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react'
|
import { FC, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react'
|
||||||
@ -109,6 +110,11 @@ const extractPromptContent = (response: any): string | null => {
|
|||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add static variable before component definition
|
||||||
|
let isFirstResourcesListCall = true
|
||||||
|
let isFirstPromptListCall = true
|
||||||
|
const initMcpDelay = 3
|
||||||
|
|
||||||
const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, ToolbarButton, ...props }) => {
|
const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, ToolbarButton, ...props }) => {
|
||||||
const { activedMcpServers } = useMCPServers()
|
const { activedMcpServers } = useMCPServers()
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
@ -308,6 +314,11 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
|
|||||||
const promptList = useMemo(async () => {
|
const promptList = useMemo(async () => {
|
||||||
const prompts: MCPPrompt[] = []
|
const prompts: MCPPrompt[] = []
|
||||||
|
|
||||||
|
if (isFirstPromptListCall) {
|
||||||
|
await delay(initMcpDelay)
|
||||||
|
isFirstPromptListCall = false
|
||||||
|
}
|
||||||
|
|
||||||
for (const server of activedMcpServers) {
|
for (const server of activedMcpServers) {
|
||||||
const serverPrompts = await window.api.mcp.listPrompts(server)
|
const serverPrompts = await window.api.mcp.listPrompts(server)
|
||||||
prompts.push(...serverPrompts)
|
prompts.push(...serverPrompts)
|
||||||
@ -319,7 +330,8 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
|
|||||||
icon: <SquareTerminal />,
|
icon: <SquareTerminal />,
|
||||||
action: () => handlePromptSelect(prompt as MCPPromptWithArgs)
|
action: () => handlePromptSelect(prompt as MCPPromptWithArgs)
|
||||||
}))
|
}))
|
||||||
}, [handlePromptSelect, activedMcpServers])
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, [activedMcpServers])
|
||||||
|
|
||||||
const openPromptList = useCallback(async () => {
|
const openPromptList = useCallback(async () => {
|
||||||
const prompts = await promptList
|
const prompts = await promptList
|
||||||
@ -380,33 +392,42 @@ const MCPToolsButton: FC<Props> = ({ ref, setInputValue, resizeTextArea, Toolbar
|
|||||||
const [resourcesList, setResourcesList] = useState<QuickPanelListItem[]>([])
|
const [resourcesList, setResourcesList] = useState<QuickPanelListItem[]>([])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
let isMounted = true
|
runAsyncFunction(async () => {
|
||||||
|
let isMounted = true
|
||||||
|
|
||||||
const fetchResources = async () => {
|
const fetchResources = async () => {
|
||||||
const resources: MCPResource[] = []
|
const resources: MCPResource[] = []
|
||||||
for (const server of activedMcpServers) {
|
|
||||||
const serverResources = await window.api.mcp.listResources(server)
|
for (const server of activedMcpServers) {
|
||||||
resources.push(...serverResources)
|
const serverResources = await window.api.mcp.listResources(server)
|
||||||
|
resources.push(...serverResources)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isMounted) {
|
||||||
|
setResourcesList(
|
||||||
|
resources.map((resource) => ({
|
||||||
|
label: resource.name,
|
||||||
|
description: resource.description,
|
||||||
|
icon: <SquareTerminal />,
|
||||||
|
action: () => handleResourceSelect(resource)
|
||||||
|
}))
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isMounted) {
|
// Avoid mcp following the software startup, affecting the startup speed
|
||||||
setResourcesList(
|
if (isFirstResourcesListCall) {
|
||||||
resources.map((resource) => ({
|
await delay(initMcpDelay)
|
||||||
label: resource.name,
|
isFirstResourcesListCall = false
|
||||||
description: resource.description,
|
fetchResources()
|
||||||
icon: <SquareTerminal />,
|
|
||||||
action: () => handleResourceSelect(resource)
|
|
||||||
}))
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
fetchResources()
|
return () => {
|
||||||
|
isMounted = false
|
||||||
return () => {
|
}
|
||||||
isMounted = false
|
})
|
||||||
}
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
}, [activedMcpServers, handleResourceSelect])
|
}, [activedMcpServers])
|
||||||
|
|
||||||
const openResourcesList = useCallback(async () => {
|
const openResourcesList = useCallback(async () => {
|
||||||
const resources = resourcesList
|
const resources = resourcesList
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user