mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-03 02:59:07 +08:00
* build: add @biomejs/biome as a dependency * chore: add biome extension to vscode recommendations * chore: migrate from prettier to biome for code formatting Update VSCode settings to use Biome as the default formatter for multiple languages Add Biome to code actions on save and reorder search exclude patterns * build: add biome.json configuration file for code formatting * build: migrate from prettier to biome for formatting Update package.json scripts and biome.json configuration to use biome instead of prettier for code formatting. Adjust biome formatter includes/excludes patterns for better file matching. * refactor(eslint): remove unused prettier config and imports * chore: update biome.json configuration - Enable linter and set custom rules - Change jsxQuoteStyle to single quotes - Add json parser configuration - Set formatWithErrors to true * chore: migrate biome config from json to jsonc format The new jsonc format allows for comments in the configuration file, making it more maintainable and easier to document configuration choices. * style(biome): update ignore patterns and jsx quote style Update file ignore patterns from `/*` to `/**` for consistency Change jsxQuoteStyle from single to double quotes for alignment with project standards * refactor: simplify error type annotations from Error | any to any The change standardizes error handling by using 'any' type instead of union types with Error | any, making the code more consistent and reducing unnecessary type complexity. * chore: exclude tailwind.css from biome formatting * style: standardize quote usage and fix JSX formatting - Replace single quotes with double quotes in CSS imports and selectors - Fix JSX element closing bracket alignment and formatting - Standardize JSON formatting in package.json files * Revert "style: standardize quote usage and fix JSX formatting" This reverts commit 0947f8505d.
* fix: remove json import assertion for biome compatibility The import assertion syntax is not supported by biome, so it was replaced with a standard import statement. * style: change quote styles in biome.jsonc to use single quotes for JSX and double quotes for JS * style: change quote style from double to single in biome config * style: change JSX quote style from single to double * chore: update biome.jsonc to use single quotes for CSS formatting * chore: update biome config and format commands - Exclude tailwind.css from linter includes - Add biome lint to format commands * style: format JSX closing brackets for better readability * style: set bracketSameLine to true in biome config The change aligns with common JSX formatting preferences where brackets on the same line improve readability for many developers * Revert "style: format JSX closing brackets for better readability" This reverts commit d442c934ee. * style: format code and clean up whitespace - Remove unnecessary whitespace in CSS and TS files - Format package.json files to consistent style - Reorder tsconfig.json properties alphabetically - Improve code formatting in React components * style(biome): update biome.jsonc config with clearer comment Add explanation for keeping bracketSameLine as true to minimize changes in current PR while noting false would be better for future * chore: remove prettier dependency and format package.json files - Remove prettier from dependencies as it's no longer needed - Reformat package.json files for better readability * chore: replace prettier with biome for code formatting Remove all prettier-related configuration, dependencies, and references Update formatting scripts and documentation to use biome instead Adjust electron-builder config to exclude biome.jsonc * build: replace prettier with biome for formatting Use biome as the default formatter instead of prettier for better performance and modern tooling support * ci(i18n): replace prettier with biome for i18n
formatting Update the auto-i18n workflow to use Biome instead of Prettier for formatting translated files. This change simplifies the dependencies by removing multiple Prettier plugins and using a single tool for formatting. * fix(i18n): Auto update translations for PR #10170 * style: format package.json files by consolidating array formatting Consolidate multi-line array formatting into single-line format for better readability and consistency across package.json files * Revert "fix(i18n): Auto update translations for PR #10170" This reverts commit a7edd32efd. * ci(workflows): specify biome config path in auto-i18n workflow * chore: update biome.jsonc to use lexicographic sort order for json keys * ci(workflows): update biome format command to use --config-path flag * chore: exclude package.json from biome formatting * ci: update biome.jsonc linter configuration Update linter includes to target specific files and modify useSortedClasses rule * chore: reorder search exclude patterns in vscode settings * style(OGCard): reorder tailwind classes for consistent styling * fix(biome): update tailwind classes sorting to safe and warn level * docs(dev): update ide setup instructions in dev docs Replace Prettier with Biome as the recommended formatter and clarify editor options * build(extension-table-plus): replace prettier with biome for formatting - Add biome.json configuration file - Update package.json to use biome instead of prettier - Remove prettier from dependencies - Update lint script to use biome format * chore: replace biome.json with biome.jsonc for extended configuration Update biome configuration file to use JSONC format for comments and more detailed settings * chore: remove unused biome.jsonc configuration file --------- Co-authored-by: GitHub Action <action@github.com>
936 lines
34 KiB
TypeScript
936 lines
34 KiB
TypeScript
import crypto from 'node:crypto'
|
|
import os from 'node:os'
|
|
import path from 'node:path'
|
|
|
|
import { loggerService } from '@logger'
|
|
import { createInMemoryMCPServer } from '@main/mcpServers/factory'
|
|
import { makeSureDirExists, removeEnvProxy } from '@main/utils'
|
|
import { buildFunctionCallToolName } from '@main/utils/mcp'
|
|
import { getBinaryName, getBinaryPath } from '@main/utils/process'
|
|
import { TraceMethod, withSpanFunc } from '@mcp-trace/trace-core'
|
|
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
|
import { SSEClientTransport, SSEClientTransportOptions } from '@modelcontextprotocol/sdk/client/sse.js'
|
|
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'
|
|
import {
|
|
StreamableHTTPClientTransport,
|
|
type StreamableHTTPClientTransportOptions
|
|
} from '@modelcontextprotocol/sdk/client/streamableHttp'
|
|
import { InMemoryTransport } from '@modelcontextprotocol/sdk/inMemory'
|
|
import { McpError, type Tool as SDKTool } from '@modelcontextprotocol/sdk/types'
|
|
// Import notification schemas from MCP SDK
|
|
import {
|
|
CancelledNotificationSchema,
|
|
type GetPromptResult,
|
|
LoggingMessageNotificationSchema,
|
|
PromptListChangedNotificationSchema,
|
|
ResourceListChangedNotificationSchema,
|
|
ResourceUpdatedNotificationSchema,
|
|
ToolListChangedNotificationSchema
|
|
} from '@modelcontextprotocol/sdk/types.js'
|
|
import { nanoid } from '@reduxjs/toolkit'
|
|
import { MCPProgressEvent } from '@shared/config/types'
|
|
import { IpcChannel } from '@shared/IpcChannel'
|
|
import { defaultAppHeaders } from '@shared/utils'
|
|
import {
|
|
BuiltinMCPServerNames,
|
|
type GetResourceResponse,
|
|
isBuiltinMCPServer,
|
|
type MCPCallToolResponse,
|
|
type MCPPrompt,
|
|
type MCPResource,
|
|
type MCPServer,
|
|
type MCPTool
|
|
} from '@types'
|
|
import { app, net } from 'electron'
|
|
import { EventEmitter } from 'events'
|
|
import { memoize } from 'lodash'
|
|
import { v4 as uuidv4 } from 'uuid'
|
|
|
|
import { CacheService } from './CacheService'
|
|
import DxtService from './DxtService'
|
|
import { CallBackServer } from './mcp/oauth/callback'
|
|
import { McpOAuthClientProvider } from './mcp/oauth/provider'
|
|
import getLoginShellEnvironment from './mcp/shell-env'
|
|
import { windowService } from './WindowService'
|
|
|
|
// Signature of an async function after it has been wrapped by withCache()
type CachedFunction<T extends unknown[], R> = (...args: T) => Promise<R>

