feat: add available tools section to HUB_MODE_SYSTEM_PROMPT

- Add shared utility for generating MCP tool function names (serverName_toolName format)
- Update hub server to use consistent function naming across search, exec and prompt
- Add fetchAllActiveServerTools to ApiService for renderer process
- Update parameterBuilder to include available tools in auto/hub mode prompt
- Use CacheService for 1-minute tools caching in hub server
- Remove ToolRegistry in favor of direct fetching with caching
- Update search ranking to include server name matching
- Fix tests to use new naming format

Amp-Thread-ID: https://ampcode.com/threads/T-019b6971-d5c9-7719-9245-a89390078647
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Vaayne 2025-12-29 18:51:15 +08:00
parent 6c951a3945
commit 7a5b011bfa
27 changed files with 846 additions and 473 deletions

43
packages/shared/mcp.ts Normal file
View File

@ -0,0 +1,43 @@
/**
* Convert a string to camelCase, ensuring it's a valid JavaScript identifier.
*/
export function toCamelCase(str: string): string {
let result = str
.replace(/[^a-zA-Z0-9]+(.)/g, (_, char) => char.toUpperCase())
.replace(/^[A-Z]/, (char) => char.toLowerCase())
.replace(/[^a-zA-Z0-9]/g, '')
// Ensure valid JS identifier: must start with letter or underscore
if (result && !/^[a-zA-Z_]/.test(result)) {
result = '_' + result
}
return result
}
/**
* Generate a unique function name from server name and tool name.
* Format: serverName_toolName (camelCase)
*/
export function generateMcpToolFunctionName(
serverName: string | undefined,
toolName: string,
existingNames?: Set<string>
): string {
const serverPrefix = serverName ? toCamelCase(serverName) : ''
const toolNameCamel = toCamelCase(toolName)
const baseName = serverPrefix ? `${serverPrefix}_${toolNameCamel}` : toolNameCamel
if (!existingNames) {
return baseName
}
let name = baseName
let counter = 1
while (existingNames.has(name)) {
name = `${baseName}${counter}`
counter++
}
existingNames.add(name)
return name
}

View File

@ -6,6 +6,62 @@ A built-in MCP server that aggregates all active MCP servers in Cherry Studio an
The Hub server enables LLMs to discover and call tools from all active MCP servers without needing to know the specific server names or tool signatures upfront.
## Auto Mode Integration
The Hub server is the core component of Cherry Studio's **Auto MCP Mode**. When an assistant is set to Auto mode:
1. **Automatic Injection**: The Hub server is automatically injected as the only MCP server for the assistant
2. **System Prompt**: A specialized system prompt (`HUB_MODE_SYSTEM_PROMPT`) is appended to guide the LLM on how to use the `search` and `exec` tools
3. **Dynamic Discovery**: The LLM can discover and use any tools from all active MCP servers without manual configuration
### MCP Modes
Cherry Studio supports three MCP modes per assistant:
| Mode | Description | Tools Available |
|------|-------------|-----------------|
| **Disabled** | No MCP tools | None |
| **Auto** | Hub server only | `search`, `exec` |
| **Manual** | User selects servers | Selected server tools |
### How Auto Mode Works
```
User Message
┌─────────────────────────────────────────┐
│ Assistant (mcpMode: 'auto') │
│ │
│ System Prompt + HUB_MODE_SYSTEM_PROMPT │
│ Tools: [hub.search, hub.exec] │
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
│ LLM decides to use MCP tools │
│ │
│ 1. search({ query: "github,repo" }) │
│ 2. exec({ code: "await searchRepos()" })│
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
│ Hub Server │
│ │
│ Aggregates all active MCP servers │
│ Routes tool calls to appropriate server │
└─────────────────────────────────────────┘
```
### Relevant Code
- **Type Definition**: `src/renderer/src/types/index.ts` - `McpMode` type and `getEffectiveMcpMode()`
- **Hub Server Constant**: `src/renderer/src/store/mcp.ts` - `hubMCPServer`
- **Server Selection**: `src/renderer/src/services/ApiService.ts` - `getMcpServersForAssistant()`
- **System Prompt**: `src/renderer/src/config/prompts.ts` - `HUB_MODE_SYSTEM_PROMPT`
- **Prompt Injection**: `src/renderer/src/aiCore/prepareParams/parameterBuilder.ts`
## Tools
### `search`
@ -95,11 +151,24 @@ return { users, repos };
## Configuration
The Hub server is a built-in server identified as `@cherry/hub`. To enable it:
The Hub server is a built-in server identified as `@cherry/hub`.
### Using Auto Mode (Recommended)
The easiest way to use the Hub server is through Auto mode:
1. Click the **MCP Tools** button (hammer icon) in the input bar
2. Select **Auto** mode
3. The Hub server is automatically enabled for the assistant
### Manual Configuration
Alternatively, you can enable the Hub server manually:
1. Go to **Settings** → **MCP Servers**
2. Find **Hub** in the built-in servers list
3. Toggle it on
4. In the assistant's MCP settings, select the Hub server
## Caching

View File

