diff --git a/src/main/mcpServers/hub/generator.ts b/src/main/mcpServers/hub/generator.ts new file mode 100644 index 0000000000..479ce17d04 --- /dev/null +++ b/src/main/mcpServers/hub/generator.ts @@ -0,0 +1,160 @@ +import type { MCPServer, MCPTool } from '@types' + +import type { GeneratedTool } from './types' + +function toCamelCase(str: string): string { + return str + .replace(/[^a-zA-Z0-9]+(.)/g, (_, char) => char.toUpperCase()) + .replace(/^[A-Z]/, (char) => char.toLowerCase()) + .replace(/[^a-zA-Z0-9]/g, '') +} + +function makeUniqueFunctionName(baseName: string, existingNames: Set): string { + let name = baseName + let counter = 1 + while (existingNames.has(name)) { + name = `${baseName}${counter}` + counter++ + } + existingNames.add(name) + return name +} + +function jsonSchemaToSignature(schema: Record | undefined): string { + if (!schema || typeof schema !== 'object') { + return '{}' + } + + const properties = schema.properties as Record> | undefined + if (!properties) { + return '{}' + } + + const required = (schema.required as string[]) || [] + const parts: string[] = [] + + for (const [key, prop] of Object.entries(properties)) { + const isRequired = required.includes(key) + const typeStr = schemaTypeToTS(prop) + parts.push(`${key}${isRequired ? '' : '?'}: ${typeStr}`) + } + + return `{ ${parts.join(', ')} }` +} + +function schemaTypeToTS(prop: Record): string { + const type = prop.type as string | string[] | undefined + const enumValues = prop.enum as unknown[] | undefined + + if (enumValues && Array.isArray(enumValues)) { + return enumValues.map((v) => (typeof v === 'string' ? `"${v}"` : String(v))).join(' | ') + } + + if (Array.isArray(type)) { + return type.map((t) => primitiveTypeToTS(t)).join(' | ') + } + + if (type === 'array') { + const items = prop.items as Record | undefined + if (items) { + return `Array<${schemaTypeToTS(items)}>` + } + return 'Array' + } + + if (type === 'object') { + const properties = prop.properties as Record> | undefined + if (properties) { + return jsonSchemaToSignature(prop) + } + return 'object' + } + + return primitiveTypeToTS(type) +} + +function primitiveTypeToTS(type: string | undefined): string { + switch (type) { + case 'string': + return 'string' + case 'number': + case 'integer': + return 'number' + case 'boolean': + return 'boolean' + case 'null': + return 'null' + default: + return 'unknown' + } +} + +function generateJSDoc(tool: MCPTool, signature: string, returns: string): string { + const lines: string[] = ['/**'] + + if (tool.description) { + const descLines = tool.description.split('\n') + for (const line of descLines) { + lines.push(` * ${line}`) + } + } + + lines.push(` *`) + lines.push(` * @param {${signature}} params`) + lines.push(` * @returns {Promise<${returns}>}`) + lines.push(` */`) + + return lines.join('\n') +} + +export function generateToolFunction( + tool: MCPTool, + server: MCPServer, + existingNames: Set, + callToolFn: (toolId: string, params: unknown) => Promise +): GeneratedTool { + const toolId = `${server.id}__${tool.name}` + const baseName = toCamelCase(tool.name) + const functionName = makeUniqueFunctionName(baseName, existingNames) + + const inputSchema = tool.inputSchema as Record | undefined + const outputSchema = tool.outputSchema as Record | undefined + + const signature = jsonSchemaToSignature(inputSchema) + const returns = outputSchema ? jsonSchemaToSignature(outputSchema) : 'unknown' + + const jsDoc = generateJSDoc(tool, signature, returns) + + const jsCode = `${jsDoc} +async function ${functionName}(params) { + return await __callTool("${toolId}", params); +}` + + const fn = async (params: unknown): Promise => { + return await callToolFn(toolId, params) + } + + return { + serverId: server.id, + serverName: server.name, + toolName: tool.name, + toolId, + functionName, + jsCode, + fn, + signature, + returns, + description: tool.description + } +} + +export function generateToolsCode(tools: GeneratedTool[]): string { + if (tools.length === 0) { + return '// No tools available' + } + + const header = `// Found ${tools.length} tool(s):\n` + const code = tools.map((t) => t.jsCode).join('\n\n') + + return header + '\n' + code +} diff --git a/src/main/mcpServers/hub/runtime.ts b/src/main/mcpServers/hub/runtime.ts new file mode 100644 index 0000000000..3b34a49dab --- /dev/null +++ b/src/main/mcpServers/hub/runtime.ts @@ -0,0 +1,105 @@ +import { loggerService } from '@logger' + +import { callMcpTool } from './mcp-bridge' +import type { ConsoleMethods, ExecOutput, ExecutionContext, GeneratedTool } from './types' + +const logger = loggerService.withContext('MCPServer:Hub:Runtime') + +const MAX_LOGS = 1000 +const EXECUTION_TIMEOUT = 60000 + +export class Runtime { + async execute(code: string, tools: GeneratedTool[]): Promise { + const logs: string[] = [] + const capturedConsole = this.createCapturedConsole(logs) + + try { + const context = this.buildContext(tools, capturedConsole) + const result = await this.runCode(code, context) + + return { + result, + logs: logs.length > 0 ? logs : undefined + } + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) + logger.error('Execution error:', error as Error) + + return { + result: undefined, + logs: logs.length > 0 ? logs : undefined, + error: errorMessage + } + } + } + + private buildContext(tools: GeneratedTool[], capturedConsole: ConsoleMethods): ExecutionContext { + const context: ExecutionContext = { + __callTool: callMcpTool, + parallel: (...promises: Promise[]) => Promise.all(promises), + settle: (...promises: Promise[]) => Promise.allSettled(promises), + console: capturedConsole + } + + for (const tool of tools) { + context[tool.functionName] = tool.fn + } + + return context + } + + private async runCode(code: string, context: ExecutionContext): Promise { + const contextKeys = Object.keys(context) + const contextValues = contextKeys.map((k) => context[k]) + + const wrappedCode = ` + return (async () => { + ${code} + })() + ` + + const fn = new Function(...contextKeys, wrappedCode) + + const timeoutPromise = new Promise((_, reject) => { + setTimeout(() => { + reject(new Error(`Execution timed out after ${EXECUTION_TIMEOUT}ms`)) + }, EXECUTION_TIMEOUT) + }) + + const executionPromise = fn(...contextValues) + + return Promise.race([executionPromise, timeoutPromise]) + } + + private createCapturedConsole(logs: string[]): ConsoleMethods { + const addLog = (level: string, ...args: unknown[]) => { + if (logs.length >= MAX_LOGS) { + return + } + const message = args.map((arg) => this.stringify(arg)).join(' ') + logs.push(`[${level}] ${message}`) + } + + return { + log: (...args: unknown[]) => addLog('log', ...args), + warn: (...args: unknown[]) => addLog('warn', ...args), + error: (...args: unknown[]) => addLog('error', ...args), + info: (...args: unknown[]) => addLog('info', ...args), + debug: (...args: unknown[]) => addLog('debug', ...args) + } + } + + private stringify(value: unknown): string { + if (value === undefined) return 'undefined' + if (value === null) return 'null' + if (typeof value === 'string') return value + if (typeof value === 'number' || typeof value === 'boolean') return String(value) + if (value instanceof Error) return value.message + + try { + return JSON.stringify(value, null, 2) + } catch { + return String(value) + } + } +} diff --git a/src/main/mcpServers/hub/search.ts b/src/main/mcpServers/hub/search.ts new file mode 100644 index 0000000000..93c7eb70ac --- /dev/null +++ b/src/main/mcpServers/hub/search.ts @@ -0,0 +1,87 @@ +import { generateToolsCode } from './generator' +import type { GeneratedTool, SearchQuery, SearchResult } from './types' + +const DEFAULT_LIMIT = 10 +const MAX_LIMIT = 50 + +export function searchTools(tools: GeneratedTool[], query: SearchQuery): SearchResult { + const { query: queryStr, limit = DEFAULT_LIMIT } = query + const effectiveLimit = Math.min(Math.max(1, limit), MAX_LIMIT) + + const keywords = queryStr + .toLowerCase() + .split(',') + .map((k) => k.trim()) + .filter((k) => k.length > 0) + + if (keywords.length === 0) { + const sliced = tools.slice(0, effectiveLimit) + return { + tools: generateToolsCode(sliced), + total: tools.length + } + } + + const matchedTools = tools.filter((tool) => { + const searchText = buildSearchText(tool).toLowerCase() + return keywords.some((keyword) => searchText.includes(keyword)) + }) + + const rankedTools = rankTools(matchedTools, keywords) + const sliced = rankedTools.slice(0, effectiveLimit) + + return { + tools: generateToolsCode(sliced), + total: matchedTools.length + } +} + +function buildSearchText(tool: GeneratedTool): string { + const parts = [tool.toolName, tool.functionName, tool.serverName, tool.description || '', tool.signature] + return parts.join(' ') +} + +function rankTools(tools: GeneratedTool[], keywords: string[]): GeneratedTool[] { + const scored = tools.map((tool) => ({ + tool, + score: calculateScore(tool, keywords) + })) + + scored.sort((a, b) => b.score - a.score) + + return scored.map((s) => s.tool) +} + +function calculateScore(tool: GeneratedTool, keywords: string[]): number { + let score = 0 + const toolName = tool.toolName.toLowerCase() + const functionName = tool.functionName.toLowerCase() + const description = (tool.description || '').toLowerCase() + + for (const keyword of keywords) { + if (toolName === keyword) { + score += 10 + } else if (toolName.startsWith(keyword)) { + score += 5 + } else if (toolName.includes(keyword)) { + score += 3 + } + + if (functionName === keyword) { + score += 8 + } else if (functionName.includes(keyword)) { + score += 2 + } + + if (description.includes(keyword)) { + const count = (description.match(new RegExp(escapeRegex(keyword), 'g')) || []).length + score += Math.min(count, 3) + } + } + + return score +} + +function escapeRegex(str: string): string { + return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&') +} diff --git a/src/main/mcpServers/hub/tool-registry.ts b/src/main/mcpServers/hub/tool-registry.ts new file mode 100644 index 0000000000..f08c80230d --- /dev/null +++ b/src/main/mcpServers/hub/tool-registry.ts @@ -0,0 +1,98 @@ +import { loggerService } from '@logger' + +import { generateToolFunction } from './generator' +import { callMcpTool, getActiveServers, listToolsFromServer } from './mcp-bridge' +import type { GeneratedTool, ToolRegistryOptions } from './types' + +const logger = loggerService.withContext('MCPServer:Hub:Registry') + +const DEFAULT_TTL = 10 * 60 * 1000 + +export class ToolRegistry { + private tools: Map = new Map() + private lastRefresh: number = 0 + private readonly ttl: number + private refreshPromise: Promise | null = null + + constructor(options: ToolRegistryOptions = {}) { + this.ttl = options.ttl ?? DEFAULT_TTL + } + + async getTools(): Promise { + if (this.isExpired()) { + await this.refresh() + } + return Array.from(this.tools.values()) + } + + async getTool(toolId: string): Promise { + if (this.isExpired()) { + await this.refresh() + } + return this.tools.get(toolId) + } + + getToolByFunctionName(functionName: string): GeneratedTool | undefined { + for (const tool of this.tools.values()) { + if (tool.functionName === functionName) { + return tool + } + } + return undefined + } + + private isExpired(): boolean { + return Date.now() - this.lastRefresh > this.ttl + } + + invalidate(): void { + this.lastRefresh = 0 + this.tools.clear() + logger.debug('Tool registry invalidated') + } + + async refresh(): Promise { + if (this.refreshPromise) { + return this.refreshPromise + } + + this.refreshPromise = this.doRefresh() + + try { + await this.refreshPromise + } finally { + this.refreshPromise = null + } + } + + private async doRefresh(): Promise { + logger.debug('Refreshing tool registry') + + const servers = getActiveServers() + const newTools = new Map() + const existingNames = new Set() + + for (const server of servers) { + try { + const serverTools = await listToolsFromServer(server) + + for (const tool of serverTools) { + const generatedTool = generateToolFunction(tool, server, existingNames, callMcpTool) + + newTools.set(generatedTool.toolId, generatedTool) + } + } catch (error) { + logger.error(`Failed to list tools from server ${server.name}:`, error as Error) + } + } + + this.tools = newTools + this.lastRefresh = Date.now() + + logger.debug(`Tool registry refreshed with ${this.tools.size} tools`) + } + + getToolCount(): number { + return this.tools.size + } +}