// Payload accepted by McpService.callTool(); when callId is omitted a UUID is generated
type CallToolArgs = {
  server: MCPServer
  name: string
  // Tool arguments: an object, or a JSON string that callTool() parses
  args: any
  callId?: string
}

// Service-wide logger; prefer getServerLogger() when a server context is available
const logger = loggerService.withContext('MCPService')
|
|
|
|
/**
 * Produce a log-safe copy of `input`.
 *
 * - Values stored under sensitive keys (authorization headers, API keys, tokens, secrets,
 *   cookies, passwords) are replaced with `'<redacted>'`. Key matching ignores case and
 *   `-`/`_` separators, so `Authorization`, `X-API-Key` and `ACCESS_TOKEN` are all caught.
 * - Strings longer than 300 characters are truncated with a `…<N more>` suffix.
 * - Arrays and objects are copied recursively. A reference back to an object that is
 *   currently being copied is replaced with `'<circular>'` instead of recursing forever.
 * - Every other value (numbers, booleans, functions, …) is returned unchanged.
 *
 * @param input Any value, typically request options or tool arguments about to be logged.
 * @returns A redacted copy; `input` itself is never mutated.
 */
function redactSensitive(input: any): any {
  // Compared against normalizeKey(k), so list each name lowercase and without separators
  const SENSITIVE_KEYS = new Set([
    'authorization',
    'proxyauthorization',
    'apikey',
    'xapikey',
    'token',
    'accesstoken',
    'refreshtoken',
    'idtoken',
    'clientsecret',
    'secret',
    'password',
    'cookie',
    'setcookie'
  ])
  const MAX_STRING = 300

  const normalizeKey = (key: string): string => key.toLowerCase().replace(/[-_]/g, '')

  // Objects on the current recursion path only, so shared (non-cyclic) references are still copied
  const ancestors = new WeakSet<object>()

  const redact = (val: any): any => {
    if (val == null) return val
    if (typeof val === 'string') {
      return val.length > MAX_STRING ? `${val.slice(0, MAX_STRING)}…<${val.length - MAX_STRING} more>` : val
    }
    if (typeof val !== 'object') return val
    if (ancestors.has(val)) return '<circular>'

    ancestors.add(val)
    try {
      if (Array.isArray(val)) return val.map((v) => redact(v))
      const out: Record<string, any> = {}
      for (const [k, v] of Object.entries(val)) {
        out[k] = SENSITIVE_KEYS.has(normalizeKey(k)) ? '<redacted>' : redact(v)
      }
      return out
    } finally {
      ancestors.delete(val)
    }
  }

  return redact(input)
}
|
|
|
|
/**
 * Build a logger whose context identifies `server` (name, id, base URL and transport type),
 * optionally enriched with caller-supplied fields such as the tool name or call id.
 * When `server.type` is unset the type is inferred: `stdio` if a command is configured,
 * `http` if only a base URL is, otherwise `inmemory`.
 */
function getServerLogger(server: MCPServer, extra?: Record<string, any>) {
  const inferredType = server?.command ? 'stdio' : server?.baseUrl ? 'http' : 'inmemory'
  return loggerService.withContext('MCPService', {
    serverName: server?.name,
    serverId: server?.id,
    baseUrl: server?.baseUrl,
    type: server?.type || inferredType,
    // caller-supplied fields go last so they can override the defaults above
    ...extra
  })
}
|
|
|
|
/**
 * Higher-order function that memoizes an async function's result in `CacheService`.
 *
 * Any stored value other than `null`/`undefined` counts as a cache hit, including falsy
 * results such as `false`, `0` or `''` (a plain truthiness check would recompute those on
 * every call).
 *
 * @param fn The original function to be wrapped with caching
 * @param getCacheKey Function to generate a cache key from the function arguments
 * @param ttl Time to live for the cache entry in milliseconds
 * @param logPrefix Prefix for log messages
 * @returns The wrapped function with caching capability
 */
