From 41699a1afd4c344d2524ebdfba8e3148d46cd246 Mon Sep 17 00:00:00 2001 From: Vaayne Date: Tue, 30 Dec 2025 11:28:20 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20isolate=20hub=20exec=20work?= =?UTF-8?q?er=20and=20filter=20disabled=20tools?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/mcpServers/hub/__tests__/hub.test.ts | 42 +++- .../mcpServers/hub/__tests__/runtime.test.ts | 2 +- src/main/mcpServers/hub/mcp-bridge.ts | 10 +- src/main/mcpServers/hub/runtime.ts | 218 +++++++++++------- src/main/mcpServers/hub/types.ts | 55 +++++ src/main/mcpServers/hub/worker.js | 131 +++++++++++ src/main/services/MCPService.ts | 10 +- .../services/__tests__/MCPService.test.ts | 75 ++++++ tests/main.setup.ts | 180 ++++++++------- 9 files changed, 557 insertions(+), 166 deletions(-) create mode 100644 src/main/mcpServers/hub/worker.js create mode 100644 src/main/services/__tests__/MCPService.test.ts diff --git a/src/main/mcpServers/hub/__tests__/hub.test.ts b/src/main/mcpServers/hub/__tests__/hub.test.ts index 657719d3e7..92a5b65578 100644 --- a/src/main/mcpServers/hub/__tests__/hub.test.ts +++ b/src/main/mcpServers/hub/__tests__/hub.test.ts @@ -72,7 +72,8 @@ vi.mock('@main/services/MCPService', () => ({ } } return { content: [{ type: 'text', text: '{}' }] } - }) + }), + abortTool: vi.fn(async () => true) } })) @@ -178,6 +179,45 @@ describe('HubServer Integration', () => { }) }) + describe('exec timeouts', () => { + afterEach(() => { + vi.useRealTimers() + }) + + it('aborts in-flight tool calls and returns logs on timeout', async () => { + vi.useFakeTimers() + + let toolCallStarted: (() => void) | null = null + const toolCallStartedPromise = new Promise((resolve) => { + toolCallStarted = resolve + }) + + vi.mocked(mcpService.callToolById).mockImplementationOnce(async () => { + toolCallStarted?.() + return await new Promise(() => {}) + }) + + const execPromise = (hubServer as any).handleExec({ + code: ` + console.log("starting"); + return await github_searchRepos({ query: "hang" }); + ` + }) + + await toolCallStartedPromise + await vi.advanceTimersByTimeAsync(60000) + await vi.runAllTimersAsync() + + const execResult = await execPromise + const execOutput = JSON.parse(execResult.content[0].text) + + expect(execOutput.error).toBe('Execution timed out after 60000ms') + expect(execOutput.result).toBeUndefined() + expect(execOutput.logs).toContain('[log] starting') + expect(vi.mocked(mcpService.abortTool)).toHaveBeenCalled() + }) + }) + describe('server instance', () => { it('creates a valid MCP server instance', () => { expect(hubServer.server).toBeDefined() diff --git a/src/main/mcpServers/hub/__tests__/runtime.test.ts b/src/main/mcpServers/hub/__tests__/runtime.test.ts index 15c7ab7880..86664a2021 100644 --- a/src/main/mcpServers/hub/__tests__/runtime.test.ts +++ b/src/main/mcpServers/hub/__tests__/runtime.test.ts @@ -56,7 +56,7 @@ describe('Runtime', () => { const result = await runtime.execute('return await searchRepos({ query: "test" })', tools) - expect(result.result).toEqual({ repos: ['repo1', 'repo2'], query: { query: 'test' } }) + expect(result.result).toEqual({ toolId: 'searchRepos', params: { query: 'test' }, success: true }) }) it('captures console logs', async () => { diff --git a/src/main/mcpServers/hub/mcp-bridge.ts b/src/main/mcpServers/hub/mcp-bridge.ts index 1e7bfc7817..83549f86ac 100644 --- a/src/main/mcpServers/hub/mcp-bridge.ts +++ b/src/main/mcpServers/hub/mcp-bridge.ts @@ -19,7 +19,7 @@ export async function refreshToolMap(): Promise { } } -export const callMcpTool = async (functionName: string, params: unknown): Promise => { +export const callMcpTool = async (functionName: string, params: unknown, callId?: string): Promise => { const toolInfo = toolFunctionNameToIdMap.get(functionName) if (!toolInfo) { await refreshToolMap() @@ -28,14 +28,18 @@ export const callMcpTool = async (functionName: string, params: unknown): Promis throw new Error(`Tool not found: ${functionName}`) } const toolId = `${retryToolInfo.serverId}__${retryToolInfo.toolName}` - const result = await mcpService.callToolById(toolId, params) + const result = await mcpService.callToolById(toolId, params, callId) return extractToolResult(result) } const toolId = `${toolInfo.serverId}__${toolInfo.toolName}` - const result = await mcpService.callToolById(toolId, params) + const result = await mcpService.callToolById(toolId, params, callId) return extractToolResult(result) } +export const abortMcpTool = async (callId: string): Promise => { + return mcpService.abortTool(null as unknown as Electron.IpcMainInvokeEvent, callId) +} + function extractToolResult(result: { content: Array<{ type: string; text?: string }> }): unknown { if (!result.content || result.content.length === 0) { return null diff --git a/src/main/mcpServers/hub/runtime.ts b/src/main/mcpServers/hub/runtime.ts index 3b34a49dab..8a3495560b 100644 --- a/src/main/mcpServers/hub/runtime.ts +++ b/src/main/mcpServers/hub/runtime.ts @@ -1,105 +1,169 @@ +import crypto from 'node:crypto' +import { Worker } from 'node:worker_threads' + import { loggerService } from '@logger' -import { callMcpTool } from './mcp-bridge' -import type { ConsoleMethods, ExecOutput, ExecutionContext, GeneratedTool } from './types' +import { abortMcpTool, callMcpTool } from './mcp-bridge' +import type { + ExecOutput, + GeneratedTool, + HubWorkerCallToolMessage, + HubWorkerExecMessage, + HubWorkerMessage, + HubWorkerResultMessage +} from './types' const logger = loggerService.withContext('MCPServer:Hub:Runtime') const MAX_LOGS = 1000 const EXECUTION_TIMEOUT = 60000 +const WORKER_URL = new URL('./worker.js', import.meta.url) export class Runtime { async execute(code: string, tools: GeneratedTool[]): Promise { - const logs: string[] = [] - const capturedConsole = this.createCapturedConsole(logs) + return await new Promise((resolve) => { + const logs: string[] = [] + const activeCallIds = new Map() + let finished = false + let timedOut = false + let timeoutId: NodeJS.Timeout | null = null - try { - const context = this.buildContext(tools, capturedConsole) - const result = await this.runCode(code, context) + const worker = new Worker(WORKER_URL) - return { - result, - logs: logs.length > 0 ? logs : undefined + const addLog = (entry: string) => { + if (logs.length >= MAX_LOGS) { + return + } + logs.push(entry) } - } catch (error) { - const errorMessage = error instanceof Error ? error.message : String(error) - logger.error('Execution error:', error as Error) - return { - result: undefined, - logs: logs.length > 0 ? logs : undefined, - error: errorMessage + const finalize = async (output: ExecOutput, terminateWorker = true) => { + if (finished) { + return + } + finished = true + if (timeoutId) { + clearTimeout(timeoutId) + } + worker.removeAllListeners() + if (terminateWorker) { + try { + await worker.terminate() + } catch (error) { + logger.warn('Failed to terminate exec worker', error as Error) + } + } + resolve(output) } - } - } - private buildContext(tools: GeneratedTool[], capturedConsole: ConsoleMethods): ExecutionContext { - const context: ExecutionContext = { - __callTool: callMcpTool, - parallel: (...promises: Promise[]) => Promise.all(promises), - settle: (...promises: Promise[]) => Promise.allSettled(promises), - console: capturedConsole - } + const abortActiveTools = async () => { + const callIds = Array.from(activeCallIds.values()) + activeCallIds.clear() + if (callIds.length === 0) { + return + } + await Promise.allSettled(callIds.map((callId) => abortMcpTool(callId))) + } - for (const tool of tools) { - context[tool.functionName] = tool.fn - } + const handleToolCall = async (message: HubWorkerCallToolMessage) => { + if (finished || timedOut) { + return + } + const callId = crypto.randomUUID() + activeCallIds.set(message.requestId, callId) - return context - } + try { + const result = await callMcpTool(message.functionName, message.params, callId) + if (finished || timedOut) { + return + } + worker.postMessage({ type: 'toolResult', requestId: message.requestId, result }) + } catch (error) { + if (finished || timedOut) { + return + } + const errorMessage = error instanceof Error ? error.message : String(error) + worker.postMessage({ type: 'toolError', requestId: message.requestId, error: errorMessage }) + } finally { + activeCallIds.delete(message.requestId) + } + } - private async runCode(code: string, context: ExecutionContext): Promise { - const contextKeys = Object.keys(context) - const contextValues = contextKeys.map((k) => context[k]) + const handleResult = (message: HubWorkerResultMessage) => { + const resolvedLogs = message.logs && message.logs.length > 0 ? message.logs : logs + void finalize({ + result: message.result, + logs: resolvedLogs.length > 0 ? resolvedLogs : undefined + }) + } - const wrappedCode = ` - return (async () => { - ${code} - })() - ` + const handleError = (errorMessage: string, messageLogs?: string[], terminateWorker = true) => { + const resolvedLogs = messageLogs && messageLogs.length > 0 ? messageLogs : logs + void finalize( + { + result: undefined, + logs: resolvedLogs.length > 0 ? resolvedLogs : undefined, + error: errorMessage + }, + terminateWorker + ) + } - const fn = new Function(...contextKeys, wrappedCode) + const handleMessage = (message: HubWorkerMessage) => { + if (!message || typeof message !== 'object') { + return + } + switch (message.type) { + case 'log': + addLog(message.entry) + break + case 'callTool': + void handleToolCall(message) + break + case 'result': + handleResult(message) + break + case 'error': + handleError(message.error, message.logs) + break + default: + break + } + } - const timeoutPromise = new Promise((_, reject) => { - setTimeout(() => { - reject(new Error(`Execution timed out after ${EXECUTION_TIMEOUT}ms`)) + timeoutId = setTimeout(() => { + timedOut = true + void (async () => { + await abortActiveTools() + try { + await worker.terminate() + } catch (error) { + logger.warn('Failed to terminate exec worker after timeout', error as Error) + } + handleError(`Execution timed out after ${EXECUTION_TIMEOUT}ms`, undefined, false) + })() }, EXECUTION_TIMEOUT) - }) - const executionPromise = fn(...contextValues) + worker.on('message', handleMessage) + worker.on('error', (error) => { + logger.error('Worker execution error', error) + handleError(error instanceof Error ? error.message : String(error)) + }) + worker.on('exit', (code) => { + if (finished || timedOut) { + return + } + const message = code === 0 ? 'Exec worker exited unexpectedly' : `Exec worker exited with code ${code}` + logger.error(message) + handleError(message, undefined, false) + }) - return Promise.race([executionPromise, timeoutPromise]) - } - - private createCapturedConsole(logs: string[]): ConsoleMethods { - const addLog = (level: string, ...args: unknown[]) => { - if (logs.length >= MAX_LOGS) { - return + const execMessage: HubWorkerExecMessage = { + type: 'exec', + code, + tools: tools.map((tool) => ({ functionName: tool.functionName })) } - const message = args.map((arg) => this.stringify(arg)).join(' ') - logs.push(`[${level}] ${message}`) - } - - return { - log: (...args: unknown[]) => addLog('log', ...args), - warn: (...args: unknown[]) => addLog('warn', ...args), - error: (...args: unknown[]) => addLog('error', ...args), - info: (...args: unknown[]) => addLog('info', ...args), - debug: (...args: unknown[]) => addLog('debug', ...args) - } - } - - private stringify(value: unknown): string { - if (value === undefined) return 'undefined' - if (value === null) return 'null' - if (typeof value === 'string') return value - if (typeof value === 'number' || typeof value === 'boolean') return String(value) - if (value instanceof Error) return value.message - - try { - return JSON.stringify(value, null, 2) - } catch { - return String(value) - } + worker.postMessage(execMessage) + }) } } diff --git a/src/main/mcpServers/hub/types.ts b/src/main/mcpServers/hub/types.ts index 8b34fd7e14..2bff26b8dc 100644 --- a/src/main/mcpServers/hub/types.ts +++ b/src/main/mcpServers/hub/types.ts @@ -55,3 +55,58 @@ export interface ConsoleMethods { info: (...args: unknown[]) => void debug: (...args: unknown[]) => void } + +export type HubWorkerTool = { + functionName: string +} + +export type HubWorkerExecMessage = { + type: 'exec' + code: string + tools: HubWorkerTool[] +} + +export type HubWorkerCallToolMessage = { + type: 'callTool' + requestId: string + functionName: string + params: unknown +} + +export type HubWorkerToolResultMessage = { + type: 'toolResult' + requestId: string + result: unknown +} + +export type HubWorkerToolErrorMessage = { + type: 'toolError' + requestId: string + error: string +} + +export type HubWorkerResultMessage = { + type: 'result' + result: unknown + logs?: string[] +} + +export type HubWorkerErrorMessage = { + type: 'error' + error: string + logs?: string[] +} + +export type HubWorkerLogMessage = { + type: 'log' + entry: string +} + +export type HubWorkerMessage = + | HubWorkerExecMessage + | HubWorkerCallToolMessage + | HubWorkerToolResultMessage + | HubWorkerToolErrorMessage + | HubWorkerResultMessage + | HubWorkerErrorMessage + | HubWorkerLogMessage diff --git a/src/main/mcpServers/hub/worker.js b/src/main/mcpServers/hub/worker.js new file mode 100644 index 0000000000..a8706f0704 --- /dev/null +++ b/src/main/mcpServers/hub/worker.js @@ -0,0 +1,131 @@ +const crypto = require('node:crypto') +const { parentPort } = require('node:worker_threads') + +const MAX_LOGS = 1000 + +const logs = [] +const pendingCalls = new Map() +let isExecuting = false + +const stringify = (value) => { + if (value === undefined) return 'undefined' + if (value === null) return 'null' + if (typeof value === 'string') return value + if (typeof value === 'number' || typeof value === 'boolean') return String(value) + if (value instanceof Error) return value.message + + try { + return JSON.stringify(value, null, 2) + } catch { + return String(value) + } +} + +const pushLog = (level, args) => { + if (logs.length >= MAX_LOGS) { + return + } + const message = args.map((arg) => stringify(arg)).join(' ') + const entry = `[${level}] ${message}` + logs.push(entry) + parentPort?.postMessage({ type: 'log', entry }) +} + +const capturedConsole = { + log: (...args) => pushLog('log', args), + warn: (...args) => pushLog('warn', args), + error: (...args) => pushLog('error', args), + info: (...args) => pushLog('info', args), + debug: (...args) => pushLog('debug', args) +} + +const callTool = (functionName, params) => + new Promise((resolve, reject) => { + const requestId = crypto.randomUUID() + pendingCalls.set(requestId, { resolve, reject }) + parentPort?.postMessage({ type: 'callTool', requestId, functionName, params }) + }) + +const buildContext = (tools) => { + const context = { + __callTool: callTool, + parallel: (...promises) => Promise.all(promises), + settle: (...promises) => Promise.allSettled(promises), + console: capturedConsole + } + + for (const tool of tools) { + context[tool.functionName] = (params) => callTool(tool.functionName, params) + } + + return context +} + +const runCode = async (code, context) => { + const contextKeys = Object.keys(context) + const contextValues = contextKeys.map((key) => context[key]) + + const wrappedCode = ` + return (async () => { + ${code} + })() + ` + + const fn = new Function(...contextKeys, wrappedCode) + return await fn(...contextValues) +} + +const handleExec = async (code, tools) => { + if (isExecuting) { + return + } + isExecuting = true + + try { + const context = buildContext(tools) + const result = await runCode(code, context) + parentPort?.postMessage({ type: 'result', result, logs: logs.length > 0 ? logs : undefined }) + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) + parentPort?.postMessage({ type: 'error', error: errorMessage, logs: logs.length > 0 ? logs : undefined }) + } finally { + pendingCalls.clear() + } +} + +const handleToolResult = (message) => { + const pending = pendingCalls.get(message.requestId) + if (!pending) { + return + } + pendingCalls.delete(message.requestId) + pending.resolve(message.result) +} + +const handleToolError = (message) => { + const pending = pendingCalls.get(message.requestId) + if (!pending) { + return + } + pendingCalls.delete(message.requestId) + pending.reject(new Error(message.error)) +} + +parentPort?.on('message', (message) => { + if (!message || typeof message !== 'object') { + return + } + switch (message.type) { + case 'exec': + handleExec(message.code, message.tools ?? []) + break + case 'toolResult': + handleToolResult(message) + break + case 'toolError': + handleToolError(message) + break + default: + break + } +}) diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index d35d895914..f106dd9faf 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -180,7 +180,9 @@ class McpService { } try { const tools = await this.listToolsImpl(server) - allTools.push(...tools) + const disabledTools = new Set(server.disabledTools ?? []) + const enabledTools = disabledTools.size > 0 ? tools.filter((tool) => !disabledTools.has(tool.name)) : tools + allTools.push(...enabledTools) } catch (error) { logger.error(`[listAllActiveServerTools] Failed to list tools from ${server.name}:`, error as Error) } @@ -195,7 +197,8 @@ class McpService { */ public async callToolById( toolId: string, - params: unknown + params: unknown, + callId?: string ): Promise<{ content: Array<{ type: string; text?: string }> }> { const parts = toolId.split('__') if (parts.length < 2) { @@ -217,7 +220,8 @@ class McpService { return this.callTool(null as unknown as Electron.IpcMainInvokeEvent, { server, name: toolName, - args: params + args: params, + callId }) } diff --git a/src/main/services/__tests__/MCPService.test.ts b/src/main/services/__tests__/MCPService.test.ts new file mode 100644 index 0000000000..4757d20cff --- /dev/null +++ b/src/main/services/__tests__/MCPService.test.ts @@ -0,0 +1,75 @@ +import type { MCPServer, MCPTool } from '@types' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@main/apiServer/utils/mcp', () => ({ + getMCPServersFromRedux: vi.fn() +})) + +vi.mock('@main/services/WindowService', () => ({ + windowService: { + getMainWindow: vi.fn(() => null) + } +})) + +import { getMCPServersFromRedux } from '@main/apiServer/utils/mcp' +import mcpService from '@main/services/MCPService' + +const baseInputSchema: { type: 'object'; properties: Record; required: string[] } = { + type: 'object', + properties: {}, + required: [] +} + +const createTool = (overrides: Partial): MCPTool => ({ + id: `${overrides.serverId}__${overrides.name}`, + name: overrides.name ?? 'tool', + description: overrides.description, + serverId: overrides.serverId ?? 'server', + serverName: overrides.serverName ?? 'server', + inputSchema: baseInputSchema, + type: 'mcp', + ...overrides +}) + +describe('MCPService.listAllActiveServerTools', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + it('filters disabled tools per server', async () => { + const servers: MCPServer[] = [ + { + id: 'alpha', + name: 'Alpha', + isActive: true, + disabledTools: ['disabled_tool'] + }, + { + id: 'beta', + name: 'Beta', + isActive: true + } + ] + + vi.mocked(getMCPServersFromRedux).mockResolvedValue(servers) + + const listToolsSpy = vi.spyOn(mcpService as any, 'listToolsImpl').mockImplementation(async (server: any) => { + if (server.id === 'alpha') { + return [ + createTool({ name: 'enabled_tool', serverId: server.id, serverName: server.name }), + createTool({ name: 'disabled_tool', serverId: server.id, serverName: server.name }) + ] + } + return [createTool({ name: 'beta_tool', serverId: server.id, serverName: server.name })] + }) + + const tools = await mcpService.listAllActiveServerTools() + + expect(listToolsSpy).toHaveBeenCalledTimes(2) + expect(tools.map((tool) => tool.name)).toEqual(['enabled_tool', 'beta_tool']) + }) +}) diff --git a/tests/main.setup.ts b/tests/main.setup.ts index 5cadb89d02..9d6731e4a7 100644 --- a/tests/main.setup.ts +++ b/tests/main.setup.ts @@ -10,59 +10,69 @@ vi.mock('@logger', async () => { }) // Mock electron modules that are commonly used in main process -vi.mock('electron', () => ({ - app: { - getPath: vi.fn((key: string) => { - switch (key) { - case 'userData': - return '/mock/userData' - case 'temp': - return '/mock/temp' - case 'logs': - return '/mock/logs' - default: - return '/mock/unknown' +vi.mock('electron', () => { + const mock = { + app: { + getPath: vi.fn((key: string) => { + switch (key) { + case 'userData': + return '/mock/userData' + case 'temp': + return '/mock/temp' + case 'logs': + return '/mock/logs' + default: + return '/mock/unknown' + } + }), + getVersion: vi.fn(() => '1.0.0') + }, + ipcMain: { + handle: vi.fn(), + on: vi.fn(), + once: vi.fn(), + removeHandler: vi.fn(), + removeAllListeners: vi.fn() + }, + BrowserWindow: vi.fn(), + dialog: { + showErrorBox: vi.fn(), + showMessageBox: vi.fn(), + showOpenDialog: vi.fn(), + showSaveDialog: vi.fn() + }, + shell: { + openExternal: vi.fn(), + showItemInFolder: vi.fn() + }, + session: { + defaultSession: { + clearCache: vi.fn(), + clearStorageData: vi.fn() } - }), - getVersion: vi.fn(() => '1.0.0') - }, - ipcMain: { - handle: vi.fn(), - on: vi.fn(), - once: vi.fn(), - removeHandler: vi.fn(), - removeAllListeners: vi.fn() - }, - BrowserWindow: vi.fn(), - dialog: { - showErrorBox: vi.fn(), - showMessageBox: vi.fn(), - showOpenDialog: vi.fn(), - showSaveDialog: vi.fn() - }, - shell: { - openExternal: vi.fn(), - showItemInFolder: vi.fn() - }, - session: { - defaultSession: { - clearCache: vi.fn(), - clearStorageData: vi.fn() - } - }, - webContents: { - getAllWebContents: vi.fn(() => []) - }, - systemPreferences: { - getMediaAccessStatus: vi.fn(), - askForMediaAccess: vi.fn() - }, - screen: { - getPrimaryDisplay: vi.fn(), - getAllDisplays: vi.fn() - }, - Notification: vi.fn() -})) + }, + webContents: { + getAllWebContents: vi.fn(() => []) + }, + systemPreferences: { + getMediaAccessStatus: vi.fn(), + askForMediaAccess: vi.fn() + }, + nativeTheme: { + themeSource: 'system', + shouldUseDarkColors: false, + on: vi.fn(), + removeListener: vi.fn() + }, + screen: { + getPrimaryDisplay: vi.fn(), + getAllDisplays: vi.fn() + }, + Notification: vi.fn() + } + + return { __esModule: true, ...mock, default: mock } +}) // Mock Winston for LoggerService dependencies vi.mock('winston', () => ({ @@ -98,13 +108,17 @@ vi.mock('winston-daily-rotate-file', () => { }) // Mock Node.js modules -vi.mock('node:os', () => ({ - platform: vi.fn(() => 'darwin'), - arch: vi.fn(() => 'x64'), - version: vi.fn(() => '20.0.0'), - cpus: vi.fn(() => [{ model: 'Mock CPU' }]), - totalmem: vi.fn(() => 8 * 1024 * 1024 * 1024) // 8GB -})) +vi.mock('node:os', () => { + const mock = { + platform: vi.fn(() => 'darwin'), + arch: vi.fn(() => 'x64'), + version: vi.fn(() => '20.0.0'), + cpus: vi.fn(() => [{ model: 'Mock CPU' }]), + homedir: vi.fn(() => '/mock/home'), + totalmem: vi.fn(() => 8 * 1024 * 1024 * 1024) // 8GB + } + return { ...mock, default: mock } +}) vi.mock('node:path', async () => { const actual = await vi.importActual('node:path') @@ -115,25 +129,29 @@ vi.mock('node:path', async () => { } }) -vi.mock('node:fs', () => ({ - promises: { - access: vi.fn(), - readFile: vi.fn(), - writeFile: vi.fn(), - mkdir: vi.fn(), - readdir: vi.fn(), - stat: vi.fn(), - unlink: vi.fn(), - rmdir: vi.fn() - }, - existsSync: vi.fn(), - readFileSync: vi.fn(), - writeFileSync: vi.fn(), - mkdirSync: vi.fn(), - readdirSync: vi.fn(), - statSync: vi.fn(), - unlinkSync: vi.fn(), - rmdirSync: vi.fn(), - createReadStream: vi.fn(), - createWriteStream: vi.fn() -})) +vi.mock('node:fs', () => { + const mock = { + promises: { + access: vi.fn(), + readFile: vi.fn(), + writeFile: vi.fn(), + mkdir: vi.fn(), + readdir: vi.fn(), + stat: vi.fn(), + unlink: vi.fn(), + rmdir: vi.fn() + }, + existsSync: vi.fn(), + readFileSync: vi.fn(), + writeFileSync: vi.fn(), + mkdirSync: vi.fn(), + readdirSync: vi.fn(), + statSync: vi.fn(), + unlinkSync: vi.fn(), + rmdirSync: vi.fn(), + createReadStream: vi.fn(), + createWriteStream: vi.fn() + } + + return { ...mock, default: mock } +})