mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-12 17:09:37 +08:00
♻️ refactor(hub): simplify dependency injection for HubServer
- Remove HubServerDependencies interface and setHubServerDependencies from factory - Add initHubBridge() to mcp-bridge for direct initialization - Make HubServer constructor parameterless (uses pre-initialized bridge) - MCPService now calls initHubBridge() directly instead of factory setter - Add integration tests for full search → exec flow
This commit is contained in:
parent
19f3e6f2f0
commit
d1ed8c315f
@ -1,6 +1,6 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||
import type { BuiltinMCPServerName, MCPServer } from '@types'
|
||||
import type { BuiltinMCPServerName } from '@types'
|
||||
import { BuiltinMCPServerNames } from '@types'
|
||||
|
||||
import BraveSearchServer from './brave-search'
|
||||
@ -16,23 +16,6 @@ import ThinkingServer from './sequentialthinking'
|
||||
|
||||
const logger = loggerService.withContext('MCPFactory')
|
||||
|
||||
interface HubServerDependencies {
|
||||
mcpService: {
|
||||
listTools(_: null, server: MCPServer): Promise<unknown[]>
|
||||
callTool(
|
||||
_: null,
|
||||
args: { server: MCPServer; name: string; args: unknown; callId?: string }
|
||||
): Promise<{ content: Array<{ type: string; text?: string }> }>
|
||||
}
|
||||
mcpServersGetter: () => MCPServer[]
|
||||
}
|
||||
|
||||
let hubServerDependencies: HubServerDependencies | null = null
|
||||
|
||||
export function setHubServerDependencies(deps: HubServerDependencies): void {
|
||||
hubServerDependencies = deps
|
||||
}
|
||||
|
||||
export function createInMemoryMCPServer(
|
||||
name: BuiltinMCPServerName,
|
||||
args: string[] = [],
|
||||
@ -71,10 +54,7 @@ export function createInMemoryMCPServer(
|
||||
return new BrowserServer().server
|
||||
}
|
||||
case BuiltinMCPServerNames.hub: {
|
||||
if (!hubServerDependencies) {
|
||||
throw new Error('Hub server dependencies not set. Call setHubServerDependencies first.')
|
||||
}
|
||||
return new HubServer(hubServerDependencies.mcpService, hubServerDependencies.mcpServersGetter).server
|
||||
return new HubServer().server
|
||||
}
|
||||
default:
|
||||
throw new Error(`Unknown in-memory MCP server: ${name}`)
|
||||
|
||||
186
src/main/mcpServers/hub/__tests__/hub.test.ts
Normal file
186
src/main/mcpServers/hub/__tests__/hub.test.ts
Normal file
@ -0,0 +1,186 @@
|
||||
import type { MCPServer } from '@types'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { HubServer } from '../index'
|
||||
import { initHubBridge } from '../mcp-bridge'
|
||||
|
||||
const mockMcpServers: MCPServer[] = [
|
||||
{
|
||||
id: 'github',
|
||||
name: 'GitHub',
|
||||
command: 'npx',
|
||||
args: ['-y', 'github-mcp-server'],
|
||||
isActive: true
|
||||
} as MCPServer,
|
||||
{
|
||||
id: 'database',
|
||||
name: 'Database',
|
||||
command: 'npx',
|
||||
args: ['-y', 'db-mcp-server'],
|
||||
isActive: true
|
||||
} as MCPServer
|
||||
]
|
||||
|
||||
const mockToolDefinitions = {
|
||||
github: [
|
||||
{
|
||||
name: 'search_repos',
|
||||
description: 'Search for GitHub repositories',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
query: { type: 'string', description: 'Search query' },
|
||||
limit: { type: 'number', description: 'Max results' }
|
||||
},
|
||||
required: ['query']
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'get_user',
|
||||
description: 'Get GitHub user profile',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
username: { type: 'string', description: 'GitHub username' }
|
||||
},
|
||||
required: ['username']
|
||||
}
|
||||
}
|
||||
],
|
||||
database: [
|
||||
{
|
||||
name: 'query',
|
||||
description: 'Execute a database query',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
sql: { type: 'string', description: 'SQL query to execute' }
|
||||
},
|
||||
required: ['sql']
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
const mockMcpService = {
|
||||
listTools: vi.fn(async (_: null, server: MCPServer) => {
|
||||
return mockToolDefinitions[server.id as keyof typeof mockToolDefinitions] || []
|
||||
}),
|
||||
callTool: vi.fn(async (_: null, args: { server: MCPServer; name: string; args: unknown }) => {
|
||||
if (args.server.id === 'github' && args.name === 'search_repos') {
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify({ repos: ['repo1', 'repo2'], query: args.args }) }]
|
||||
}
|
||||
}
|
||||
if (args.server.id === 'github' && args.name === 'get_user') {
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify({ username: (args.args as any).username, id: 123 }) }]
|
||||
}
|
||||
}
|
||||
if (args.server.id === 'database' && args.name === 'query') {
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify({ rows: [{ id: 1 }, { id: 2 }] }) }]
|
||||
}
|
||||
}
|
||||
return { content: [{ type: 'text', text: '{}' }] }
|
||||
})
|
||||
}
|
||||
|
||||
describe('HubServer Integration', () => {
|
||||
let hubServer: HubServer
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
initHubBridge(mockMcpService as any, () => mockMcpServers)
|
||||
hubServer = new HubServer()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
hubServer.invalidateCache()
|
||||
})
|
||||
|
||||
describe('full search → exec flow', () => {
|
||||
it('searches for tools and executes them', async () => {
|
||||
const searchResult = await (hubServer as any).handleSearch({ query: 'github,repos' })
|
||||
|
||||
expect(searchResult.content).toBeDefined()
|
||||
const searchText = JSON.parse(searchResult.content[0].text)
|
||||
expect(searchText.total).toBeGreaterThan(0)
|
||||
expect(searchText.tools).toContain('searchRepos')
|
||||
|
||||
const execResult = await (hubServer as any).handleExec({
|
||||
code: 'return await searchRepos({ query: "test" })'
|
||||
})
|
||||
|
||||
expect(execResult.content).toBeDefined()
|
||||
const execOutput = JSON.parse(execResult.content[0].text)
|
||||
expect(execOutput.result).toEqual({ repos: ['repo1', 'repo2'], query: { query: 'test' } })
|
||||
})
|
||||
|
||||
it('handles multiple tool calls in parallel', async () => {
|
||||
await (hubServer as any).handleSearch({ query: 'github' })
|
||||
|
||||
const execResult = await (hubServer as any).handleExec({
|
||||
code: `
|
||||
const results = await parallel(
|
||||
searchRepos({ query: "react" }),
|
||||
getUser({ username: "octocat" })
|
||||
);
|
||||
return results
|
||||
`
|
||||
})
|
||||
|
||||
const execOutput = JSON.parse(execResult.content[0].text)
|
||||
expect(execOutput.result).toHaveLength(2)
|
||||
expect(execOutput.result[0]).toEqual({ repos: ['repo1', 'repo2'], query: { query: 'react' } })
|
||||
expect(execOutput.result[1]).toEqual({ username: 'octocat', id: 123 })
|
||||
})
|
||||
|
||||
it('searches across multiple servers', async () => {
|
||||
const searchResult = await (hubServer as any).handleSearch({ query: 'query' })
|
||||
|
||||
const searchText = JSON.parse(searchResult.content[0].text)
|
||||
expect(searchText.tools).toContain('query')
|
||||
})
|
||||
})
|
||||
|
||||
describe('cache invalidation', () => {
|
||||
it('refreshes tools after invalidation', async () => {
|
||||
await (hubServer as any).handleSearch({ query: 'github' })
|
||||
|
||||
const initialCallCount = mockMcpService.listTools.mock.calls.length
|
||||
|
||||
hubServer.invalidateCache()
|
||||
|
||||
await (hubServer as any).handleSearch({ query: 'github' })
|
||||
|
||||
expect(mockMcpService.listTools.mock.calls.length).toBeGreaterThan(initialCallCount)
|
||||
})
|
||||
})
|
||||
|
||||
describe('error handling', () => {
|
||||
it('throws error for invalid search query', async () => {
|
||||
await expect((hubServer as any).handleSearch({})).rejects.toThrow('query parameter is required')
|
||||
})
|
||||
|
||||
it('throws error for invalid exec code', async () => {
|
||||
await expect((hubServer as any).handleExec({})).rejects.toThrow('code parameter is required')
|
||||
})
|
||||
|
||||
it('handles runtime errors in exec', async () => {
|
||||
const execResult = await (hubServer as any).handleExec({
|
||||
code: 'throw new Error("test error")'
|
||||
})
|
||||
|
||||
const execOutput = JSON.parse(execResult.content[0].text)
|
||||
expect(execOutput.error).toBe('test error')
|
||||
})
|
||||
})
|
||||
|
||||
describe('server instance', () => {
|
||||
it('creates a valid MCP server instance', () => {
|
||||
expect(hubServer.server).toBeDefined()
|
||||
expect(hubServer.server.setRequestHandler).toBeDefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -1,9 +1,7 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||
import { CallToolRequestSchema, ErrorCode, ListToolsRequestSchema, McpError } from '@modelcontextprotocol/sdk/types.js'
|
||||
import type { MCPServer } from '@types'
|
||||
|
||||
import { setMCPServersGetter, setMCPService } from './mcp-bridge'
|
||||
import { Runtime } from './runtime'
|
||||
import { searchTools } from './search'
|
||||
import { ToolRegistry } from './tool-registry'
|
||||
@ -11,23 +9,12 @@ import type { ExecInput, SearchQuery } from './types'
|
||||
|
||||
const logger = loggerService.withContext('MCPServer:Hub')
|
||||
|
||||
interface MCPServiceInterface {
|
||||
listTools(_: null, server: MCPServer): Promise<unknown[]>
|
||||
callTool(
|
||||
_: null,
|
||||
args: { server: MCPServer; name: string; args: unknown; callId?: string }
|
||||
): Promise<{ content: Array<{ type: string; text?: string }> }>
|
||||
}
|
||||
|
||||
export class HubServer {
|
||||
public server: Server
|
||||
private toolRegistry: ToolRegistry
|
||||
private runtime: Runtime
|
||||
|
||||
constructor(mcpService: MCPServiceInterface, mcpServersGetter: () => MCPServer[]) {
|
||||
setMCPService(mcpService as any)
|
||||
setMCPServersGetter(mcpServersGetter)
|
||||
|
||||
constructor() {
|
||||
this.toolRegistry = new ToolRegistry()
|
||||
this.runtime = new Runtime()
|
||||
|
||||
|
||||
@ -22,6 +22,11 @@ export function setMCPServersGetter(getter: () => MCPServer[]): void {
|
||||
mcpServersGetter = getter
|
||||
}
|
||||
|
||||
export function initHubBridge(service: MCPServiceInterface, serversGetter: () => MCPServer[]): void {
|
||||
mcpServiceInstance = service
|
||||
mcpServersGetter = serversGetter
|
||||
}
|
||||
|
||||
export function getActiveServers(): MCPServer[] {
|
||||
if (!mcpServersGetter) {
|
||||
logger.warn('MCP servers getter not set')
|
||||
|
||||
@ -3,7 +3,8 @@ import os from 'node:os'
|
||||
import path from 'node:path'
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import { createInMemoryMCPServer, setHubServerDependencies } from '@main/mcpServers/factory'
|
||||
import { createInMemoryMCPServer } from '@main/mcpServers/factory'
|
||||
import { initHubBridge } from '@main/mcpServers/hub/mcp-bridge'
|
||||
import { makeSureDirExists, removeEnvProxy } from '@main/utils'
|
||||
import { buildFunctionCallToolName } from '@main/utils/mcp'
|
||||
import { findCommandInShellEnv, getBinaryName, getBinaryPath, isBinaryExists } from '@main/utils/process'
|
||||
@ -169,14 +170,14 @@ class McpService {
|
||||
}
|
||||
|
||||
private initializeHubDependencies(): void {
|
||||
setHubServerDependencies({
|
||||
mcpService: {
|
||||
initHubBridge(
|
||||
{
|
||||
listTools: (_: null, server: MCPServer) => this.listToolsImpl(server),
|
||||
callTool: async (_: null, args: { server: MCPServer; name: string; args: unknown; callId?: string }) => {
|
||||
return this.callTool(null as unknown as Electron.IpcMainInvokeEvent, args)
|
||||
}
|
||||
},
|
||||
mcpServersGetter: () => {
|
||||
() => {
|
||||
try {
|
||||
const servers = reduxService.selectSync<MCPServer[]>('state.mcp.servers')
|
||||
return servers || []
|
||||
@ -184,7 +185,7 @@ class McpService {
|
||||
return []
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
private getServerKey(server: MCPServer): string {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user