function withCache<T extends unknown[], R>(
  fn: (...args: T) => Promise<R>,
  getCacheKey: (...args: T) => string,
  ttl: number,
  logPrefix: string
): CachedFunction<T, R> {
  return async (...args: T): Promise<R> => {
    const cacheKey = getCacheKey(...args)

    if (CacheService.has(cacheKey)) {
      const cachedData = CacheService.get<R>(cacheKey)
      // Only report and return a hit once an actual value came back from the cache
      if (cachedData !== undefined && cachedData !== null) {
        logger.debug(`${logPrefix} loaded from cache`, { cacheKey })
        return cachedData
      }
    }

    const start = Date.now()
    const result = await fn(...args)
    CacheService.set(cacheKey, result, ttl)
    logger.debug(`${logPrefix} cached`, { cacheKey, ttlMs: ttl, durationMs: Date.now() - start })
    return result
  }
}
|
|
|
|
class McpService {
|
|
/** Connected clients, keyed by getServerKey(server). */
private clients = new Map<string, Client>()
/** initClient() promises still in flight, so concurrent callers share one connection attempt. */
private pendingClients = new Map<string, Promise<Client>>()
/** Resolves DXT package configuration and removes extracted DXT server directories. */
private dxtService = new DxtService()
/** Abort controllers of tool calls that are still running, keyed by call id. */
private activeToolCalls = new Map<string, AbortController>()

/**
 * Bind the public methods to this instance so they keep the correct `this` when passed
 * around as detached callbacks (presumably as IPC handlers, given their
 * `Electron.IpcMainInvokeEvent` first parameter — confirm at the registration site).
 */
constructor() {
  // connection lifecycle
  this.initClient = this.initClient.bind(this)
  this.closeClient = this.closeClient.bind(this)
  this.stopServer = this.stopServer.bind(this)
  this.removeServer = this.removeServer.bind(this)
  this.restartServer = this.restartServer.bind(this)
  this.cleanup = this.cleanup.bind(this)

  // tools
  this.listTools = this.listTools.bind(this)
  this.callTool = this.callTool.bind(this)
  this.abortTool = this.abortTool.bind(this)

  // prompts
  this.listPrompts = this.listPrompts.bind(this)
  this.getPrompt = this.getPrompt.bind(this)

  // resources
  this.listResources = this.listResources.bind(this)
  this.getResource = this.getResource.bind(this)

  // diagnostics
  this.checkMcpConnectivity = this.checkMcpConnectivity.bind(this)
  this.getServerVersion = this.getServerVersion.bind(this)
}
|
|
|
|
/**
 * Derive the identity key of a server configuration.
 *
 * The key is the JSON of the connection-relevant fields (base URL, command, args, registry URL,
 * env and id). It indexes both the client map and the `mcp:*` cache entries, and the property
 * order below is part of the key, so keep it stable.
 */
private getServerKey(server: MCPServer): string {
  const { baseUrl, command, args, registryUrl, env, id } = server
  const identity = {
    baseUrl,
    command,
    // normalize a missing/invalid args value so it produces the same key as no args
    args: Array.isArray(args) ? args : [],
    registryUrl,
    env,
    id
  }
  return JSON.stringify(identity)
}
|
|
|
|
/**
 * Return a connected client for `server`, creating one if necessary.
 *
 * Lookup order:
 * 1. An initialization already in flight for the same server key is awaited and shared.
 * 2. A cached client is reused if it answers a ping within one second; otherwise it is dropped.
 * 3. A new client is created with a transport chosen from the configuration: in-memory for
 *    built-in servers (except mcp-auto-install), streamable HTTP or SSE when `baseUrl` is set,
 *    and stdio when `command` is set. If the first connect fails with an "Unauthorized" error,
 *    an OAuth flow is run and the connection retried.
 *
 * Environment additions made during setup (DXT-resolved env, registry URLs, MCP_REGISTRY_PATH)
 * are merged into a local copy instead of being written back to `server.env`. Mutating the
 * caller's object would change getServerKey(server) after the client was stored under the
 * original key, so later lookups with the same object (notification cache invalidation,
 * connectivity-check cleanup, repeat initClient calls) would miss it.
 *
 * @param server The server configuration to connect to.
 * @returns The connected client.
 * @throws When the transport cannot be created or the connection/OAuth flow fails.
 */
async initClient(server: MCPServer): Promise<Client> {
  const serverKey = this.getServerKey(server)

  // If there's a pending initialization, wait for it
  const pendingClient = this.pendingClients.get(serverKey)
  if (pendingClient) {
    getServerLogger(server).silly(`Waiting for pending client initialization`)
    return pendingClient
  }

  // Check if we already have a client for this server configuration
  const existingClient = this.clients.get(serverKey)
  if (existingClient) {
    try {
      // Check if the existing client is still connected
      const pingResult = await existingClient.ping({
        // add short timeout to prevent hanging
        timeout: 1000
      })
      getServerLogger(server).debug(`Ping result`, { ok: !!pingResult })
      // If the ping fails, remove the client from the cache and create a new one below
      if (!pingResult) {
        this.clients.delete(serverKey)
      } else {
        return existingClient
      }
    } catch (error: any) {
      getServerLogger(server).error(`Error pinging server ${server.name}`, error as Error)
      this.clients.delete(serverKey)
    }
  }

  // App defaults first so user-configured server.headers can override them
  const prepareHeaders = () => {
    return {
      ...defaultAppHeaders(),
      ...server.headers
    }
  }

  // Create a promise for the initialization process
  const initPromise = (async () => {
    try {
      // Create new client instance for each connection
      const client = new Client({ name: 'Cherry Studio', version: app.getVersion() }, { capabilities: {} })

      let args = [...(server.args || [])]

      // The OAuth provider is identified by an MD5 hash of the server URL
      const authProvider = new McpOAuthClientProvider({
        serverUrlHash: crypto
          .createHash('md5')
          .update(server.baseUrl || '')
          .digest('hex')
      })

      const initTransport = async (): Promise<
        StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
      > => {
        // Create appropriate transport based on configuration
        if (isBuiltinMCPServer(server) && server.name !== BuiltinMCPServerNames.mcpAutoInstall) {
          getServerLogger(server).debug(`Using in-memory transport`)
          const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair()
          // start the in-memory server with the given name and environment variables
          const inMemoryServer = createInMemoryMCPServer(server.name, args, server.env || {})
          try {
            await inMemoryServer.connect(serverTransport)
            getServerLogger(server).debug(`In-memory server started`)
          } catch (error: any) {
            getServerLogger(server).error(`Error starting in-memory server`, error as Error)
            throw new Error(`Failed to start in-memory server: ${error.message}`)
          }
          // set the client transport to the client
          return clientTransport
        } else if (server.baseUrl) {
          if (server.type === 'streamableHttp') {
            const options: StreamableHTTPClientTransportOptions = {
              fetch: async (url, init) => {
                return net.fetch(typeof url === 'string' ? url : url.toString(), init)
              },
              requestInit: {
                headers: prepareHeaders()
              },
              authProvider
            }
            // redact headers before logging
            getServerLogger(server).debug(`StreamableHTTPClientTransport options`, {
              options: redactSensitive(options)
            })
            return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
          } else if (server.type === 'sse') {
            const options: SSEClientTransportOptions = {
              eventSourceInit: {
                fetch: async (url, init) => {
                  return net.fetch(typeof url === 'string' ? url : url.toString(), init)
                }
              },
              requestInit: {
                headers: prepareHeaders()
              },
              authProvider
            }
            return new SSEClientTransport(new URL(server.baseUrl!), options)
          } else {
            throw new Error('Invalid server type')
          }
        } else if (server.command) {
          let cmd = server.command
          // Environment for the child process. Built on a local copy so server.env, and
          // therefore getServerKey(server), is left unchanged by initialization.
          let env: Record<string, string> = { ...server.env }

          // For DXT servers, use resolved configuration with platform overrides and variable substitution
          if (server.dxtPath) {
            const resolvedConfig = this.dxtService.getResolvedMcpConfig(server.dxtPath)
            if (resolvedConfig) {
              cmd = resolvedConfig.command
              args = resolvedConfig.args
              // Merge resolved environment variables with existing ones
              env = {
                ...env,
                ...resolvedConfig.env
              }
              getServerLogger(server).debug(`Using resolved DXT config`, {
                command: cmd,
                args
              })
            } else {
              getServerLogger(server).warn(`Failed to resolve DXT config, falling back to manifest values`)
            }
          }

          if (server.command === 'npx') {
            cmd = await getBinaryPath('bun')
            getServerLogger(server).debug(`Using command`, { command: cmd })

            // npx runs through bun: if args exist, prepend `x -y` (each only if not already present)
            if (args && args.length > 0) {
              if (!args.includes('-y')) {
                args.unshift('-y')
              }
              if (!args.includes('x')) {
                args.unshift('x')
              }
            }
            if (server.registryUrl) {
              env = {
                ...env,
                NPM_CONFIG_REGISTRY: server.registryUrl
              }

              // if the server name is mcp-auto-install, use the mcp-registry.json file in the bin directory
              if (server.name.includes('mcp-auto-install')) {
                const binPath = await getBinaryPath()
                makeSureDirExists(binPath)
                env.MCP_REGISTRY_PATH = path.join(binPath, '..', 'config', 'mcp-registry.json')
              }
            }
          } else if (server.command === 'uvx' || server.command === 'uv') {
            cmd = await getBinaryPath(server.command)
            if (server.registryUrl) {
              env = {
                ...env,
                UV_DEFAULT_INDEX: server.registryUrl,
                PIP_INDEX_URL: server.registryUrl
              }
            }
          }

          getServerLogger(server).debug(`Starting server`, { command: cmd, args })
          // Copy before removeEnvProxy() edits it in place (its return value is unused), in case
          // getLoginShellEnv() hands back a shared, e.g. memoized, object — otherwise later
          // non-bun servers would silently lose their proxy settings too.
          const loginShellEnv = { ...(await this.getLoginShellEnv()) }

          // Bun not support proxy https://github.com/oven-sh/bun/issues/16812
          if (cmd.includes('bun')) {
            removeEnvProxy(loginShellEnv)
          }

          const transportOptions: any = {
            command: cmd,
            args,
            env: {
              ...loginShellEnv,
              ...env
            },
            stderr: 'pipe'
          }

          // For DXT servers, set the working directory to the extracted path
          if (server.dxtPath) {
            transportOptions.cwd = server.dxtPath
            getServerLogger(server).debug(`Setting working directory for DXT server`, {
              cwd: server.dxtPath
            })
          }

          const stdioTransport = new StdioClientTransport(transportOptions)
          stdioTransport.stderr?.on('data', (data) =>
            getServerLogger(server).debug(`Stdio stderr`, { data: data.toString() })
          )
          return stdioTransport
        } else {
          throw new Error('Either baseUrl or command must be provided')
        }
      }

      const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => {
        getServerLogger(server).debug(`Starting OAuth flow`)
        // Create an event emitter for the OAuth callback
        const events = new EventEmitter()

        // Create a callback server
        const callbackServer = new CallBackServer({
          port: authProvider.config.callbackPort,
          path: authProvider.config.callbackPath || '/oauth/callback',
          events
        })

        // Set a timeout to close the callback server
        const timeoutId = setTimeout(() => {
          getServerLogger(server).warn(`OAuth flow timed out`)
          callbackServer.close()
        }, 300000) // 5 minutes timeout

        try {
          // Wait for the authorization code
          const authCode = await callbackServer.waitForAuthCode()
          getServerLogger(server).debug(`Received auth code`)

          // Complete the OAuth flow
          await transport.finishAuth(authCode)

          getServerLogger(server).debug(`OAuth flow completed`)

          const newTransport = await initTransport()
          // Try to connect again
          await client.connect(newTransport)

          getServerLogger(server).debug(`Successfully authenticated`)
        } catch (oauthError) {
          getServerLogger(server).error(`OAuth authentication failed`, oauthError as Error)
          throw new Error(
            `OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}`
          )
        } finally {
          // Clear the timeout and close the callback server
          clearTimeout(timeoutId)
          callbackServer.close()
        }
      }

      try {
        const transport = await initTransport()
        try {
          await client.connect(transport)
        } catch (error: any) {
          if (
            error instanceof Error &&
            (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))
          ) {
            logger.debug(`Authentication required for server: ${server.name}`)
            await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
          } else {
            throw error
          }
        }

        // Store the new client in the cache
        this.clients.set(serverKey, client)

        // Set up notification handlers
        this.setupNotificationHandlers(client, server)

        // Clear existing cache to ensure fresh data
        this.clearServerCache(serverKey)

        logger.debug(`Activated server: ${server.name}`)
        return client
      } catch (error) {
        getServerLogger(server).error(`Error activating server ${server.name}`, error as Error)
        throw error
      }
    } finally {
      // Clean up the pending promise when done
      this.pendingClients.delete(serverKey)
    }
  })()

