feat: isolate hub exec worker and filter disabled tools

This commit is contained in:
Vaayne 2025-12-30 11:28:20 +08:00
parent d005d4291f
commit 41699a1afd
9 changed files with 557 additions and 166 deletions

View File

@ -72,7 +72,8 @@ vi.mock('@main/services/MCPService', () => ({
}
}
return { content: [{ type: 'text', text: '{}' }] }
})
}),
abortTool: vi.fn(async () => true)
}
}))
@ -178,6 +179,45 @@ describe('HubServer Integration', () => {
})
})
describe('exec timeouts', () => {
afterEach(() => {
vi.useRealTimers()
})
it('aborts in-flight tool calls and returns logs on timeout', async () => {
vi.useFakeTimers()
let toolCallStarted: (() => void) | null = null
const toolCallStartedPromise = new Promise<void>((resolve) => {
toolCallStarted = resolve
})
vi.mocked(mcpService.callToolById).mockImplementationOnce(async () => {
toolCallStarted?.()
return await new Promise(() => {})
})
const execPromise = (hubServer as any).handleExec({
code: `
console.log("starting");
return await github_searchRepos({ query: "hang" });
`
})
await toolCallStartedPromise
await vi.advanceTimersByTimeAsync(60000)
await vi.runAllTimersAsync()
const execResult = await execPromise
const execOutput = JSON.parse(execResult.content[0].text)
expect(execOutput.error).toBe('Execution timed out after 60000ms')
expect(execOutput.result).toBeUndefined()
expect(execOutput.logs).toContain('[log] starting')
expect(vi.mocked(mcpService.abortTool)).toHaveBeenCalled()
})
})
describe('server instance', () => {
it('creates a valid MCP server instance', () => {
expect(hubServer.server).toBeDefined()

View File

@ -56,7 +56,7 @@ describe('Runtime', () => {
const result = await runtime.execute('return await searchRepos({ query: "test" })', tools)
expect(result.result).toEqual({ repos: ['repo1', 'repo2'], query: { query: 'test' } })
expect(result.result).toEqual({ toolId: 'searchRepos', params: { query: 'test' }, success: true })
})
it('captures console logs', async () => {

View File

@ -19,7 +19,7 @@ export async function refreshToolMap(): Promise<void> {
}
}
export const callMcpTool = async (functionName: string, params: unknown): Promise<unknown> => {
export const callMcpTool = async (functionName: string, params: unknown, callId?: string): Promise<unknown> => {
const toolInfo = toolFunctionNameToIdMap.get(functionName)
if (!toolInfo) {
await refreshToolMap()
@ -28,14 +28,18 @@ export const callMcpTool = async (functionName: string, params: unknown): Promis
throw new Error(`Tool not found: ${functionName}`)
}
const toolId = `${retryToolInfo.serverId}__${retryToolInfo.toolName}`
const result = await mcpService.callToolById(toolId, params)
const result = await mcpService.callToolById(toolId, params, callId)
return extractToolResult(result)
}
const toolId = `${toolInfo.serverId}__${toolInfo.toolName}`
const result = await mcpService.callToolById(toolId, params)
const result = await mcpService.callToolById(toolId, params, callId)
return extractToolResult(result)
}
export const abortMcpTool = async (callId: string): Promise<boolean> => {
return mcpService.abortTool(null as unknown as Electron.IpcMainInvokeEvent, callId)
}
function extractToolResult(result: { content: Array<{ type: string; text?: string }> }): unknown {
if (!result.content || result.content.length === 0) {
return null

View File

@ -1,105 +1,169 @@
import crypto from 'node:crypto'
import { Worker } from 'node:worker_threads'
import { loggerService } from '@logger'
import { callMcpTool } from './mcp-bridge'
import type { ConsoleMethods, ExecOutput, ExecutionContext, GeneratedTool } from './types'
import { abortMcpTool, callMcpTool } from './mcp-bridge'
import type {
ExecOutput,
GeneratedTool,
HubWorkerCallToolMessage,
HubWorkerExecMessage,
HubWorkerMessage,
HubWorkerResultMessage
} from './types'
const logger = loggerService.withContext('MCPServer:Hub:Runtime')
const MAX_LOGS = 1000
const EXECUTION_TIMEOUT = 60000
const WORKER_URL = new URL('./worker.js', import.meta.url)
export class Runtime {
async execute(code: string, tools: GeneratedTool[]): Promise<ExecOutput> {
const logs: string[] = []
const capturedConsole = this.createCapturedConsole(logs)
return await new Promise<ExecOutput>((resolve) => {
const logs: string[] = []
const activeCallIds = new Map<string, string>()
let finished = false
let timedOut = false
let timeoutId: NodeJS.Timeout | null = null
try {
const context = this.buildContext(tools, capturedConsole)
const result = await this.runCode(code, context)
const worker = new Worker(WORKER_URL)
return {
result,
logs: logs.length > 0 ? logs : undefined
const addLog = (entry: string) => {
if (logs.length >= MAX_LOGS) {
return
}
logs.push(entry)
}
} catch (error) {
const errorMessage = error instanceof Error ? error.message : String(error)
logger.error('Execution error:', error as Error)
return {
result: undefined,
logs: logs.length > 0 ? logs : undefined,
error: errorMessage
const finalize = async (output: ExecOutput, terminateWorker = true) => {
if (finished) {
return
}
finished = true
if (timeoutId) {
clearTimeout(timeoutId)
}
worker.removeAllListeners()
if (terminateWorker) {
try {
await worker.terminate()
} catch (error) {
logger.warn('Failed to terminate exec worker', error as Error)
}
}
resolve(output)
}
}
}
private buildContext(tools: GeneratedTool[], capturedConsole: ConsoleMethods): ExecutionContext {
const context: ExecutionContext = {
__callTool: callMcpTool,
parallel: <T>(...promises: Promise<T>[]) => Promise.all(promises),
settle: <T>(...promises: Promise<T>[]) => Promise.allSettled(promises),
console: capturedConsole
}
const abortActiveTools = async () => {
const callIds = Array.from(activeCallIds.values())
activeCallIds.clear()
if (callIds.length === 0) {
return
}
await Promise.allSettled(callIds.map((callId) => abortMcpTool(callId)))
}
for (const tool of tools) {
context[tool.functionName] = tool.fn
}
const handleToolCall = async (message: HubWorkerCallToolMessage) => {
if (finished || timedOut) {
return
}
const callId = crypto.randomUUID()
activeCallIds.set(message.requestId, callId)
return context
}
try {
const result = await callMcpTool(message.functionName, message.params, callId)
if (finished || timedOut) {
return
}
worker.postMessage({ type: 'toolResult', requestId: message.requestId, result })
} catch (error) {
if (finished || timedOut) {
return
}
const errorMessage = error instanceof Error ? error.message : String(error)
worker.postMessage({ type: 'toolError', requestId: message.requestId, error: errorMessage })
} finally {
activeCallIds.delete(message.requestId)
}
}
private async runCode(code: string, context: ExecutionContext): Promise<unknown> {
const contextKeys = Object.keys(context)
const contextValues = contextKeys.map((k) => context[k])
const handleResult = (message: HubWorkerResultMessage) => {
const resolvedLogs = message.logs && message.logs.length > 0 ? message.logs : logs
void finalize({
result: message.result,
logs: resolvedLogs.length > 0 ? resolvedLogs : undefined
})
}
const wrappedCode = `
return (async () => {
${code}
})()
`
const handleError = (errorMessage: string, messageLogs?: string[], terminateWorker = true) => {
const resolvedLogs = messageLogs && messageLogs.length > 0 ? messageLogs : logs
void finalize(
{
result: undefined,
logs: resolvedLogs.length > 0 ? resolvedLogs : undefined,
error: errorMessage
},
terminateWorker
)
}
const fn = new Function(...contextKeys, wrappedCode)
const handleMessage = (message: HubWorkerMessage) => {
if (!message || typeof message !== 'object') {
return
}
switch (message.type) {
case 'log':
addLog(message.entry)
break
case 'callTool':
void handleToolCall(message)
break
case 'result':
handleResult(message)
break
case 'error':
handleError(message.error, message.logs)
break
default:
break
}
}
const timeoutPromise = new Promise<never>((_, reject) => {
setTimeout(() => {
reject(new Error(`Execution timed out after ${EXECUTION_TIMEOUT}ms`))
timeoutId = setTimeout(() => {
timedOut = true
void (async () => {
await abortActiveTools()
try {
await worker.terminate()
} catch (error) {
logger.warn('Failed to terminate exec worker after timeout', error as Error)
}
handleError(`Execution timed out after ${EXECUTION_TIMEOUT}ms`, undefined, false)
})()
}, EXECUTION_TIMEOUT)
})
const executionPromise = fn(...contextValues)
worker.on('message', handleMessage)
worker.on('error', (error) => {
logger.error('Worker execution error', error)
handleError(error instanceof Error ? error.message : String(error))
})
worker.on('exit', (code) => {
if (finished || timedOut) {
return
}
const message = code === 0 ? 'Exec worker exited unexpectedly' : `Exec worker exited with code ${code}`
logger.error(message)
handleError(message, undefined, false)
})
return Promise.race([executionPromise, timeoutPromise])
}
private createCapturedConsole(logs: string[]): ConsoleMethods {
const addLog = (level: string, ...args: unknown[]) => {
if (logs.length >= MAX_LOGS) {
return
const execMessage: HubWorkerExecMessage = {
type: 'exec',
code,
tools: tools.map((tool) => ({ functionName: tool.functionName }))
}
const message = args.map((arg) => this.stringify(arg)).join(' ')
logs.push(`[${level}] ${message}`)
}
return {
log: (...args: unknown[]) => addLog('log', ...args),
warn: (...args: unknown[]) => addLog('warn', ...args),
error: (...args: unknown[]) => addLog('error', ...args),
info: (...args: unknown[]) => addLog('info', ...args),
debug: (...args: unknown[]) => addLog('debug', ...args)
}
}
private stringify(value: unknown): string {
if (value === undefined) return 'undefined'
if (value === null) return 'null'
if (typeof value === 'string') return value
if (typeof value === 'number' || typeof value === 'boolean') return String(value)
if (value instanceof Error) return value.message
try {
return JSON.stringify(value, null, 2)
} catch {
return String(value)
}
worker.postMessage(execMessage)
})
}
}

View File

@ -55,3 +55,58 @@ export interface ConsoleMethods {
info: (...args: unknown[]) => void
debug: (...args: unknown[]) => void
}
export type HubWorkerTool = {
functionName: string
}
export type HubWorkerExecMessage = {
type: 'exec'
code: string
tools: HubWorkerTool[]
}
export type HubWorkerCallToolMessage = {
type: 'callTool'
requestId: string
functionName: string
params: unknown
}
export type HubWorkerToolResultMessage = {
type: 'toolResult'
requestId: string
result: unknown
}
export type HubWorkerToolErrorMessage = {
type: 'toolError'
requestId: string
error: string
}
export type HubWorkerResultMessage = {
type: 'result'
result: unknown
logs?: string[]
}
export type HubWorkerErrorMessage = {
type: 'error'
error: string
logs?: string[]
}
export type HubWorkerLogMessage = {
type: 'log'
entry: string
}
export type HubWorkerMessage =
| HubWorkerExecMessage
| HubWorkerCallToolMessage
| HubWorkerToolResultMessage
| HubWorkerToolErrorMessage
| HubWorkerResultMessage
| HubWorkerErrorMessage
| HubWorkerLogMessage

View File

@ -0,0 +1,131 @@
const crypto = require('node:crypto')
const { parentPort } = require('node:worker_threads')
const MAX_LOGS = 1000
const logs = []
const pendingCalls = new Map()
let isExecuting = false
const stringify = (value) => {
if (value === undefined) return 'undefined'
if (value === null) return 'null'
if (typeof value === 'string') return value
if (typeof value === 'number' || typeof value === 'boolean') return String(value)
if (value instanceof Error) return value.message
try {
return JSON.stringify(value, null, 2)
} catch {
return String(value)
}
}
const pushLog = (level, args) => {
if (logs.length >= MAX_LOGS) {
return
}
const message = args.map((arg) => stringify(arg)).join(' ')
const entry = `[${level}] ${message}`
logs.push(entry)
parentPort?.postMessage({ type: 'log', entry })
}
const capturedConsole = {
log: (...args) => pushLog('log', args),
warn: (...args) => pushLog('warn', args),
error: (...args) => pushLog('error', args),
info: (...args) => pushLog('info', args),
debug: (...args) => pushLog('debug', args)
}
const callTool = (functionName, params) =>
new Promise((resolve, reject) => {
const requestId = crypto.randomUUID()
pendingCalls.set(requestId, { resolve, reject })
parentPort?.postMessage({ type: 'callTool', requestId, functionName, params })
})
const buildContext = (tools) => {
const context = {
__callTool: callTool,
parallel: (...promises) => Promise.all(promises),
settle: (...promises) => Promise.allSettled(promises),
console: capturedConsole
}
for (const tool of tools) {
context[tool.functionName] = (params) => callTool(tool.functionName, params)
}
return context
}
const runCode = async (code, context) => {
const contextKeys = Object.keys(context)
const contextValues = contextKeys.map((key) => context[key])
const wrappedCode = `
return (async () => {
${code}
})()
`
const fn = new Function(...contextKeys, wrappedCode)
return await fn(...contextValues)
}
const handleExec = async (code, tools) => {
if (isExecuting) {
return
}
isExecuting = true
try {
const context = buildContext(tools)
const result = await runCode(code, context)
parentPort?.postMessage({ type: 'result', result, logs: logs.length > 0 ? logs : undefined })
} catch (error) {
const errorMessage = error instanceof Error ? error.message : String(error)
parentPort?.postMessage({ type: 'error', error: errorMessage, logs: logs.length > 0 ? logs : undefined })
} finally {
pendingCalls.clear()
}
}
const handleToolResult = (message) => {
const pending = pendingCalls.get(message.requestId)
if (!pending) {
return
}
pendingCalls.delete(message.requestId)
pending.resolve(message.result)
}
const handleToolError = (message) => {
const pending = pendingCalls.get(message.requestId)
if (!pending) {
return
}
pendingCalls.delete(message.requestId)
pending.reject(new Error(message.error))
}
parentPort?.on('message', (message) => {
if (!message || typeof message !== 'object') {
return
}
switch (message.type) {
case 'exec':
handleExec(message.code, message.tools ?? [])
break
case 'toolResult':
handleToolResult(message)
break
case 'toolError':
handleToolError(message)
break
default:
break
}
})

View File

@ -180,7 +180,9 @@ class McpService {
}
try {
const tools = await this.listToolsImpl(server)
allTools.push(...tools)
const disabledTools = new Set(server.disabledTools ?? [])
const enabledTools = disabledTools.size > 0 ? tools.filter((tool) => !disabledTools.has(tool.name)) : tools
allTools.push(...enabledTools)
} catch (error) {
logger.error(`[listAllActiveServerTools] Failed to list tools from ${server.name}:`, error as Error)
}
@ -195,7 +197,8 @@ class McpService {
*/
public async callToolById(
toolId: string,
params: unknown
params: unknown,
callId?: string
): Promise<{ content: Array<{ type: string; text?: string }> }> {
const parts = toolId.split('__')
if (parts.length < 2) {
@ -217,7 +220,8 @@ class McpService {
return this.callTool(null as unknown as Electron.IpcMainInvokeEvent, {
server,
name: toolName,
args: params
args: params,
callId
})
}

View File

@ -0,0 +1,75 @@
import type { MCPServer, MCPTool } from '@types'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
vi.mock('@main/apiServer/utils/mcp', () => ({
getMCPServersFromRedux: vi.fn()
}))
vi.mock('@main/services/WindowService', () => ({
windowService: {
getMainWindow: vi.fn(() => null)
}
}))
import { getMCPServersFromRedux } from '@main/apiServer/utils/mcp'
import mcpService from '@main/services/MCPService'
const baseInputSchema: { type: 'object'; properties: Record<string, unknown>; required: string[] } = {
type: 'object',
properties: {},
required: []
}
const createTool = (overrides: Partial<MCPTool>): MCPTool => ({
id: `${overrides.serverId}__${overrides.name}`,
name: overrides.name ?? 'tool',
description: overrides.description,
serverId: overrides.serverId ?? 'server',
serverName: overrides.serverName ?? 'server',
inputSchema: baseInputSchema,
type: 'mcp',
...overrides
})
describe('MCPService.listAllActiveServerTools', () => {
beforeEach(() => {
vi.clearAllMocks()
})
afterEach(() => {
vi.restoreAllMocks()
})
it('filters disabled tools per server', async () => {
const servers: MCPServer[] = [
{
id: 'alpha',
name: 'Alpha',
isActive: true,
disabledTools: ['disabled_tool']
},
{
id: 'beta',
name: 'Beta',
isActive: true
}
]
vi.mocked(getMCPServersFromRedux).mockResolvedValue(servers)
const listToolsSpy = vi.spyOn(mcpService as any, 'listToolsImpl').mockImplementation(async (server: any) => {
if (server.id === 'alpha') {
return [
createTool({ name: 'enabled_tool', serverId: server.id, serverName: server.name }),
createTool({ name: 'disabled_tool', serverId: server.id, serverName: server.name })
]
}
return [createTool({ name: 'beta_tool', serverId: server.id, serverName: server.name })]
})
const tools = await mcpService.listAllActiveServerTools()
expect(listToolsSpy).toHaveBeenCalledTimes(2)
expect(tools.map((tool) => tool.name)).toEqual(['enabled_tool', 'beta_tool'])
})
})

View File

@ -10,59 +10,69 @@ vi.mock('@logger', async () => {
})
// Mock electron modules that are commonly used in main process
vi.mock('electron', () => ({
app: {
getPath: vi.fn((key: string) => {
switch (key) {
case 'userData':
return '/mock/userData'
case 'temp':
return '/mock/temp'
case 'logs':
return '/mock/logs'
default:
return '/mock/unknown'
vi.mock('electron', () => {
const mock = {
app: {
getPath: vi.fn((key: string) => {
switch (key) {
case 'userData':
return '/mock/userData'
case 'temp':
return '/mock/temp'
case 'logs':
return '/mock/logs'
default:
return '/mock/unknown'
}
}),
getVersion: vi.fn(() => '1.0.0')
},
ipcMain: {
handle: vi.fn(),
on: vi.fn(),
once: vi.fn(),
removeHandler: vi.fn(),
removeAllListeners: vi.fn()
},
BrowserWindow: vi.fn(),
dialog: {
showErrorBox: vi.fn(),
showMessageBox: vi.fn(),
showOpenDialog: vi.fn(),
showSaveDialog: vi.fn()
},
shell: {
openExternal: vi.fn(),
showItemInFolder: vi.fn()
},
session: {
defaultSession: {
clearCache: vi.fn(),
clearStorageData: vi.fn()
}
}),
getVersion: vi.fn(() => '1.0.0')
},
ipcMain: {
handle: vi.fn(),
on: vi.fn(),
once: vi.fn(),
removeHandler: vi.fn(),
removeAllListeners: vi.fn()
},
BrowserWindow: vi.fn(),
dialog: {
showErrorBox: vi.fn(),
showMessageBox: vi.fn(),
showOpenDialog: vi.fn(),
showSaveDialog: vi.fn()
},
shell: {
openExternal: vi.fn(),
showItemInFolder: vi.fn()
},
session: {
defaultSession: {
clearCache: vi.fn(),
clearStorageData: vi.fn()
}
},
webContents: {
getAllWebContents: vi.fn(() => [])
},
systemPreferences: {
getMediaAccessStatus: vi.fn(),
askForMediaAccess: vi.fn()
},
screen: {
getPrimaryDisplay: vi.fn(),
getAllDisplays: vi.fn()
},
Notification: vi.fn()
}))
},
webContents: {
getAllWebContents: vi.fn(() => [])
},
systemPreferences: {
getMediaAccessStatus: vi.fn(),
askForMediaAccess: vi.fn()
},
nativeTheme: {
themeSource: 'system',
shouldUseDarkColors: false,
on: vi.fn(),
removeListener: vi.fn()
},
screen: {
getPrimaryDisplay: vi.fn(),
getAllDisplays: vi.fn()
},
Notification: vi.fn()
}
return { __esModule: true, ...mock, default: mock }
})
// Mock Winston for LoggerService dependencies
vi.mock('winston', () => ({
@ -98,13 +108,17 @@ vi.mock('winston-daily-rotate-file', () => {
})
// Mock Node.js modules
vi.mock('node:os', () => ({
platform: vi.fn(() => 'darwin'),
arch: vi.fn(() => 'x64'),
version: vi.fn(() => '20.0.0'),
cpus: vi.fn(() => [{ model: 'Mock CPU' }]),
totalmem: vi.fn(() => 8 * 1024 * 1024 * 1024) // 8GB
}))
vi.mock('node:os', () => {
const mock = {
platform: vi.fn(() => 'darwin'),
arch: vi.fn(() => 'x64'),
version: vi.fn(() => '20.0.0'),
cpus: vi.fn(() => [{ model: 'Mock CPU' }]),
homedir: vi.fn(() => '/mock/home'),
totalmem: vi.fn(() => 8 * 1024 * 1024 * 1024) // 8GB
}
return { ...mock, default: mock }
})
vi.mock('node:path', async () => {
const actual = await vi.importActual('node:path')
@ -115,25 +129,29 @@ vi.mock('node:path', async () => {
}
})
vi.mock('node:fs', () => ({
promises: {
access: vi.fn(),
readFile: vi.fn(),
writeFile: vi.fn(),
mkdir: vi.fn(),
readdir: vi.fn(),
stat: vi.fn(),
unlink: vi.fn(),
rmdir: vi.fn()
},
existsSync: vi.fn(),
readFileSync: vi.fn(),
writeFileSync: vi.fn(),
mkdirSync: vi.fn(),
readdirSync: vi.fn(),
statSync: vi.fn(),
unlinkSync: vi.fn(),
rmdirSync: vi.fn(),
createReadStream: vi.fn(),
createWriteStream: vi.fn()
}))
vi.mock('node:fs', () => {
const mock = {
promises: {
access: vi.fn(),
readFile: vi.fn(),
writeFile: vi.fn(),
mkdir: vi.fn(),
readdir: vi.fn(),
stat: vi.fn(),
unlink: vi.fn(),
rmdir: vi.fn()
},
existsSync: vi.fn(),
readFileSync: vi.fn(),
writeFileSync: vi.fn(),
mkdirSync: vi.fn(),
readdirSync: vi.fn(),
statSync: vi.fn(),
unlinkSync: vi.fn(),
rmdirSync: vi.fn(),
createReadStream: vi.fn(),
createWriteStream: vi.fn()
}
return { ...mock, default: mock }
})