♻️ refactor(hub): simplify dependency injection for HubServer

- Remove HubServerDependencies interface and setHubServerDependencies from factory
- Add initHubBridge() to mcp-bridge for direct initialization
- Make HubServer constructor parameterless (uses pre-initialized bridge)
- MCPService now calls initHubBridge() directly instead of factory setter
- Add integration tests for full search → exec flow
This commit is contained in:
Vaayne 2025-12-29 15:36:30 +08:00
parent 19f3e6f2f0
commit d1ed8c315f
5 changed files with 200 additions and 41 deletions

View File

@ -1,6 +1,6 @@
import { loggerService } from '@logger'
import type { Server } from '@modelcontextprotocol/sdk/server/index.js'
import type { BuiltinMCPServerName, MCPServer } from '@types'
import type { BuiltinMCPServerName } from '@types'
import { BuiltinMCPServerNames } from '@types'
import BraveSearchServer from './brave-search'
@ -16,23 +16,6 @@ import ThinkingServer from './sequentialthinking'
const logger = loggerService.withContext('MCPFactory')
interface HubServerDependencies {
mcpService: {
listTools(_: null, server: MCPServer): Promise<unknown[]>
callTool(
_: null,
args: { server: MCPServer; name: string; args: unknown; callId?: string }
): Promise<{ content: Array<{ type: string; text?: string }> }>
}
mcpServersGetter: () => MCPServer[]
}
let hubServerDependencies: HubServerDependencies | null = null
export function setHubServerDependencies(deps: HubServerDependencies): void {
hubServerDependencies = deps
}
export function createInMemoryMCPServer(
name: BuiltinMCPServerName,
args: string[] = [],
@ -71,10 +54,7 @@ export function createInMemoryMCPServer(
return new BrowserServer().server
}
case BuiltinMCPServerNames.hub: {
if (!hubServerDependencies) {
throw new Error('Hub server dependencies not set. Call setHubServerDependencies first.')
}
return new HubServer(hubServerDependencies.mcpService, hubServerDependencies.mcpServersGetter).server
return new HubServer().server
}
default:
throw new Error(`Unknown in-memory MCP server: ${name}`)

View File

@ -0,0 +1,186 @@
import type { MCPServer } from '@types'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { HubServer } from '../index'
import { initHubBridge } from '../mcp-bridge'
const mockMcpServers: MCPServer[] = [
{
id: 'github',
name: 'GitHub',
command: 'npx',
args: ['-y', 'github-mcp-server'],
isActive: true
} as MCPServer,
{
id: 'database',
name: 'Database',
command: 'npx',
args: ['-y', 'db-mcp-server'],
isActive: true
} as MCPServer
]
const mockToolDefinitions = {
github: [
{
name: 'search_repos',
description: 'Search for GitHub repositories',
inputSchema: {
type: 'object',
properties: {
query: { type: 'string', description: 'Search query' },
limit: { type: 'number', description: 'Max results' }
},
required: ['query']
}
},
{
name: 'get_user',
description: 'Get GitHub user profile',
inputSchema: {
type: 'object',
properties: {
username: { type: 'string', description: 'GitHub username' }
},
required: ['username']
}
}
],
database: [
{
name: 'query',
description: 'Execute a database query',
inputSchema: {
type: 'object',
properties: {
sql: { type: 'string', description: 'SQL query to execute' }
},
required: ['sql']
}
}
]
}
const mockMcpService = {
listTools: vi.fn(async (_: null, server: MCPServer) => {
return mockToolDefinitions[server.id as keyof typeof mockToolDefinitions] || []
}),
callTool: vi.fn(async (_: null, args: { server: MCPServer; name: string; args: unknown }) => {
if (args.server.id === 'github' && args.name === 'search_repos') {
return {
content: [{ type: 'text', text: JSON.stringify({ repos: ['repo1', 'repo2'], query: args.args }) }]
}
}
if (args.server.id === 'github' && args.name === 'get_user') {
return {
content: [{ type: 'text', text: JSON.stringify({ username: (args.args as any).username, id: 123 }) }]
}
}
if (args.server.id === 'database' && args.name === 'query') {
return {
content: [{ type: 'text', text: JSON.stringify({ rows: [{ id: 1 }, { id: 2 }] }) }]
}
}
return { content: [{ type: 'text', text: '{}' }] }
})
}
describe('HubServer Integration', () => {
let hubServer: HubServer
beforeEach(() => {
vi.clearAllMocks()
initHubBridge(mockMcpService as any, () => mockMcpServers)
hubServer = new HubServer()
})
afterEach(() => {
hubServer.invalidateCache()
})
describe('full search → exec flow', () => {
it('searches for tools and executes them', async () => {
const searchResult = await (hubServer as any).handleSearch({ query: 'github,repos' })
expect(searchResult.content).toBeDefined()
const searchText = JSON.parse(searchResult.content[0].text)
expect(searchText.total).toBeGreaterThan(0)
expect(searchText.tools).toContain('searchRepos')
const execResult = await (hubServer as any).handleExec({
code: 'return await searchRepos({ query: "test" })'
})
expect(execResult.content).toBeDefined()
const execOutput = JSON.parse(execResult.content[0].text)
expect(execOutput.result).toEqual({ repos: ['repo1', 'repo2'], query: { query: 'test' } })
})
it('handles multiple tool calls in parallel', async () => {
await (hubServer as any).handleSearch({ query: 'github' })
const execResult = await (hubServer as any).handleExec({
code: `
const results = await parallel(
searchRepos({ query: "react" }),
getUser({ username: "octocat" })
);
return results
`
})
const execOutput = JSON.parse(execResult.content[0].text)
expect(execOutput.result).toHaveLength(2)
expect(execOutput.result[0]).toEqual({ repos: ['repo1', 'repo2'], query: { query: 'react' } })
expect(execOutput.result[1]).toEqual({ username: 'octocat', id: 123 })
})
it('searches across multiple servers', async () => {
const searchResult = await (hubServer as any).handleSearch({ query: 'query' })
const searchText = JSON.parse(searchResult.content[0].text)
expect(searchText.tools).toContain('query')
})
})
describe('cache invalidation', () => {
it('refreshes tools after invalidation', async () => {
await (hubServer as any).handleSearch({ query: 'github' })
const initialCallCount = mockMcpService.listTools.mock.calls.length
hubServer.invalidateCache()
await (hubServer as any).handleSearch({ query: 'github' })
expect(mockMcpService.listTools.mock.calls.length).toBeGreaterThan(initialCallCount)
})
})
describe('error handling', () => {
it('throws error for invalid search query', async () => {
await expect((hubServer as any).handleSearch({})).rejects.toThrow('query parameter is required')
})
it('throws error for invalid exec code', async () => {
await expect((hubServer as any).handleExec({})).rejects.toThrow('code parameter is required')
})
it('handles runtime errors in exec', async () => {
const execResult = await (hubServer as any).handleExec({
code: 'throw new Error("test error")'
})
const execOutput = JSON.parse(execResult.content[0].text)
expect(execOutput.error).toBe('test error')
})
})
describe('server instance', () => {
it('creates a valid MCP server instance', () => {
expect(hubServer.server).toBeDefined()
expect(hubServer.server.setRequestHandler).toBeDefined()
})
})
})

View File

@ -1,9 +1,7 @@
import { loggerService } from '@logger'
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
import { CallToolRequestSchema, ErrorCode, ListToolsRequestSchema, McpError } from '@modelcontextprotocol/sdk/types.js'
import type { MCPServer } from '@types'
import { setMCPServersGetter, setMCPService } from './mcp-bridge'
import { Runtime } from './runtime'
import { searchTools } from './search'
import { ToolRegistry } from './tool-registry'
@ -11,23 +9,12 @@ import type { ExecInput, SearchQuery } from './types'
const logger = loggerService.withContext('MCPServer:Hub')
interface MCPServiceInterface {
listTools(_: null, server: MCPServer): Promise<unknown[]>
callTool(
_: null,
args: { server: MCPServer; name: string; args: unknown; callId?: string }
): Promise<{ content: Array<{ type: string; text?: string }> }>
}
export class HubServer {
public server: Server
private toolRegistry: ToolRegistry
private runtime: Runtime
constructor(mcpService: MCPServiceInterface, mcpServersGetter: () => MCPServer[]) {
setMCPService(mcpService as any)
setMCPServersGetter(mcpServersGetter)
constructor() {
this.toolRegistry = new ToolRegistry()
this.runtime = new Runtime()

View File

@ -22,6 +22,11 @@ export function setMCPServersGetter(getter: () => MCPServer[]): void {
mcpServersGetter = getter
}
export function initHubBridge(service: MCPServiceInterface, serversGetter: () => MCPServer[]): void {
mcpServiceInstance = service
mcpServersGetter = serversGetter
}
export function getActiveServers(): MCPServer[] {
if (!mcpServersGetter) {
logger.warn('MCP servers getter not set')

View File

@ -3,7 +3,8 @@ import os from 'node:os'
import path from 'node:path'
import { loggerService } from '@logger'
import { createInMemoryMCPServer, setHubServerDependencies } from '@main/mcpServers/factory'
import { createInMemoryMCPServer } from '@main/mcpServers/factory'
import { initHubBridge } from '@main/mcpServers/hub/mcp-bridge'
import { makeSureDirExists, removeEnvProxy } from '@main/utils'
import { buildFunctionCallToolName } from '@main/utils/mcp'
import { findCommandInShellEnv, getBinaryName, getBinaryPath, isBinaryExists } from '@main/utils/process'
@ -169,14 +170,14 @@ class McpService {
}
private initializeHubDependencies(): void {
setHubServerDependencies({
mcpService: {
initHubBridge(
{
listTools: (_: null, server: MCPServer) => this.listToolsImpl(server),
callTool: async (_: null, args: { server: MCPServer; name: string; args: unknown; callId?: string }) => {
return this.callTool(null as unknown as Electron.IpcMainInvokeEvent, args)
}
},
mcpServersGetter: () => {
() => {
try {
const servers = reduxService.selectSync<MCPServer[]>('state.mcp.servers')
return servers || []
@ -184,7 +185,7 @@ class McpService {
return []
}
}
})
)
}
private getServerKey(server: MCPServer): string {