  // Store the pending promise
  this.pendingClients.set(serverKey, initPromise)

  return initPromise
}
|
|
|
|
/**
 * Register handlers for server-initiated notifications on `client`.
 *
 * List-changed notifications (tools, resources, prompts) and resource-updated notifications
 * evict the matching `mcp:list_*` cache entries for this server. Cancellation and logging
 * notifications are only logged. A failure while registering is logged, not thrown.
 */
private setupNotificationHandlers(client: Client, server: MCPServer) {
  const serverKey = this.getServerKey(server)
  const evict = (prefix: string) => CacheService.remove(`${prefix}:${serverKey}`)

  try {
    client.setNotificationHandler(ToolListChangedNotificationSchema, async () => {
      logger.debug(`Tools list changed for server: ${server.name}`)
      evict('mcp:list_tool')
    })

    client.setNotificationHandler(ResourceListChangedNotificationSchema, async () => {
      logger.debug(`Resources list changed for server: ${server.name}`)
      evict('mcp:list_resources')
    })

    client.setNotificationHandler(PromptListChangedNotificationSchema, async () => {
      logger.debug(`Prompts list changed for server: ${server.name}`)
      evict('mcp:list_prompts')
    })

    client.setNotificationHandler(ResourceUpdatedNotificationSchema, async () => {
      logger.debug(`Resource updated for server: ${server.name}`)
      this.clearResourceCaches(serverKey)
    })

    client.setNotificationHandler(CancelledNotificationSchema, async (notification) => {
      logger.debug(`Operation cancelled for server: ${server.name}`, notification.params)
    })

    client.setNotificationHandler(LoggingMessageNotificationSchema, async (notification) => {
      logger.debug(`Message from server ${server.name}:`, notification.params)
    })

    getServerLogger(server).debug(`Set up notification handlers`)
  } catch (error) {
    getServerLogger(server).error(`Failed to set up notification handlers`, error as Error)
  }
}
|
|
|
|
/**
 * Evict the cached resource listing (`mcp:list_resources:<serverKey>`) of one server.
 */