@ -23,20 +23,13 @@ describe('generator', () => {
type: 'mcp' as const
}
const server = {
id: 'github',
name: 'github-server',
isActive: true
}
const existingNames = new Set<string>()
const callTool = async () => ({ success: true })
const result = generateToolFunction(tool, server as any, existingNames, callTool)
const result = generateToolFunction(tool, existingNames, callTool)
expect(result.toolId).toBe('github__search_repos')
expect(result.functionName).toBe('searchRepos')
expect(result.jsCode).toContain('async function searchRepos')
expect(result.functionName).toBe('githubServer_searchRepos')
expect(result.jsCode).toContain('async function githubServer_searchRepos')
expect(result.jsCode).toContain('Search for GitHub repositories')
expect(result.jsCode).toContain('__callTool')
})
@ -51,13 +44,12 @@ describe('generator', () => {
type: 'mcp' as const
}
const server = { id: 'server1', name: 'server1', isActive: true }
const existingNames = new Set<string>(['search'])
const existingNames = new Set<string>(['server1_search'])
const callTool = async () => ({})
const result = generateToolFunction(tool, server as any, existingNames, callTool)
const result = generateToolFunction(tool, existingNames, callTool)
expect(result.functionName).toBe('search1')
expect(result.functionName).toBe('server1_search1')
})
it('handles enum types in schema', () => {
@ -78,11 +70,10 @@ describe('generator', () => {
type: 'mcp' as const
}
const server = { id: 'browser', name: 'browser', isActive: true }
const existingNames = new Set<string>()
const callTool = async () => ({})
const result = generateToolFunction(tool, server as any, existingNames, callTool)
const result = generateToolFunction(tool, existingNames, callTool)
expect(result.jsCode).toContain('"chromium" | "firefox" | "webkit"')
})
@ -95,9 +86,8 @@ describe('generator', () => {
serverId: 's1',
serverName: 'server1',
toolName: 'tool1',
toolId: 's1__tool1',
functionName: 'tool1',
jsCode: 'async function tool1() {}',
functionName: 'server1_tool1',
jsCode: 'async function server1_tool1() {}',
fn: async () => ({}),
signature: '{}',
returns: 'unknown'
@ -106,9 +96,8 @@ describe('generator', () => {
serverId: 's2',
serverName: 'server2',
toolName: 'tool2',
toolId: 's2__tool2',
functionName: 'tool2',
jsCode: 'async function tool2() {}',
functionName: 'server2_tool2',
jsCode: 'async function server2_tool2() {}',
fn: async () => ({}),
signature: '{}',
returns: 'unknown'
@ -118,8 +107,8 @@ describe('generator', () => {
const result = generateToolsCode(tools)
expect(result).toContain('Found 2 tool(s)')
expect(result).toContain('async function tool1')
expect(result).toContain('async function tool2')
expect(result).toContain('async function server1_tool1')
expect(result).toContain('async function server2_tool2')
})
it('returns message for empty tools', () => {

View File

@ -1,102 +1,93 @@
import type { MCPServer } from '@types'
import type { MCPTool } from '@types'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { HubServer } from '../index'
import { initHubBridge } from '../mcp-bridge'
const mockMcpServers: MCPServer[] = [
const mockTools: MCPTool[] = [
{
id: 'github',
name: 'GitHub',
command: 'npx',
args: ['-y', 'github-mcp-server'],
isActive: true
} as MCPServer,
id: 'github__search_repos',
name: 'search_repos',
description: 'Search for GitHub repositories',
serverId: 'github',
serverName: 'GitHub',
inputSchema: {
type: 'object',
properties: {
query: { type: 'string', description: 'Search query' },
limit: { type: 'number', description: 'Max results' }
},
required: ['query']
},
type: 'mcp'
},
{
id: 'database',
name: 'Database',
command: 'npx',
args: ['-y', 'db-mcp-server'],
isActive: true
} as MCPServer
id: 'github__get_user',
name: 'get_user',
description: 'Get GitHub user profile',
serverId: 'github',
serverName: 'GitHub',
inputSchema: {
type: 'object',
properties: {
username: { type: 'string', description: 'GitHub username' }
},
required: ['username']
},
type: 'mcp'
},
{
id: 'database__query',
name: 'query',
description: 'Execute a database query',
serverId: 'database',
serverName: 'Database',
inputSchema: {
type: 'object',
properties: {
sql: { type: 'string', description: 'SQL query to execute' }
},
required: ['sql']
},
type: 'mcp'
}
]
const mockToolDefinitions = {
github: [
{
name: 'search_repos',
description: 'Search for GitHub repositories',
inputSchema: {
type: 'object',
properties: {
query: { type: 'string', description: 'Search query' },
limit: { type: 'number', description: 'Max results' }
},
required: ['query']
vi.mock('@main/services/MCPService', () => ({
default: {
listAllActiveServerTools: vi.fn(async () => mockTools),
callToolById: vi.fn(async (toolId: string, args: unknown) => {
if (toolId === 'github__search_repos') {
return {
content: [{ type: 'text', text: JSON.stringify({ repos: ['repo1', 'repo2'], query: args }) }]
}
}
},
{
name: 'get_user',
description: 'Get GitHub user profile',
inputSchema: {
type: 'object',
properties: {
username: { type: 'string', description: 'GitHub username' }
},
required: ['username']
if (toolId === 'github__get_user') {
return {
content: [{ type: 'text', text: JSON.stringify({ username: (args as any).username, id: 123 }) }]
}
}
}
],
database: [
{
name: 'query',
description: 'Execute a database query',
inputSchema: {
type: 'object',
properties: {
sql: { type: 'string', description: 'SQL query to execute' }
},
required: ['sql']
if (toolId === 'database__query') {
return {
content: [{ type: 'text', text: JSON.stringify({ rows: [{ id: 1 }, { id: 2 }] }) }]
}
}
}
]
}
return { content: [{ type: 'text', text: '{}' }] }
})
}
}))
const mockMcpService = {
listTools: vi.fn(async (_: null, server: MCPServer) => {
return mockToolDefinitions[server.id as keyof typeof mockToolDefinitions] || []
}),
callTool: vi.fn(async (_: null, args: { server: MCPServer; name: string; args: unknown }) => {
if (args.server.id === 'github' && args.name === 'search_repos') {
return {
content: [{ type: 'text', text: JSON.stringify({ repos: ['repo1', 'repo2'], query: args.args }) }]
}
}
if (args.server.id === 'github' && args.name === 'get_user') {
return {
content: [{ type: 'text', text: JSON.stringify({ username: (args.args as any).username, id: 123 }) }]
}
}
if (args.server.id === 'database' && args.name === 'query') {
return {
content: [{ type: 'text', text: JSON.stringify({ rows: [{ id: 1 }, { id: 2 }] }) }]
}
}
return { content: [{ type: 'text', text: '{}' }] }
})
}
import mcpService from '@main/services/MCPService'
describe('HubServer Integration', () => {
let hubServer: HubServer
beforeEach(() => {
vi.clearAllMocks()
initHubBridge(mockMcpService as any, () => mockMcpServers)
hubServer = new HubServer()
})
afterEach(() => {
hubServer.invalidateCache()
vi.clearAllMocks()
})
describe('full search → exec flow', () => {
@ -106,10 +97,10 @@ describe('HubServer Integration', () => {
expect(searchResult.content).toBeDefined()
const searchText = JSON.parse(searchResult.content[0].text)
expect(searchText.total).toBeGreaterThan(0)
expect(searchText.tools).toContain('searchRepos')
expect(searchText.tools).toContain('gitHub_searchRepos')
const execResult = await (hubServer as any).handleExec({
code: 'return await searchRepos({ query: "test" })'
code: 'return await gitHub_searchRepos({ query: "test" })'
})
expect(execResult.content).toBeDefined()
@ -123,8 +114,8 @@ describe('HubServer Integration', () => {
const execResult = await (hubServer as any).handleExec({
code: `
const results = await parallel(
searchRepos({ query: "react" }),
getUser({ username: "octocat" })
gitHub_searchRepos({ query: "react" }),
gitHub_getUser({ username: "octocat" })
);
return results
`
@ -140,21 +131,31 @@ describe('HubServer Integration', () => {
const searchResult = await (hubServer as any).handleSearch({ query: 'query' })
const searchText = JSON.parse(searchResult.content[0].text)
expect(searchText.tools).toContain('query')
expect(searchText.tools).toContain('database_query')
})
})
describe('cache invalidation', () => {
it('refreshes tools after invalidation', async () => {
describe('tools caching', () => {
it('uses cached tools within TTL', async () => {
await (hubServer as any).handleSearch({ query: 'github' })
const firstCallCount = vi.mocked(mcpService.listAllActiveServerTools).mock.calls.length
const initialCallCount = mockMcpService.listTools.mock.calls.length
await (hubServer as any).handleSearch({ query: 'github' })
const secondCallCount = vi.mocked(mcpService.listAllActiveServerTools).mock.calls.length
expect(secondCallCount).toBe(firstCallCount) // Should use cache
})
it('refreshes tools after cache invalidation', async () => {
await (hubServer as any).handleSearch({ query: 'github' })
const firstCallCount = vi.mocked(mcpService.listAllActiveServerTools).mock.calls.length
hubServer.invalidateCache()
await (hubServer as any).handleSearch({ query: 'github' })
const secondCallCount = vi.mocked(mcpService.listAllActiveServerTools).mock.calls.length
expect(mockMcpService.listTools.mock.calls.length).toBeGreaterThan(initialCallCount)
expect(secondCallCount).toBe(firstCallCount + 1)
})
})

View File

@ -16,9 +16,8 @@ const createMockTool = (partial: Partial<GeneratedTool>): GeneratedTool => ({
serverId: 'server1',
serverName: 'server1',
toolName: 'tool',
toolId: 'server1__tool',
functionName: 'mockTool',
jsCode: 'async function mockTool() {}',
functionName: 'server1_mockTool',
jsCode: 'async function server1_mockTool() {}',
fn: async (params) => ({ result: params }),
signature: '{}',
returns: 'unknown',

View File

@ -4,12 +4,11 @@ import { searchTools } from '../search'
import type { GeneratedTool } from '../types'
const createMockTool = (partial: Partial<GeneratedTool>): GeneratedTool => {
const functionName = partial.functionName || 'tool'
const functionName = partial.functionName || 'server1_tool'
return {
serverId: 'server1',
serverName: 'server1',
toolName: partial.toolName || 'tool',
toolId: partial.toolId || 'server1__tool',
functionName,
jsCode: `async function ${functionName}() {}`,
fn: async () => ({}),
@ -86,13 +85,13 @@ describe('search', () => {
it('respects limit parameter', () => {
const tools = Array.from({ length: 20 }, (_, i) =>
createMockTool({ toolName: `tool${i}`, functionName: `tool${i}`, toolId: `s__tool${i}` })
createMockTool({ toolName: `tool${i}`, functionName: `server1_tool${i}` })
)
const result = searchTools(tools, { query: 'tool', limit: 5 })
expect(result.total).toBe(20)
const matches = (result.tools.match(/async function tool\d+/g) || []).length
const matches = (result.tools.match(/async function server1_tool\d+/g) || []).length
expect(matches).toBe(5)
})

View File

@ -1,25 +1,8 @@
import type { MCPServer, MCPTool } from '@types'
import { generateMcpToolFunctionName } from '@shared/mcp'
import type { MCPTool } from '@types'
import type { GeneratedTool } from './types'
function toCamelCase(str: string): string {
return str
.replace(/[^a-zA-Z0-9]+(.)/g, (_, char) => char.toUpperCase())
.replace(/^[A-Z]/, (char) => char.toLowerCase())
.replace(/[^a-zA-Z0-9]/g, '')
}
function makeUniqueFunctionName(baseName: string, existingNames: Set<string>): string {
let name = baseName
let counter = 1
while (existingNames.has(name)) {
name = `${baseName}${counter}`
counter++
}
existingNames.add(name)
return name
}
function jsonSchemaToSignature(schema: Record<string, unknown> | undefined): string {
if (!schema || typeof schema !== 'object') {
return '{}'
@ -109,13 +92,10 @@ function generateJSDoc(tool: MCPTool, signature: string, returns: string): strin
export function generateToolFunction(
tool: MCPTool,
server: MCPServer,
existingNames: Set<string>,
callToolFn: (toolId: string, params: unknown) => Promise<unknown>
callToolFn: (functionName: string, params: unknown) => Promise<unknown>
): GeneratedTool {
const toolId = `${server.id}__${tool.name}`
const baseName = toCamelCase(tool.name)
const functionName = makeUniqueFunctionName(baseName, existingNames)
const functionName = generateMcpToolFunctionName(tool.serverName, tool.name, existingNames)
const inputSchema = tool.inputSchema as Record<string, unknown> | undefined
const outputSchema = tool.outputSchema as Record<string, unknown> | undefined
@ -127,18 +107,17 @@ export function generateToolFunction(
const jsCode = `${jsDoc}
async function ${functionName}(params) {
return await __callTool("${toolId}", params);
return await __callTool("${functionName}", params);
}`
const fn = async (params: unknown): Promise<unknown> => {
return await callToolFn(toolId, params)
return await callToolFn(functionName, params)
}
return {
serverId: server.id,
serverName: server.name,
serverId: tool.serverId,
serverName: tool.serverName,
toolName: tool.name,
toolId,
functionName,
jsCode,
fn,

View File

@ -1,13 +1,17 @@
import { loggerService } from '@logger'
import { CacheService } from '@main/services/CacheService'
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
import { CallToolRequestSchema, ErrorCode, ListToolsRequestSchema, McpError } from '@modelcontextprotocol/sdk/types.js'
import { generateToolFunction } from './generator'
import { callMcpTool, listAllTools } from './mcp-bridge'
import { Runtime } from './runtime'
import { searchTools } from './search'
import { ToolRegistry } from './tool-registry'
import type { ExecInput, SearchQuery } from './types'
import type { ExecInput, GeneratedTool, SearchQuery } from './types'
const logger = loggerService.withContext('MCPServer:Hub')
const TOOLS_CACHE_KEY = 'hub:tools'
const TOOLS_CACHE_TTL = 60 * 1000 // 1 minute
/**
* Hub MCP Server - A meta-server that aggregates all active MCP servers.
@ -23,11 +27,9 @@ const logger = loggerService.withContext('MCPServer:Hub')
*/
export class HubServer {
public server: Server
private toolRegistry: ToolRegistry
private runtime: Runtime
constructor() {
this.toolRegistry = new ToolRegistry()
this.runtime = new Runtime()
this.server = new Server(
@ -118,12 +120,32 @@ export class HubServer {
})
}
private async fetchTools(): Promise<GeneratedTool[]> {
const cached = CacheService.get<GeneratedTool[]>(TOOLS_CACHE_KEY)
if (cached) {
logger.debug('Returning cached tools')
return cached
}
logger.debug('Fetching fresh tools')
const allTools = await listAllTools()
const existingNames = new Set<string>()
const tools = allTools.map((tool) => generateToolFunction(tool, existingNames, callMcpTool))
CacheService.set(TOOLS_CACHE_KEY, tools, TOOLS_CACHE_TTL)
return tools
}
invalidateCache(): void {
CacheService.remove(TOOLS_CACHE_KEY)
logger.debug('Tools cache invalidated')
}
private async handleSearch(query: SearchQuery) {
if (!query.query || typeof query.query !== 'string') {
throw new McpError(ErrorCode.InvalidParams, 'query parameter is required and must be a string')
}
const tools = await this.toolRegistry.getTools()
const tools = await this.fetchTools()
const result = searchTools(tools, query)
return {
@ -141,7 +163,7 @@ export class HubServer {
throw new McpError(ErrorCode.InvalidParams, 'code parameter is required and must be a string')
}
const tools = await this.toolRegistry.getTools()
const tools = await this.fetchTools()
const result = await this.runtime.execute(input.code, tools)
return {
@ -153,10 +175,6 @@ export class HubServer {
]
}
}
invalidateCache(): void {
this.toolRegistry.invalidate()
}
}
export default HubServer

View File

@ -1,84 +1,38 @@
import { loggerService } from '@logger'
import { BuiltinMCPServerNames, type MCPServer, type MCPTool } from '@types'
/**
* Bridge module for Hub server to access MCPService.
* Re-exports the methods needed by tool-registry and runtime.
*/
import mcpService from '@main/services/MCPService'
import { generateMcpToolFunctionName } from '@shared/mcp'
const logger = loggerService.withContext('MCPServer:Hub:Bridge')
export const listAllTools = () => mcpService.listAllActiveServerTools()
let mcpServiceInstance: MCPServiceInterface | null = null
let mcpServersGetter: (() => MCPServer[]) | null = null
const toolFunctionNameToIdMap = new Map<string, { serverId: string; toolName: string }>()
interface MCPServiceInterface {
listTools(_: null, server: MCPServer): Promise<MCPTool[]>
callTool(
_: null,
args: { server: MCPServer; name: string; args: unknown; callId?: string }
): Promise<{ content: Array<{ type: string; text?: string }> }>
}
export function setMCPService(service: MCPServiceInterface): void {
mcpServiceInstance = service
}
export function setMCPServersGetter(getter: () => MCPServer[]): void {
mcpServersGetter = getter
}
export function initHubBridge(service: MCPServiceInterface, serversGetter: () => MCPServer[]): void {
mcpServiceInstance = service
mcpServersGetter = serversGetter
}
export function getActiveServers(): MCPServer[] {
if (!mcpServersGetter) {
logger.warn('MCP servers getter not set')
return []
}
const servers = mcpServersGetter()
return servers.filter((s) => s.isActive && s.name !== BuiltinMCPServerNames.hub)
}
export async function listToolsFromServer(server: MCPServer): Promise<MCPTool[]> {
if (!mcpServiceInstance) {
logger.error('MCP service not initialized')
return []
}
try {
return await mcpServiceInstance.listTools(null, server)
} catch (error) {
logger.error(`Failed to list tools from server ${server.name}:`, error as Error)
return []
export async function refreshToolMap(): Promise<void> {
const tools = await listAllTools()
toolFunctionNameToIdMap.clear()
const existingNames = new Set<string>()
for (const tool of tools) {
const functionName = generateMcpToolFunctionName(tool.serverName, tool.name, existingNames)
toolFunctionNameToIdMap.set(functionName, { serverId: tool.serverId, toolName: tool.name })
}
}
export async function callMcpTool(toolId: string, params: unknown): Promise<unknown> {
if (!mcpServiceInstance) {
throw new Error('MCP service not initialized')
export const callMcpTool = async (functionName: string, params: unknown): Promise<unknown> => {
const toolInfo = toolFunctionNameToIdMap.get(functionName)
if (!toolInfo) {
await refreshToolMap()
const retryToolInfo = toolFunctionNameToIdMap.get(functionName)
if (!retryToolInfo) {
throw new Error(`Tool not found: ${functionName}`)
}
const toolId = `${retryToolInfo.serverId}__${retryToolInfo.toolName}`
const result = await mcpService.callToolById(toolId, params)
return extractToolResult(result)
}
const parts = toolId.split('__')
if (parts.length < 2) {
throw new Error(`Invalid tool ID format: ${toolId}`)
}
const serverId = parts[0]
const toolName = parts.slice(1).join('__')
const servers = getActiveServers()
const server = servers.find((s) => s.id === serverId)
if (!server) {
throw new Error(`Server not found: ${serverId}`)
}
logger.debug(`Calling tool ${toolName} on server ${server.name}`)
const result = await mcpServiceInstance.callTool(null, {
server,
name: toolName,
args: params
})
const toolId = `${toolInfo.serverId}__${toolInfo.toolName}`
const result = await mcpService.callToolById(toolId, params)
return extractToolResult(result)
}

View File

@ -37,7 +37,15 @@ export function searchTools(tools: GeneratedTool[], query: SearchQuery): SearchR
}
function buildSearchText(tool: GeneratedTool): string {
const parts = [tool.toolName, tool.functionName, tool.serverName, tool.description || '', tool.signature]
const combinedName = tool.serverName ? `${tool.serverName}_${tool.toolName}` : tool.toolName
const parts = [
tool.toolName,
tool.functionName,
tool.serverName,
combinedName,
tool.description || '',
tool.signature
]
return parts.join(' ')
}
@ -55,10 +63,12 @@ function rankTools(tools: GeneratedTool[], keywords: string[]): GeneratedTool[]
function calculateScore(tool: GeneratedTool, keywords: string[]): number {
let score = 0
const toolName = tool.toolName.toLowerCase()
const serverName = (tool.serverName || '').toLowerCase()
const functionName = tool.functionName.toLowerCase()
const description = (tool.description || '').toLowerCase()
for (const keyword of keywords) {
// Match tool name
if (toolName === keyword) {
score += 10
} else if (toolName.startsWith(keyword)) {
@ -67,12 +77,24 @@ function calculateScore(tool: GeneratedTool, keywords: string[]): number {
score += 3
}
if (functionName === keyword) {
// Match server name
if (serverName === keyword) {
score += 8
} else if (functionName.includes(keyword)) {
} else if (serverName.startsWith(keyword)) {
score += 4
} else if (serverName.includes(keyword)) {
score += 2
}
// Match function name (serverName_toolName format)
if (functionName === keyword) {
score += 10
} else if (functionName.startsWith(keyword)) {
score += 5
} else if (functionName.includes(keyword)) {
score += 3
}
if (description.includes(keyword)) {
const count = (description.match(new RegExp(escapeRegex(keyword), 'g')) || []).length
score += Math.min(count, 3)

View File

@ -1,98 +0,0 @@
import { loggerService } from '@logger'
import { generateToolFunction } from './generator'
import { callMcpTool, getActiveServers, listToolsFromServer } from './mcp-bridge'
import type { GeneratedTool, ToolRegistryOptions } from './types'
const logger = loggerService.withContext('MCPServer:Hub:Registry')
const DEFAULT_TTL = 10 * 60 * 1000
export class ToolRegistry {
private tools: Map<string, GeneratedTool> = new Map()
private lastRefresh: number = 0
private readonly ttl: number
private refreshPromise: Promise<void> | null = null
constructor(options: ToolRegistryOptions = {}) {
this.ttl = options.ttl ?? DEFAULT_TTL
}
async getTools(): Promise<GeneratedTool[]> {
if (this.isExpired()) {
await this.refresh()
}
return Array.from(this.tools.values())
}
async getTool(toolId: string): Promise<GeneratedTool | undefined> {
if (this.isExpired()) {
await this.refresh()
}
return this.tools.get(toolId)
}
getToolByFunctionName(functionName: string): GeneratedTool | undefined {
for (const tool of this.tools.values()) {
if (tool.functionName === functionName) {
return tool
}
}
return undefined
}
private isExpired(): boolean {
return Date.now() - this.lastRefresh > this.ttl
}
invalidate(): void {
this.lastRefresh = 0
this.tools.clear()
logger.debug('Tool registry invalidated')
}
async refresh(): Promise<void> {
if (this.refreshPromise) {
return this.refreshPromise
}
this.refreshPromise = this.doRefresh()
try {
await this.refreshPromise
} finally {
this.refreshPromise = null
}
}
private async doRefresh(): Promise<void> {
logger.debug('Refreshing tool registry')
const servers = getActiveServers()
const newTools = new Map<string, GeneratedTool>()
const existingNames = new Set<string>()
for (const server of servers) {
try {
const serverTools = await listToolsFromServer(server)
for (const tool of serverTools) {
const generatedTool = generateToolFunction(tool, server, existingNames, callMcpTool)
newTools.set(generatedTool.toolId, generatedTool)
}
} catch (error) {
logger.error(`Failed to list tools from server ${server.name}:`, error as Error)
}
}
this.tools = newTools
this.lastRefresh = Date.now()
logger.debug(`Tool registry refreshed with ${this.tools.size} tools`)
}
getToolCount(): number {
return this.tools.size
}
}

View File

@ -4,7 +4,6 @@ export interface GeneratedTool {
serverId: string
serverName: string
toolName: string
toolId: string
functionName: string
jsCode: string
fn: (params: unknown) => Promise<unknown>
@ -42,7 +41,7 @@ export interface MCPToolWithServer extends MCPTool {
}
export interface ExecutionContext {
__callTool: (toolId: string, params: unknown) => Promise<unknown>
__callTool: (functionName: string, params: unknown) => Promise<unknown>
parallel: <T>(...promises: Promise<T>[]) => Promise<T[]>
settle: <T>(...promises: Promise<T>[]) => Promise<PromiseSettledResult<T>[]>
console: ConsoleMethods

View File

@ -3,8 +3,8 @@ import os from 'node:os'
import path from 'node:path'
import { loggerService } from '@logger'
import { getMCPServersFromRedux } from '@main/apiServer/utils/mcp'
import { createInMemoryMCPServer } from '@main/mcpServers/factory'
import { initHubBridge } from '@main/mcpServers/hub/mcp-bridge'
import { makeSureDirExists, removeEnvProxy } from '@main/utils'
import { buildFunctionCallToolName } from '@main/utils/mcp'
import { findCommandInShellEnv, getBinaryName, getBinaryPath, isBinaryExists } from '@main/utils/process'
@ -59,7 +59,6 @@ import DxtService from './DxtService'
import { CallBackServer } from './mcp/oauth/callback'
import { McpOAuthClientProvider } from './mcp/oauth/provider'
import { ServerLogBuffer } from './mcp/ServerLogBuffer'
import { reduxService } from './ReduxService'
import { windowService } from './WindowService'
// Generic type for caching wrapped functions
@ -165,27 +164,61 @@ class McpService {
this.checkMcpConnectivity = this.checkMcpConnectivity.bind(this)
this.getServerVersion = this.getServerVersion.bind(this)
this.getServerLogs = this.getServerLogs.bind(this)
this.initializeHubDependencies()
}
private initializeHubDependencies(): void {
initHubBridge(
{
listTools: (_: null, server: MCPServer) => this.listToolsImpl(server),
callTool: async (_: null, args: { server: MCPServer; name: string; args: unknown; callId?: string }) => {
return this.callTool(null as unknown as Electron.IpcMainInvokeEvent, args)
}
},
() => {
try {
const servers = reduxService.selectSync<MCPServer[]>('state.mcp.servers')
return servers || []
} catch {
return []
}
/**
* List all tools from all active MCP servers (excluding hub).
* Used by Hub server's tool registry.
*/
public async listAllActiveServerTools(): Promise<MCPTool[]> {
const servers = await getMCPServersFromRedux()
const allTools: MCPTool[] = []
for (const server of servers) {
if (!server.isActive) {
continue
}
)
try {
const tools = await this.listToolsImpl(server)
allTools.push(...tools)
} catch (error) {
logger.error(`[listAllActiveServerTools] Failed to list tools from ${server.name}:`, error as Error)
}
}
return allTools
}
/**
* Call a tool by its full ID (serverId__toolName format).
* Used by Hub server's runtime.
*/
public async callToolById(
toolId: string,
params: unknown
): Promise<{ content: Array<{ type: string; text?: string }> }> {
const parts = toolId.split('__')
if (parts.length < 2) {
throw new Error(`Invalid tool ID format: ${toolId}`)
}
const serverId = parts[0]
const toolName = parts.slice(1).join('__')
const servers = await getMCPServersFromRedux()
const server = servers.find((s) => s.id === serverId)
if (!server) {
throw new Error(`Server not found: ${serverId}`)
}
logger.debug(`[callToolById] Calling tool ${toolName} on server ${server.name}`)
return this.callTool(null as unknown as Electron.IpcMainInvokeEvent, {
server,
name: toolName,
args: params
})
}
private getServerKey(server: MCPServer): string {

View File

@ -26,11 +26,13 @@ import {
isSupportedThinkingTokenModel,
isWebSearchModel
} from '@renderer/config/models'
import { getHubModeSystemPrompt } from '@renderer/config/prompts'
import { fetchAllActiveServerTools } from '@renderer/services/ApiService'
import { getDefaultModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
import type { Model } from '@renderer/types'
import { type Assistant, type MCPTool, type Provider, SystemProviderIds } from '@renderer/types'
import { type Assistant, getEffectiveMcpMode, type MCPTool, type Provider, SystemProviderIds } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern'
import { replacePromptVariables } from '@renderer/utils/prompt'
@ -244,7 +246,15 @@ export async function buildStreamTextParams(
}
if (assistant.prompt) {
params.system = await replacePromptVariables(assistant.prompt, model.name)
let systemPrompt = await replacePromptVariables(assistant.prompt, model.name)
if (getEffectiveMcpMode(assistant) === 'auto') {
const allActiveTools = await fetchAllActiveServerTools()
systemPrompt = systemPrompt + '\n\n' + getHubModeSystemPrompt(allActiveTools)
}
params.system = systemPrompt
} else if (getEffectiveMcpMode(assistant) === 'auto') {
const allActiveTools = await fetchAllActiveServerTools()
params.system = getHubModeSystemPrompt(allActiveTools)
}
logger.debug('params', params)

View File

@ -1,3 +1,4 @@
import { generateMcpToolFunctionName } from '@shared/mcp'
import dayjs from 'dayjs'
export const AGENT_PROMPT = `
@ -468,3 +469,68 @@ Example: [nytimes.com](https://nytimes.com/some-page).
If have multiple citations, please directly list them like this:
[www.nytimes.com](https://nytimes.com/some-page)[www.bbc.com](https://bbc.com/some-page)
`
const HUB_MODE_SYSTEM_PROMPT_BASE = `
## MCP Tools (Code Mode)
You have access to MCP tools via the hub server.
### Workflow
1. Call \`search\` with relevant keywords to discover tools
2. Review the returned function signatures and their parameters
3. Call \`exec\` with JavaScript code using those functions
4. The last expression in your code becomes the result
### Example Usage
**Step 1: Search for tools**
\`\`\`
search({ query: "github,repository" })
\`\`\`
**Step 2: Use discovered tools**
\`\`\`javascript
exec({
code: \`
const repos = await searchRepos({ query: "react", limit: 5 })
const details = await parallel(
repos.map(r => getRepoDetails({ owner: r.owner, repo: r.name }))
)
return { repos, details }
\`
})
\`\`\`
### Best Practices
- Always search first to discover available tools and their exact signatures
- Use descriptive variable names in your code
- Handle errors gracefully with try/catch when needed
- Use \`parallel()\` for independent operations to improve performance
`
interface ToolInfo {
name: string
serverName?: string
description?: string
}
export function getHubModeSystemPrompt(tools: ToolInfo[] = []): string {
if (tools.length === 0) {
return HUB_MODE_SYSTEM_PROMPT_BASE
}
const existingNames = new Set<string>()
const toolsSection = tools
.map((t) => {
const functionName = generateMcpToolFunctionName(t.serverName, t.name, existingNames)
const desc = t.description || ''
const truncatedDesc = desc.length > 50 ? `${desc.slice(0, 50)}...` : desc
return `- ${functionName}: ${truncatedDesc}`
})
.join('\n')
return `${HUB_MODE_SYSTEM_PROMPT_BASE}
### Available Tools
${toolsSection}
`
}

View File

@ -544,6 +544,20 @@
"description": "Default enabled MCP servers",
"enableFirst": "Enable this server in MCP settings first",
"label": "MCP Servers",
"mode": {
"auto": {
"description": "AI discovers and uses tools automatically",
"label": "Auto"
},
"disabled": {
"description": "No MCP tools",
"label": "Disabled"
},
"manual": {
"description": "Select specific MCP servers",
"label": "Manual"
}
},
"noServersAvailable": "No MCP servers available. Add servers in settings",
"title": "MCP Settings"
},

View File

@ -544,6 +544,20 @@
"description": "默认启用的 MCP 服务器",
"enableFirst": "请先在 MCP 设置中启用此服务器",
"label": "MCP 服务器",
"mode": {
"auto": {
"description": "AI 自动发现和使用工具",
"label": "自动"
},
"disabled": {
"description": "不使用 MCP 工具",
"label": "禁用"
},
"manual": {
"description": "选择特定的 MCP 服务器",
"label": "手动"
}
},
"noServersAvailable": "无可用 MCP 服务器。请在设置中添加服务器",
"title": "MCP 服务器"
},

View File

@ -544,6 +544,20 @@
"description": "預設啟用的 MCP 伺服器",
"enableFirst": "請先在 MCP 設定中啟用此伺服器",
"label": "MCP 伺服器",
"mode": {
"auto": {
"description": "AI 自動發現和使用工具",
"label": "自動"
},
"disabled": {
"description": "不使用 MCP 工具",
"label": "停用"
},
"manual": {
"description": "選擇特定的 MCP 伺服器",
"label": "手動"
}
},
"noServersAvailable": "無可用 MCP 伺服器。請在設定中新增伺服器",
"title": "MCP 設定"
},

View File

@ -8,11 +8,12 @@ import { useTimer } from '@renderer/hooks/useTimer'
import type { ToolQuickPanelApi } from '@renderer/pages/home/Inputbar/types'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { EventEmitter } from '@renderer/services/EventService'
import type { MCPPrompt, MCPResource, MCPServer } from '@renderer/types'
import type { McpMode, MCPPrompt, MCPResource, MCPServer } from '@renderer/types'
import { getEffectiveMcpMode } from '@renderer/types'
import { isToolUseModeFunction } from '@renderer/utils/assistant'
import { isGeminiWebSearchProvider, isSupportUrlContextProvider } from '@renderer/utils/provider'
import { Form, Input, Tooltip } from 'antd'
import { CircleX, Hammer, Plus } from 'lucide-react'
import { CircleX, Hammer, Plus, Sparkles } from 'lucide-react'
import type { FC } from 'react'
import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
@ -25,7 +26,6 @@ interface Props {
resizeTextArea: () => void
}
// 添加类型定义
interface PromptArgument {
name: string
description?: string
@ -44,24 +44,19 @@ interface ResourceData {
uri?: string
}
// 提取到组件外的工具函数
const extractPromptContent = (response: any): string | null => {
// Handle string response (backward compatibility)
if (typeof response === 'string') {
return response
}
// Handle GetMCPPromptResponse format
if (response && Array.isArray(response.messages)) {
let formattedContent = ''
for (const message of response.messages) {
if (!message.content) continue
// Add role prefix if available
const rolePrefix = message.role ? `**${message.role.charAt(0).toUpperCase() + message.role.slice(1)}:** ` : ''
// Process different content types
switch (message.content.type) {
case 'text':
formattedContent += `${rolePrefix}${message.content.text}\n\n`
@ -98,7 +93,6 @@ const extractPromptContent = (response: any): string | null => {
return formattedContent.trim()
}
// Fallback handling for single message format
if (response && response.messages && response.messages.length > 0) {
const message = response.messages[0]
if (message.content && message.content.text) {
@ -121,7 +115,6 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
const model = assistant.model
const { setTimeoutTimer } = useTimer()
// 使用 useRef 存储不需要触发重渲染的值
const isMountedRef = useRef(true)
useEffect(() => {
@ -130,11 +123,30 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
}
}, [])
const currentMode = useMemo(() => getEffectiveMcpMode(assistant), [assistant])
const mcpServers = useMemo(() => assistant.mcpServers || [], [assistant.mcpServers])
const assistantMcpServers = useMemo(
() => activedMcpServers.filter((server) => mcpServers.some((s) => s.id === server.id)),
[activedMcpServers, mcpServers]
)
const handleModeChange = useCallback(
(mode: McpMode) => {
setTimeoutTimer(
'updateMcpMode',
() => {
updateAssistant({
...assistant,
mcpMode: mode
})
},
200
)
},
[assistant, setTimeoutTimer, updateAssistant]
)
const handleMcpServerSelect = useCallback(
(server: MCPServer) => {
const update = { ...assistant }
@ -144,29 +156,24 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
update.mcpServers = [...mcpServers, server]
}
// only for gemini
if (update.mcpServers.length > 0 && isGeminiModel(model) && isToolUseModeFunction(assistant)) {
const provider = getProviderByModel(model)
if (isSupportUrlContextProvider(provider) && assistant.enableUrlContext) {
window.toast.warning(t('chat.mcp.warning.url_context'))
update.enableUrlContext = false
}
if (
// 非官方 API (openrouter etc.) 可能支持同时启用内置搜索和函数调用
// 这里先假设 gemini type 和 vertexai type 不支持
isGeminiWebSearchProvider(provider) &&
assistant.enableWebSearch
) {
if (isGeminiWebSearchProvider(provider) && assistant.enableWebSearch) {
window.toast.warning(t('chat.mcp.warning.gemini_web_search'))
update.enableWebSearch = false
}
}
update.mcpMode = 'manual'
updateAssistant(update)
},
[assistant, assistantMcpServers, mcpServers, model, t, updateAssistant]
)
// 使用 useRef 缓存事件处理函数
const handleMcpServerSelectRef = useRef(handleMcpServerSelect)
handleMcpServerSelectRef.current = handleMcpServerSelect
@ -176,23 +183,7 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
return () => EventEmitter.off('mcp-server-select', handler)
}, [])
const updateMcpEnabled = useCallback(
(enabled: boolean) => {
setTimeoutTimer(
'updateMcpEnabled',
() => {
updateAssistant({
...assistant,
mcpServers: enabled ? assistant.mcpServers || [] : []
})
},
200
)
},
[assistant, setTimeoutTimer, updateAssistant]
)
const menuItems = useMemo(() => {
const manualModeMenuItems = useMemo(() => {
const newList: QuickPanelListItem[] = activedMcpServers.map((server) => ({
label: server.name,
description: server.description || server.baseUrl,
@ -207,33 +198,69 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
action: () => navigate('/settings/mcp')
})
newList.unshift({
label: t('settings.input.clear.all'),
description: t('settings.mcp.disable.description'),
icon: <CircleX />,
isSelected: false,
action: () => {
updateMcpEnabled(false)
quickPanelHook.close()
}
})
return newList
}, [activedMcpServers, t, assistantMcpServers, navigate, updateMcpEnabled, quickPanelHook])
}, [activedMcpServers, t, assistantMcpServers, navigate])
const openQuickPanel = useCallback(() => {
const openManualModePanel = useCallback(() => {
quickPanelHook.open({
title: t('settings.mcp.title'),
list: menuItems,
title: t('assistants.settings.mcp.mode.manual.label'),
list: manualModeMenuItems,
symbol: QuickPanelReservedSymbol.Mcp,
multiple: true,
afterAction({ item }) {
item.isSelected = !item.isSelected
}
})
}, [manualModeMenuItems, quickPanelHook, t])
const menuItems = useMemo(() => {
const newList: QuickPanelListItem[] = []
newList.push({
label: t('assistants.settings.mcp.mode.disabled.label'),
description: t('assistants.settings.mcp.mode.disabled.description'),
icon: <CircleX />,
isSelected: currentMode === 'disabled',
action: () => {
handleModeChange('disabled')
quickPanelHook.close()
}
})
newList.push({
label: t('assistants.settings.mcp.mode.auto.label'),
description: t('assistants.settings.mcp.mode.auto.description'),
icon: <Sparkles />,
isSelected: currentMode === 'auto',
action: () => {
handleModeChange('auto')
quickPanelHook.close()
}
})
newList.push({
label: t('assistants.settings.mcp.mode.manual.label'),
description: t('assistants.settings.mcp.mode.manual.description'),
icon: <Hammer />,
isSelected: currentMode === 'manual',
isMenu: true,
action: () => {
openManualModePanel()
}
})
return newList
}, [t, currentMode, handleModeChange, quickPanelHook, openManualModePanel])
const openQuickPanel = useCallback(() => {
quickPanelHook.open({
title: t('settings.mcp.title'),
list: menuItems,
symbol: QuickPanelReservedSymbol.Mcp,
multiple: false
})
}, [menuItems, quickPanelHook, t])
// 使用 useCallback 优化 insertPromptIntoTextArea
const insertPromptIntoTextArea = useCallback(
(promptText: string) => {
setInputValue((prev) => {
@ -245,7 +272,6 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
const selectionEndPosition = cursorPosition + promptText.length
const newText = prev.slice(0, cursorPosition) + promptText + prev.slice(cursorPosition)
// 使用 requestAnimationFrame 优化 DOM 操作
requestAnimationFrame(() => {
textArea.focus()
textArea.setSelectionRange(selectionStart, selectionEndPosition)
@ -424,7 +450,6 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
[activedMcpServers, t, insertPromptIntoTextArea]
)
// 优化 resourcesList 的状态更新
const [resourcesList, setResourcesList] = useState<QuickPanelListItem[]>([])
useEffect(() => {
@ -514,17 +539,26 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
}
}, [openPromptList, openQuickPanel, openResourcesList, quickPanel, t])
const isActive = currentMode !== 'disabled'
const getButtonIcon = () => {
switch (currentMode) {
case 'auto':
return <Sparkles size={18} />
case 'disabled':
case 'manual':
default:
return <Hammer size={18} />
}
}
return (
<Tooltip placement="top" title={t('settings.mcp.title')} mouseLeaveDelay={0} arrow>
<ActionIconButton
onClick={handleOpenQuickPanel}
active={assistant.mcpServers && assistant.mcpServers.length > 0}
aria-label={t('settings.mcp.title')}>
<Hammer size={18} />
<ActionIconButton onClick={handleOpenQuickPanel} active={isActive} aria-label={t('settings.mcp.title')}>
{getButtonIcon()}
</ActionIconButton>
</Tooltip>
)
}
// 使用 React.memo 包装组件
export default React.memo(MCPToolsButton)

View File

@ -1,6 +1,7 @@
import { ActionIconButton } from '@renderer/components/Buttons'
import { useAssistant } from '@renderer/hooks/useAssistant'
import { useTimer } from '@renderer/hooks/useTimer'
import { getEffectiveMcpMode } from '@renderer/types'
import { isToolUseModeFunction } from '@renderer/utils/assistant'
import { Tooltip } from 'antd'
import { Link } from 'lucide-react'
@ -30,8 +31,7 @@ const UrlContextButton: FC<Props> = ({ assistantId }) => {
() => {
const update = { ...assistant }
if (
assistant.mcpServers &&
assistant.mcpServers.length > 0 &&
getEffectiveMcpMode(assistant) !== 'disabled' &&
urlContentNewState === true &&
isToolUseModeFunction(assistant)
) {

View File

@ -16,7 +16,7 @@ import { useWebSearchProviders } from '@renderer/hooks/useWebSearchProviders'
import type { ToolQuickPanelController, ToolRenderContext } from '@renderer/pages/home/Inputbar/types'
import { getProviderByModel } from '@renderer/services/AssistantService'
import WebSearchService from '@renderer/services/WebSearchService'
import type { WebSearchProvider, WebSearchProviderId } from '@renderer/types'
import { getEffectiveMcpMode, type WebSearchProvider, type WebSearchProviderId } from '@renderer/types'
import { hasObjectKey } from '@renderer/utils'
import { isToolUseModeFunction } from '@renderer/utils/assistant'
import { isPromptToolUse } from '@renderer/utils/mcp-tools'
@ -108,8 +108,7 @@ export const useWebSearchPanelController = (assistantId: string, quickPanelContr
isGeminiModel(model) &&
isToolUseModeFunction(assistant) &&
update.enableWebSearch &&
assistant.mcpServers &&
assistant.mcpServers.length > 0
getEffectiveMcpMode(assistant) !== 'disabled'
) {
update.enableWebSearch = false
window.toast.warning(t('chat.mcp.warning.gemini_web_search'))

View File

@ -1,8 +1,9 @@
import { InfoCircleOutlined } from '@ant-design/icons'
import { Box } from '@renderer/components/Layout'
import { useMCPServers } from '@renderer/hooks/useMCPServers'
import type { Assistant, AssistantSettings } from '@renderer/types'
import { Empty, Switch, Tooltip } from 'antd'
import type { Assistant, AssistantSettings, McpMode } from '@renderer/types'
import { getEffectiveMcpMode } from '@renderer/types'
import { Empty, Radio, Switch, Tooltip } from 'antd'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@ -27,22 +28,26 @@ const AssistantMCPSettings: React.FC<Props> = ({ assistant, updateAssistant }) =
const { t } = useTranslation()
const { mcpServers: allMcpServers } = useMCPServers()
const currentMode = getEffectiveMcpMode(assistant)
const handleModeChange = (mode: McpMode) => {
updateAssistant({ ...assistant, mcpMode: mode })
}
const onUpdate = (ids: string[]) => {
const mcpServers = ids
.map((id) => allMcpServers.find((server) => server.id === id))
.filter((server): server is MCPServer => server !== undefined && server.isActive)
updateAssistant({ ...assistant, mcpServers })
updateAssistant({ ...assistant, mcpServers, mcpMode: 'manual' })
}
const handleServerToggle = (serverId: string) => {
const currentServerIds = assistant.mcpServers?.map((server) => server.id) || []
if (currentServerIds.includes(serverId)) {
// Remove server if it's already enabled
onUpdate(currentServerIds.filter((id) => id !== serverId))
} else {
// Add server if it's not enabled
onUpdate([...currentServerIds, serverId])
}
}
@ -58,49 +63,77 @@ const AssistantMCPSettings: React.FC<Props> = ({ assistant, updateAssistant }) =
<InfoIcon />
</Tooltip>
</Box>
{allMcpServers.length > 0 && (
<EnabledCount>
{enabledCount} / {allMcpServers.length} {t('settings.mcp.active')}
</EnabledCount>
)}
</HeaderContainer>
{allMcpServers.length > 0 ? (
<ServerList>
{allMcpServers.map((server) => {
const isEnabled = assistant.mcpServers?.some((s) => s.id === server.id) || false
<ModeSelector>
<Radio.Group value={currentMode} onChange={(e) => handleModeChange(e.target.value)}>
<Radio.Button value="disabled">
<ModeOption>
<ModeLabel>{t('assistants.settings.mcp.mode.disabled.label')}</ModeLabel>
<ModeDescription>{t('assistants.settings.mcp.mode.disabled.description')}</ModeDescription>
</ModeOption>
</Radio.Button>
<Radio.Button value="auto">
<ModeOption>
<ModeLabel>{t('assistants.settings.mcp.mode.auto.label')}</ModeLabel>
<ModeDescription>{t('assistants.settings.mcp.mode.auto.description')}</ModeDescription>
</ModeOption>
</Radio.Button>
<Radio.Button value="manual">
<ModeOption>
<ModeLabel>{t('assistants.settings.mcp.mode.manual.label')}</ModeLabel>
<ModeDescription>{t('assistants.settings.mcp.mode.manual.description')}</ModeDescription>
</ModeOption>
</Radio.Button>
</Radio.Group>
</ModeSelector>
return (
<ServerItem key={server.id} isEnabled={isEnabled}>
<ServerInfo>
<ServerName>{server.name}</ServerName>
{server.description && <ServerDescription>{server.description}</ServerDescription>}
{server.baseUrl && <ServerUrl>{server.baseUrl}</ServerUrl>}
</ServerInfo>
<Tooltip
title={
!server.isActive
? t('assistants.settings.mcp.enableFirst', 'Enable this server in MCP settings first')
: undefined
}>
<Switch
checked={isEnabled}
disabled={!server.isActive}
onChange={() => handleServerToggle(server.id)}
size="small"
/>
</Tooltip>
</ServerItem>
)
})}
</ServerList>
) : (
<EmptyContainer>
<Empty
description={t('assistants.settings.mcp.noServersAvailable', 'No MCP servers available')}
image={Empty.PRESENTED_IMAGE_SIMPLE}
/>
</EmptyContainer>
{currentMode === 'manual' && (
<>
{allMcpServers.length > 0 && (
<EnabledCount>
{enabledCount} / {allMcpServers.length} {t('settings.mcp.active')}
</EnabledCount>
)}
{allMcpServers.length > 0 ? (
<ServerList>
{allMcpServers.map((server) => {
const isEnabled = assistant.mcpServers?.some((s) => s.id === server.id) || false
return (
<ServerItem key={server.id} isEnabled={isEnabled}>
<ServerInfo>
<ServerName>{server.name}</ServerName>
{server.description && <ServerDescription>{server.description}</ServerDescription>}
{server.baseUrl && <ServerUrl>{server.baseUrl}</ServerUrl>}
</ServerInfo>
<Tooltip
title={
!server.isActive
? t('assistants.settings.mcp.enableFirst', 'Enable this server in MCP settings first')
: undefined
}>
<Switch
checked={isEnabled}
disabled={!server.isActive}
onChange={() => handleServerToggle(server.id)}
size="small"
/>
</Tooltip>
</ServerItem>
)
})}
</ServerList>
) : (
<EmptyContainer>
<Empty
description={t('assistants.settings.mcp.noServersAvailable', 'No MCP servers available')}
image={Empty.PRESENTED_IMAGE_SIMPLE}
/>
</EmptyContainer>
)}
</>
)}
</Container>
)
@ -127,9 +160,54 @@ const InfoIcon = styled(InfoCircleOutlined)`
cursor: help;
`
const ModeSelector = styled.div`
margin-bottom: 16px;
.ant-radio-group {
display: flex;
flex-direction: column;
gap: 8px;
}
.ant-radio-button-wrapper {
height: auto;
padding: 12px 16px;
border-radius: 8px;
border: 1px solid var(--color-border);
&:not(:first-child)::before {
display: none;
}
&:first-child {
border-radius: 8px;
}
&:last-child {
border-radius: 8px;
}
}
`
const ModeOption = styled.div`
display: flex;
flex-direction: column;
gap: 2px;
`
const ModeLabel = styled.span`
font-weight: 600;
`
const ModeDescription = styled.span`
font-size: 12px;
color: var(--color-text-2);
`
const EnabledCount = styled.span`
font-size: 12px;
color: var(--color-text-2);
margin-bottom: 8px;
`
const EmptyContainer = styled.div`

View File

@ -8,8 +8,9 @@ import { isDedicatedImageGenerationModel, isEmbeddingModel, isFunctionCallingMod
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import store from '@renderer/store'
import { hubMCPServer } from '@renderer/store/mcp'
import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
import { type FetchChatCompletionParams, isSystemProvider } from '@renderer/types'
import { type FetchChatCompletionParams, getEffectiveMcpMode, isSystemProvider } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { type Chunk, ChunkType } from '@renderer/types/chunk'
import type { Message, ResponseError } from '@renderer/types/newMessage'
@ -51,14 +52,60 @@ import type { StreamProcessorCallbacks } from './StreamProcessingService'
const logger = loggerService.withContext('ApiService')
export async function fetchMcpTools(assistant: Assistant) {
// Get MCP tools (Fix duplicate declaration)
let mcpTools: MCPTool[] = [] // Initialize as empty array
/**
* Get the MCP servers to use based on the assistant's MCP mode.
*/
export function getMcpServersForAssistant(assistant: Assistant): MCPServer[] {
const mode = getEffectiveMcpMode(assistant)
const allMcpServers = store.getState().mcp.servers || []
const activedMcpServers = allMcpServers.filter((s) => s.isActive)
const assistantMcpServers = assistant.mcpServers || []
const enabledMCPs = activedMcpServers.filter((server) => assistantMcpServers.some((s) => s.id === server.id))
switch (mode) {
case 'disabled':
return []
case 'auto':
return [hubMCPServer]
case 'manual': {
const assistantMcpServers = assistant.mcpServers || []
return activedMcpServers.filter((server) => assistantMcpServers.some((s) => s.id === server.id))
}
default:
return []
}
}
export async function fetchAllActiveServerTools(): Promise<MCPTool[]> {
const allMcpServers = store.getState().mcp.servers || []
const activedMcpServers = allMcpServers.filter((s) => s.isActive)
if (activedMcpServers.length === 0) {
return []
}
try {
const toolPromises = activedMcpServers.map(async (mcpServer: MCPServer) => {
try {
const tools = await window.api.mcp.listTools(mcpServer)
return tools.filter((tool: any) => !mcpServer.disabledTools?.includes(tool.name))
} catch (error) {
logger.error(`Error fetching tools from MCP server ${mcpServer.name}:`, error as Error)
return []
}
})
const results = await Promise.allSettled(toolPromises)
return results
.filter((result): result is PromiseFulfilledResult<MCPTool[]> => result.status === 'fulfilled')
.map((result) => result.value)
.flat()
} catch (toolError) {
logger.error('Error fetching all active server tools:', toolError as Error)
return []
}
}
export async function fetchMcpTools(assistant: Assistant) {
let mcpTools: MCPTool[] = []
const enabledMCPs = getMcpServersForAssistant(assistant)
if (enabledMCPs && enabledMCPs.length > 0) {
try {

View File

@ -0,0 +1,54 @@
import { describe, expect, it } from 'vitest'
import type { Assistant, MCPServer } from '../../types'
import { getEffectiveMcpMode } from '../../types'
describe('getEffectiveMcpMode', () => {
it('should return mcpMode when explicitly set to auto', () => {
const assistant: Partial<Assistant> = { mcpMode: 'auto' }
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('auto')
})
it('should return disabled when mcpMode is explicitly disabled', () => {
const assistant: Partial<Assistant> = { mcpMode: 'disabled' }
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('disabled')
})
it('should return manual when mcpMode is explicitly manual', () => {
const assistant: Partial<Assistant> = { mcpMode: 'manual' }
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('manual')
})
it('should return manual when no mcpMode but mcpServers has items (backward compatibility)', () => {
const assistant: Partial<Assistant> = {
mcpServers: [{ id: 'test', name: 'Test Server', isActive: true }] as MCPServer[]
}
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('manual')
})
it('should return disabled when no mcpMode and no mcpServers (backward compatibility)', () => {
const assistant: Partial<Assistant> = {}
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('disabled')
})
it('should return disabled when no mcpMode and empty mcpServers (backward compatibility)', () => {
const assistant: Partial<Assistant> = { mcpServers: [] }
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('disabled')
})
it('should prioritize explicit mcpMode over mcpServers presence', () => {
const assistant: Partial<Assistant> = {
mcpMode: 'disabled',
mcpServers: [{ id: 'test', name: 'Test Server', isActive: true }] as MCPServer[]
}
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('disabled')
})
it('should return auto when mcpMode is auto regardless of mcpServers', () => {
const assistant: Partial<Assistant> = {
mcpMode: 'auto',
mcpServers: [{ id: 'test', name: 'Test Server', isActive: true }] as MCPServer[]
}
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('auto')
})
})

View File

@ -86,6 +86,20 @@ export { mcpSlice }
// Export the reducer as default export
export default mcpSlice.reducer
/**
* Hub MCP server for auto mode - aggregates all MCP servers for LLM code mode.
* This server is injected automatically when mcpMode === 'auto'.
*/
export const hubMCPServer: BuiltinMCPServer = {
id: 'hub',
name: BuiltinMCPServerNames.hub,
type: 'inMemory',
isActive: true,
provider: 'CherryAI',
installSource: 'builtin',
isTrusted: true
}
/**
* User-installable built-in MCP servers shown in the UI.
*

View File

@ -27,6 +27,8 @@ export * from './ocr'
export * from './plugin'
export * from './provider'
export type McpMode = 'disabled' | 'auto' | 'manual'
export type Assistant = {
id: string
name: string
@ -47,6 +49,8 @@ export type Assistant = {
// enableUrlContext 是 Gemini/Anthropic 的特有功能
enableUrlContext?: boolean
enableGenerateImage?: boolean
/** MCP mode: 'disabled' (no MCP), 'auto' (hub server only), 'manual' (user selects servers) */
mcpMode?: McpMode
mcpServers?: MCPServer[]
knowledgeRecognition?: 'off' | 'on'
regularPhrases?: QuickPhrase[] // Added for regular phrase
@ -57,6 +61,15 @@ export type Assistant = {
targetLanguage?: TranslateLanguage
}
/**
* Get the effective MCP mode for an assistant with backward compatibility.
* Legacy assistants without mcpMode default based on mcpServers presence.
*/
export function getEffectiveMcpMode(assistant: Assistant): McpMode {
if (assistant.mcpMode) return assistant.mcpMode
return (assistant.mcpServers?.length ?? 0) > 0 ? 'manual' : 'disabled'
}
export type TranslateAssistant = Assistant & {
model: Model
content: string

View File

@ -13,7 +13,7 @@ import { isFunctionCallingModel, isVisionModel } from '@renderer/config/models'
import i18n from '@renderer/i18n'
import { currentSpan } from '@renderer/services/SpanManagerService'
import store from '@renderer/store'
import { addMCPServer } from '@renderer/store/mcp'
import { addMCPServer, hubMCPServer } from '@renderer/store/mcp'
import type {
Assistant,
MCPCallToolResponse,
@ -325,7 +325,16 @@ export function filterMCPTools(
export function getMcpServerByTool(tool: MCPTool) {
const servers = store.getState().mcp.servers
return servers.find((s) => s.id === tool.serverId)
const server = servers.find((s) => s.id === tool.serverId)
if (server) {
return server
}
// For hub server (auto mode), the server isn't in the store
// Return the hub server constant if the tool's serverId matches
if (tool.serverId === 'hub') {
return hubMCPServer
}
return undefined
}
export function isToolAutoApproved(tool: MCPTool, server?: MCPServer): boolean {