mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-12 00:49:14 +08:00
✨ feat: isolate hub exec worker and filter disabled tools
This commit is contained in:
parent
d005d4291f
commit
41699a1afd
@ -72,7 +72,8 @@ vi.mock('@main/services/MCPService', () => ({
|
||||
}
|
||||
}
|
||||
return { content: [{ type: 'text', text: '{}' }] }
|
||||
})
|
||||
}),
|
||||
abortTool: vi.fn(async () => true)
|
||||
}
|
||||
}))
|
||||
|
||||
@ -178,6 +179,45 @@ describe('HubServer Integration', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('exec timeouts', () => {
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('aborts in-flight tool calls and returns logs on timeout', async () => {
|
||||
vi.useFakeTimers()
|
||||
|
||||
let toolCallStarted: (() => void) | null = null
|
||||
const toolCallStartedPromise = new Promise<void>((resolve) => {
|
||||
toolCallStarted = resolve
|
||||
})
|
||||
|
||||
vi.mocked(mcpService.callToolById).mockImplementationOnce(async () => {
|
||||
toolCallStarted?.()
|
||||
return await new Promise(() => {})
|
||||
})
|
||||
|
||||
const execPromise = (hubServer as any).handleExec({
|
||||
code: `
|
||||
console.log("starting");
|
||||
return await github_searchRepos({ query: "hang" });
|
||||
`
|
||||
})
|
||||
|
||||
await toolCallStartedPromise
|
||||
await vi.advanceTimersByTimeAsync(60000)
|
||||
await vi.runAllTimersAsync()
|
||||
|
||||
const execResult = await execPromise
|
||||
const execOutput = JSON.parse(execResult.content[0].text)
|
||||
|
||||
expect(execOutput.error).toBe('Execution timed out after 60000ms')
|
||||
expect(execOutput.result).toBeUndefined()
|
||||
expect(execOutput.logs).toContain('[log] starting')
|
||||
expect(vi.mocked(mcpService.abortTool)).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('server instance', () => {
|
||||
it('creates a valid MCP server instance', () => {
|
||||
expect(hubServer.server).toBeDefined()
|
||||
|
||||
@ -56,7 +56,7 @@ describe('Runtime', () => {
|
||||
|
||||
const result = await runtime.execute('return await searchRepos({ query: "test" })', tools)
|
||||
|
||||
expect(result.result).toEqual({ repos: ['repo1', 'repo2'], query: { query: 'test' } })
|
||||
expect(result.result).toEqual({ toolId: 'searchRepos', params: { query: 'test' }, success: true })
|
||||
})
|
||||
|
||||
it('captures console logs', async () => {
|
||||
|
||||
@ -19,7 +19,7 @@ export async function refreshToolMap(): Promise<void> {
|
||||
}
|
||||
}
|
||||
|
||||
export const callMcpTool = async (functionName: string, params: unknown): Promise<unknown> => {
|
||||
export const callMcpTool = async (functionName: string, params: unknown, callId?: string): Promise<unknown> => {
|
||||
const toolInfo = toolFunctionNameToIdMap.get(functionName)
|
||||
if (!toolInfo) {
|
||||
await refreshToolMap()
|
||||
@ -28,14 +28,18 @@ export const callMcpTool = async (functionName: string, params: unknown): Promis
|
||||
throw new Error(`Tool not found: ${functionName}`)
|
||||
}
|
||||
const toolId = `${retryToolInfo.serverId}__${retryToolInfo.toolName}`
|
||||
const result = await mcpService.callToolById(toolId, params)
|
||||
const result = await mcpService.callToolById(toolId, params, callId)
|
||||
return extractToolResult(result)
|
||||
}
|
||||
const toolId = `${toolInfo.serverId}__${toolInfo.toolName}`
|
||||
const result = await mcpService.callToolById(toolId, params)
|
||||
const result = await mcpService.callToolById(toolId, params, callId)
|
||||
return extractToolResult(result)
|
||||
}
|
||||
|
||||
export const abortMcpTool = async (callId: string): Promise<boolean> => {
|
||||
return mcpService.abortTool(null as unknown as Electron.IpcMainInvokeEvent, callId)
|
||||
}
|
||||
|
||||
function extractToolResult(result: { content: Array<{ type: string; text?: string }> }): unknown {
|
||||
if (!result.content || result.content.length === 0) {
|
||||
return null
|
||||
|
||||
@ -1,105 +1,169 @@
|
||||
import crypto from 'node:crypto'
|
||||
import { Worker } from 'node:worker_threads'
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
|
||||
import { callMcpTool } from './mcp-bridge'
|
||||
import type { ConsoleMethods, ExecOutput, ExecutionContext, GeneratedTool } from './types'
|
||||
import { abortMcpTool, callMcpTool } from './mcp-bridge'
|
||||
import type {
|
||||
ExecOutput,
|
||||
GeneratedTool,
|
||||
HubWorkerCallToolMessage,
|
||||
HubWorkerExecMessage,
|
||||
HubWorkerMessage,
|
||||
HubWorkerResultMessage
|
||||
} from './types'
|
||||
|
||||
const logger = loggerService.withContext('MCPServer:Hub:Runtime')
|
||||
|
||||
const MAX_LOGS = 1000
|
||||
const EXECUTION_TIMEOUT = 60000
|
||||
const WORKER_URL = new URL('./worker.js', import.meta.url)
|
||||
|
||||
export class Runtime {
|
||||
async execute(code: string, tools: GeneratedTool[]): Promise<ExecOutput> {
|
||||
const logs: string[] = []
|
||||
const capturedConsole = this.createCapturedConsole(logs)
|
||||
return await new Promise<ExecOutput>((resolve) => {
|
||||
const logs: string[] = []
|
||||
const activeCallIds = new Map<string, string>()
|
||||
let finished = false
|
||||
let timedOut = false
|
||||
let timeoutId: NodeJS.Timeout | null = null
|
||||
|
||||
try {
|
||||
const context = this.buildContext(tools, capturedConsole)
|
||||
const result = await this.runCode(code, context)
|
||||
const worker = new Worker(WORKER_URL)
|
||||
|
||||
return {
|
||||
result,
|
||||
logs: logs.length > 0 ? logs : undefined
|
||||
const addLog = (entry: string) => {
|
||||
if (logs.length >= MAX_LOGS) {
|
||||
return
|
||||
}
|
||||
logs.push(entry)
|
||||
}
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error)
|
||||
logger.error('Execution error:', error as Error)
|
||||
|
||||
return {
|
||||
result: undefined,
|
||||
logs: logs.length > 0 ? logs : undefined,
|
||||
error: errorMessage
|
||||
const finalize = async (output: ExecOutput, terminateWorker = true) => {
|
||||
if (finished) {
|
||||
return
|
||||
}
|
||||
finished = true
|
||||
if (timeoutId) {
|
||||
clearTimeout(timeoutId)
|
||||
}
|
||||
worker.removeAllListeners()
|
||||
if (terminateWorker) {
|
||||
try {
|
||||
await worker.terminate()
|
||||
} catch (error) {
|
||||
logger.warn('Failed to terminate exec worker', error as Error)
|
||||
}
|
||||
}
|
||||
resolve(output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private buildContext(tools: GeneratedTool[], capturedConsole: ConsoleMethods): ExecutionContext {
|
||||
const context: ExecutionContext = {
|
||||
__callTool: callMcpTool,
|
||||
parallel: <T>(...promises: Promise<T>[]) => Promise.all(promises),
|
||||
settle: <T>(...promises: Promise<T>[]) => Promise.allSettled(promises),
|
||||
console: capturedConsole
|
||||
}
|
||||
const abortActiveTools = async () => {
|
||||
const callIds = Array.from(activeCallIds.values())
|
||||
activeCallIds.clear()
|
||||
if (callIds.length === 0) {
|
||||
return
|
||||
}
|
||||
await Promise.allSettled(callIds.map((callId) => abortMcpTool(callId)))
|
||||
}
|
||||
|
||||
for (const tool of tools) {
|
||||
context[tool.functionName] = tool.fn
|
||||
}
|
||||
const handleToolCall = async (message: HubWorkerCallToolMessage) => {
|
||||
if (finished || timedOut) {
|
||||
return
|
||||
}
|
||||
const callId = crypto.randomUUID()
|
||||
activeCallIds.set(message.requestId, callId)
|
||||
|
||||
return context
|
||||
}
|
||||
try {
|
||||
const result = await callMcpTool(message.functionName, message.params, callId)
|
||||
if (finished || timedOut) {
|
||||
return
|
||||
}
|
||||
worker.postMessage({ type: 'toolResult', requestId: message.requestId, result })
|
||||
} catch (error) {
|
||||
if (finished || timedOut) {
|
||||
return
|
||||
}
|
||||
const errorMessage = error instanceof Error ? error.message : String(error)
|
||||
worker.postMessage({ type: 'toolError', requestId: message.requestId, error: errorMessage })
|
||||
} finally {
|
||||
activeCallIds.delete(message.requestId)
|
||||
}
|
||||
}
|
||||
|
||||
private async runCode(code: string, context: ExecutionContext): Promise<unknown> {
|
||||
const contextKeys = Object.keys(context)
|
||||
const contextValues = contextKeys.map((k) => context[k])
|
||||
const handleResult = (message: HubWorkerResultMessage) => {
|
||||
const resolvedLogs = message.logs && message.logs.length > 0 ? message.logs : logs
|
||||
void finalize({
|
||||
result: message.result,
|
||||
logs: resolvedLogs.length > 0 ? resolvedLogs : undefined
|
||||
})
|
||||
}
|
||||
|
||||
const wrappedCode = `
|
||||
return (async () => {
|
||||
${code}
|
||||
})()
|
||||
`
|
||||
const handleError = (errorMessage: string, messageLogs?: string[], terminateWorker = true) => {
|
||||
const resolvedLogs = messageLogs && messageLogs.length > 0 ? messageLogs : logs
|
||||
void finalize(
|
||||
{
|
||||
result: undefined,
|
||||
logs: resolvedLogs.length > 0 ? resolvedLogs : undefined,
|
||||
error: errorMessage
|
||||
},
|
||||
terminateWorker
|
||||
)
|
||||
}
|
||||
|
||||
const fn = new Function(...contextKeys, wrappedCode)
|
||||
const handleMessage = (message: HubWorkerMessage) => {
|
||||
if (!message || typeof message !== 'object') {
|
||||
return
|
||||
}
|
||||
switch (message.type) {
|
||||
case 'log':
|
||||
addLog(message.entry)
|
||||
break
|
||||
case 'callTool':
|
||||
void handleToolCall(message)
|
||||
break
|
||||
case 'result':
|
||||
handleResult(message)
|
||||
break
|
||||
case 'error':
|
||||
handleError(message.error, message.logs)
|
||||
break
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
const timeoutPromise = new Promise<never>((_, reject) => {
|
||||
setTimeout(() => {
|
||||
reject(new Error(`Execution timed out after ${EXECUTION_TIMEOUT}ms`))
|
||||
timeoutId = setTimeout(() => {
|
||||
timedOut = true
|
||||
void (async () => {
|
||||
await abortActiveTools()
|
||||
try {
|
||||
await worker.terminate()
|
||||
} catch (error) {
|
||||
logger.warn('Failed to terminate exec worker after timeout', error as Error)
|
||||
}
|
||||
handleError(`Execution timed out after ${EXECUTION_TIMEOUT}ms`, undefined, false)
|
||||
})()
|
||||
}, EXECUTION_TIMEOUT)
|
||||
})
|
||||
|
||||
const executionPromise = fn(...contextValues)
|
||||
worker.on('message', handleMessage)
|
||||
worker.on('error', (error) => {
|
||||
logger.error('Worker execution error', error)
|
||||
handleError(error instanceof Error ? error.message : String(error))
|
||||
})
|
||||
worker.on('exit', (code) => {
|
||||
if (finished || timedOut) {
|
||||
return
|
||||
}
|
||||
const message = code === 0 ? 'Exec worker exited unexpectedly' : `Exec worker exited with code ${code}`
|
||||
logger.error(message)
|
||||
handleError(message, undefined, false)
|
||||
})
|
||||
|
||||
return Promise.race([executionPromise, timeoutPromise])
|
||||
}
|
||||
|
||||
private createCapturedConsole(logs: string[]): ConsoleMethods {
|
||||
const addLog = (level: string, ...args: unknown[]) => {
|
||||
if (logs.length >= MAX_LOGS) {
|
||||
return
|
||||
const execMessage: HubWorkerExecMessage = {
|
||||
type: 'exec',
|
||||
code,
|
||||
tools: tools.map((tool) => ({ functionName: tool.functionName }))
|
||||
}
|
||||
const message = args.map((arg) => this.stringify(arg)).join(' ')
|
||||
logs.push(`[${level}] ${message}`)
|
||||
}
|
||||
|
||||
return {
|
||||
log: (...args: unknown[]) => addLog('log', ...args),
|
||||
warn: (...args: unknown[]) => addLog('warn', ...args),
|
||||
error: (...args: unknown[]) => addLog('error', ...args),
|
||||
info: (...args: unknown[]) => addLog('info', ...args),
|
||||
debug: (...args: unknown[]) => addLog('debug', ...args)
|
||||
}
|
||||
}
|
||||
|
||||
private stringify(value: unknown): string {
|
||||
if (value === undefined) return 'undefined'
|
||||
if (value === null) return 'null'
|
||||
if (typeof value === 'string') return value
|
||||
if (typeof value === 'number' || typeof value === 'boolean') return String(value)
|
||||
if (value instanceof Error) return value.message
|
||||
|
||||
try {
|
||||
return JSON.stringify(value, null, 2)
|
||||
} catch {
|
||||
return String(value)
|
||||
}
|
||||
worker.postMessage(execMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -55,3 +55,58 @@ export interface ConsoleMethods {
|
||||
info: (...args: unknown[]) => void
|
||||
debug: (...args: unknown[]) => void
|
||||
}
|
||||
|
||||
export type HubWorkerTool = {
|
||||
functionName: string
|
||||
}
|
||||
|
||||
export type HubWorkerExecMessage = {
|
||||
type: 'exec'
|
||||
code: string
|
||||
tools: HubWorkerTool[]
|
||||
}
|
||||
|
||||
export type HubWorkerCallToolMessage = {
|
||||
type: 'callTool'
|
||||
requestId: string
|
||||
functionName: string
|
||||
params: unknown
|
||||
}
|
||||
|
||||
export type HubWorkerToolResultMessage = {
|
||||
type: 'toolResult'
|
||||
requestId: string
|
||||
result: unknown
|
||||
}
|
||||
|
||||
export type HubWorkerToolErrorMessage = {
|
||||
type: 'toolError'
|
||||
requestId: string
|
||||
error: string
|
||||
}
|
||||
|
||||
export type HubWorkerResultMessage = {
|
||||
type: 'result'
|
||||
result: unknown
|
||||
logs?: string[]
|
||||
}
|
||||
|
||||
export type HubWorkerErrorMessage = {
|
||||
type: 'error'
|
||||
error: string
|
||||
logs?: string[]
|
||||
}
|
||||
|
||||
export type HubWorkerLogMessage = {
|
||||
type: 'log'
|
||||
entry: string
|
||||
}
|
||||
|
||||
export type HubWorkerMessage =
|
||||
| HubWorkerExecMessage
|
||||
| HubWorkerCallToolMessage
|
||||
| HubWorkerToolResultMessage
|
||||
| HubWorkerToolErrorMessage
|
||||
| HubWorkerResultMessage
|
||||
| HubWorkerErrorMessage
|
||||
| HubWorkerLogMessage
|
||||
|
||||
131
src/main/mcpServers/hub/worker.js
Normal file
131
src/main/mcpServers/hub/worker.js
Normal file
@ -0,0 +1,131 @@
|
||||
const crypto = require('node:crypto')
|
||||
const { parentPort } = require('node:worker_threads')
|
||||
|
||||
const MAX_LOGS = 1000
|
||||
|
||||
const logs = []
|
||||
const pendingCalls = new Map()
|
||||
let isExecuting = false
|
||||
|
||||
const stringify = (value) => {
|
||||
if (value === undefined) return 'undefined'
|
||||
if (value === null) return 'null'
|
||||
if (typeof value === 'string') return value
|
||||
if (typeof value === 'number' || typeof value === 'boolean') return String(value)
|
||||
if (value instanceof Error) return value.message
|
||||
|
||||
try {
|
||||
return JSON.stringify(value, null, 2)
|
||||
} catch {
|
||||
return String(value)
|
||||
}
|
||||
}
|
||||
|
||||
const pushLog = (level, args) => {
|
||||
if (logs.length >= MAX_LOGS) {
|
||||
return
|
||||
}
|
||||
const message = args.map((arg) => stringify(arg)).join(' ')
|
||||
const entry = `[${level}] ${message}`
|
||||
logs.push(entry)
|
||||
parentPort?.postMessage({ type: 'log', entry })
|
||||
}
|
||||
|
||||
const capturedConsole = {
|
||||
log: (...args) => pushLog('log', args),
|
||||
warn: (...args) => pushLog('warn', args),
|
||||
error: (...args) => pushLog('error', args),
|
||||
info: (...args) => pushLog('info', args),
|
||||
debug: (...args) => pushLog('debug', args)
|
||||
}
|
||||
|
||||
const callTool = (functionName, params) =>
|
||||
new Promise((resolve, reject) => {
|
||||
const requestId = crypto.randomUUID()
|
||||
pendingCalls.set(requestId, { resolve, reject })
|
||||
parentPort?.postMessage({ type: 'callTool', requestId, functionName, params })
|
||||
})
|
||||
|
||||
const buildContext = (tools) => {
|
||||
const context = {
|
||||
__callTool: callTool,
|
||||
parallel: (...promises) => Promise.all(promises),
|
||||
settle: (...promises) => Promise.allSettled(promises),
|
||||
console: capturedConsole
|
||||
}
|
||||
|
||||
for (const tool of tools) {
|
||||
context[tool.functionName] = (params) => callTool(tool.functionName, params)
|
||||
}
|
||||
|
||||
return context
|
||||
}
|
||||
|
||||
const runCode = async (code, context) => {
|
||||
const contextKeys = Object.keys(context)
|
||||
const contextValues = contextKeys.map((key) => context[key])
|
||||
|
||||
const wrappedCode = `
|
||||
return (async () => {
|
||||
${code}
|
||||
})()
|
||||
`
|
||||
|
||||
const fn = new Function(...contextKeys, wrappedCode)
|
||||
return await fn(...contextValues)
|
||||
}
|
||||
|
||||
const handleExec = async (code, tools) => {
|
||||
if (isExecuting) {
|
||||
return
|
||||
}
|
||||
isExecuting = true
|
||||
|
||||
try {
|
||||
const context = buildContext(tools)
|
||||
const result = await runCode(code, context)
|
||||
parentPort?.postMessage({ type: 'result', result, logs: logs.length > 0 ? logs : undefined })
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error)
|
||||
parentPort?.postMessage({ type: 'error', error: errorMessage, logs: logs.length > 0 ? logs : undefined })
|
||||
} finally {
|
||||
pendingCalls.clear()
|
||||
}
|
||||
}
|
||||
|
||||
const handleToolResult = (message) => {
|
||||
const pending = pendingCalls.get(message.requestId)
|
||||
if (!pending) {
|
||||
return
|
||||
}
|
||||
pendingCalls.delete(message.requestId)
|
||||
pending.resolve(message.result)
|
||||
}
|
||||
|
||||
const handleToolError = (message) => {
|
||||
const pending = pendingCalls.get(message.requestId)
|
||||
if (!pending) {
|
||||
return
|
||||
}
|
||||
pendingCalls.delete(message.requestId)
|
||||
pending.reject(new Error(message.error))
|
||||
}
|
||||
|
||||
parentPort?.on('message', (message) => {
|
||||
if (!message || typeof message !== 'object') {
|
||||
return
|
||||
}
|
||||
switch (message.type) {
|
||||
case 'exec':
|
||||
handleExec(message.code, message.tools ?? [])
|
||||
break
|
||||
case 'toolResult':
|
||||
handleToolResult(message)
|
||||
break
|
||||
case 'toolError':
|
||||
handleToolError(message)
|
||||
break
|
||||
default:
|
||||
break
|
||||
}
|
||||
})
|
||||
@ -180,7 +180,9 @@ class McpService {
|
||||
}
|
||||
try {
|
||||
const tools = await this.listToolsImpl(server)
|
||||
allTools.push(...tools)
|
||||
const disabledTools = new Set(server.disabledTools ?? [])
|
||||
const enabledTools = disabledTools.size > 0 ? tools.filter((tool) => !disabledTools.has(tool.name)) : tools
|
||||
allTools.push(...enabledTools)
|
||||
} catch (error) {
|
||||
logger.error(`[listAllActiveServerTools] Failed to list tools from ${server.name}:`, error as Error)
|
||||
}
|
||||
@ -195,7 +197,8 @@ class McpService {
|
||||
*/
|
||||
public async callToolById(
|
||||
toolId: string,
|
||||
params: unknown
|
||||
params: unknown,
|
||||
callId?: string
|
||||
): Promise<{ content: Array<{ type: string; text?: string }> }> {
|
||||
const parts = toolId.split('__')
|
||||
if (parts.length < 2) {
|
||||
@ -217,7 +220,8 @@ class McpService {
|
||||
return this.callTool(null as unknown as Electron.IpcMainInvokeEvent, {
|
||||
server,
|
||||
name: toolName,
|
||||
args: params
|
||||
args: params,
|
||||
callId
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
75
src/main/services/__tests__/MCPService.test.ts
Normal file
75
src/main/services/__tests__/MCPService.test.ts
Normal file
@ -0,0 +1,75 @@
|
||||
import type { MCPServer, MCPTool } from '@types'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
vi.mock('@main/apiServer/utils/mcp', () => ({
|
||||
getMCPServersFromRedux: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@main/services/WindowService', () => ({
|
||||
windowService: {
|
||||
getMainWindow: vi.fn(() => null)
|
||||
}
|
||||
}))
|
||||
|
||||
import { getMCPServersFromRedux } from '@main/apiServer/utils/mcp'
|
||||
import mcpService from '@main/services/MCPService'
|
||||
|
||||
const baseInputSchema: { type: 'object'; properties: Record<string, unknown>; required: string[] } = {
|
||||
type: 'object',
|
||||
properties: {},
|
||||
required: []
|
||||
}
|
||||
|
||||
const createTool = (overrides: Partial<MCPTool>): MCPTool => ({
|
||||
id: `${overrides.serverId}__${overrides.name}`,
|
||||
name: overrides.name ?? 'tool',
|
||||
description: overrides.description,
|
||||
serverId: overrides.serverId ?? 'server',
|
||||
serverName: overrides.serverName ?? 'server',
|
||||
inputSchema: baseInputSchema,
|
||||
type: 'mcp',
|
||||
...overrides
|
||||
})
|
||||
|
||||
describe('MCPService.listAllActiveServerTools', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
it('filters disabled tools per server', async () => {
|
||||
const servers: MCPServer[] = [
|
||||
{
|
||||
id: 'alpha',
|
||||
name: 'Alpha',
|
||||
isActive: true,
|
||||
disabledTools: ['disabled_tool']
|
||||
},
|
||||
{
|
||||
id: 'beta',
|
||||
name: 'Beta',
|
||||
isActive: true
|
||||
}
|
||||
]
|
||||
|
||||
vi.mocked(getMCPServersFromRedux).mockResolvedValue(servers)
|
||||
|
||||
const listToolsSpy = vi.spyOn(mcpService as any, 'listToolsImpl').mockImplementation(async (server: any) => {
|
||||
if (server.id === 'alpha') {
|
||||
return [
|
||||
createTool({ name: 'enabled_tool', serverId: server.id, serverName: server.name }),
|
||||
createTool({ name: 'disabled_tool', serverId: server.id, serverName: server.name })
|
||||
]
|
||||
}
|
||||
return [createTool({ name: 'beta_tool', serverId: server.id, serverName: server.name })]
|
||||
})
|
||||
|
||||
const tools = await mcpService.listAllActiveServerTools()
|
||||
|
||||
expect(listToolsSpy).toHaveBeenCalledTimes(2)
|
||||
expect(tools.map((tool) => tool.name)).toEqual(['enabled_tool', 'beta_tool'])
|
||||
})
|
||||
})
|
||||
@ -10,59 +10,69 @@ vi.mock('@logger', async () => {
|
||||
})
|
||||
|
||||
// Mock electron modules that are commonly used in main process
|
||||
vi.mock('electron', () => ({
|
||||
app: {
|
||||
getPath: vi.fn((key: string) => {
|
||||
switch (key) {
|
||||
case 'userData':
|
||||
return '/mock/userData'
|
||||
case 'temp':
|
||||
return '/mock/temp'
|
||||
case 'logs':
|
||||
return '/mock/logs'
|
||||
default:
|
||||
return '/mock/unknown'
|
||||
vi.mock('electron', () => {
|
||||
const mock = {
|
||||
app: {
|
||||
getPath: vi.fn((key: string) => {
|
||||
switch (key) {
|
||||
case 'userData':
|
||||
return '/mock/userData'
|
||||
case 'temp':
|
||||
return '/mock/temp'
|
||||
case 'logs':
|
||||
return '/mock/logs'
|
||||
default:
|
||||
return '/mock/unknown'
|
||||
}
|
||||
}),
|
||||
getVersion: vi.fn(() => '1.0.0')
|
||||
},
|
||||
ipcMain: {
|
||||
handle: vi.fn(),
|
||||
on: vi.fn(),
|
||||
once: vi.fn(),
|
||||
removeHandler: vi.fn(),
|
||||
removeAllListeners: vi.fn()
|
||||
},
|
||||
BrowserWindow: vi.fn(),
|
||||
dialog: {
|
||||
showErrorBox: vi.fn(),
|
||||
showMessageBox: vi.fn(),
|
||||
showOpenDialog: vi.fn(),
|
||||
showSaveDialog: vi.fn()
|
||||
},
|
||||
shell: {
|
||||
openExternal: vi.fn(),
|
||||
showItemInFolder: vi.fn()
|
||||
},
|
||||
session: {
|
||||
defaultSession: {
|
||||
clearCache: vi.fn(),
|
||||
clearStorageData: vi.fn()
|
||||
}
|
||||
}),
|
||||
getVersion: vi.fn(() => '1.0.0')
|
||||
},
|
||||
ipcMain: {
|
||||
handle: vi.fn(),
|
||||
on: vi.fn(),
|
||||
once: vi.fn(),
|
||||
removeHandler: vi.fn(),
|
||||
removeAllListeners: vi.fn()
|
||||
},
|
||||
BrowserWindow: vi.fn(),
|
||||
dialog: {
|
||||
showErrorBox: vi.fn(),
|
||||
showMessageBox: vi.fn(),
|
||||
showOpenDialog: vi.fn(),
|
||||
showSaveDialog: vi.fn()
|
||||
},
|
||||
shell: {
|
||||
openExternal: vi.fn(),
|
||||
showItemInFolder: vi.fn()
|
||||
},
|
||||
session: {
|
||||
defaultSession: {
|
||||
clearCache: vi.fn(),
|
||||
clearStorageData: vi.fn()
|
||||
}
|
||||
},
|
||||
webContents: {
|
||||
getAllWebContents: vi.fn(() => [])
|
||||
},
|
||||
systemPreferences: {
|
||||
getMediaAccessStatus: vi.fn(),
|
||||
askForMediaAccess: vi.fn()
|
||||
},
|
||||
screen: {
|
||||
getPrimaryDisplay: vi.fn(),
|
||||
getAllDisplays: vi.fn()
|
||||
},
|
||||
Notification: vi.fn()
|
||||
}))
|
||||
},
|
||||
webContents: {
|
||||
getAllWebContents: vi.fn(() => [])
|
||||
},
|
||||
systemPreferences: {
|
||||
getMediaAccessStatus: vi.fn(),
|
||||
askForMediaAccess: vi.fn()
|
||||
},
|
||||
nativeTheme: {
|
||||
themeSource: 'system',
|
||||
shouldUseDarkColors: false,
|
||||
on: vi.fn(),
|
||||
removeListener: vi.fn()
|
||||
},
|
||||
screen: {
|
||||
getPrimaryDisplay: vi.fn(),
|
||||
getAllDisplays: vi.fn()
|
||||
},
|
||||
Notification: vi.fn()
|
||||
}
|
||||
|
||||
return { __esModule: true, ...mock, default: mock }
|
||||
})
|
||||
|
||||
// Mock Winston for LoggerService dependencies
|
||||
vi.mock('winston', () => ({
|
||||
@ -98,13 +108,17 @@ vi.mock('winston-daily-rotate-file', () => {
|
||||
})
|
||||
|
||||
// Mock Node.js modules
|
||||
vi.mock('node:os', () => ({
|
||||
platform: vi.fn(() => 'darwin'),
|
||||
arch: vi.fn(() => 'x64'),
|
||||
version: vi.fn(() => '20.0.0'),
|
||||
cpus: vi.fn(() => [{ model: 'Mock CPU' }]),
|
||||
totalmem: vi.fn(() => 8 * 1024 * 1024 * 1024) // 8GB
|
||||
}))
|
||||
vi.mock('node:os', () => {
|
||||
const mock = {
|
||||
platform: vi.fn(() => 'darwin'),
|
||||
arch: vi.fn(() => 'x64'),
|
||||
version: vi.fn(() => '20.0.0'),
|
||||
cpus: vi.fn(() => [{ model: 'Mock CPU' }]),
|
||||
homedir: vi.fn(() => '/mock/home'),
|
||||
totalmem: vi.fn(() => 8 * 1024 * 1024 * 1024) // 8GB
|
||||
}
|
||||
return { ...mock, default: mock }
|
||||
})
|
||||
|
||||
vi.mock('node:path', async () => {
|
||||
const actual = await vi.importActual('node:path')
|
||||
@ -115,25 +129,29 @@ vi.mock('node:path', async () => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('node:fs', () => ({
|
||||
promises: {
|
||||
access: vi.fn(),
|
||||
readFile: vi.fn(),
|
||||
writeFile: vi.fn(),
|
||||
mkdir: vi.fn(),
|
||||
readdir: vi.fn(),
|
||||
stat: vi.fn(),
|
||||
unlink: vi.fn(),
|
||||
rmdir: vi.fn()
|
||||
},
|
||||
existsSync: vi.fn(),
|
||||
readFileSync: vi.fn(),
|
||||
writeFileSync: vi.fn(),
|
||||
mkdirSync: vi.fn(),
|
||||
readdirSync: vi.fn(),
|
||||
statSync: vi.fn(),
|
||||
unlinkSync: vi.fn(),
|
||||
rmdirSync: vi.fn(),
|
||||
createReadStream: vi.fn(),
|
||||
createWriteStream: vi.fn()
|
||||
}))
|
||||
vi.mock('node:fs', () => {
|
||||
const mock = {
|
||||
promises: {
|
||||
access: vi.fn(),
|
||||
readFile: vi.fn(),
|
||||
writeFile: vi.fn(),
|
||||
mkdir: vi.fn(),
|
||||
readdir: vi.fn(),
|
||||
stat: vi.fn(),
|
||||
unlink: vi.fn(),
|
||||
rmdir: vi.fn()
|
||||
},
|
||||
existsSync: vi.fn(),
|
||||
readFileSync: vi.fn(),
|
||||
writeFileSync: vi.fn(),
|
||||
mkdirSync: vi.fn(),
|
||||
readdirSync: vi.fn(),
|
||||
statSync: vi.fn(),
|
||||
unlinkSync: vi.fn(),
|
||||
rmdirSync: vi.fn(),
|
||||
createReadStream: vi.fn(),
|
||||
createWriteStream: vi.fn()
|
||||
}
|
||||
|
||||
return { ...mock, default: mock }
|
||||
})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user