private clearResourceCaches(serverKey: string) {
  const resourcesCacheKey = `mcp:list_resources:${serverKey}`
  CacheService.remove(resourcesCacheKey)
}
|
|
|
|
/**
 * Evict the cached tool, prompt and resource listings of one server.
 *
 * NOTE(review): per-prompt entries (`mcp:get_prompt:<serverKey>:…`) are not removed here.
 */
private clearServerCache(serverKey: string) {
  const listingPrefixes = ['mcp:list_tool', 'mcp:list_prompts', 'mcp:list_resources']
  for (const prefix of listingPrefixes) {
    CacheService.remove(`${prefix}:${serverKey}`)
  }
  logger.debug(`Cleared all caches for server`, { serverKey })
}
|
|
|
|
/**
 * Close and forget the client stored under `serverKey`, then evict its cached listings.
 *
 * The client is removed from the map before close() is awaited, and the caches are cleared
 * in `finally`. A failing close() therefore cannot leave a dead client behind for initClient()
 * to find; the close error itself is still propagated to the caller.
 * Logs a warning and does nothing when no client is registered under the key.
 */
async closeClient(serverKey: string) {
  const client = this.clients.get(serverKey)
  if (!client) {
    logger.warn(`No client found for server`, { serverKey })
    return
  }

  // Forget the client first: a concurrent initClient() then builds a fresh client instead of
  // pinging one that is being closed, and this call can no longer delete that fresh client.
  this.clients.delete(serverKey)
  try {
    await client.close()
    logger.debug(`Closed server`, { serverKey })
  } finally {
    // Clear all caches for this server
    this.clearServerCache(serverKey)
  }
}
|
|
|
|
/**
 * Disconnect the client associated with `server`, if one is connected.
 */
async stopServer(_: Electron.IpcMainInvokeEvent, server: MCPServer) {
  getServerLogger(server).debug(`Stopping server`)
  await this.closeClient(this.getServerKey(server))
}
|
|
|
|
/**
 * Disconnect `server` (when connected) and, for DXT servers, delete its extracted directory.
 * DXT cleanup failures are logged rather than thrown.
 */
async removeServer(_: Electron.IpcMainInvokeEvent, server: MCPServer) {
  const serverKey = this.getServerKey(server)
  // Only close when connected, which avoids closeClient's "No client found" warning
  if (this.clients.has(serverKey)) {
    await this.closeClient(serverKey)
  }

  if (!server.dxtPath) return

  try {
    if (this.dxtService.cleanupDxtServer(server.name)) {
      getServerLogger(server).debug(`Cleaned up DXT server directory`)
    }
  } catch (error) {
    getServerLogger(server).error(`Failed to cleanup DXT server`, error as Error)
  }
}
|
|
|
|
/**
 * Close any existing client for `server`, evict its cached listings and connect again.
 */
async restartServer(_: Electron.IpcMainInvokeEvent, server: MCPServer) {
  const serverKey = this.getServerKey(server)
  getServerLogger(server).debug(`Restarting server`)

  await this.closeClient(serverKey)
  // closeClient only evicts caches when a client existed, so evict unconditionally here
  this.clearServerCache(serverKey)

  await this.initClient(server)
}
|
|
|
|
/**
 * Close every connected client, one at a time. A failure is logged per client so that one
 * misbehaving server does not prevent the others from being closed.
 */
