diff --git a/packages/shared/mcp.ts b/packages/shared/mcp.ts new file mode 100644 index 0000000000..7836c00db0 --- /dev/null +++ b/packages/shared/mcp.ts @@ -0,0 +1,43 @@ +/** + * Convert a string to camelCase, ensuring it's a valid JavaScript identifier. + */ +export function toCamelCase(str: string): string { + let result = str + .replace(/[^a-zA-Z0-9]+(.)/g, (_, char) => char.toUpperCase()) + .replace(/^[A-Z]/, (char) => char.toLowerCase()) + .replace(/[^a-zA-Z0-9]/g, '') + + // Ensure valid JS identifier: must start with letter or underscore + if (result && !/^[a-zA-Z_]/.test(result)) { + result = '_' + result + } + + return result +} + +/** + * Generate a unique function name from server name and tool name. + * Format: serverName_toolName (camelCase) + */ +export function generateMcpToolFunctionName( + serverName: string | undefined, + toolName: string, + existingNames?: Set +): string { + const serverPrefix = serverName ? toCamelCase(serverName) : '' + const toolNameCamel = toCamelCase(toolName) + const baseName = serverPrefix ? `${serverPrefix}_${toolNameCamel}` : toolNameCamel + + if (!existingNames) { + return baseName + } + + let name = baseName + let counter = 1 + while (existingNames.has(name)) { + name = `${baseName}${counter}` + counter++ + } + existingNames.add(name) + return name +} diff --git a/src/main/mcpServers/hub/README.md b/src/main/mcpServers/hub/README.md index 671075da5e..e15141b9c5 100644 --- a/src/main/mcpServers/hub/README.md +++ b/src/main/mcpServers/hub/README.md @@ -6,6 +6,62 @@ A built-in MCP server that aggregates all active MCP servers in Cherry Studio an The Hub server enables LLMs to discover and call tools from all active MCP servers without needing to know the specific server names or tool signatures upfront. +## Auto Mode Integration + +The Hub server is the core component of Cherry Studio's **Auto MCP Mode**. When an assistant is set to Auto mode: + +1. **Automatic Injection**: The Hub server is automatically injected as the only MCP server for the assistant +2. **System Prompt**: A specialized system prompt (`HUB_MODE_SYSTEM_PROMPT`) is appended to guide the LLM on how to use the `search` and `exec` tools +3. **Dynamic Discovery**: The LLM can discover and use any tools from all active MCP servers without manual configuration + +### MCP Modes + +Cherry Studio supports three MCP modes per assistant: + +| Mode | Description | Tools Available | +|------|-------------|-----------------| +| **Disabled** | No MCP tools | None | +| **Auto** | Hub server only | `search`, `exec` | +| **Manual** | User selects servers | Selected server tools | + +### How Auto Mode Works + +``` +User Message + │ + ▼ +┌─────────────────────────────────────────┐ +│ Assistant (mcpMode: 'auto') │ +│ │ +│ System Prompt + HUB_MODE_SYSTEM_PROMPT │ +│ Tools: [hub.search, hub.exec] │ +└─────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────┐ +│ LLM decides to use MCP tools │ +│ │ +│ 1. search({ query: "github,repo" }) │ +│ 2. exec({ code: "await searchRepos()" })│ +└─────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────┐ +│ Hub Server │ +│ │ +│ Aggregates all active MCP servers │ +│ Routes tool calls to appropriate server │ +└─────────────────────────────────────────┘ +``` + +### Relevant Code + +- **Type Definition**: `src/renderer/src/types/index.ts` - `McpMode` type and `getEffectiveMcpMode()` +- **Hub Server Constant**: `src/renderer/src/store/mcp.ts` - `hubMCPServer` +- **Server Selection**: `src/renderer/src/services/ApiService.ts` - `getMcpServersForAssistant()` +- **System Prompt**: `src/renderer/src/config/prompts.ts` - `HUB_MODE_SYSTEM_PROMPT` +- **Prompt Injection**: `src/renderer/src/aiCore/prepareParams/parameterBuilder.ts` + ## Tools ### `search` @@ -95,11 +151,24 @@ return { users, repos }; ## Configuration -The Hub server is a built-in server identified as `@cherry/hub`. To enable it: +The Hub server is a built-in server identified as `@cherry/hub`. + +### Using Auto Mode (Recommended) + +The easiest way to use the Hub server is through Auto mode: + +1. Click the **MCP Tools** button (hammer icon) in the input bar +2. Select **Auto** mode +3. The Hub server is automatically enabled for the assistant + +### Manual Configuration + +Alternatively, you can enable the Hub server manually: 1. Go to **Settings** → **MCP Servers** 2. Find **Hub** in the built-in servers list 3. Toggle it on +4. In the assistant's MCP settings, select the Hub server ## Caching diff --git a/src/main/mcpServers/hub/__tests__/generator.test.ts b/src/main/mcpServers/hub/__tests__/generator.test.ts index bb4e0549d4..d541c625bc 100644 --- a/src/main/mcpServers/hub/__tests__/generator.test.ts +++ b/src/main/mcpServers/hub/__tests__/generator.test.ts @@ -23,20 +23,13 @@ describe('generator', () => { type: 'mcp' as const } - const server = { - id: 'github', - name: 'github-server', - isActive: true - } - const existingNames = new Set() const callTool = async () => ({ success: true }) - const result = generateToolFunction(tool, server as any, existingNames, callTool) + const result = generateToolFunction(tool, existingNames, callTool) - expect(result.toolId).toBe('github__search_repos') - expect(result.functionName).toBe('searchRepos') - expect(result.jsCode).toContain('async function searchRepos') + expect(result.functionName).toBe('githubServer_searchRepos') + expect(result.jsCode).toContain('async function githubServer_searchRepos') expect(result.jsCode).toContain('Search for GitHub repositories') expect(result.jsCode).toContain('__callTool') }) @@ -51,13 +44,12 @@ describe('generator', () => { type: 'mcp' as const } - const server = { id: 'server1', name: 'server1', isActive: true } - const existingNames = new Set(['search']) + const existingNames = new Set(['server1_search']) const callTool = async () => ({}) - const result = generateToolFunction(tool, server as any, existingNames, callTool) + const result = generateToolFunction(tool, existingNames, callTool) - expect(result.functionName).toBe('search1') + expect(result.functionName).toBe('server1_search1') }) it('handles enum types in schema', () => { @@ -78,11 +70,10 @@ describe('generator', () => { type: 'mcp' as const } - const server = { id: 'browser', name: 'browser', isActive: true } const existingNames = new Set() const callTool = async () => ({}) - const result = generateToolFunction(tool, server as any, existingNames, callTool) + const result = generateToolFunction(tool, existingNames, callTool) expect(result.jsCode).toContain('"chromium" | "firefox" | "webkit"') }) @@ -95,9 +86,8 @@ describe('generator', () => { serverId: 's1', serverName: 'server1', toolName: 'tool1', - toolId: 's1__tool1', - functionName: 'tool1', - jsCode: 'async function tool1() {}', + functionName: 'server1_tool1', + jsCode: 'async function server1_tool1() {}', fn: async () => ({}), signature: '{}', returns: 'unknown' @@ -106,9 +96,8 @@ describe('generator', () => { serverId: 's2', serverName: 'server2', toolName: 'tool2', - toolId: 's2__tool2', - functionName: 'tool2', - jsCode: 'async function tool2() {}', + functionName: 'server2_tool2', + jsCode: 'async function server2_tool2() {}', fn: async () => ({}), signature: '{}', returns: 'unknown' @@ -118,8 +107,8 @@ describe('generator', () => { const result = generateToolsCode(tools) expect(result).toContain('Found 2 tool(s)') - expect(result).toContain('async function tool1') - expect(result).toContain('async function tool2') + expect(result).toContain('async function server1_tool1') + expect(result).toContain('async function server2_tool2') }) it('returns message for empty tools', () => { diff --git a/src/main/mcpServers/hub/__tests__/hub.test.ts b/src/main/mcpServers/hub/__tests__/hub.test.ts index 335f60f723..32b94f8933 100644 --- a/src/main/mcpServers/hub/__tests__/hub.test.ts +++ b/src/main/mcpServers/hub/__tests__/hub.test.ts @@ -1,102 +1,93 @@ -import type { MCPServer } from '@types' +import type { MCPTool } from '@types' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { HubServer } from '../index' -import { initHubBridge } from '../mcp-bridge' -const mockMcpServers: MCPServer[] = [ +const mockTools: MCPTool[] = [ { - id: 'github', - name: 'GitHub', - command: 'npx', - args: ['-y', 'github-mcp-server'], - isActive: true - } as MCPServer, + id: 'github__search_repos', + name: 'search_repos', + description: 'Search for GitHub repositories', + serverId: 'github', + serverName: 'GitHub', + inputSchema: { + type: 'object', + properties: { + query: { type: 'string', description: 'Search query' }, + limit: { type: 'number', description: 'Max results' } + }, + required: ['query'] + }, + type: 'mcp' + }, { - id: 'database', - name: 'Database', - command: 'npx', - args: ['-y', 'db-mcp-server'], - isActive: true - } as MCPServer + id: 'github__get_user', + name: 'get_user', + description: 'Get GitHub user profile', + serverId: 'github', + serverName: 'GitHub', + inputSchema: { + type: 'object', + properties: { + username: { type: 'string', description: 'GitHub username' } + }, + required: ['username'] + }, + type: 'mcp' + }, + { + id: 'database__query', + name: 'query', + description: 'Execute a database query', + serverId: 'database', + serverName: 'Database', + inputSchema: { + type: 'object', + properties: { + sql: { type: 'string', description: 'SQL query to execute' } + }, + required: ['sql'] + }, + type: 'mcp' + } ] -const mockToolDefinitions = { - github: [ - { - name: 'search_repos', - description: 'Search for GitHub repositories', - inputSchema: { - type: 'object', - properties: { - query: { type: 'string', description: 'Search query' }, - limit: { type: 'number', description: 'Max results' } - }, - required: ['query'] +vi.mock('@main/services/MCPService', () => ({ + default: { + listAllActiveServerTools: vi.fn(async () => mockTools), + callToolById: vi.fn(async (toolId: string, args: unknown) => { + if (toolId === 'github__search_repos') { + return { + content: [{ type: 'text', text: JSON.stringify({ repos: ['repo1', 'repo2'], query: args }) }] + } } - }, - { - name: 'get_user', - description: 'Get GitHub user profile', - inputSchema: { - type: 'object', - properties: { - username: { type: 'string', description: 'GitHub username' } - }, - required: ['username'] + if (toolId === 'github__get_user') { + return { + content: [{ type: 'text', text: JSON.stringify({ username: (args as any).username, id: 123 }) }] + } } - } - ], - database: [ - { - name: 'query', - description: 'Execute a database query', - inputSchema: { - type: 'object', - properties: { - sql: { type: 'string', description: 'SQL query to execute' } - }, - required: ['sql'] + if (toolId === 'database__query') { + return { + content: [{ type: 'text', text: JSON.stringify({ rows: [{ id: 1 }, { id: 2 }] }) }] + } } - } - ] -} + return { content: [{ type: 'text', text: '{}' }] } + }) + } +})) -const mockMcpService = { - listTools: vi.fn(async (_: null, server: MCPServer) => { - return mockToolDefinitions[server.id as keyof typeof mockToolDefinitions] || [] - }), - callTool: vi.fn(async (_: null, args: { server: MCPServer; name: string; args: unknown }) => { - if (args.server.id === 'github' && args.name === 'search_repos') { - return { - content: [{ type: 'text', text: JSON.stringify({ repos: ['repo1', 'repo2'], query: args.args }) }] - } - } - if (args.server.id === 'github' && args.name === 'get_user') { - return { - content: [{ type: 'text', text: JSON.stringify({ username: (args.args as any).username, id: 123 }) }] - } - } - if (args.server.id === 'database' && args.name === 'query') { - return { - content: [{ type: 'text', text: JSON.stringify({ rows: [{ id: 1 }, { id: 2 }] }) }] - } - } - return { content: [{ type: 'text', text: '{}' }] } - }) -} +import mcpService from '@main/services/MCPService' describe('HubServer Integration', () => { let hubServer: HubServer beforeEach(() => { vi.clearAllMocks() - initHubBridge(mockMcpService as any, () => mockMcpServers) hubServer = new HubServer() }) afterEach(() => { - hubServer.invalidateCache() + vi.clearAllMocks() }) describe('full search → exec flow', () => { @@ -106,10 +97,10 @@ describe('HubServer Integration', () => { expect(searchResult.content).toBeDefined() const searchText = JSON.parse(searchResult.content[0].text) expect(searchText.total).toBeGreaterThan(0) - expect(searchText.tools).toContain('searchRepos') + expect(searchText.tools).toContain('gitHub_searchRepos') const execResult = await (hubServer as any).handleExec({ - code: 'return await searchRepos({ query: "test" })' + code: 'return await gitHub_searchRepos({ query: "test" })' }) expect(execResult.content).toBeDefined() @@ -123,8 +114,8 @@ describe('HubServer Integration', () => { const execResult = await (hubServer as any).handleExec({ code: ` const results = await parallel( - searchRepos({ query: "react" }), - getUser({ username: "octocat" }) + gitHub_searchRepos({ query: "react" }), + gitHub_getUser({ username: "octocat" }) ); return results ` @@ -140,21 +131,31 @@ describe('HubServer Integration', () => { const searchResult = await (hubServer as any).handleSearch({ query: 'query' }) const searchText = JSON.parse(searchResult.content[0].text) - expect(searchText.tools).toContain('query') + expect(searchText.tools).toContain('database_query') }) }) - describe('cache invalidation', () => { - it('refreshes tools after invalidation', async () => { + describe('tools caching', () => { + it('uses cached tools within TTL', async () => { await (hubServer as any).handleSearch({ query: 'github' }) + const firstCallCount = vi.mocked(mcpService.listAllActiveServerTools).mock.calls.length - const initialCallCount = mockMcpService.listTools.mock.calls.length + await (hubServer as any).handleSearch({ query: 'github' }) + const secondCallCount = vi.mocked(mcpService.listAllActiveServerTools).mock.calls.length + + expect(secondCallCount).toBe(firstCallCount) // Should use cache + }) + + it('refreshes tools after cache invalidation', async () => { + await (hubServer as any).handleSearch({ query: 'github' }) + const firstCallCount = vi.mocked(mcpService.listAllActiveServerTools).mock.calls.length hubServer.invalidateCache() await (hubServer as any).handleSearch({ query: 'github' }) + const secondCallCount = vi.mocked(mcpService.listAllActiveServerTools).mock.calls.length - expect(mockMcpService.listTools.mock.calls.length).toBeGreaterThan(initialCallCount) + expect(secondCallCount).toBe(firstCallCount + 1) }) }) diff --git a/src/main/mcpServers/hub/__tests__/runtime.test.ts b/src/main/mcpServers/hub/__tests__/runtime.test.ts index 76366c5e0a..15c7ab7880 100644 --- a/src/main/mcpServers/hub/__tests__/runtime.test.ts +++ b/src/main/mcpServers/hub/__tests__/runtime.test.ts @@ -16,9 +16,8 @@ const createMockTool = (partial: Partial): GeneratedTool => ({ serverId: 'server1', serverName: 'server1', toolName: 'tool', - toolId: 'server1__tool', - functionName: 'mockTool', - jsCode: 'async function mockTool() {}', + functionName: 'server1_mockTool', + jsCode: 'async function server1_mockTool() {}', fn: async (params) => ({ result: params }), signature: '{}', returns: 'unknown', diff --git a/src/main/mcpServers/hub/__tests__/search.test.ts b/src/main/mcpServers/hub/__tests__/search.test.ts index 6d004797ef..4e483003f2 100644 --- a/src/main/mcpServers/hub/__tests__/search.test.ts +++ b/src/main/mcpServers/hub/__tests__/search.test.ts @@ -4,12 +4,11 @@ import { searchTools } from '../search' import type { GeneratedTool } from '../types' const createMockTool = (partial: Partial): GeneratedTool => { - const functionName = partial.functionName || 'tool' + const functionName = partial.functionName || 'server1_tool' return { serverId: 'server1', serverName: 'server1', toolName: partial.toolName || 'tool', - toolId: partial.toolId || 'server1__tool', functionName, jsCode: `async function ${functionName}() {}`, fn: async () => ({}), @@ -86,13 +85,13 @@ describe('search', () => { it('respects limit parameter', () => { const tools = Array.from({ length: 20 }, (_, i) => - createMockTool({ toolName: `tool${i}`, functionName: `tool${i}`, toolId: `s__tool${i}` }) + createMockTool({ toolName: `tool${i}`, functionName: `server1_tool${i}` }) ) const result = searchTools(tools, { query: 'tool', limit: 5 }) expect(result.total).toBe(20) - const matches = (result.tools.match(/async function tool\d+/g) || []).length + const matches = (result.tools.match(/async function server1_tool\d+/g) || []).length expect(matches).toBe(5) }) diff --git a/src/main/mcpServers/hub/generator.ts b/src/main/mcpServers/hub/generator.ts index 479ce17d04..12989a1589 100644 --- a/src/main/mcpServers/hub/generator.ts +++ b/src/main/mcpServers/hub/generator.ts @@ -1,25 +1,8 @@ -import type { MCPServer, MCPTool } from '@types' +import { generateMcpToolFunctionName } from '@shared/mcp' +import type { MCPTool } from '@types' import type { GeneratedTool } from './types' -function toCamelCase(str: string): string { - return str - .replace(/[^a-zA-Z0-9]+(.)/g, (_, char) => char.toUpperCase()) - .replace(/^[A-Z]/, (char) => char.toLowerCase()) - .replace(/[^a-zA-Z0-9]/g, '') -} - -function makeUniqueFunctionName(baseName: string, existingNames: Set): string { - let name = baseName - let counter = 1 - while (existingNames.has(name)) { - name = `${baseName}${counter}` - counter++ - } - existingNames.add(name) - return name -} - function jsonSchemaToSignature(schema: Record | undefined): string { if (!schema || typeof schema !== 'object') { return '{}' @@ -109,13 +92,10 @@ function generateJSDoc(tool: MCPTool, signature: string, returns: string): strin export function generateToolFunction( tool: MCPTool, - server: MCPServer, existingNames: Set, - callToolFn: (toolId: string, params: unknown) => Promise + callToolFn: (functionName: string, params: unknown) => Promise ): GeneratedTool { - const toolId = `${server.id}__${tool.name}` - const baseName = toCamelCase(tool.name) - const functionName = makeUniqueFunctionName(baseName, existingNames) + const functionName = generateMcpToolFunctionName(tool.serverName, tool.name, existingNames) const inputSchema = tool.inputSchema as Record | undefined const outputSchema = tool.outputSchema as Record | undefined @@ -127,18 +107,17 @@ export function generateToolFunction( const jsCode = `${jsDoc} async function ${functionName}(params) { - return await __callTool("${toolId}", params); + return await __callTool("${functionName}", params); }` const fn = async (params: unknown): Promise => { - return await callToolFn(toolId, params) + return await callToolFn(functionName, params) } return { - serverId: server.id, - serverName: server.name, + serverId: tool.serverId, + serverName: tool.serverName, toolName: tool.name, - toolId, functionName, jsCode, fn, diff --git a/src/main/mcpServers/hub/index.ts b/src/main/mcpServers/hub/index.ts index 08b3372c26..2c60ba49b8 100644 --- a/src/main/mcpServers/hub/index.ts +++ b/src/main/mcpServers/hub/index.ts @@ -1,13 +1,17 @@ import { loggerService } from '@logger' +import { CacheService } from '@main/services/CacheService' import { Server } from '@modelcontextprotocol/sdk/server/index.js' import { CallToolRequestSchema, ErrorCode, ListToolsRequestSchema, McpError } from '@modelcontextprotocol/sdk/types.js' +import { generateToolFunction } from './generator' +import { callMcpTool, listAllTools } from './mcp-bridge' import { Runtime } from './runtime' import { searchTools } from './search' -import { ToolRegistry } from './tool-registry' -import type { ExecInput, SearchQuery } from './types' +import type { ExecInput, GeneratedTool, SearchQuery } from './types' const logger = loggerService.withContext('MCPServer:Hub') +const TOOLS_CACHE_KEY = 'hub:tools' +const TOOLS_CACHE_TTL = 60 * 1000 // 1 minute /** * Hub MCP Server - A meta-server that aggregates all active MCP servers. @@ -23,11 +27,9 @@ const logger = loggerService.withContext('MCPServer:Hub') */ export class HubServer { public server: Server - private toolRegistry: ToolRegistry private runtime: Runtime constructor() { - this.toolRegistry = new ToolRegistry() this.runtime = new Runtime() this.server = new Server( @@ -118,12 +120,32 @@ export class HubServer { }) } + private async fetchTools(): Promise { + const cached = CacheService.get(TOOLS_CACHE_KEY) + if (cached) { + logger.debug('Returning cached tools') + return cached + } + + logger.debug('Fetching fresh tools') + const allTools = await listAllTools() + const existingNames = new Set() + const tools = allTools.map((tool) => generateToolFunction(tool, existingNames, callMcpTool)) + CacheService.set(TOOLS_CACHE_KEY, tools, TOOLS_CACHE_TTL) + return tools + } + + invalidateCache(): void { + CacheService.remove(TOOLS_CACHE_KEY) + logger.debug('Tools cache invalidated') + } + private async handleSearch(query: SearchQuery) { if (!query.query || typeof query.query !== 'string') { throw new McpError(ErrorCode.InvalidParams, 'query parameter is required and must be a string') } - const tools = await this.toolRegistry.getTools() + const tools = await this.fetchTools() const result = searchTools(tools, query) return { @@ -141,7 +163,7 @@ export class HubServer { throw new McpError(ErrorCode.InvalidParams, 'code parameter is required and must be a string') } - const tools = await this.toolRegistry.getTools() + const tools = await this.fetchTools() const result = await this.runtime.execute(input.code, tools) return { @@ -153,10 +175,6 @@ export class HubServer { ] } } - - invalidateCache(): void { - this.toolRegistry.invalidate() - } } export default HubServer diff --git a/src/main/mcpServers/hub/mcp-bridge.ts b/src/main/mcpServers/hub/mcp-bridge.ts index e0490c078c..1e7bfc7817 100644 --- a/src/main/mcpServers/hub/mcp-bridge.ts +++ b/src/main/mcpServers/hub/mcp-bridge.ts @@ -1,84 +1,38 @@ -import { loggerService } from '@logger' -import { BuiltinMCPServerNames, type MCPServer, type MCPTool } from '@types' +/** + * Bridge module for Hub server to access MCPService. + * Re-exports the methods needed by tool-registry and runtime. + */ +import mcpService from '@main/services/MCPService' +import { generateMcpToolFunctionName } from '@shared/mcp' -const logger = loggerService.withContext('MCPServer:Hub:Bridge') +export const listAllTools = () => mcpService.listAllActiveServerTools() -let mcpServiceInstance: MCPServiceInterface | null = null -let mcpServersGetter: (() => MCPServer[]) | null = null +const toolFunctionNameToIdMap = new Map() -interface MCPServiceInterface { - listTools(_: null, server: MCPServer): Promise - callTool( - _: null, - args: { server: MCPServer; name: string; args: unknown; callId?: string } - ): Promise<{ content: Array<{ type: string; text?: string }> }> -} - -export function setMCPService(service: MCPServiceInterface): void { - mcpServiceInstance = service -} - -export function setMCPServersGetter(getter: () => MCPServer[]): void { - mcpServersGetter = getter -} - -export function initHubBridge(service: MCPServiceInterface, serversGetter: () => MCPServer[]): void { - mcpServiceInstance = service - mcpServersGetter = serversGetter -} - -export function getActiveServers(): MCPServer[] { - if (!mcpServersGetter) { - logger.warn('MCP servers getter not set') - return [] - } - - const servers = mcpServersGetter() - return servers.filter((s) => s.isActive && s.name !== BuiltinMCPServerNames.hub) -} - -export async function listToolsFromServer(server: MCPServer): Promise { - if (!mcpServiceInstance) { - logger.error('MCP service not initialized') - return [] - } - - try { - return await mcpServiceInstance.listTools(null, server) - } catch (error) { - logger.error(`Failed to list tools from server ${server.name}:`, error as Error) - return [] +export async function refreshToolMap(): Promise { + const tools = await listAllTools() + toolFunctionNameToIdMap.clear() + const existingNames = new Set() + for (const tool of tools) { + const functionName = generateMcpToolFunctionName(tool.serverName, tool.name, existingNames) + toolFunctionNameToIdMap.set(functionName, { serverId: tool.serverId, toolName: tool.name }) } } -export async function callMcpTool(toolId: string, params: unknown): Promise { - if (!mcpServiceInstance) { - throw new Error('MCP service not initialized') +export const callMcpTool = async (functionName: string, params: unknown): Promise => { + const toolInfo = toolFunctionNameToIdMap.get(functionName) + if (!toolInfo) { + await refreshToolMap() + const retryToolInfo = toolFunctionNameToIdMap.get(functionName) + if (!retryToolInfo) { + throw new Error(`Tool not found: ${functionName}`) + } + const toolId = `${retryToolInfo.serverId}__${retryToolInfo.toolName}` + const result = await mcpService.callToolById(toolId, params) + return extractToolResult(result) } - - const parts = toolId.split('__') - if (parts.length < 2) { - throw new Error(`Invalid tool ID format: ${toolId}`) - } - - const serverId = parts[0] - const toolName = parts.slice(1).join('__') - - const servers = getActiveServers() - const server = servers.find((s) => s.id === serverId) - - if (!server) { - throw new Error(`Server not found: ${serverId}`) - } - - logger.debug(`Calling tool ${toolName} on server ${server.name}`) - - const result = await mcpServiceInstance.callTool(null, { - server, - name: toolName, - args: params - }) - + const toolId = `${toolInfo.serverId}__${toolInfo.toolName}` + const result = await mcpService.callToolById(toolId, params) return extractToolResult(result) } diff --git a/src/main/mcpServers/hub/search.ts b/src/main/mcpServers/hub/search.ts index 93c7eb70ac..7bed36a285 100644 --- a/src/main/mcpServers/hub/search.ts +++ b/src/main/mcpServers/hub/search.ts @@ -37,7 +37,15 @@ export function searchTools(tools: GeneratedTool[], query: SearchQuery): SearchR } function buildSearchText(tool: GeneratedTool): string { - const parts = [tool.toolName, tool.functionName, tool.serverName, tool.description || '', tool.signature] + const combinedName = tool.serverName ? `${tool.serverName}_${tool.toolName}` : tool.toolName + const parts = [ + tool.toolName, + tool.functionName, + tool.serverName, + combinedName, + tool.description || '', + tool.signature + ] return parts.join(' ') } @@ -55,10 +63,12 @@ function rankTools(tools: GeneratedTool[], keywords: string[]): GeneratedTool[] function calculateScore(tool: GeneratedTool, keywords: string[]): number { let score = 0 const toolName = tool.toolName.toLowerCase() + const serverName = (tool.serverName || '').toLowerCase() const functionName = tool.functionName.toLowerCase() const description = (tool.description || '').toLowerCase() for (const keyword of keywords) { + // Match tool name if (toolName === keyword) { score += 10 } else if (toolName.startsWith(keyword)) { @@ -67,12 +77,24 @@ function calculateScore(tool: GeneratedTool, keywords: string[]): number { score += 3 } - if (functionName === keyword) { + // Match server name + if (serverName === keyword) { score += 8 - } else if (functionName.includes(keyword)) { + } else if (serverName.startsWith(keyword)) { + score += 4 + } else if (serverName.includes(keyword)) { score += 2 } + // Match function name (serverName_toolName format) + if (functionName === keyword) { + score += 10 + } else if (functionName.startsWith(keyword)) { + score += 5 + } else if (functionName.includes(keyword)) { + score += 3 + } + if (description.includes(keyword)) { const count = (description.match(new RegExp(escapeRegex(keyword), 'g')) || []).length score += Math.min(count, 3) diff --git a/src/main/mcpServers/hub/tool-registry.ts b/src/main/mcpServers/hub/tool-registry.ts deleted file mode 100644 index f08c80230d..0000000000 --- a/src/main/mcpServers/hub/tool-registry.ts +++ /dev/null @@ -1,98 +0,0 @@ -import { loggerService } from '@logger' - -import { generateToolFunction } from './generator' -import { callMcpTool, getActiveServers, listToolsFromServer } from './mcp-bridge' -import type { GeneratedTool, ToolRegistryOptions } from './types' - -const logger = loggerService.withContext('MCPServer:Hub:Registry') - -const DEFAULT_TTL = 10 * 60 * 1000 - -export class ToolRegistry { - private tools: Map = new Map() - private lastRefresh: number = 0 - private readonly ttl: number - private refreshPromise: Promise | null = null - - constructor(options: ToolRegistryOptions = {}) { - this.ttl = options.ttl ?? DEFAULT_TTL - } - - async getTools(): Promise { - if (this.isExpired()) { - await this.refresh() - } - return Array.from(this.tools.values()) - } - - async getTool(toolId: string): Promise { - if (this.isExpired()) { - await this.refresh() - } - return this.tools.get(toolId) - } - - getToolByFunctionName(functionName: string): GeneratedTool | undefined { - for (const tool of this.tools.values()) { - if (tool.functionName === functionName) { - return tool - } - } - return undefined - } - - private isExpired(): boolean { - return Date.now() - this.lastRefresh > this.ttl - } - - invalidate(): void { - this.lastRefresh = 0 - this.tools.clear() - logger.debug('Tool registry invalidated') - } - - async refresh(): Promise { - if (this.refreshPromise) { - return this.refreshPromise - } - - this.refreshPromise = this.doRefresh() - - try { - await this.refreshPromise - } finally { - this.refreshPromise = null - } - } - - private async doRefresh(): Promise { - logger.debug('Refreshing tool registry') - - const servers = getActiveServers() - const newTools = new Map() - const existingNames = new Set() - - for (const server of servers) { - try { - const serverTools = await listToolsFromServer(server) - - for (const tool of serverTools) { - const generatedTool = generateToolFunction(tool, server, existingNames, callMcpTool) - - newTools.set(generatedTool.toolId, generatedTool) - } - } catch (error) { - logger.error(`Failed to list tools from server ${server.name}:`, error as Error) - } - } - - this.tools = newTools - this.lastRefresh = Date.now() - - logger.debug(`Tool registry refreshed with ${this.tools.size} tools`) - } - - getToolCount(): number { - return this.tools.size - } -} diff --git a/src/main/mcpServers/hub/types.ts b/src/main/mcpServers/hub/types.ts index 5513e31857..8b34fd7e14 100644 --- a/src/main/mcpServers/hub/types.ts +++ b/src/main/mcpServers/hub/types.ts @@ -4,7 +4,6 @@ export interface GeneratedTool { serverId: string serverName: string toolName: string - toolId: string functionName: string jsCode: string fn: (params: unknown) => Promise @@ -42,7 +41,7 @@ export interface MCPToolWithServer extends MCPTool { } export interface ExecutionContext { - __callTool: (toolId: string, params: unknown) => Promise + __callTool: (functionName: string, params: unknown) => Promise parallel: (...promises: Promise[]) => Promise settle: (...promises: Promise[]) => Promise[]> console: ConsoleMethods diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index 12461da760..7b636c89bf 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -3,8 +3,8 @@ import os from 'node:os' import path from 'node:path' import { loggerService } from '@logger' +import { getMCPServersFromRedux } from '@main/apiServer/utils/mcp' import { createInMemoryMCPServer } from '@main/mcpServers/factory' -import { initHubBridge } from '@main/mcpServers/hub/mcp-bridge' import { makeSureDirExists, removeEnvProxy } from '@main/utils' import { buildFunctionCallToolName } from '@main/utils/mcp' import { findCommandInShellEnv, getBinaryName, getBinaryPath, isBinaryExists } from '@main/utils/process' @@ -59,7 +59,6 @@ import DxtService from './DxtService' import { CallBackServer } from './mcp/oauth/callback' import { McpOAuthClientProvider } from './mcp/oauth/provider' import { ServerLogBuffer } from './mcp/ServerLogBuffer' -import { reduxService } from './ReduxService' import { windowService } from './WindowService' // Generic type for caching wrapped functions @@ -165,27 +164,61 @@ class McpService { this.checkMcpConnectivity = this.checkMcpConnectivity.bind(this) this.getServerVersion = this.getServerVersion.bind(this) this.getServerLogs = this.getServerLogs.bind(this) - - this.initializeHubDependencies() } - private initializeHubDependencies(): void { - initHubBridge( - { - listTools: (_: null, server: MCPServer) => this.listToolsImpl(server), - callTool: async (_: null, args: { server: MCPServer; name: string; args: unknown; callId?: string }) => { - return this.callTool(null as unknown as Electron.IpcMainInvokeEvent, args) - } - }, - () => { - try { - const servers = reduxService.selectSync('state.mcp.servers') - return servers || [] - } catch { - return [] - } + /** + * List all tools from all active MCP servers (excluding hub). + * Used by Hub server's tool registry. + */ + public async listAllActiveServerTools(): Promise { + const servers = await getMCPServersFromRedux() + const allTools: MCPTool[] = [] + + for (const server of servers) { + if (!server.isActive) { + continue } - ) + try { + const tools = await this.listToolsImpl(server) + allTools.push(...tools) + } catch (error) { + logger.error(`[listAllActiveServerTools] Failed to list tools from ${server.name}:`, error as Error) + } + } + + return allTools + } + + /** + * Call a tool by its full ID (serverId__toolName format). + * Used by Hub server's runtime. + */ + public async callToolById( + toolId: string, + params: unknown + ): Promise<{ content: Array<{ type: string; text?: string }> }> { + const parts = toolId.split('__') + if (parts.length < 2) { + throw new Error(`Invalid tool ID format: ${toolId}`) + } + + const serverId = parts[0] + const toolName = parts.slice(1).join('__') + + const servers = await getMCPServersFromRedux() + const server = servers.find((s) => s.id === serverId) + + if (!server) { + throw new Error(`Server not found: ${serverId}`) + } + + logger.debug(`[callToolById] Calling tool ${toolName} on server ${server.name}`) + + return this.callTool(null as unknown as Electron.IpcMainInvokeEvent, { + server, + name: toolName, + args: params + }) } private getServerKey(server: MCPServer): string { diff --git a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts index 52234c5f1f..1d92660072 100644 --- a/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts +++ b/src/renderer/src/aiCore/prepareParams/parameterBuilder.ts @@ -26,11 +26,13 @@ import { isSupportedThinkingTokenModel, isWebSearchModel } from '@renderer/config/models' +import { getHubModeSystemPrompt } from '@renderer/config/prompts' +import { fetchAllActiveServerTools } from '@renderer/services/ApiService' import { getDefaultModel } from '@renderer/services/AssistantService' import store from '@renderer/store' import type { CherryWebSearchConfig } from '@renderer/store/websearch' import type { Model } from '@renderer/types' -import { type Assistant, type MCPTool, type Provider, SystemProviderIds } from '@renderer/types' +import { type Assistant, getEffectiveMcpMode, type MCPTool, type Provider, SystemProviderIds } from '@renderer/types' import type { StreamTextParams } from '@renderer/types/aiCoreTypes' import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern' import { replacePromptVariables } from '@renderer/utils/prompt' @@ -244,7 +246,15 @@ export async function buildStreamTextParams( } if (assistant.prompt) { - params.system = await replacePromptVariables(assistant.prompt, model.name) + let systemPrompt = await replacePromptVariables(assistant.prompt, model.name) + if (getEffectiveMcpMode(assistant) === 'auto') { + const allActiveTools = await fetchAllActiveServerTools() + systemPrompt = systemPrompt + '\n\n' + getHubModeSystemPrompt(allActiveTools) + } + params.system = systemPrompt + } else if (getEffectiveMcpMode(assistant) === 'auto') { + const allActiveTools = await fetchAllActiveServerTools() + params.system = getHubModeSystemPrompt(allActiveTools) } logger.debug('params', params) diff --git a/src/renderer/src/config/prompts.ts b/src/renderer/src/config/prompts.ts index 815eb7d113..0f8493938f 100644 --- a/src/renderer/src/config/prompts.ts +++ b/src/renderer/src/config/prompts.ts @@ -1,3 +1,4 @@ +import { generateMcpToolFunctionName } from '@shared/mcp' import dayjs from 'dayjs' export const AGENT_PROMPT = ` @@ -468,3 +469,68 @@ Example: [nytimes.com](https://nytimes.com/some-page). If have multiple citations, please directly list them like this: [www.nytimes.com](https://nytimes.com/some-page)[www.bbc.com](https://bbc.com/some-page) ` + +const HUB_MODE_SYSTEM_PROMPT_BASE = ` +## MCP Tools (Code Mode) + +You have access to MCP tools via the hub server. + +### Workflow +1. Call \`search\` with relevant keywords to discover tools +2. Review the returned function signatures and their parameters +3. Call \`exec\` with JavaScript code using those functions +4. The last expression in your code becomes the result + +### Example Usage + +**Step 1: Search for tools** +\`\`\` +search({ query: "github,repository" }) +\`\`\` + +**Step 2: Use discovered tools** +\`\`\`javascript +exec({ + code: \` + const repos = await searchRepos({ query: "react", limit: 5 }) + const details = await parallel( + repos.map(r => getRepoDetails({ owner: r.owner, repo: r.name })) + ) + return { repos, details } + \` +}) +\`\`\` + +### Best Practices +- Always search first to discover available tools and their exact signatures +- Use descriptive variable names in your code +- Handle errors gracefully with try/catch when needed +- Use \`parallel()\` for independent operations to improve performance +` + +interface ToolInfo { + name: string + serverName?: string + description?: string +} + +export function getHubModeSystemPrompt(tools: ToolInfo[] = []): string { + if (tools.length === 0) { + return HUB_MODE_SYSTEM_PROMPT_BASE + } + + const existingNames = new Set() + const toolsSection = tools + .map((t) => { + const functionName = generateMcpToolFunctionName(t.serverName, t.name, existingNames) + const desc = t.description || '' + const truncatedDesc = desc.length > 50 ? `${desc.slice(0, 50)}...` : desc + return `- ${functionName}: ${truncatedDesc}` + }) + .join('\n') + + return `${HUB_MODE_SYSTEM_PROMPT_BASE} +### Available Tools +${toolsSection} +` +} diff --git a/src/renderer/src/i18n/locales/en-us.json b/src/renderer/src/i18n/locales/en-us.json index 9e60f31f00..0e556cc547 100644 --- a/src/renderer/src/i18n/locales/en-us.json +++ b/src/renderer/src/i18n/locales/en-us.json @@ -544,6 +544,20 @@ "description": "Default enabled MCP servers", "enableFirst": "Enable this server in MCP settings first", "label": "MCP Servers", + "mode": { + "auto": { + "description": "AI discovers and uses tools automatically", + "label": "Auto" + }, + "disabled": { + "description": "No MCP tools", + "label": "Disabled" + }, + "manual": { + "description": "Select specific MCP servers", + "label": "Manual" + } + }, "noServersAvailable": "No MCP servers available. Add servers in settings", "title": "MCP Settings" }, diff --git a/src/renderer/src/i18n/locales/zh-cn.json b/src/renderer/src/i18n/locales/zh-cn.json index b9b07a596c..88967e3dcd 100644 --- a/src/renderer/src/i18n/locales/zh-cn.json +++ b/src/renderer/src/i18n/locales/zh-cn.json @@ -544,6 +544,20 @@ "description": "默认启用的 MCP 服务器", "enableFirst": "请先在 MCP 设置中启用此服务器", "label": "MCP 服务器", + "mode": { + "auto": { + "description": "AI 自动发现和使用工具", + "label": "自动" + }, + "disabled": { + "description": "不使用 MCP 工具", + "label": "禁用" + }, + "manual": { + "description": "选择特定的 MCP 服务器", + "label": "手动" + } + }, "noServersAvailable": "无可用 MCP 服务器。请在设置中添加服务器", "title": "MCP 服务器" }, diff --git a/src/renderer/src/i18n/locales/zh-tw.json b/src/renderer/src/i18n/locales/zh-tw.json index 3d613f00f4..63e7128a48 100644 --- a/src/renderer/src/i18n/locales/zh-tw.json +++ b/src/renderer/src/i18n/locales/zh-tw.json @@ -544,6 +544,20 @@ "description": "預設啟用的 MCP 伺服器", "enableFirst": "請先在 MCP 設定中啟用此伺服器", "label": "MCP 伺服器", + "mode": { + "auto": { + "description": "AI 自動發現和使用工具", + "label": "自動" + }, + "disabled": { + "description": "不使用 MCP 工具", + "label": "停用" + }, + "manual": { + "description": "選擇特定的 MCP 伺服器", + "label": "手動" + } + }, "noServersAvailable": "無可用 MCP 伺服器。請在設定中新增伺服器", "title": "MCP 設定" }, diff --git a/src/renderer/src/pages/home/Inputbar/tools/components/MCPToolsButton.tsx b/src/renderer/src/pages/home/Inputbar/tools/components/MCPToolsButton.tsx index 0b0610e1d6..73397a139d 100644 --- a/src/renderer/src/pages/home/Inputbar/tools/components/MCPToolsButton.tsx +++ b/src/renderer/src/pages/home/Inputbar/tools/components/MCPToolsButton.tsx @@ -8,11 +8,12 @@ import { useTimer } from '@renderer/hooks/useTimer' import type { ToolQuickPanelApi } from '@renderer/pages/home/Inputbar/types' import { getProviderByModel } from '@renderer/services/AssistantService' import { EventEmitter } from '@renderer/services/EventService' -import type { MCPPrompt, MCPResource, MCPServer } from '@renderer/types' +import type { McpMode, MCPPrompt, MCPResource, MCPServer } from '@renderer/types' +import { getEffectiveMcpMode } from '@renderer/types' import { isToolUseModeFunction } from '@renderer/utils/assistant' import { isGeminiWebSearchProvider, isSupportUrlContextProvider } from '@renderer/utils/provider' import { Form, Input, Tooltip } from 'antd' -import { CircleX, Hammer, Plus } from 'lucide-react' +import { CircleX, Hammer, Plus, Sparkles } from 'lucide-react' import type { FC } from 'react' import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -25,7 +26,6 @@ interface Props { resizeTextArea: () => void } -// 添加类型定义 interface PromptArgument { name: string description?: string @@ -44,24 +44,19 @@ interface ResourceData { uri?: string } -// 提取到组件外的工具函数 const extractPromptContent = (response: any): string | null => { - // Handle string response (backward compatibility) if (typeof response === 'string') { return response } - // Handle GetMCPPromptResponse format if (response && Array.isArray(response.messages)) { let formattedContent = '' for (const message of response.messages) { if (!message.content) continue - // Add role prefix if available const rolePrefix = message.role ? `**${message.role.charAt(0).toUpperCase() + message.role.slice(1)}:** ` : '' - // Process different content types switch (message.content.type) { case 'text': formattedContent += `${rolePrefix}${message.content.text}\n\n` @@ -98,7 +93,6 @@ const extractPromptContent = (response: any): string | null => { return formattedContent.trim() } - // Fallback handling for single message format if (response && response.messages && response.messages.length > 0) { const message = response.messages[0] if (message.content && message.content.text) { @@ -121,7 +115,6 @@ const MCPToolsButton: FC = ({ quickPanel, setInputValue, resizeTextArea, const model = assistant.model const { setTimeoutTimer } = useTimer() - // 使用 useRef 存储不需要触发重渲染的值 const isMountedRef = useRef(true) useEffect(() => { @@ -130,11 +123,30 @@ const MCPToolsButton: FC = ({ quickPanel, setInputValue, resizeTextArea, } }, []) + const currentMode = useMemo(() => getEffectiveMcpMode(assistant), [assistant]) + const mcpServers = useMemo(() => assistant.mcpServers || [], [assistant.mcpServers]) const assistantMcpServers = useMemo( () => activedMcpServers.filter((server) => mcpServers.some((s) => s.id === server.id)), [activedMcpServers, mcpServers] ) + + const handleModeChange = useCallback( + (mode: McpMode) => { + setTimeoutTimer( + 'updateMcpMode', + () => { + updateAssistant({ + ...assistant, + mcpMode: mode + }) + }, + 200 + ) + }, + [assistant, setTimeoutTimer, updateAssistant] + ) + const handleMcpServerSelect = useCallback( (server: MCPServer) => { const update = { ...assistant } @@ -144,29 +156,24 @@ const MCPToolsButton: FC = ({ quickPanel, setInputValue, resizeTextArea, update.mcpServers = [...mcpServers, server] } - // only for gemini if (update.mcpServers.length > 0 && isGeminiModel(model) && isToolUseModeFunction(assistant)) { const provider = getProviderByModel(model) if (isSupportUrlContextProvider(provider) && assistant.enableUrlContext) { window.toast.warning(t('chat.mcp.warning.url_context')) update.enableUrlContext = false } - if ( - // 非官方 API (openrouter etc.) 可能支持同时启用内置搜索和函数调用 - // 这里先假设 gemini type 和 vertexai type 不支持 - isGeminiWebSearchProvider(provider) && - assistant.enableWebSearch - ) { + if (isGeminiWebSearchProvider(provider) && assistant.enableWebSearch) { window.toast.warning(t('chat.mcp.warning.gemini_web_search')) update.enableWebSearch = false } } + + update.mcpMode = 'manual' updateAssistant(update) }, [assistant, assistantMcpServers, mcpServers, model, t, updateAssistant] ) - // 使用 useRef 缓存事件处理函数 const handleMcpServerSelectRef = useRef(handleMcpServerSelect) handleMcpServerSelectRef.current = handleMcpServerSelect @@ -176,23 +183,7 @@ const MCPToolsButton: FC = ({ quickPanel, setInputValue, resizeTextArea, return () => EventEmitter.off('mcp-server-select', handler) }, []) - const updateMcpEnabled = useCallback( - (enabled: boolean) => { - setTimeoutTimer( - 'updateMcpEnabled', - () => { - updateAssistant({ - ...assistant, - mcpServers: enabled ? assistant.mcpServers || [] : [] - }) - }, - 200 - ) - }, - [assistant, setTimeoutTimer, updateAssistant] - ) - - const menuItems = useMemo(() => { + const manualModeMenuItems = useMemo(() => { const newList: QuickPanelListItem[] = activedMcpServers.map((server) => ({ label: server.name, description: server.description || server.baseUrl, @@ -207,33 +198,69 @@ const MCPToolsButton: FC = ({ quickPanel, setInputValue, resizeTextArea, action: () => navigate('/settings/mcp') }) - newList.unshift({ - label: t('settings.input.clear.all'), - description: t('settings.mcp.disable.description'), - icon: , - isSelected: false, - action: () => { - updateMcpEnabled(false) - quickPanelHook.close() - } - }) - return newList - }, [activedMcpServers, t, assistantMcpServers, navigate, updateMcpEnabled, quickPanelHook]) + }, [activedMcpServers, t, assistantMcpServers, navigate]) - const openQuickPanel = useCallback(() => { + const openManualModePanel = useCallback(() => { quickPanelHook.open({ - title: t('settings.mcp.title'), - list: menuItems, + title: t('assistants.settings.mcp.mode.manual.label'), + list: manualModeMenuItems, symbol: QuickPanelReservedSymbol.Mcp, multiple: true, afterAction({ item }) { item.isSelected = !item.isSelected } }) + }, [manualModeMenuItems, quickPanelHook, t]) + + const menuItems = useMemo(() => { + const newList: QuickPanelListItem[] = [] + + newList.push({ + label: t('assistants.settings.mcp.mode.disabled.label'), + description: t('assistants.settings.mcp.mode.disabled.description'), + icon: , + isSelected: currentMode === 'disabled', + action: () => { + handleModeChange('disabled') + quickPanelHook.close() + } + }) + + newList.push({ + label: t('assistants.settings.mcp.mode.auto.label'), + description: t('assistants.settings.mcp.mode.auto.description'), + icon: , + isSelected: currentMode === 'auto', + action: () => { + handleModeChange('auto') + quickPanelHook.close() + } + }) + + newList.push({ + label: t('assistants.settings.mcp.mode.manual.label'), + description: t('assistants.settings.mcp.mode.manual.description'), + icon: , + isSelected: currentMode === 'manual', + isMenu: true, + action: () => { + openManualModePanel() + } + }) + + return newList + }, [t, currentMode, handleModeChange, quickPanelHook, openManualModePanel]) + + const openQuickPanel = useCallback(() => { + quickPanelHook.open({ + title: t('settings.mcp.title'), + list: menuItems, + symbol: QuickPanelReservedSymbol.Mcp, + multiple: false + }) }, [menuItems, quickPanelHook, t]) - // 使用 useCallback 优化 insertPromptIntoTextArea const insertPromptIntoTextArea = useCallback( (promptText: string) => { setInputValue((prev) => { @@ -245,7 +272,6 @@ const MCPToolsButton: FC = ({ quickPanel, setInputValue, resizeTextArea, const selectionEndPosition = cursorPosition + promptText.length const newText = prev.slice(0, cursorPosition) + promptText + prev.slice(cursorPosition) - // 使用 requestAnimationFrame 优化 DOM 操作 requestAnimationFrame(() => { textArea.focus() textArea.setSelectionRange(selectionStart, selectionEndPosition) @@ -424,7 +450,6 @@ const MCPToolsButton: FC = ({ quickPanel, setInputValue, resizeTextArea, [activedMcpServers, t, insertPromptIntoTextArea] ) - // 优化 resourcesList 的状态更新 const [resourcesList, setResourcesList] = useState([]) useEffect(() => { @@ -514,17 +539,26 @@ const MCPToolsButton: FC = ({ quickPanel, setInputValue, resizeTextArea, } }, [openPromptList, openQuickPanel, openResourcesList, quickPanel, t]) + const isActive = currentMode !== 'disabled' + + const getButtonIcon = () => { + switch (currentMode) { + case 'auto': + return + case 'disabled': + case 'manual': + default: + return + } + } + return ( - 0} - aria-label={t('settings.mcp.title')}> - + + {getButtonIcon()} ) } -// 使用 React.memo 包装组件 export default React.memo(MCPToolsButton) diff --git a/src/renderer/src/pages/home/Inputbar/tools/components/UrlContextbutton.tsx b/src/renderer/src/pages/home/Inputbar/tools/components/UrlContextbutton.tsx index 2c2b5077a7..aa8aac3bf7 100644 --- a/src/renderer/src/pages/home/Inputbar/tools/components/UrlContextbutton.tsx +++ b/src/renderer/src/pages/home/Inputbar/tools/components/UrlContextbutton.tsx @@ -1,6 +1,7 @@ import { ActionIconButton } from '@renderer/components/Buttons' import { useAssistant } from '@renderer/hooks/useAssistant' import { useTimer } from '@renderer/hooks/useTimer' +import { getEffectiveMcpMode } from '@renderer/types' import { isToolUseModeFunction } from '@renderer/utils/assistant' import { Tooltip } from 'antd' import { Link } from 'lucide-react' @@ -30,8 +31,7 @@ const UrlContextButton: FC = ({ assistantId }) => { () => { const update = { ...assistant } if ( - assistant.mcpServers && - assistant.mcpServers.length > 0 && + getEffectiveMcpMode(assistant) !== 'disabled' && urlContentNewState === true && isToolUseModeFunction(assistant) ) { diff --git a/src/renderer/src/pages/home/Inputbar/tools/components/WebSearchQuickPanelManager.tsx b/src/renderer/src/pages/home/Inputbar/tools/components/WebSearchQuickPanelManager.tsx index 5728887af8..262b1f0492 100644 --- a/src/renderer/src/pages/home/Inputbar/tools/components/WebSearchQuickPanelManager.tsx +++ b/src/renderer/src/pages/home/Inputbar/tools/components/WebSearchQuickPanelManager.tsx @@ -16,7 +16,7 @@ import { useWebSearchProviders } from '@renderer/hooks/useWebSearchProviders' import type { ToolQuickPanelController, ToolRenderContext } from '@renderer/pages/home/Inputbar/types' import { getProviderByModel } from '@renderer/services/AssistantService' import WebSearchService from '@renderer/services/WebSearchService' -import type { WebSearchProvider, WebSearchProviderId } from '@renderer/types' +import { getEffectiveMcpMode, type WebSearchProvider, type WebSearchProviderId } from '@renderer/types' import { hasObjectKey } from '@renderer/utils' import { isToolUseModeFunction } from '@renderer/utils/assistant' import { isPromptToolUse } from '@renderer/utils/mcp-tools' @@ -108,8 +108,7 @@ export const useWebSearchPanelController = (assistantId: string, quickPanelContr isGeminiModel(model) && isToolUseModeFunction(assistant) && update.enableWebSearch && - assistant.mcpServers && - assistant.mcpServers.length > 0 + getEffectiveMcpMode(assistant) !== 'disabled' ) { update.enableWebSearch = false window.toast.warning(t('chat.mcp.warning.gemini_web_search')) diff --git a/src/renderer/src/pages/settings/AssistantSettings/AssistantMCPSettings.tsx b/src/renderer/src/pages/settings/AssistantSettings/AssistantMCPSettings.tsx index ac89141092..9509a9bec2 100644 --- a/src/renderer/src/pages/settings/AssistantSettings/AssistantMCPSettings.tsx +++ b/src/renderer/src/pages/settings/AssistantSettings/AssistantMCPSettings.tsx @@ -1,8 +1,9 @@ import { InfoCircleOutlined } from '@ant-design/icons' import { Box } from '@renderer/components/Layout' import { useMCPServers } from '@renderer/hooks/useMCPServers' -import type { Assistant, AssistantSettings } from '@renderer/types' -import { Empty, Switch, Tooltip } from 'antd' +import type { Assistant, AssistantSettings, McpMode } from '@renderer/types' +import { getEffectiveMcpMode } from '@renderer/types' +import { Empty, Radio, Switch, Tooltip } from 'antd' import { useTranslation } from 'react-i18next' import styled from 'styled-components' @@ -27,22 +28,26 @@ const AssistantMCPSettings: React.FC = ({ assistant, updateAssistant }) = const { t } = useTranslation() const { mcpServers: allMcpServers } = useMCPServers() + const currentMode = getEffectiveMcpMode(assistant) + + const handleModeChange = (mode: McpMode) => { + updateAssistant({ ...assistant, mcpMode: mode }) + } + const onUpdate = (ids: string[]) => { const mcpServers = ids .map((id) => allMcpServers.find((server) => server.id === id)) .filter((server): server is MCPServer => server !== undefined && server.isActive) - updateAssistant({ ...assistant, mcpServers }) + updateAssistant({ ...assistant, mcpServers, mcpMode: 'manual' }) } const handleServerToggle = (serverId: string) => { const currentServerIds = assistant.mcpServers?.map((server) => server.id) || [] if (currentServerIds.includes(serverId)) { - // Remove server if it's already enabled onUpdate(currentServerIds.filter((id) => id !== serverId)) } else { - // Add server if it's not enabled onUpdate([...currentServerIds, serverId]) } } @@ -58,49 +63,77 @@ const AssistantMCPSettings: React.FC = ({ assistant, updateAssistant }) = - {allMcpServers.length > 0 && ( - - {enabledCount} / {allMcpServers.length} {t('settings.mcp.active')} - - )} - {allMcpServers.length > 0 ? ( - - {allMcpServers.map((server) => { - const isEnabled = assistant.mcpServers?.some((s) => s.id === server.id) || false + + handleModeChange(e.target.value)}> + + + {t('assistants.settings.mcp.mode.disabled.label')} + {t('assistants.settings.mcp.mode.disabled.description')} + + + + + {t('assistants.settings.mcp.mode.auto.label')} + {t('assistants.settings.mcp.mode.auto.description')} + + + + + {t('assistants.settings.mcp.mode.manual.label')} + {t('assistants.settings.mcp.mode.manual.description')} + + + + - return ( - - - {server.name} - {server.description && {server.description}} - {server.baseUrl && {server.baseUrl}} - - - handleServerToggle(server.id)} - size="small" - /> - - - ) - })} - - ) : ( - - - + {currentMode === 'manual' && ( + <> + {allMcpServers.length > 0 && ( + + {enabledCount} / {allMcpServers.length} {t('settings.mcp.active')} + + )} + + {allMcpServers.length > 0 ? ( + + {allMcpServers.map((server) => { + const isEnabled = assistant.mcpServers?.some((s) => s.id === server.id) || false + + return ( + + + {server.name} + {server.description && {server.description}} + {server.baseUrl && {server.baseUrl}} + + + handleServerToggle(server.id)} + size="small" + /> + + + ) + })} + + ) : ( + + + + )} + )} ) @@ -127,9 +160,54 @@ const InfoIcon = styled(InfoCircleOutlined)` cursor: help; ` +const ModeSelector = styled.div` + margin-bottom: 16px; + + .ant-radio-group { + display: flex; + flex-direction: column; + gap: 8px; + } + + .ant-radio-button-wrapper { + height: auto; + padding: 12px 16px; + border-radius: 8px; + border: 1px solid var(--color-border); + + &:not(:first-child)::before { + display: none; + } + + &:first-child { + border-radius: 8px; + } + + &:last-child { + border-radius: 8px; + } + } +` + +const ModeOption = styled.div` + display: flex; + flex-direction: column; + gap: 2px; +` + +const ModeLabel = styled.span` + font-weight: 600; +` + +const ModeDescription = styled.span` + font-size: 12px; + color: var(--color-text-2); +` + const EnabledCount = styled.span` font-size: 12px; color: var(--color-text-2); + margin-bottom: 8px; ` const EmptyContainer = styled.div` diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 0cd57a353a..5b86d3c5c5 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -8,8 +8,9 @@ import { isDedicatedImageGenerationModel, isEmbeddingModel, isFunctionCallingMod import { getStoreSetting } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' import store from '@renderer/store' +import { hubMCPServer } from '@renderer/store/mcp' import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types' -import { type FetchChatCompletionParams, isSystemProvider } from '@renderer/types' +import { type FetchChatCompletionParams, getEffectiveMcpMode, isSystemProvider } from '@renderer/types' import type { StreamTextParams } from '@renderer/types/aiCoreTypes' import { type Chunk, ChunkType } from '@renderer/types/chunk' import type { Message, ResponseError } from '@renderer/types/newMessage' @@ -51,14 +52,60 @@ import type { StreamProcessorCallbacks } from './StreamProcessingService' const logger = loggerService.withContext('ApiService') -export async function fetchMcpTools(assistant: Assistant) { - // Get MCP tools (Fix duplicate declaration) - let mcpTools: MCPTool[] = [] // Initialize as empty array +/** + * Get the MCP servers to use based on the assistant's MCP mode. + */ +export function getMcpServersForAssistant(assistant: Assistant): MCPServer[] { + const mode = getEffectiveMcpMode(assistant) const allMcpServers = store.getState().mcp.servers || [] const activedMcpServers = allMcpServers.filter((s) => s.isActive) - const assistantMcpServers = assistant.mcpServers || [] - const enabledMCPs = activedMcpServers.filter((server) => assistantMcpServers.some((s) => s.id === server.id)) + switch (mode) { + case 'disabled': + return [] + case 'auto': + return [hubMCPServer] + case 'manual': { + const assistantMcpServers = assistant.mcpServers || [] + return activedMcpServers.filter((server) => assistantMcpServers.some((s) => s.id === server.id)) + } + default: + return [] + } +} + +export async function fetchAllActiveServerTools(): Promise { + const allMcpServers = store.getState().mcp.servers || [] + const activedMcpServers = allMcpServers.filter((s) => s.isActive) + + if (activedMcpServers.length === 0) { + return [] + } + + try { + const toolPromises = activedMcpServers.map(async (mcpServer: MCPServer) => { + try { + const tools = await window.api.mcp.listTools(mcpServer) + return tools.filter((tool: any) => !mcpServer.disabledTools?.includes(tool.name)) + } catch (error) { + logger.error(`Error fetching tools from MCP server ${mcpServer.name}:`, error as Error) + return [] + } + }) + const results = await Promise.allSettled(toolPromises) + return results + .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') + .map((result) => result.value) + .flat() + } catch (toolError) { + logger.error('Error fetching all active server tools:', toolError as Error) + return [] + } +} + +export async function fetchMcpTools(assistant: Assistant) { + let mcpTools: MCPTool[] = [] + const enabledMCPs = getMcpServersForAssistant(assistant) if (enabledMCPs && enabledMCPs.length > 0) { try { diff --git a/src/renderer/src/services/__tests__/mcpMode.test.ts b/src/renderer/src/services/__tests__/mcpMode.test.ts new file mode 100644 index 0000000000..9117caa666 --- /dev/null +++ b/src/renderer/src/services/__tests__/mcpMode.test.ts @@ -0,0 +1,54 @@ +import { describe, expect, it } from 'vitest' + +import type { Assistant, MCPServer } from '../../types' +import { getEffectiveMcpMode } from '../../types' + +describe('getEffectiveMcpMode', () => { + it('should return mcpMode when explicitly set to auto', () => { + const assistant: Partial = { mcpMode: 'auto' } + expect(getEffectiveMcpMode(assistant as Assistant)).toBe('auto') + }) + + it('should return disabled when mcpMode is explicitly disabled', () => { + const assistant: Partial = { mcpMode: 'disabled' } + expect(getEffectiveMcpMode(assistant as Assistant)).toBe('disabled') + }) + + it('should return manual when mcpMode is explicitly manual', () => { + const assistant: Partial = { mcpMode: 'manual' } + expect(getEffectiveMcpMode(assistant as Assistant)).toBe('manual') + }) + + it('should return manual when no mcpMode but mcpServers has items (backward compatibility)', () => { + const assistant: Partial = { + mcpServers: [{ id: 'test', name: 'Test Server', isActive: true }] as MCPServer[] + } + expect(getEffectiveMcpMode(assistant as Assistant)).toBe('manual') + }) + + it('should return disabled when no mcpMode and no mcpServers (backward compatibility)', () => { + const assistant: Partial = {} + expect(getEffectiveMcpMode(assistant as Assistant)).toBe('disabled') + }) + + it('should return disabled when no mcpMode and empty mcpServers (backward compatibility)', () => { + const assistant: Partial = { mcpServers: [] } + expect(getEffectiveMcpMode(assistant as Assistant)).toBe('disabled') + }) + + it('should prioritize explicit mcpMode over mcpServers presence', () => { + const assistant: Partial = { + mcpMode: 'disabled', + mcpServers: [{ id: 'test', name: 'Test Server', isActive: true }] as MCPServer[] + } + expect(getEffectiveMcpMode(assistant as Assistant)).toBe('disabled') + }) + + it('should return auto when mcpMode is auto regardless of mcpServers', () => { + const assistant: Partial = { + mcpMode: 'auto', + mcpServers: [{ id: 'test', name: 'Test Server', isActive: true }] as MCPServer[] + } + expect(getEffectiveMcpMode(assistant as Assistant)).toBe('auto') + }) +}) diff --git a/src/renderer/src/store/mcp.ts b/src/renderer/src/store/mcp.ts index 72cb4e39dc..0e2028a2b8 100644 --- a/src/renderer/src/store/mcp.ts +++ b/src/renderer/src/store/mcp.ts @@ -86,6 +86,20 @@ export { mcpSlice } // Export the reducer as default export export default mcpSlice.reducer +/** + * Hub MCP server for auto mode - aggregates all MCP servers for LLM code mode. + * This server is injected automatically when mcpMode === 'auto'. + */ +export const hubMCPServer: BuiltinMCPServer = { + id: 'hub', + name: BuiltinMCPServerNames.hub, + type: 'inMemory', + isActive: true, + provider: 'CherryAI', + installSource: 'builtin', + isTrusted: true +} + /** * User-installable built-in MCP servers shown in the UI. * diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 8f3c046c5a..c0b1fac916 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -27,6 +27,8 @@ export * from './ocr' export * from './plugin' export * from './provider' +export type McpMode = 'disabled' | 'auto' | 'manual' + export type Assistant = { id: string name: string @@ -47,6 +49,8 @@ export type Assistant = { // enableUrlContext 是 Gemini/Anthropic 的特有功能 enableUrlContext?: boolean enableGenerateImage?: boolean + /** MCP mode: 'disabled' (no MCP), 'auto' (hub server only), 'manual' (user selects servers) */ + mcpMode?: McpMode mcpServers?: MCPServer[] knowledgeRecognition?: 'off' | 'on' regularPhrases?: QuickPhrase[] // Added for regular phrase @@ -57,6 +61,15 @@ export type Assistant = { targetLanguage?: TranslateLanguage } +/** + * Get the effective MCP mode for an assistant with backward compatibility. + * Legacy assistants without mcpMode default based on mcpServers presence. + */ +export function getEffectiveMcpMode(assistant: Assistant): McpMode { + if (assistant.mcpMode) return assistant.mcpMode + return (assistant.mcpServers?.length ?? 0) > 0 ? 'manual' : 'disabled' +} + export type TranslateAssistant = Assistant & { model: Model content: string diff --git a/src/renderer/src/utils/mcp-tools.ts b/src/renderer/src/utils/mcp-tools.ts index 691689dcc4..3bc4a273cb 100644 --- a/src/renderer/src/utils/mcp-tools.ts +++ b/src/renderer/src/utils/mcp-tools.ts @@ -13,7 +13,7 @@ import { isFunctionCallingModel, isVisionModel } from '@renderer/config/models' import i18n from '@renderer/i18n' import { currentSpan } from '@renderer/services/SpanManagerService' import store from '@renderer/store' -import { addMCPServer } from '@renderer/store/mcp' +import { addMCPServer, hubMCPServer } from '@renderer/store/mcp' import type { Assistant, MCPCallToolResponse, @@ -325,7 +325,16 @@ export function filterMCPTools( export function getMcpServerByTool(tool: MCPTool) { const servers = store.getState().mcp.servers - return servers.find((s) => s.id === tool.serverId) + const server = servers.find((s) => s.id === tool.serverId) + if (server) { + return server + } + // For hub server (auto mode), the server isn't in the store + // Return the hub server constant if the tool's serverId matches + if (tool.serverId === 'hub') { + return hubMCPServer + } + return undefined } export function isToolAutoApproved(tool: MCPTool, server?: MCPServer): boolean {