async cleanup() {
  // closeClient deletes the entry being visited, which is safe while iterating a Map
  for (const serverKey of this.clients.keys()) {
    try {
      await this.closeClient(serverKey)
    } catch (error: any) {
      logger.error(`Failed to close client`, error as Error)
    }
  }
}
|
|
|
|
/**
 * Check connectivity for an MCP server by connecting and listing its tools.
 *
 * @returns `true` when the tool listing succeeds. `false` on any failure; in that case the
 *   client is closed (best effort) so the next attempt starts from a clean state. This method
 *   never rejects, even if closing the failed client throws.
 */
public async checkMcpConnectivity(_: Electron.IpcMainInvokeEvent, server: MCPServer): Promise<boolean> {
  getServerLogger(server).debug(`Checking connectivity`)
  try {
    getServerLogger(server).debug(`About to call initClient`, { hasInitClient: !!this.initClient })

    if (!this.initClient) {
      throw new Error('initClient method is not available')
    }

    const client = await this.initClient(server)
    // Attempt to list tools as a way to check connectivity
    await client.listTools()
    getServerLogger(server).debug(`Connectivity check successful`)
    return true
  } catch (error) {
    getServerLogger(server).error(`Connectivity check failed`, error as Error)
    // Close the client so the next attempt starts clean. A close failure must not turn
    // this boolean check into a rejected promise.
    try {
      await this.closeClient(this.getServerKey(server))
    } catch (closeError) {
      getServerLogger(server).error(`Failed to close client after connectivity check`, closeError as Error)
    }
    return false
  }
}
|
|
|
|
/**
 * Fetch the tool list from `server` (uncached) and tag each tool with an id built by
 * buildFunctionCallToolName(), plus the server's id and name.
 *
 * @throws Rethrows (after logging) any error from the tools/list request.
 */
private async listToolsImpl(server: MCPServer): Promise<MCPTool[]> {
  const client = await this.initClient(server)
  try {
    const { tools } = await client.listTools()
    return tools.map((tool: SDKTool) => {
      const serverTool: MCPTool = {
        ...tool,
        id: buildFunctionCallToolName(server.name, tool.name),
        serverId: server.id,
        serverName: server.name,
        type: 'mcp'
      }
      getServerLogger(server).debug(`Listing tools`, { tool: serverTool })
      return serverTool
    })
  } catch (error: unknown) {
    getServerLogger(server).error(`Failed to list tools`, error as Error)
    throw error
  }
}
|
|
|
|
/**
 * List the tools exposed by `server`. The call is traced as `<server name>.ListTool`, and the
 * result is cached for 5 minutes under `mcp:list_tool:<serverKey>`.
 */
async listTools(_: Electron.IpcMainInvokeEvent, server: MCPServer) {
  const toolsCacheKey = (target: MCPServer) => `mcp:list_tool:${this.getServerKey(target)}`

  const listWithCache = (target: MCPServer) =>
    withCache<[MCPServer], MCPTool[]>(
      this.listToolsImpl.bind(this),
      toolsCacheKey,
      5 * 60 * 1000, // 5 minutes TTL
      `[MCP] Tools from ${target.name}`
    )(target)

  return withSpanFunc(`${server.name}.ListTool`, 'MCP', listWithCache, [server])
}
|
|
|
|
/**
 * Call a tool on an MCP server.
 *
 * The call's AbortController is registered in `activeToolCalls` under `callId` (or a generated
 * UUID) so it can be cancelled while running; presumably abortTool() does that — confirm.
 * The call is traced as `<server>.<tool>`, and progress is forwarded to the main window over
 * `IpcChannel.Mcp_Progress`. A string `args` is parsed as JSON, and an empty or blank string
 * means "no arguments".
 *
 * Timeout: `server.timeout` seconds, defaulting to 60s. For `longRunning` servers the timeout
 * is reset on every progress notification, capped at 10 minutes in total.
 */
public async callTool(
  _: Electron.IpcMainInvokeEvent,
  { server, name, args, callId }: CallToolArgs
): Promise<MCPCallToolResponse> {
  const toolCallId = callId || uuidv4()
  const abortController = new AbortController()
  this.activeToolCalls.set(toolCallId, abortController)

  const callToolFunc = async ({ server, name, args }: CallToolArgs) => {
    const toolLogger = getServerLogger(server, { tool: name, callId: toolCallId })
    try {
      toolLogger.debug(`Calling tool`, {
        args: redactSensitive(args)
      })
      if (typeof args === 'string') {
        if (args.trim() === '') {
          // No arguments; JSON.parse('') would throw and log a spurious parse error
          args = {}
        } else {
          try {
            args = JSON.parse(args)
          } catch (e) {
            toolLogger.error('args parse error', e as Error, {
              args
            })
          }
        }
      }
      const client = await this.initClient(server)
      const result = await client.callTool({ name, arguments: args }, undefined, {
        onprogress: (progress) => {
          const ratio = progress.progress / (progress.total || 1)
          toolLogger.debug(`Progress`, { ratio })
          const mainWindow = windowService.getMainWindow()
          if (mainWindow) {
            mainWindow.webContents.send(IpcChannel.Mcp_Progress, {
              callId: toolCallId,
              progress: ratio
            } as MCPProgressEvent)
          }
        },
        timeout: server.timeout ? server.timeout * 1000 : 60000, // Default timeout of 1 minute
        // Need server side support: https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle#timeouts
        resetTimeoutOnProgress: server.longRunning,
        maxTotalTimeout: server.longRunning ? 10 * 60 * 1000 : undefined,
        // Use this call's own controller rather than re-reading the map: if the entry was
        // removed before this point (e.g. aborted while initClient was still connecting), the
        // lookup would yield no signal and the call could no longer be cancelled.
        signal: abortController.signal
      })
      return result as MCPCallToolResponse
    } catch (error) {
      toolLogger.error(`Error calling tool`, error as Error)
      throw error
    } finally {
      // Only remove our own entry; a newer call reusing the same callId may have replaced it
      if (this.activeToolCalls.get(toolCallId) === abortController) {
        this.activeToolCalls.delete(toolCallId)
      }
    }
  }

  return await withSpanFunc(`${server.name}.${name}`, `MCP`, callToolFunc, [{ server, name, args }])
}
|
|
|
|
/**
 * Report the install directory (`~/.cherrystudio/bin`) and the expected paths of the `uv`
 * and `bun` binaries within it, using the executable names returned by getBinaryName().
 */
public async getInstallInfo() {
  const dir = path.join(os.homedir(), '.cherrystudio', 'bin')
  const toBinPath = async (tool: string) => path.join(dir, await getBinaryName(tool))

  const uvPath = await toBinPath('uv')
  const bunPath = await toBinPath('bun')
  return { dir, uvPath, bunPath }
}
|
|
|
|
/**
 * List prompts available on an MCP server (uncached).
 *
 * Each prompt gets a fresh `p`-prefixed nanoid plus the server's id and name. The listing is
 * best-effort: any request failure yields `[]`. Every failure except JSON-RPC -32601 (method
 * not found, i.e. the server does not implement prompts) is logged; previously, errors other
 * than McpError were swallowed silently.
 */
private async listPromptsImpl(server: MCPServer): Promise<MCPPrompt[]> {
  const client = await this.initClient(server)
  getServerLogger(server).debug(`Listing prompts`)
  try {
    const { prompts } = await client.listPrompts()
    return prompts.map((prompt: any) => ({
      ...prompt,
      id: `p${nanoid()}`,
      serverId: server.id,
      serverName: server.name
    }))
  } catch (error: unknown) {
    // -32601 is the code for the method not found
    const isMethodNotFound = error instanceof McpError && error.code === -32601
    if (!isMethodNotFound) {
      getServerLogger(server).error(`Failed to list prompts`, error as Error)
    }
    return []
  }
}
|
|
|
|
/**
 * List prompts available on an MCP server, caching the result for one hour under
 * `mcp:list_prompts:<serverKey>`.
 */
public async listPrompts(_: Electron.IpcMainInvokeEvent, server: MCPServer): Promise<MCPPrompt[]> {
  const promptsCacheKey = (target: MCPServer) => `mcp:list_prompts:${this.getServerKey(target)}`
  const oneHourMs = 60 * 60 * 1000

  return withCache<[MCPServer], MCPPrompt[]>(
    this.listPromptsImpl.bind(this),
    promptsCacheKey,
    oneHourMs,
    `[MCP] Prompts from ${server.name}`
  )(server)
}
|
|
|
|
/**
 * Fetch a single prompt from an MCP server without any caching.
 *
 * @param server - Server that owns the prompt.
 * @param name - Prompt name as advertised by the server.
 * @param args - Optional prompt arguments, forwarded unchanged as `arguments`.
 * @returns The `GetPromptResult` from the client. Errors from client initialisation or from the
 *   request itself propagate to the caller unchanged.
 */
private async getPromptImpl(server: MCPServer, name: string, args?: Record<string, any>): Promise<GetPromptResult> {
  logger.debug(`Getting prompt ${name} from server: ${server.name}`)
  const client = await this.initClient(server)
  const promptResult = await client.getPrompt({ name, arguments: args })
  return promptResult
}
|
|
|
|
/**
 * Get a specific prompt from an MCP server, caching the result for 30 minutes.
 *
 * The cache key combines the server key, the prompt name and the JSON-serialised arguments
 * (or the literal `no-args` when no arguments are supplied).
 */
@TraceMethod({ spanName: 'getPrompt', tag: 'mcp' })
public async getPrompt(
  _: Electron.IpcMainInvokeEvent,
  { server, name, args }: { server: MCPServer; name: string; args?: Record<string, any> }
): Promise<GetPromptResult> {
  const cachedGetPrompt = withCache<[MCPServer, string, Record<string, any> | undefined], GetPromptResult>(
    this.getPromptImpl.bind(this),
    (target, promptName, promptArgs) =>
      `mcp:get_prompt:${this.getServerKey(target)}:${promptName}:${promptArgs ? JSON.stringify(promptArgs) : 'no-args'}`,
    30 * 60 * 1000, // 30 minutes TTL
    `[MCP] Prompt ${name} from ${server.name}`
  )
  const prompt = await cachedGetPrompt(server, name, args)
  return prompt
}
|
|
|
|
/**
 * List the resources exposed by an MCP server (uncached implementation).
 *
 * Each resource is copied and tagged with the owning server's id and name. A missing or
 * non-array `resources` field is treated as an empty list.
 *
 * @param server - The MCP server to query; client initialisation errors propagate.
 * @returns The server's resources, or `[]` when listing fails.
 */
private async listResourcesImpl(server: MCPServer): Promise<MCPResource[]> {
  const client = await this.initClient(server)
  logger.debug(`Listing resources for server: ${server.name}`)
  try {
    const { resources } = await client.listResources()
    const resourceList = Array.isArray(resources) ? resources : []
    return resourceList.map((resource: any) => ({
      ...resource,
      serverId: server.id,
      serverName: server.name
    }))
  } catch (error: any) {
    // JSON-RPC -32601 ("method not found") means the server has no resource support: stay quiet.
    const methodMissing = error?.code === -32601
    if (!methodMissing) {
      getServerLogger(server).error(`Failed to list resources`, error as Error)
    }
    return []
  }
}
|
|
|
|
/**
 * List resources available on an MCP server, reusing a cached result for up to 60 minutes.
 *
 * The cache entry is keyed on `this.getServerKey(server)`; misses fall through to
 * `listResourcesImpl`.
 */
public async listResources(_: Electron.IpcMainInvokeEvent, server: MCPServer): Promise<MCPResource[]> {
  const resourceListKey = (target: MCPServer) => `mcp:list_resources:${this.getServerKey(target)}`
  const cachedListResources = withCache<[MCPServer], MCPResource[]>(
    this.listResourcesImpl.bind(this),
    resourceListKey,
    60 * 60 * 1000, // 60 minutes TTL
    `[MCP] Resources from ${server.name}`
  )
  return cachedListResources(server)
}
|
|
|
|
/**
 * Read a single resource from an MCP server (uncached implementation).
 *
 * Every returned content entry is copied and tagged with the owning server's id and name.
 *
 * @param server - Server that hosts the resource.
 * @param uri - URI of the resource to read.
 * @returns `{ contents }`, which is empty when the server returned no contents.
 * @throws Error whose message names the URI, the server and the underlying failure reason when
 *   the read fails. Client initialisation errors propagate unchanged.
 */
private async getResourceImpl(server: MCPServer, uri: string): Promise<GetResourceResponse> {
  getServerLogger(server, { uri }).debug(`Getting resource`)
  const client = await this.initClient(server)
  try {
    const result = await client.readResource({ uri })
    const contents: MCPResource[] = (result.contents ?? []).map((content: any) => ({
      ...content,
      serverId: server.id,
      serverName: server.name
    }))
    return { contents }
  } catch (error: unknown) {
    getServerLogger(server, { uri }).error(`Failed to get resource`, error as Error)
    // Thrown values are not guaranteed to be Error objects. Reading `.message` blindly turned a
    // thrown string into "undefined", and a thrown null/undefined into a brand-new TypeError
    // that hid the original failure. Fall back to the value's string form instead.
    const detail = (error as { message?: unknown } | null | undefined)?.message
    const reason = typeof detail === 'string' ? detail : String(error)
    throw new Error(`Failed to get resource ${uri} from server: ${server.name}: ${reason}`)
  }
}
|
|
|
|
/**
 * Read a resource from an MCP server, caching the response for 30 minutes.
 *
 * Cache entries are keyed on the server key plus the resource URI; misses fall through to
 * `getResourceImpl`.
 */
@TraceMethod({ spanName: 'getResource', tag: 'mcp' })
public async getResource(
  _: Electron.IpcMainInvokeEvent,
  { server, uri }: { server: MCPServer; uri: string }
): Promise<GetResourceResponse> {
  const thirtyMinutesMs = 30 * 60 * 1000
  const cachedGetResource = withCache<[MCPServer, string], GetResourceResponse>(
    this.getResourceImpl.bind(this),
    (target, resourceUri) => `mcp:get_resource:${this.getServerKey(target)}:${resourceUri}`,
    thirtyMinutesMs,
    `[MCP] Resource ${uri} from ${server.name}`
  )
  const response = await cachedGetResource(server, uri)
  return response
}
|
|
|
|
/**
 * Environment variables of the user's login shell, with Cherry Studio's managed binary
 * directory (`~/.cherrystudio/bin`) appended to `PATH`.
 *
 * Wrapped in `memoize`, so the login shell is presumably queried only once per process.
 * NOTE(review): this means the `{}` fallback returned on failure is cached as well. Confirm
 * that a failure should not be retried.
 *
 * @returns The merged environment, or `{}` if the login shell environment could not be read.
 */
private getLoginShellEnv = memoize(async (): Promise<Record<string, string>> => {
  try {
    const loginEnv = await getLoginShellEnvironment()
    const cherryBinPath = path.join(os.homedir(), '.cherrystudio', 'bin')
    const currentPath = loginEnv.PATH
    // Guard against a missing or empty PATH. Plain concatenation turned a missing PATH into the
    // literal directory "undefined". An empty PATH became a leading empty entry, which POSIX
    // resolves as the current working directory.
    // path.delimiter is ';' on win32 and ':' elsewhere.
    const mergedPath = currentPath ? `${currentPath}${path.delimiter}${cherryBinPath}` : cherryBinPath
    logger.debug('Successfully fetched login shell environment variables:')
    // Return a copy rather than mutating the object handed back by getLoginShellEnvironment.
    return { ...loginEnv, PATH: mergedPath }
  } catch (error) {
    logger.error('Failed to fetch login shell environment variables:', error as Error)
    return {}
  }
})
|
|
|
|
/**
 * Abort an in-flight tool call tracked in `activeToolCalls`.
 *
 * The tracked entry exposes `abort()` and `signal`, presumably an AbortController.
 *
 * @param callId - Identifier under which the tool call was registered.
 * @returns `true` if a matching call was found, aborted and removed; `false` if none was active.
 */
public async abortTool(_: Electron.IpcMainInvokeEvent, callId: string) {
  const controller = this.activeToolCalls.get(callId)
  if (!controller) {
    logger.warn(`No active tool call found for callId`, { callId })
    return false
  }
  controller.abort()
  this.activeToolCalls.delete(callId)
  logger.debug(`Aborted tool call`, { callId })
  return true
}
|
|
|
|
/**
 * Get the version string reported by an MCP server.
 *
 * @param server - Server to inspect; a client is initialised for it if needed.
 * @returns The reported version, or `null` when none is available or any step fails (failures
 *   are logged, never thrown).
 */
public async getServerVersion(_: Electron.IpcMainInvokeEvent, server: MCPServer): Promise<string | null> {
  try {
    getServerLogger(server).debug(`Getting server version`)
    const client = await this.initClient(server)

    // Server implementation info as exposed by the client; it may carry a version.
    const info = client.getServerVersion()
    getServerLogger(server).debug(`Server info`, redactSensitive(info))

    const version = info?.version
    if (!version) {
      getServerLogger(server).warn(`No version information available`)
      return null
    }

    getServerLogger(server).debug(`Server version`, { version })
    return version
  } catch (error: any) {
    getServerLogger(server).error(`Failed to get server version`, error as Error)
    return null
  }
}
|
|
}
|
|
|
|
// Module-level singleton: the instance is created once when this module is evaluated,
// and every importer of the default export shares it.
export default new McpService()
|