mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-12 00:49:14 +08:00
✨ feat: add available tools section to HUB_MODE_SYSTEM_PROMPT
- Add shared utility for generating MCP tool function names (serverName_toolName format) - Update hub server to use consistent function naming across search, exec and prompt - Add fetchAllActiveServerTools to ApiService for renderer process - Update parameterBuilder to include available tools in auto/hub mode prompt - Use CacheService for 1-minute tools caching in hub server - Remove ToolRegistry in favor of direct fetching with caching - Update search ranking to include server name matching - Fix tests to use new naming format Amp-Thread-ID: https://ampcode.com/threads/T-019b6971-d5c9-7719-9245-a89390078647 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
6c951a3945
commit
7a5b011bfa
43
packages/shared/mcp.ts
Normal file
43
packages/shared/mcp.ts
Normal file
@ -0,0 +1,43 @@
|
||||
/**
|
||||
* Convert a string to camelCase, ensuring it's a valid JavaScript identifier.
|
||||
*/
|
||||
export function toCamelCase(str: string): string {
|
||||
let result = str
|
||||
.replace(/[^a-zA-Z0-9]+(.)/g, (_, char) => char.toUpperCase())
|
||||
.replace(/^[A-Z]/, (char) => char.toLowerCase())
|
||||
.replace(/[^a-zA-Z0-9]/g, '')
|
||||
|
||||
// Ensure valid JS identifier: must start with letter or underscore
|
||||
if (result && !/^[a-zA-Z_]/.test(result)) {
|
||||
result = '_' + result
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a unique function name from server name and tool name.
|
||||
* Format: serverName_toolName (camelCase)
|
||||
*/
|
||||
export function generateMcpToolFunctionName(
|
||||
serverName: string | undefined,
|
||||
toolName: string,
|
||||
existingNames?: Set<string>
|
||||
): string {
|
||||
const serverPrefix = serverName ? toCamelCase(serverName) : ''
|
||||
const toolNameCamel = toCamelCase(toolName)
|
||||
const baseName = serverPrefix ? `${serverPrefix}_${toolNameCamel}` : toolNameCamel
|
||||
|
||||
if (!existingNames) {
|
||||
return baseName
|
||||
}
|
||||
|
||||
let name = baseName
|
||||
let counter = 1
|
||||
while (existingNames.has(name)) {
|
||||
name = `${baseName}${counter}`
|
||||
counter++
|
||||
}
|
||||
existingNames.add(name)
|
||||
return name
|
||||
}
|
||||
@ -6,6 +6,62 @@ A built-in MCP server that aggregates all active MCP servers in Cherry Studio an
|
||||
|
||||
The Hub server enables LLMs to discover and call tools from all active MCP servers without needing to know the specific server names or tool signatures upfront.
|
||||
|
||||
## Auto Mode Integration
|
||||
|
||||
The Hub server is the core component of Cherry Studio's **Auto MCP Mode**. When an assistant is set to Auto mode:
|
||||
|
||||
1. **Automatic Injection**: The Hub server is automatically injected as the only MCP server for the assistant
|
||||
2. **System Prompt**: A specialized system prompt (`HUB_MODE_SYSTEM_PROMPT`) is appended to guide the LLM on how to use the `search` and `exec` tools
|
||||
3. **Dynamic Discovery**: The LLM can discover and use any tools from all active MCP servers without manual configuration
|
||||
|
||||
### MCP Modes
|
||||
|
||||
Cherry Studio supports three MCP modes per assistant:
|
||||
|
||||
| Mode | Description | Tools Available |
|
||||
|------|-------------|-----------------|
|
||||
| **Disabled** | No MCP tools | None |
|
||||
| **Auto** | Hub server only | `search`, `exec` |
|
||||
| **Manual** | User selects servers | Selected server tools |
|
||||
|
||||
### How Auto Mode Works
|
||||
|
||||
```
|
||||
User Message
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────┐
|
||||
│ Assistant (mcpMode: 'auto') │
|
||||
│ │
|
||||
│ System Prompt + HUB_MODE_SYSTEM_PROMPT │
|
||||
│ Tools: [hub.search, hub.exec] │
|
||||
└─────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────┐
|
||||
│ LLM decides to use MCP tools │
|
||||
│ │
|
||||
│ 1. search({ query: "github,repo" }) │
|
||||
│ 2. exec({ code: "await searchRepos()" })│
|
||||
└─────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────┐
|
||||
│ Hub Server │
|
||||
│ │
|
||||
│ Aggregates all active MCP servers │
|
||||
│ Routes tool calls to appropriate server │
|
||||
└─────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Relevant Code
|
||||
|
||||
- **Type Definition**: `src/renderer/src/types/index.ts` - `McpMode` type and `getEffectiveMcpMode()`
|
||||
- **Hub Server Constant**: `src/renderer/src/store/mcp.ts` - `hubMCPServer`
|
||||
- **Server Selection**: `src/renderer/src/services/ApiService.ts` - `getMcpServersForAssistant()`
|
||||
- **System Prompt**: `src/renderer/src/config/prompts.ts` - `HUB_MODE_SYSTEM_PROMPT`
|
||||
- **Prompt Injection**: `src/renderer/src/aiCore/prepareParams/parameterBuilder.ts`
|
||||
|
||||
## Tools
|
||||
|
||||
### `search`
|
||||
@ -95,11 +151,24 @@ return { users, repos };
|
||||
|
||||
## Configuration
|
||||
|
||||
The Hub server is a built-in server identified as `@cherry/hub`. To enable it:
|
||||
The Hub server is a built-in server identified as `@cherry/hub`.
|
||||
|
||||
### Using Auto Mode (Recommended)
|
||||
|
||||
The easiest way to use the Hub server is through Auto mode:
|
||||
|
||||
1. Click the **MCP Tools** button (hammer icon) in the input bar
|
||||
2. Select **Auto** mode
|
||||
3. The Hub server is automatically enabled for the assistant
|
||||
|
||||
### Manual Configuration
|
||||
|
||||
Alternatively, you can enable the Hub server manually:
|
||||
|
||||
1. Go to **Settings** → **MCP Servers**
|
||||
2. Find **Hub** in the built-in servers list
|
||||
3. Toggle it on
|
||||
4. In the assistant's MCP settings, select the Hub server
|
||||
|
||||
## Caching
|
||||
|
||||
|
||||
@ -23,20 +23,13 @@ describe('generator', () => {
|
||||
type: 'mcp' as const
|
||||
}
|
||||
|
||||
const server = {
|
||||
id: 'github',
|
||||
name: 'github-server',
|
||||
isActive: true
|
||||
}
|
||||
|
||||
const existingNames = new Set<string>()
|
||||
const callTool = async () => ({ success: true })
|
||||
|
||||
const result = generateToolFunction(tool, server as any, existingNames, callTool)
|
||||
const result = generateToolFunction(tool, existingNames, callTool)
|
||||
|
||||
expect(result.toolId).toBe('github__search_repos')
|
||||
expect(result.functionName).toBe('searchRepos')
|
||||
expect(result.jsCode).toContain('async function searchRepos')
|
||||
expect(result.functionName).toBe('githubServer_searchRepos')
|
||||
expect(result.jsCode).toContain('async function githubServer_searchRepos')
|
||||
expect(result.jsCode).toContain('Search for GitHub repositories')
|
||||
expect(result.jsCode).toContain('__callTool')
|
||||
})
|
||||
@ -51,13 +44,12 @@ describe('generator', () => {
|
||||
type: 'mcp' as const
|
||||
}
|
||||
|
||||
const server = { id: 'server1', name: 'server1', isActive: true }
|
||||
const existingNames = new Set<string>(['search'])
|
||||
const existingNames = new Set<string>(['server1_search'])
|
||||
const callTool = async () => ({})
|
||||
|
||||
const result = generateToolFunction(tool, server as any, existingNames, callTool)
|
||||
const result = generateToolFunction(tool, existingNames, callTool)
|
||||
|
||||
expect(result.functionName).toBe('search1')
|
||||
expect(result.functionName).toBe('server1_search1')
|
||||
})
|
||||
|
||||
it('handles enum types in schema', () => {
|
||||
@ -78,11 +70,10 @@ describe('generator', () => {
|
||||
type: 'mcp' as const
|
||||
}
|
||||
|
||||
const server = { id: 'browser', name: 'browser', isActive: true }
|
||||
const existingNames = new Set<string>()
|
||||
const callTool = async () => ({})
|
||||
|
||||
const result = generateToolFunction(tool, server as any, existingNames, callTool)
|
||||
const result = generateToolFunction(tool, existingNames, callTool)
|
||||
|
||||
expect(result.jsCode).toContain('"chromium" | "firefox" | "webkit"')
|
||||
})
|
||||
@ -95,9 +86,8 @@ describe('generator', () => {
|
||||
serverId: 's1',
|
||||
serverName: 'server1',
|
||||
toolName: 'tool1',
|
||||
toolId: 's1__tool1',
|
||||
functionName: 'tool1',
|
||||
jsCode: 'async function tool1() {}',
|
||||
functionName: 'server1_tool1',
|
||||
jsCode: 'async function server1_tool1() {}',
|
||||
fn: async () => ({}),
|
||||
signature: '{}',
|
||||
returns: 'unknown'
|
||||
@ -106,9 +96,8 @@ describe('generator', () => {
|
||||
serverId: 's2',
|
||||
serverName: 'server2',
|
||||
toolName: 'tool2',
|
||||
toolId: 's2__tool2',
|
||||
functionName: 'tool2',
|
||||
jsCode: 'async function tool2() {}',
|
||||
functionName: 'server2_tool2',
|
||||
jsCode: 'async function server2_tool2() {}',
|
||||
fn: async () => ({}),
|
||||
signature: '{}',
|
||||
returns: 'unknown'
|
||||
@ -118,8 +107,8 @@ describe('generator', () => {
|
||||
const result = generateToolsCode(tools)
|
||||
|
||||
expect(result).toContain('Found 2 tool(s)')
|
||||
expect(result).toContain('async function tool1')
|
||||
expect(result).toContain('async function tool2')
|
||||
expect(result).toContain('async function server1_tool1')
|
||||
expect(result).toContain('async function server2_tool2')
|
||||
})
|
||||
|
||||
it('returns message for empty tools', () => {
|
||||
|
||||
@ -1,102 +1,93 @@
|
||||
import type { MCPServer } from '@types'
|
||||
import type { MCPTool } from '@types'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { HubServer } from '../index'
|
||||
import { initHubBridge } from '../mcp-bridge'
|
||||
|
||||
const mockMcpServers: MCPServer[] = [
|
||||
const mockTools: MCPTool[] = [
|
||||
{
|
||||
id: 'github',
|
||||
name: 'GitHub',
|
||||
command: 'npx',
|
||||
args: ['-y', 'github-mcp-server'],
|
||||
isActive: true
|
||||
} as MCPServer,
|
||||
id: 'github__search_repos',
|
||||
name: 'search_repos',
|
||||
description: 'Search for GitHub repositories',
|
||||
serverId: 'github',
|
||||
serverName: 'GitHub',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
query: { type: 'string', description: 'Search query' },
|
||||
limit: { type: 'number', description: 'Max results' }
|
||||
},
|
||||
required: ['query']
|
||||
},
|
||||
type: 'mcp'
|
||||
},
|
||||
{
|
||||
id: 'database',
|
||||
name: 'Database',
|
||||
command: 'npx',
|
||||
args: ['-y', 'db-mcp-server'],
|
||||
isActive: true
|
||||
} as MCPServer
|
||||
id: 'github__get_user',
|
||||
name: 'get_user',
|
||||
description: 'Get GitHub user profile',
|
||||
serverId: 'github',
|
||||
serverName: 'GitHub',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
username: { type: 'string', description: 'GitHub username' }
|
||||
},
|
||||
required: ['username']
|
||||
},
|
||||
type: 'mcp'
|
||||
},
|
||||
{
|
||||
id: 'database__query',
|
||||
name: 'query',
|
||||
description: 'Execute a database query',
|
||||
serverId: 'database',
|
||||
serverName: 'Database',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
sql: { type: 'string', description: 'SQL query to execute' }
|
||||
},
|
||||
required: ['sql']
|
||||
},
|
||||
type: 'mcp'
|
||||
}
|
||||
]
|
||||
|
||||
const mockToolDefinitions = {
|
||||
github: [
|
||||
{
|
||||
name: 'search_repos',
|
||||
description: 'Search for GitHub repositories',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
query: { type: 'string', description: 'Search query' },
|
||||
limit: { type: 'number', description: 'Max results' }
|
||||
},
|
||||
required: ['query']
|
||||
vi.mock('@main/services/MCPService', () => ({
|
||||
default: {
|
||||
listAllActiveServerTools: vi.fn(async () => mockTools),
|
||||
callToolById: vi.fn(async (toolId: string, args: unknown) => {
|
||||
if (toolId === 'github__search_repos') {
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify({ repos: ['repo1', 'repo2'], query: args }) }]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'get_user',
|
||||
description: 'Get GitHub user profile',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
username: { type: 'string', description: 'GitHub username' }
|
||||
},
|
||||
required: ['username']
|
||||
if (toolId === 'github__get_user') {
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify({ username: (args as any).username, id: 123 }) }]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
database: [
|
||||
{
|
||||
name: 'query',
|
||||
description: 'Execute a database query',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
sql: { type: 'string', description: 'SQL query to execute' }
|
||||
},
|
||||
required: ['sql']
|
||||
if (toolId === 'database__query') {
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify({ rows: [{ id: 1 }, { id: 2 }] }) }]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
return { content: [{ type: 'text', text: '{}' }] }
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
const mockMcpService = {
|
||||
listTools: vi.fn(async (_: null, server: MCPServer) => {
|
||||
return mockToolDefinitions[server.id as keyof typeof mockToolDefinitions] || []
|
||||
}),
|
||||
callTool: vi.fn(async (_: null, args: { server: MCPServer; name: string; args: unknown }) => {
|
||||
if (args.server.id === 'github' && args.name === 'search_repos') {
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify({ repos: ['repo1', 'repo2'], query: args.args }) }]
|
||||
}
|
||||
}
|
||||
if (args.server.id === 'github' && args.name === 'get_user') {
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify({ username: (args.args as any).username, id: 123 }) }]
|
||||
}
|
||||
}
|
||||
if (args.server.id === 'database' && args.name === 'query') {
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify({ rows: [{ id: 1 }, { id: 2 }] }) }]
|
||||
}
|
||||
}
|
||||
return { content: [{ type: 'text', text: '{}' }] }
|
||||
})
|
||||
}
|
||||
import mcpService from '@main/services/MCPService'
|
||||
|
||||
describe('HubServer Integration', () => {
|
||||
let hubServer: HubServer
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
initHubBridge(mockMcpService as any, () => mockMcpServers)
|
||||
hubServer = new HubServer()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
hubServer.invalidateCache()
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('full search → exec flow', () => {
|
||||
@ -106,10 +97,10 @@ describe('HubServer Integration', () => {
|
||||
expect(searchResult.content).toBeDefined()
|
||||
const searchText = JSON.parse(searchResult.content[0].text)
|
||||
expect(searchText.total).toBeGreaterThan(0)
|
||||
expect(searchText.tools).toContain('searchRepos')
|
||||
expect(searchText.tools).toContain('gitHub_searchRepos')
|
||||
|
||||
const execResult = await (hubServer as any).handleExec({
|
||||
code: 'return await searchRepos({ query: "test" })'
|
||||
code: 'return await gitHub_searchRepos({ query: "test" })'
|
||||
})
|
||||
|
||||
expect(execResult.content).toBeDefined()
|
||||
@ -123,8 +114,8 @@ describe('HubServer Integration', () => {
|
||||
const execResult = await (hubServer as any).handleExec({
|
||||
code: `
|
||||
const results = await parallel(
|
||||
searchRepos({ query: "react" }),
|
||||
getUser({ username: "octocat" })
|
||||
gitHub_searchRepos({ query: "react" }),
|
||||
gitHub_getUser({ username: "octocat" })
|
||||
);
|
||||
return results
|
||||
`
|
||||
@ -140,21 +131,31 @@ describe('HubServer Integration', () => {
|
||||
const searchResult = await (hubServer as any).handleSearch({ query: 'query' })
|
||||
|
||||
const searchText = JSON.parse(searchResult.content[0].text)
|
||||
expect(searchText.tools).toContain('query')
|
||||
expect(searchText.tools).toContain('database_query')
|
||||
})
|
||||
})
|
||||
|
||||
describe('cache invalidation', () => {
|
||||
it('refreshes tools after invalidation', async () => {
|
||||
describe('tools caching', () => {
|
||||
it('uses cached tools within TTL', async () => {
|
||||
await (hubServer as any).handleSearch({ query: 'github' })
|
||||
const firstCallCount = vi.mocked(mcpService.listAllActiveServerTools).mock.calls.length
|
||||
|
||||
const initialCallCount = mockMcpService.listTools.mock.calls.length
|
||||
await (hubServer as any).handleSearch({ query: 'github' })
|
||||
const secondCallCount = vi.mocked(mcpService.listAllActiveServerTools).mock.calls.length
|
||||
|
||||
expect(secondCallCount).toBe(firstCallCount) // Should use cache
|
||||
})
|
||||
|
||||
it('refreshes tools after cache invalidation', async () => {
|
||||
await (hubServer as any).handleSearch({ query: 'github' })
|
||||
const firstCallCount = vi.mocked(mcpService.listAllActiveServerTools).mock.calls.length
|
||||
|
||||
hubServer.invalidateCache()
|
||||
|
||||
await (hubServer as any).handleSearch({ query: 'github' })
|
||||
const secondCallCount = vi.mocked(mcpService.listAllActiveServerTools).mock.calls.length
|
||||
|
||||
expect(mockMcpService.listTools.mock.calls.length).toBeGreaterThan(initialCallCount)
|
||||
expect(secondCallCount).toBe(firstCallCount + 1)
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@ -16,9 +16,8 @@ const createMockTool = (partial: Partial<GeneratedTool>): GeneratedTool => ({
|
||||
serverId: 'server1',
|
||||
serverName: 'server1',
|
||||
toolName: 'tool',
|
||||
toolId: 'server1__tool',
|
||||
functionName: 'mockTool',
|
||||
jsCode: 'async function mockTool() {}',
|
||||
functionName: 'server1_mockTool',
|
||||
jsCode: 'async function server1_mockTool() {}',
|
||||
fn: async (params) => ({ result: params }),
|
||||
signature: '{}',
|
||||
returns: 'unknown',
|
||||
|
||||
@ -4,12 +4,11 @@ import { searchTools } from '../search'
|
||||
import type { GeneratedTool } from '../types'
|
||||
|
||||
const createMockTool = (partial: Partial<GeneratedTool>): GeneratedTool => {
|
||||
const functionName = partial.functionName || 'tool'
|
||||
const functionName = partial.functionName || 'server1_tool'
|
||||
return {
|
||||
serverId: 'server1',
|
||||
serverName: 'server1',
|
||||
toolName: partial.toolName || 'tool',
|
||||
toolId: partial.toolId || 'server1__tool',
|
||||
functionName,
|
||||
jsCode: `async function ${functionName}() {}`,
|
||||
fn: async () => ({}),
|
||||
@ -86,13 +85,13 @@ describe('search', () => {
|
||||
|
||||
it('respects limit parameter', () => {
|
||||
const tools = Array.from({ length: 20 }, (_, i) =>
|
||||
createMockTool({ toolName: `tool${i}`, functionName: `tool${i}`, toolId: `s__tool${i}` })
|
||||
createMockTool({ toolName: `tool${i}`, functionName: `server1_tool${i}` })
|
||||
)
|
||||
|
||||
const result = searchTools(tools, { query: 'tool', limit: 5 })
|
||||
|
||||
expect(result.total).toBe(20)
|
||||
const matches = (result.tools.match(/async function tool\d+/g) || []).length
|
||||
const matches = (result.tools.match(/async function server1_tool\d+/g) || []).length
|
||||
expect(matches).toBe(5)
|
||||
})
|
||||
|
||||
|
||||
@ -1,25 +1,8 @@
|
||||
import type { MCPServer, MCPTool } from '@types'
|
||||
import { generateMcpToolFunctionName } from '@shared/mcp'
|
||||
import type { MCPTool } from '@types'
|
||||
|
||||
import type { GeneratedTool } from './types'
|
||||
|
||||
function toCamelCase(str: string): string {
|
||||
return str
|
||||
.replace(/[^a-zA-Z0-9]+(.)/g, (_, char) => char.toUpperCase())
|
||||
.replace(/^[A-Z]/, (char) => char.toLowerCase())
|
||||
.replace(/[^a-zA-Z0-9]/g, '')
|
||||
}
|
||||
|
||||
function makeUniqueFunctionName(baseName: string, existingNames: Set<string>): string {
|
||||
let name = baseName
|
||||
let counter = 1
|
||||
while (existingNames.has(name)) {
|
||||
name = `${baseName}${counter}`
|
||||
counter++
|
||||
}
|
||||
existingNames.add(name)
|
||||
return name
|
||||
}
|
||||
|
||||
function jsonSchemaToSignature(schema: Record<string, unknown> | undefined): string {
|
||||
if (!schema || typeof schema !== 'object') {
|
||||
return '{}'
|
||||
@ -109,13 +92,10 @@ function generateJSDoc(tool: MCPTool, signature: string, returns: string): strin
|
||||
|
||||
export function generateToolFunction(
|
||||
tool: MCPTool,
|
||||
server: MCPServer,
|
||||
existingNames: Set<string>,
|
||||
callToolFn: (toolId: string, params: unknown) => Promise<unknown>
|
||||
callToolFn: (functionName: string, params: unknown) => Promise<unknown>
|
||||
): GeneratedTool {
|
||||
const toolId = `${server.id}__${tool.name}`
|
||||
const baseName = toCamelCase(tool.name)
|
||||
const functionName = makeUniqueFunctionName(baseName, existingNames)
|
||||
const functionName = generateMcpToolFunctionName(tool.serverName, tool.name, existingNames)
|
||||
|
||||
const inputSchema = tool.inputSchema as Record<string, unknown> | undefined
|
||||
const outputSchema = tool.outputSchema as Record<string, unknown> | undefined
|
||||
@ -127,18 +107,17 @@ export function generateToolFunction(
|
||||
|
||||
const jsCode = `${jsDoc}
|
||||
async function ${functionName}(params) {
|
||||
return await __callTool("${toolId}", params);
|
||||
return await __callTool("${functionName}", params);
|
||||
}`
|
||||
|
||||
const fn = async (params: unknown): Promise<unknown> => {
|
||||
return await callToolFn(toolId, params)
|
||||
return await callToolFn(functionName, params)
|
||||
}
|
||||
|
||||
return {
|
||||
serverId: server.id,
|
||||
serverName: server.name,
|
||||
serverId: tool.serverId,
|
||||
serverName: tool.serverName,
|
||||
toolName: tool.name,
|
||||
toolId,
|
||||
functionName,
|
||||
jsCode,
|
||||
fn,
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { CacheService } from '@main/services/CacheService'
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||
import { CallToolRequestSchema, ErrorCode, ListToolsRequestSchema, McpError } from '@modelcontextprotocol/sdk/types.js'
|
||||
|
||||
import { generateToolFunction } from './generator'
|
||||
import { callMcpTool, listAllTools } from './mcp-bridge'
|
||||
import { Runtime } from './runtime'
|
||||
import { searchTools } from './search'
|
||||
import { ToolRegistry } from './tool-registry'
|
||||
import type { ExecInput, SearchQuery } from './types'
|
||||
import type { ExecInput, GeneratedTool, SearchQuery } from './types'
|
||||
|
||||
const logger = loggerService.withContext('MCPServer:Hub')
|
||||
const TOOLS_CACHE_KEY = 'hub:tools'
|
||||
const TOOLS_CACHE_TTL = 60 * 1000 // 1 minute
|
||||
|
||||
/**
|
||||
* Hub MCP Server - A meta-server that aggregates all active MCP servers.
|
||||
@ -23,11 +27,9 @@ const logger = loggerService.withContext('MCPServer:Hub')
|
||||
*/
|
||||
export class HubServer {
|
||||
public server: Server
|
||||
private toolRegistry: ToolRegistry
|
||||
private runtime: Runtime
|
||||
|
||||
constructor() {
|
||||
this.toolRegistry = new ToolRegistry()
|
||||
this.runtime = new Runtime()
|
||||
|
||||
this.server = new Server(
|
||||
@ -118,12 +120,32 @@ export class HubServer {
|
||||
})
|
||||
}
|
||||
|
||||
private async fetchTools(): Promise<GeneratedTool[]> {
|
||||
const cached = CacheService.get<GeneratedTool[]>(TOOLS_CACHE_KEY)
|
||||
if (cached) {
|
||||
logger.debug('Returning cached tools')
|
||||
return cached
|
||||
}
|
||||
|
||||
logger.debug('Fetching fresh tools')
|
||||
const allTools = await listAllTools()
|
||||
const existingNames = new Set<string>()
|
||||
const tools = allTools.map((tool) => generateToolFunction(tool, existingNames, callMcpTool))
|
||||
CacheService.set(TOOLS_CACHE_KEY, tools, TOOLS_CACHE_TTL)
|
||||
return tools
|
||||
}
|
||||
|
||||
invalidateCache(): void {
|
||||
CacheService.remove(TOOLS_CACHE_KEY)
|
||||
logger.debug('Tools cache invalidated')
|
||||
}
|
||||
|
||||
private async handleSearch(query: SearchQuery) {
|
||||
if (!query.query || typeof query.query !== 'string') {
|
||||
throw new McpError(ErrorCode.InvalidParams, 'query parameter is required and must be a string')
|
||||
}
|
||||
|
||||
const tools = await this.toolRegistry.getTools()
|
||||
const tools = await this.fetchTools()
|
||||
const result = searchTools(tools, query)
|
||||
|
||||
return {
|
||||
@ -141,7 +163,7 @@ export class HubServer {
|
||||
throw new McpError(ErrorCode.InvalidParams, 'code parameter is required and must be a string')
|
||||
}
|
||||
|
||||
const tools = await this.toolRegistry.getTools()
|
||||
const tools = await this.fetchTools()
|
||||
const result = await this.runtime.execute(input.code, tools)
|
||||
|
||||
return {
|
||||
@ -153,10 +175,6 @@ export class HubServer {
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
invalidateCache(): void {
|
||||
this.toolRegistry.invalidate()
|
||||
}
|
||||
}
|
||||
|
||||
export default HubServer
|
||||
|
||||
@ -1,84 +1,38 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { BuiltinMCPServerNames, type MCPServer, type MCPTool } from '@types'
|
||||
/**
|
||||
* Bridge module for Hub server to access MCPService.
|
||||
* Re-exports the methods needed by tool-registry and runtime.
|
||||
*/
|
||||
import mcpService from '@main/services/MCPService'
|
||||
import { generateMcpToolFunctionName } from '@shared/mcp'
|
||||
|
||||
const logger = loggerService.withContext('MCPServer:Hub:Bridge')
|
||||
export const listAllTools = () => mcpService.listAllActiveServerTools()
|
||||
|
||||
let mcpServiceInstance: MCPServiceInterface | null = null
|
||||
let mcpServersGetter: (() => MCPServer[]) | null = null
|
||||
const toolFunctionNameToIdMap = new Map<string, { serverId: string; toolName: string }>()
|
||||
|
||||
interface MCPServiceInterface {
|
||||
listTools(_: null, server: MCPServer): Promise<MCPTool[]>
|
||||
callTool(
|
||||
_: null,
|
||||
args: { server: MCPServer; name: string; args: unknown; callId?: string }
|
||||
): Promise<{ content: Array<{ type: string; text?: string }> }>
|
||||
}
|
||||
|
||||
export function setMCPService(service: MCPServiceInterface): void {
|
||||
mcpServiceInstance = service
|
||||
}
|
||||
|
||||
export function setMCPServersGetter(getter: () => MCPServer[]): void {
|
||||
mcpServersGetter = getter
|
||||
}
|
||||
|
||||
export function initHubBridge(service: MCPServiceInterface, serversGetter: () => MCPServer[]): void {
|
||||
mcpServiceInstance = service
|
||||
mcpServersGetter = serversGetter
|
||||
}
|
||||
|
||||
export function getActiveServers(): MCPServer[] {
|
||||
if (!mcpServersGetter) {
|
||||
logger.warn('MCP servers getter not set')
|
||||
return []
|
||||
}
|
||||
|
||||
const servers = mcpServersGetter()
|
||||
return servers.filter((s) => s.isActive && s.name !== BuiltinMCPServerNames.hub)
|
||||
}
|
||||
|
||||
export async function listToolsFromServer(server: MCPServer): Promise<MCPTool[]> {
|
||||
if (!mcpServiceInstance) {
|
||||
logger.error('MCP service not initialized')
|
||||
return []
|
||||
}
|
||||
|
||||
try {
|
||||
return await mcpServiceInstance.listTools(null, server)
|
||||
} catch (error) {
|
||||
logger.error(`Failed to list tools from server ${server.name}:`, error as Error)
|
||||
return []
|
||||
export async function refreshToolMap(): Promise<void> {
|
||||
const tools = await listAllTools()
|
||||
toolFunctionNameToIdMap.clear()
|
||||
const existingNames = new Set<string>()
|
||||
for (const tool of tools) {
|
||||
const functionName = generateMcpToolFunctionName(tool.serverName, tool.name, existingNames)
|
||||
toolFunctionNameToIdMap.set(functionName, { serverId: tool.serverId, toolName: tool.name })
|
||||
}
|
||||
}
|
||||
|
||||
export async function callMcpTool(toolId: string, params: unknown): Promise<unknown> {
|
||||
if (!mcpServiceInstance) {
|
||||
throw new Error('MCP service not initialized')
|
||||
export const callMcpTool = async (functionName: string, params: unknown): Promise<unknown> => {
|
||||
const toolInfo = toolFunctionNameToIdMap.get(functionName)
|
||||
if (!toolInfo) {
|
||||
await refreshToolMap()
|
||||
const retryToolInfo = toolFunctionNameToIdMap.get(functionName)
|
||||
if (!retryToolInfo) {
|
||||
throw new Error(`Tool not found: ${functionName}`)
|
||||
}
|
||||
const toolId = `${retryToolInfo.serverId}__${retryToolInfo.toolName}`
|
||||
const result = await mcpService.callToolById(toolId, params)
|
||||
return extractToolResult(result)
|
||||
}
|
||||
|
||||
const parts = toolId.split('__')
|
||||
if (parts.length < 2) {
|
||||
throw new Error(`Invalid tool ID format: ${toolId}`)
|
||||
}
|
||||
|
||||
const serverId = parts[0]
|
||||
const toolName = parts.slice(1).join('__')
|
||||
|
||||
const servers = getActiveServers()
|
||||
const server = servers.find((s) => s.id === serverId)
|
||||
|
||||
if (!server) {
|
||||
throw new Error(`Server not found: ${serverId}`)
|
||||
}
|
||||
|
||||
logger.debug(`Calling tool ${toolName} on server ${server.name}`)
|
||||
|
||||
const result = await mcpServiceInstance.callTool(null, {
|
||||
server,
|
||||
name: toolName,
|
||||
args: params
|
||||
})
|
||||
|
||||
const toolId = `${toolInfo.serverId}__${toolInfo.toolName}`
|
||||
const result = await mcpService.callToolById(toolId, params)
|
||||
return extractToolResult(result)
|
||||
}
|
||||
|
||||
|
||||
@ -37,7 +37,15 @@ export function searchTools(tools: GeneratedTool[], query: SearchQuery): SearchR
|
||||
}
|
||||
|
||||
function buildSearchText(tool: GeneratedTool): string {
|
||||
const parts = [tool.toolName, tool.functionName, tool.serverName, tool.description || '', tool.signature]
|
||||
const combinedName = tool.serverName ? `${tool.serverName}_${tool.toolName}` : tool.toolName
|
||||
const parts = [
|
||||
tool.toolName,
|
||||
tool.functionName,
|
||||
tool.serverName,
|
||||
combinedName,
|
||||
tool.description || '',
|
||||
tool.signature
|
||||
]
|
||||
return parts.join(' ')
|
||||
}
|
||||
|
||||
@ -55,10 +63,12 @@ function rankTools(tools: GeneratedTool[], keywords: string[]): GeneratedTool[]
|
||||
function calculateScore(tool: GeneratedTool, keywords: string[]): number {
|
||||
let score = 0
|
||||
const toolName = tool.toolName.toLowerCase()
|
||||
const serverName = (tool.serverName || '').toLowerCase()
|
||||
const functionName = tool.functionName.toLowerCase()
|
||||
const description = (tool.description || '').toLowerCase()
|
||||
|
||||
for (const keyword of keywords) {
|
||||
// Match tool name
|
||||
if (toolName === keyword) {
|
||||
score += 10
|
||||
} else if (toolName.startsWith(keyword)) {
|
||||
@ -67,12 +77,24 @@ function calculateScore(tool: GeneratedTool, keywords: string[]): number {
|
||||
score += 3
|
||||
}
|
||||
|
||||
if (functionName === keyword) {
|
||||
// Match server name
|
||||
if (serverName === keyword) {
|
||||
score += 8
|
||||
} else if (functionName.includes(keyword)) {
|
||||
} else if (serverName.startsWith(keyword)) {
|
||||
score += 4
|
||||
} else if (serverName.includes(keyword)) {
|
||||
score += 2
|
||||
}
|
||||
|
||||
// Match function name (serverName_toolName format)
|
||||
if (functionName === keyword) {
|
||||
score += 10
|
||||
} else if (functionName.startsWith(keyword)) {
|
||||
score += 5
|
||||
} else if (functionName.includes(keyword)) {
|
||||
score += 3
|
||||
}
|
||||
|
||||
if (description.includes(keyword)) {
|
||||
const count = (description.match(new RegExp(escapeRegex(keyword), 'g')) || []).length
|
||||
score += Math.min(count, 3)
|
||||
|
||||
@ -1,98 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
|
||||
import { generateToolFunction } from './generator'
|
||||
import { callMcpTool, getActiveServers, listToolsFromServer } from './mcp-bridge'
|
||||
import type { GeneratedTool, ToolRegistryOptions } from './types'
|
||||
|
||||
const logger = loggerService.withContext('MCPServer:Hub:Registry')
|
||||
|
||||
const DEFAULT_TTL = 10 * 60 * 1000
|
||||
|
||||
export class ToolRegistry {
|
||||
private tools: Map<string, GeneratedTool> = new Map()
|
||||
private lastRefresh: number = 0
|
||||
private readonly ttl: number
|
||||
private refreshPromise: Promise<void> | null = null
|
||||
|
||||
constructor(options: ToolRegistryOptions = {}) {
|
||||
this.ttl = options.ttl ?? DEFAULT_TTL
|
||||
}
|
||||
|
||||
async getTools(): Promise<GeneratedTool[]> {
|
||||
if (this.isExpired()) {
|
||||
await this.refresh()
|
||||
}
|
||||
return Array.from(this.tools.values())
|
||||
}
|
||||
|
||||
async getTool(toolId: string): Promise<GeneratedTool | undefined> {
|
||||
if (this.isExpired()) {
|
||||
await this.refresh()
|
||||
}
|
||||
return this.tools.get(toolId)
|
||||
}
|
||||
|
||||
getToolByFunctionName(functionName: string): GeneratedTool | undefined {
|
||||
for (const tool of this.tools.values()) {
|
||||
if (tool.functionName === functionName) {
|
||||
return tool
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
private isExpired(): boolean {
|
||||
return Date.now() - this.lastRefresh > this.ttl
|
||||
}
|
||||
|
||||
invalidate(): void {
|
||||
this.lastRefresh = 0
|
||||
this.tools.clear()
|
||||
logger.debug('Tool registry invalidated')
|
||||
}
|
||||
|
||||
async refresh(): Promise<void> {
|
||||
if (this.refreshPromise) {
|
||||
return this.refreshPromise
|
||||
}
|
||||
|
||||
this.refreshPromise = this.doRefresh()
|
||||
|
||||
try {
|
||||
await this.refreshPromise
|
||||
} finally {
|
||||
this.refreshPromise = null
|
||||
}
|
||||
}
|
||||
|
||||
private async doRefresh(): Promise<void> {
|
||||
logger.debug('Refreshing tool registry')
|
||||
|
||||
const servers = getActiveServers()
|
||||
const newTools = new Map<string, GeneratedTool>()
|
||||
const existingNames = new Set<string>()
|
||||
|
||||
for (const server of servers) {
|
||||
try {
|
||||
const serverTools = await listToolsFromServer(server)
|
||||
|
||||
for (const tool of serverTools) {
|
||||
const generatedTool = generateToolFunction(tool, server, existingNames, callMcpTool)
|
||||
|
||||
newTools.set(generatedTool.toolId, generatedTool)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to list tools from server ${server.name}:`, error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
this.tools = newTools
|
||||
this.lastRefresh = Date.now()
|
||||
|
||||
logger.debug(`Tool registry refreshed with ${this.tools.size} tools`)
|
||||
}
|
||||
|
||||
getToolCount(): number {
|
||||
return this.tools.size
|
||||
}
|
||||
}
|
||||
@ -4,7 +4,6 @@ export interface GeneratedTool {
|
||||
serverId: string
|
||||
serverName: string
|
||||
toolName: string
|
||||
toolId: string
|
||||
functionName: string
|
||||
jsCode: string
|
||||
fn: (params: unknown) => Promise<unknown>
|
||||
@ -42,7 +41,7 @@ export interface MCPToolWithServer extends MCPTool {
|
||||
}
|
||||
|
||||
export interface ExecutionContext {
|
||||
__callTool: (toolId: string, params: unknown) => Promise<unknown>
|
||||
__callTool: (functionName: string, params: unknown) => Promise<unknown>
|
||||
parallel: <T>(...promises: Promise<T>[]) => Promise<T[]>
|
||||
settle: <T>(...promises: Promise<T>[]) => Promise<PromiseSettledResult<T>[]>
|
||||
console: ConsoleMethods
|
||||
|
||||
@ -3,8 +3,8 @@ import os from 'node:os'
|
||||
import path from 'node:path'
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import { getMCPServersFromRedux } from '@main/apiServer/utils/mcp'
|
||||
import { createInMemoryMCPServer } from '@main/mcpServers/factory'
|
||||
import { initHubBridge } from '@main/mcpServers/hub/mcp-bridge'
|
||||
import { makeSureDirExists, removeEnvProxy } from '@main/utils'
|
||||
import { buildFunctionCallToolName } from '@main/utils/mcp'
|
||||
import { findCommandInShellEnv, getBinaryName, getBinaryPath, isBinaryExists } from '@main/utils/process'
|
||||
@ -59,7 +59,6 @@ import DxtService from './DxtService'
|
||||
import { CallBackServer } from './mcp/oauth/callback'
|
||||
import { McpOAuthClientProvider } from './mcp/oauth/provider'
|
||||
import { ServerLogBuffer } from './mcp/ServerLogBuffer'
|
||||
import { reduxService } from './ReduxService'
|
||||
import { windowService } from './WindowService'
|
||||
|
||||
// Generic type for caching wrapped functions
|
||||
@ -165,27 +164,61 @@ class McpService {
|
||||
this.checkMcpConnectivity = this.checkMcpConnectivity.bind(this)
|
||||
this.getServerVersion = this.getServerVersion.bind(this)
|
||||
this.getServerLogs = this.getServerLogs.bind(this)
|
||||
|
||||
this.initializeHubDependencies()
|
||||
}
|
||||
|
||||
private initializeHubDependencies(): void {
|
||||
initHubBridge(
|
||||
{
|
||||
listTools: (_: null, server: MCPServer) => this.listToolsImpl(server),
|
||||
callTool: async (_: null, args: { server: MCPServer; name: string; args: unknown; callId?: string }) => {
|
||||
return this.callTool(null as unknown as Electron.IpcMainInvokeEvent, args)
|
||||
}
|
||||
},
|
||||
() => {
|
||||
try {
|
||||
const servers = reduxService.selectSync<MCPServer[]>('state.mcp.servers')
|
||||
return servers || []
|
||||
} catch {
|
||||
return []
|
||||
}
|
||||
/**
|
||||
* List all tools from all active MCP servers (excluding hub).
|
||||
* Used by Hub server's tool registry.
|
||||
*/
|
||||
public async listAllActiveServerTools(): Promise<MCPTool[]> {
|
||||
const servers = await getMCPServersFromRedux()
|
||||
const allTools: MCPTool[] = []
|
||||
|
||||
for (const server of servers) {
|
||||
if (!server.isActive) {
|
||||
continue
|
||||
}
|
||||
)
|
||||
try {
|
||||
const tools = await this.listToolsImpl(server)
|
||||
allTools.push(...tools)
|
||||
} catch (error) {
|
||||
logger.error(`[listAllActiveServerTools] Failed to list tools from ${server.name}:`, error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
return allTools
|
||||
}
|
||||
|
||||
/**
|
||||
* Call a tool by its full ID (serverId__toolName format).
|
||||
* Used by Hub server's runtime.
|
||||
*/
|
||||
public async callToolById(
|
||||
toolId: string,
|
||||
params: unknown
|
||||
): Promise<{ content: Array<{ type: string; text?: string }> }> {
|
||||
const parts = toolId.split('__')
|
||||
if (parts.length < 2) {
|
||||
throw new Error(`Invalid tool ID format: ${toolId}`)
|
||||
}
|
||||
|
||||
const serverId = parts[0]
|
||||
const toolName = parts.slice(1).join('__')
|
||||
|
||||
const servers = await getMCPServersFromRedux()
|
||||
const server = servers.find((s) => s.id === serverId)
|
||||
|
||||
if (!server) {
|
||||
throw new Error(`Server not found: ${serverId}`)
|
||||
}
|
||||
|
||||
logger.debug(`[callToolById] Calling tool ${toolName} on server ${server.name}`)
|
||||
|
||||
return this.callTool(null as unknown as Electron.IpcMainInvokeEvent, {
|
||||
server,
|
||||
name: toolName,
|
||||
args: params
|
||||
})
|
||||
}
|
||||
|
||||
private getServerKey(server: MCPServer): string {
|
||||
|
||||
@ -26,11 +26,13 @@ import {
|
||||
isSupportedThinkingTokenModel,
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import { getHubModeSystemPrompt } from '@renderer/config/prompts'
|
||||
import { fetchAllActiveServerTools } from '@renderer/services/ApiService'
|
||||
import { getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import store from '@renderer/store'
|
||||
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
|
||||
import type { Model } from '@renderer/types'
|
||||
import { type Assistant, type MCPTool, type Provider, SystemProviderIds } from '@renderer/types'
|
||||
import { type Assistant, getEffectiveMcpMode, type MCPTool, type Provider, SystemProviderIds } from '@renderer/types'
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern'
|
||||
import { replacePromptVariables } from '@renderer/utils/prompt'
|
||||
@ -244,7 +246,15 @@ export async function buildStreamTextParams(
|
||||
}
|
||||
|
||||
if (assistant.prompt) {
|
||||
params.system = await replacePromptVariables(assistant.prompt, model.name)
|
||||
let systemPrompt = await replacePromptVariables(assistant.prompt, model.name)
|
||||
if (getEffectiveMcpMode(assistant) === 'auto') {
|
||||
const allActiveTools = await fetchAllActiveServerTools()
|
||||
systemPrompt = systemPrompt + '\n\n' + getHubModeSystemPrompt(allActiveTools)
|
||||
}
|
||||
params.system = systemPrompt
|
||||
} else if (getEffectiveMcpMode(assistant) === 'auto') {
|
||||
const allActiveTools = await fetchAllActiveServerTools()
|
||||
params.system = getHubModeSystemPrompt(allActiveTools)
|
||||
}
|
||||
|
||||
logger.debug('params', params)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { generateMcpToolFunctionName } from '@shared/mcp'
|
||||
import dayjs from 'dayjs'
|
||||
|
||||
export const AGENT_PROMPT = `
|
||||
@ -468,3 +469,68 @@ Example: [nytimes.com](https://nytimes.com/some-page).
|
||||
If have multiple citations, please directly list them like this:
|
||||
[www.nytimes.com](https://nytimes.com/some-page)[www.bbc.com](https://bbc.com/some-page)
|
||||
`
|
||||
|
||||
const HUB_MODE_SYSTEM_PROMPT_BASE = `
|
||||
## MCP Tools (Code Mode)
|
||||
|
||||
You have access to MCP tools via the hub server.
|
||||
|
||||
### Workflow
|
||||
1. Call \`search\` with relevant keywords to discover tools
|
||||
2. Review the returned function signatures and their parameters
|
||||
3. Call \`exec\` with JavaScript code using those functions
|
||||
4. The last expression in your code becomes the result
|
||||
|
||||
### Example Usage
|
||||
|
||||
**Step 1: Search for tools**
|
||||
\`\`\`
|
||||
search({ query: "github,repository" })
|
||||
\`\`\`
|
||||
|
||||
**Step 2: Use discovered tools**
|
||||
\`\`\`javascript
|
||||
exec({
|
||||
code: \`
|
||||
const repos = await searchRepos({ query: "react", limit: 5 })
|
||||
const details = await parallel(
|
||||
repos.map(r => getRepoDetails({ owner: r.owner, repo: r.name }))
|
||||
)
|
||||
return { repos, details }
|
||||
\`
|
||||
})
|
||||
\`\`\`
|
||||
|
||||
### Best Practices
|
||||
- Always search first to discover available tools and their exact signatures
|
||||
- Use descriptive variable names in your code
|
||||
- Handle errors gracefully with try/catch when needed
|
||||
- Use \`parallel()\` for independent operations to improve performance
|
||||
`
|
||||
|
||||
interface ToolInfo {
|
||||
name: string
|
||||
serverName?: string
|
||||
description?: string
|
||||
}
|
||||
|
||||
export function getHubModeSystemPrompt(tools: ToolInfo[] = []): string {
|
||||
if (tools.length === 0) {
|
||||
return HUB_MODE_SYSTEM_PROMPT_BASE
|
||||
}
|
||||
|
||||
const existingNames = new Set<string>()
|
||||
const toolsSection = tools
|
||||
.map((t) => {
|
||||
const functionName = generateMcpToolFunctionName(t.serverName, t.name, existingNames)
|
||||
const desc = t.description || ''
|
||||
const truncatedDesc = desc.length > 50 ? `${desc.slice(0, 50)}...` : desc
|
||||
return `- ${functionName}: ${truncatedDesc}`
|
||||
})
|
||||
.join('\n')
|
||||
|
||||
return `${HUB_MODE_SYSTEM_PROMPT_BASE}
|
||||
### Available Tools
|
||||
${toolsSection}
|
||||
`
|
||||
}
|
||||
|
||||
@ -544,6 +544,20 @@
|
||||
"description": "Default enabled MCP servers",
|
||||
"enableFirst": "Enable this server in MCP settings first",
|
||||
"label": "MCP Servers",
|
||||
"mode": {
|
||||
"auto": {
|
||||
"description": "AI discovers and uses tools automatically",
|
||||
"label": "Auto"
|
||||
},
|
||||
"disabled": {
|
||||
"description": "No MCP tools",
|
||||
"label": "Disabled"
|
||||
},
|
||||
"manual": {
|
||||
"description": "Select specific MCP servers",
|
||||
"label": "Manual"
|
||||
}
|
||||
},
|
||||
"noServersAvailable": "No MCP servers available. Add servers in settings",
|
||||
"title": "MCP Settings"
|
||||
},
|
||||
|
||||
@ -544,6 +544,20 @@
|
||||
"description": "默认启用的 MCP 服务器",
|
||||
"enableFirst": "请先在 MCP 设置中启用此服务器",
|
||||
"label": "MCP 服务器",
|
||||
"mode": {
|
||||
"auto": {
|
||||
"description": "AI 自动发现和使用工具",
|
||||
"label": "自动"
|
||||
},
|
||||
"disabled": {
|
||||
"description": "不使用 MCP 工具",
|
||||
"label": "禁用"
|
||||
},
|
||||
"manual": {
|
||||
"description": "选择特定的 MCP 服务器",
|
||||
"label": "手动"
|
||||
}
|
||||
},
|
||||
"noServersAvailable": "无可用 MCP 服务器。请在设置中添加服务器",
|
||||
"title": "MCP 服务器"
|
||||
},
|
||||
|
||||
@ -544,6 +544,20 @@
|
||||
"description": "預設啟用的 MCP 伺服器",
|
||||
"enableFirst": "請先在 MCP 設定中啟用此伺服器",
|
||||
"label": "MCP 伺服器",
|
||||
"mode": {
|
||||
"auto": {
|
||||
"description": "AI 自動發現和使用工具",
|
||||
"label": "自動"
|
||||
},
|
||||
"disabled": {
|
||||
"description": "不使用 MCP 工具",
|
||||
"label": "停用"
|
||||
},
|
||||
"manual": {
|
||||
"description": "選擇特定的 MCP 伺服器",
|
||||
"label": "手動"
|
||||
}
|
||||
},
|
||||
"noServersAvailable": "無可用 MCP 伺服器。請在設定中新增伺服器",
|
||||
"title": "MCP 設定"
|
||||
},
|
||||
|
||||
@ -8,11 +8,12 @@ import { useTimer } from '@renderer/hooks/useTimer'
|
||||
import type { ToolQuickPanelApi } from '@renderer/pages/home/Inputbar/types'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { EventEmitter } from '@renderer/services/EventService'
|
||||
import type { MCPPrompt, MCPResource, MCPServer } from '@renderer/types'
|
||||
import type { McpMode, MCPPrompt, MCPResource, MCPServer } from '@renderer/types'
|
||||
import { getEffectiveMcpMode } from '@renderer/types'
|
||||
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||
import { isGeminiWebSearchProvider, isSupportUrlContextProvider } from '@renderer/utils/provider'
|
||||
import { Form, Input, Tooltip } from 'antd'
|
||||
import { CircleX, Hammer, Plus } from 'lucide-react'
|
||||
import { CircleX, Hammer, Plus, Sparkles } from 'lucide-react'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
@ -25,7 +26,6 @@ interface Props {
|
||||
resizeTextArea: () => void
|
||||
}
|
||||
|
||||
// 添加类型定义
|
||||
interface PromptArgument {
|
||||
name: string
|
||||
description?: string
|
||||
@ -44,24 +44,19 @@ interface ResourceData {
|
||||
uri?: string
|
||||
}
|
||||
|
||||
// 提取到组件外的工具函数
|
||||
const extractPromptContent = (response: any): string | null => {
|
||||
// Handle string response (backward compatibility)
|
||||
if (typeof response === 'string') {
|
||||
return response
|
||||
}
|
||||
|
||||
// Handle GetMCPPromptResponse format
|
||||
if (response && Array.isArray(response.messages)) {
|
||||
let formattedContent = ''
|
||||
|
||||
for (const message of response.messages) {
|
||||
if (!message.content) continue
|
||||
|
||||
// Add role prefix if available
|
||||
const rolePrefix = message.role ? `**${message.role.charAt(0).toUpperCase() + message.role.slice(1)}:** ` : ''
|
||||
|
||||
// Process different content types
|
||||
switch (message.content.type) {
|
||||
case 'text':
|
||||
formattedContent += `${rolePrefix}${message.content.text}\n\n`
|
||||
@ -98,7 +93,6 @@ const extractPromptContent = (response: any): string | null => {
|
||||
return formattedContent.trim()
|
||||
}
|
||||
|
||||
// Fallback handling for single message format
|
||||
if (response && response.messages && response.messages.length > 0) {
|
||||
const message = response.messages[0]
|
||||
if (message.content && message.content.text) {
|
||||
@ -121,7 +115,6 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
|
||||
const model = assistant.model
|
||||
const { setTimeoutTimer } = useTimer()
|
||||
|
||||
// 使用 useRef 存储不需要触发重渲染的值
|
||||
const isMountedRef = useRef(true)
|
||||
|
||||
useEffect(() => {
|
||||
@ -130,11 +123,30 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
|
||||
}
|
||||
}, [])
|
||||
|
||||
const currentMode = useMemo(() => getEffectiveMcpMode(assistant), [assistant])
|
||||
|
||||
const mcpServers = useMemo(() => assistant.mcpServers || [], [assistant.mcpServers])
|
||||
const assistantMcpServers = useMemo(
|
||||
() => activedMcpServers.filter((server) => mcpServers.some((s) => s.id === server.id)),
|
||||
[activedMcpServers, mcpServers]
|
||||
)
|
||||
|
||||
const handleModeChange = useCallback(
|
||||
(mode: McpMode) => {
|
||||
setTimeoutTimer(
|
||||
'updateMcpMode',
|
||||
() => {
|
||||
updateAssistant({
|
||||
...assistant,
|
||||
mcpMode: mode
|
||||
})
|
||||
},
|
||||
200
|
||||
)
|
||||
},
|
||||
[assistant, setTimeoutTimer, updateAssistant]
|
||||
)
|
||||
|
||||
const handleMcpServerSelect = useCallback(
|
||||
(server: MCPServer) => {
|
||||
const update = { ...assistant }
|
||||
@ -144,29 +156,24 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
|
||||
update.mcpServers = [...mcpServers, server]
|
||||
}
|
||||
|
||||
// only for gemini
|
||||
if (update.mcpServers.length > 0 && isGeminiModel(model) && isToolUseModeFunction(assistant)) {
|
||||
const provider = getProviderByModel(model)
|
||||
if (isSupportUrlContextProvider(provider) && assistant.enableUrlContext) {
|
||||
window.toast.warning(t('chat.mcp.warning.url_context'))
|
||||
update.enableUrlContext = false
|
||||
}
|
||||
if (
|
||||
// 非官方 API (openrouter etc.) 可能支持同时启用内置搜索和函数调用
|
||||
// 这里先假设 gemini type 和 vertexai type 不支持
|
||||
isGeminiWebSearchProvider(provider) &&
|
||||
assistant.enableWebSearch
|
||||
) {
|
||||
if (isGeminiWebSearchProvider(provider) && assistant.enableWebSearch) {
|
||||
window.toast.warning(t('chat.mcp.warning.gemini_web_search'))
|
||||
update.enableWebSearch = false
|
||||
}
|
||||
}
|
||||
|
||||
update.mcpMode = 'manual'
|
||||
updateAssistant(update)
|
||||
},
|
||||
[assistant, assistantMcpServers, mcpServers, model, t, updateAssistant]
|
||||
)
|
||||
|
||||
// 使用 useRef 缓存事件处理函数
|
||||
const handleMcpServerSelectRef = useRef(handleMcpServerSelect)
|
||||
handleMcpServerSelectRef.current = handleMcpServerSelect
|
||||
|
||||
@ -176,23 +183,7 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
|
||||
return () => EventEmitter.off('mcp-server-select', handler)
|
||||
}, [])
|
||||
|
||||
const updateMcpEnabled = useCallback(
|
||||
(enabled: boolean) => {
|
||||
setTimeoutTimer(
|
||||
'updateMcpEnabled',
|
||||
() => {
|
||||
updateAssistant({
|
||||
...assistant,
|
||||
mcpServers: enabled ? assistant.mcpServers || [] : []
|
||||
})
|
||||
},
|
||||
200
|
||||
)
|
||||
},
|
||||
[assistant, setTimeoutTimer, updateAssistant]
|
||||
)
|
||||
|
||||
const menuItems = useMemo(() => {
|
||||
const manualModeMenuItems = useMemo(() => {
|
||||
const newList: QuickPanelListItem[] = activedMcpServers.map((server) => ({
|
||||
label: server.name,
|
||||
description: server.description || server.baseUrl,
|
||||
@ -207,33 +198,69 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
|
||||
action: () => navigate('/settings/mcp')
|
||||
})
|
||||
|
||||
newList.unshift({
|
||||
label: t('settings.input.clear.all'),
|
||||
description: t('settings.mcp.disable.description'),
|
||||
icon: <CircleX />,
|
||||
isSelected: false,
|
||||
action: () => {
|
||||
updateMcpEnabled(false)
|
||||
quickPanelHook.close()
|
||||
}
|
||||
})
|
||||
|
||||
return newList
|
||||
}, [activedMcpServers, t, assistantMcpServers, navigate, updateMcpEnabled, quickPanelHook])
|
||||
}, [activedMcpServers, t, assistantMcpServers, navigate])
|
||||
|
||||
const openQuickPanel = useCallback(() => {
|
||||
const openManualModePanel = useCallback(() => {
|
||||
quickPanelHook.open({
|
||||
title: t('settings.mcp.title'),
|
||||
list: menuItems,
|
||||
title: t('assistants.settings.mcp.mode.manual.label'),
|
||||
list: manualModeMenuItems,
|
||||
symbol: QuickPanelReservedSymbol.Mcp,
|
||||
multiple: true,
|
||||
afterAction({ item }) {
|
||||
item.isSelected = !item.isSelected
|
||||
}
|
||||
})
|
||||
}, [manualModeMenuItems, quickPanelHook, t])
|
||||
|
||||
const menuItems = useMemo(() => {
|
||||
const newList: QuickPanelListItem[] = []
|
||||
|
||||
newList.push({
|
||||
label: t('assistants.settings.mcp.mode.disabled.label'),
|
||||
description: t('assistants.settings.mcp.mode.disabled.description'),
|
||||
icon: <CircleX />,
|
||||
isSelected: currentMode === 'disabled',
|
||||
action: () => {
|
||||
handleModeChange('disabled')
|
||||
quickPanelHook.close()
|
||||
}
|
||||
})
|
||||
|
||||
newList.push({
|
||||
label: t('assistants.settings.mcp.mode.auto.label'),
|
||||
description: t('assistants.settings.mcp.mode.auto.description'),
|
||||
icon: <Sparkles />,
|
||||
isSelected: currentMode === 'auto',
|
||||
action: () => {
|
||||
handleModeChange('auto')
|
||||
quickPanelHook.close()
|
||||
}
|
||||
})
|
||||
|
||||
newList.push({
|
||||
label: t('assistants.settings.mcp.mode.manual.label'),
|
||||
description: t('assistants.settings.mcp.mode.manual.description'),
|
||||
icon: <Hammer />,
|
||||
isSelected: currentMode === 'manual',
|
||||
isMenu: true,
|
||||
action: () => {
|
||||
openManualModePanel()
|
||||
}
|
||||
})
|
||||
|
||||
return newList
|
||||
}, [t, currentMode, handleModeChange, quickPanelHook, openManualModePanel])
|
||||
|
||||
const openQuickPanel = useCallback(() => {
|
||||
quickPanelHook.open({
|
||||
title: t('settings.mcp.title'),
|
||||
list: menuItems,
|
||||
symbol: QuickPanelReservedSymbol.Mcp,
|
||||
multiple: false
|
||||
})
|
||||
}, [menuItems, quickPanelHook, t])
|
||||
|
||||
// 使用 useCallback 优化 insertPromptIntoTextArea
|
||||
const insertPromptIntoTextArea = useCallback(
|
||||
(promptText: string) => {
|
||||
setInputValue((prev) => {
|
||||
@ -245,7 +272,6 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
|
||||
const selectionEndPosition = cursorPosition + promptText.length
|
||||
const newText = prev.slice(0, cursorPosition) + promptText + prev.slice(cursorPosition)
|
||||
|
||||
// 使用 requestAnimationFrame 优化 DOM 操作
|
||||
requestAnimationFrame(() => {
|
||||
textArea.focus()
|
||||
textArea.setSelectionRange(selectionStart, selectionEndPosition)
|
||||
@ -424,7 +450,6 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
|
||||
[activedMcpServers, t, insertPromptIntoTextArea]
|
||||
)
|
||||
|
||||
// 优化 resourcesList 的状态更新
|
||||
const [resourcesList, setResourcesList] = useState<QuickPanelListItem[]>([])
|
||||
|
||||
useEffect(() => {
|
||||
@ -514,17 +539,26 @@ const MCPToolsButton: FC<Props> = ({ quickPanel, setInputValue, resizeTextArea,
|
||||
}
|
||||
}, [openPromptList, openQuickPanel, openResourcesList, quickPanel, t])
|
||||
|
||||
const isActive = currentMode !== 'disabled'
|
||||
|
||||
const getButtonIcon = () => {
|
||||
switch (currentMode) {
|
||||
case 'auto':
|
||||
return <Sparkles size={18} />
|
||||
case 'disabled':
|
||||
case 'manual':
|
||||
default:
|
||||
return <Hammer size={18} />
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Tooltip placement="top" title={t('settings.mcp.title')} mouseLeaveDelay={0} arrow>
|
||||
<ActionIconButton
|
||||
onClick={handleOpenQuickPanel}
|
||||
active={assistant.mcpServers && assistant.mcpServers.length > 0}
|
||||
aria-label={t('settings.mcp.title')}>
|
||||
<Hammer size={18} />
|
||||
<ActionIconButton onClick={handleOpenQuickPanel} active={isActive} aria-label={t('settings.mcp.title')}>
|
||||
{getButtonIcon()}
|
||||
</ActionIconButton>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
|
||||
// 使用 React.memo 包装组件
|
||||
export default React.memo(MCPToolsButton)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import { ActionIconButton } from '@renderer/components/Buttons'
|
||||
import { useAssistant } from '@renderer/hooks/useAssistant'
|
||||
import { useTimer } from '@renderer/hooks/useTimer'
|
||||
import { getEffectiveMcpMode } from '@renderer/types'
|
||||
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||
import { Tooltip } from 'antd'
|
||||
import { Link } from 'lucide-react'
|
||||
@ -30,8 +31,7 @@ const UrlContextButton: FC<Props> = ({ assistantId }) => {
|
||||
() => {
|
||||
const update = { ...assistant }
|
||||
if (
|
||||
assistant.mcpServers &&
|
||||
assistant.mcpServers.length > 0 &&
|
||||
getEffectiveMcpMode(assistant) !== 'disabled' &&
|
||||
urlContentNewState === true &&
|
||||
isToolUseModeFunction(assistant)
|
||||
) {
|
||||
|
||||
@ -16,7 +16,7 @@ import { useWebSearchProviders } from '@renderer/hooks/useWebSearchProviders'
|
||||
import type { ToolQuickPanelController, ToolRenderContext } from '@renderer/pages/home/Inputbar/types'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import WebSearchService from '@renderer/services/WebSearchService'
|
||||
import type { WebSearchProvider, WebSearchProviderId } from '@renderer/types'
|
||||
import { getEffectiveMcpMode, type WebSearchProvider, type WebSearchProviderId } from '@renderer/types'
|
||||
import { hasObjectKey } from '@renderer/utils'
|
||||
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||
import { isPromptToolUse } from '@renderer/utils/mcp-tools'
|
||||
@ -108,8 +108,7 @@ export const useWebSearchPanelController = (assistantId: string, quickPanelContr
|
||||
isGeminiModel(model) &&
|
||||
isToolUseModeFunction(assistant) &&
|
||||
update.enableWebSearch &&
|
||||
assistant.mcpServers &&
|
||||
assistant.mcpServers.length > 0
|
||||
getEffectiveMcpMode(assistant) !== 'disabled'
|
||||
) {
|
||||
update.enableWebSearch = false
|
||||
window.toast.warning(t('chat.mcp.warning.gemini_web_search'))
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import { InfoCircleOutlined } from '@ant-design/icons'
|
||||
import { Box } from '@renderer/components/Layout'
|
||||
import { useMCPServers } from '@renderer/hooks/useMCPServers'
|
||||
import type { Assistant, AssistantSettings } from '@renderer/types'
|
||||
import { Empty, Switch, Tooltip } from 'antd'
|
||||
import type { Assistant, AssistantSettings, McpMode } from '@renderer/types'
|
||||
import { getEffectiveMcpMode } from '@renderer/types'
|
||||
import { Empty, Radio, Switch, Tooltip } from 'antd'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled from 'styled-components'
|
||||
|
||||
@ -27,22 +28,26 @@ const AssistantMCPSettings: React.FC<Props> = ({ assistant, updateAssistant }) =
|
||||
const { t } = useTranslation()
|
||||
const { mcpServers: allMcpServers } = useMCPServers()
|
||||
|
||||
const currentMode = getEffectiveMcpMode(assistant)
|
||||
|
||||
const handleModeChange = (mode: McpMode) => {
|
||||
updateAssistant({ ...assistant, mcpMode: mode })
|
||||
}
|
||||
|
||||
const onUpdate = (ids: string[]) => {
|
||||
const mcpServers = ids
|
||||
.map((id) => allMcpServers.find((server) => server.id === id))
|
||||
.filter((server): server is MCPServer => server !== undefined && server.isActive)
|
||||
|
||||
updateAssistant({ ...assistant, mcpServers })
|
||||
updateAssistant({ ...assistant, mcpServers, mcpMode: 'manual' })
|
||||
}
|
||||
|
||||
const handleServerToggle = (serverId: string) => {
|
||||
const currentServerIds = assistant.mcpServers?.map((server) => server.id) || []
|
||||
|
||||
if (currentServerIds.includes(serverId)) {
|
||||
// Remove server if it's already enabled
|
||||
onUpdate(currentServerIds.filter((id) => id !== serverId))
|
||||
} else {
|
||||
// Add server if it's not enabled
|
||||
onUpdate([...currentServerIds, serverId])
|
||||
}
|
||||
}
|
||||
@ -58,49 +63,77 @@ const AssistantMCPSettings: React.FC<Props> = ({ assistant, updateAssistant }) =
|
||||
<InfoIcon />
|
||||
</Tooltip>
|
||||
</Box>
|
||||
{allMcpServers.length > 0 && (
|
||||
<EnabledCount>
|
||||
{enabledCount} / {allMcpServers.length} {t('settings.mcp.active')}
|
||||
</EnabledCount>
|
||||
)}
|
||||
</HeaderContainer>
|
||||
|
||||
{allMcpServers.length > 0 ? (
|
||||
<ServerList>
|
||||
{allMcpServers.map((server) => {
|
||||
const isEnabled = assistant.mcpServers?.some((s) => s.id === server.id) || false
|
||||
<ModeSelector>
|
||||
<Radio.Group value={currentMode} onChange={(e) => handleModeChange(e.target.value)}>
|
||||
<Radio.Button value="disabled">
|
||||
<ModeOption>
|
||||
<ModeLabel>{t('assistants.settings.mcp.mode.disabled.label')}</ModeLabel>
|
||||
<ModeDescription>{t('assistants.settings.mcp.mode.disabled.description')}</ModeDescription>
|
||||
</ModeOption>
|
||||
</Radio.Button>
|
||||
<Radio.Button value="auto">
|
||||
<ModeOption>
|
||||
<ModeLabel>{t('assistants.settings.mcp.mode.auto.label')}</ModeLabel>
|
||||
<ModeDescription>{t('assistants.settings.mcp.mode.auto.description')}</ModeDescription>
|
||||
</ModeOption>
|
||||
</Radio.Button>
|
||||
<Radio.Button value="manual">
|
||||
<ModeOption>
|
||||
<ModeLabel>{t('assistants.settings.mcp.mode.manual.label')}</ModeLabel>
|
||||
<ModeDescription>{t('assistants.settings.mcp.mode.manual.description')}</ModeDescription>
|
||||
</ModeOption>
|
||||
</Radio.Button>
|
||||
</Radio.Group>
|
||||
</ModeSelector>
|
||||
|
||||
return (
|
||||
<ServerItem key={server.id} isEnabled={isEnabled}>
|
||||
<ServerInfo>
|
||||
<ServerName>{server.name}</ServerName>
|
||||
{server.description && <ServerDescription>{server.description}</ServerDescription>}
|
||||
{server.baseUrl && <ServerUrl>{server.baseUrl}</ServerUrl>}
|
||||
</ServerInfo>
|
||||
<Tooltip
|
||||
title={
|
||||
!server.isActive
|
||||
? t('assistants.settings.mcp.enableFirst', 'Enable this server in MCP settings first')
|
||||
: undefined
|
||||
}>
|
||||
<Switch
|
||||
checked={isEnabled}
|
||||
disabled={!server.isActive}
|
||||
onChange={() => handleServerToggle(server.id)}
|
||||
size="small"
|
||||
/>
|
||||
</Tooltip>
|
||||
</ServerItem>
|
||||
)
|
||||
})}
|
||||
</ServerList>
|
||||
) : (
|
||||
<EmptyContainer>
|
||||
<Empty
|
||||
description={t('assistants.settings.mcp.noServersAvailable', 'No MCP servers available')}
|
||||
image={Empty.PRESENTED_IMAGE_SIMPLE}
|
||||
/>
|
||||
</EmptyContainer>
|
||||
{currentMode === 'manual' && (
|
||||
<>
|
||||
{allMcpServers.length > 0 && (
|
||||
<EnabledCount>
|
||||
{enabledCount} / {allMcpServers.length} {t('settings.mcp.active')}
|
||||
</EnabledCount>
|
||||
)}
|
||||
|
||||
{allMcpServers.length > 0 ? (
|
||||
<ServerList>
|
||||
{allMcpServers.map((server) => {
|
||||
const isEnabled = assistant.mcpServers?.some((s) => s.id === server.id) || false
|
||||
|
||||
return (
|
||||
<ServerItem key={server.id} isEnabled={isEnabled}>
|
||||
<ServerInfo>
|
||||
<ServerName>{server.name}</ServerName>
|
||||
{server.description && <ServerDescription>{server.description}</ServerDescription>}
|
||||
{server.baseUrl && <ServerUrl>{server.baseUrl}</ServerUrl>}
|
||||
</ServerInfo>
|
||||
<Tooltip
|
||||
title={
|
||||
!server.isActive
|
||||
? t('assistants.settings.mcp.enableFirst', 'Enable this server in MCP settings first')
|
||||
: undefined
|
||||
}>
|
||||
<Switch
|
||||
checked={isEnabled}
|
||||
disabled={!server.isActive}
|
||||
onChange={() => handleServerToggle(server.id)}
|
||||
size="small"
|
||||
/>
|
||||
</Tooltip>
|
||||
</ServerItem>
|
||||
)
|
||||
})}
|
||||
</ServerList>
|
||||
) : (
|
||||
<EmptyContainer>
|
||||
<Empty
|
||||
description={t('assistants.settings.mcp.noServersAvailable', 'No MCP servers available')}
|
||||
image={Empty.PRESENTED_IMAGE_SIMPLE}
|
||||
/>
|
||||
</EmptyContainer>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Container>
|
||||
)
|
||||
@ -127,9 +160,54 @@ const InfoIcon = styled(InfoCircleOutlined)`
|
||||
cursor: help;
|
||||
`
|
||||
|
||||
const ModeSelector = styled.div`
|
||||
margin-bottom: 16px;
|
||||
|
||||
.ant-radio-group {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.ant-radio-button-wrapper {
|
||||
height: auto;
|
||||
padding: 12px 16px;
|
||||
border-radius: 8px;
|
||||
border: 1px solid var(--color-border);
|
||||
|
||||
&:not(:first-child)::before {
|
||||
display: none;
|
||||
}
|
||||
|
||||
&:first-child {
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
&:last-child {
|
||||
border-radius: 8px;
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
const ModeOption = styled.div`
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 2px;
|
||||
`
|
||||
|
||||
const ModeLabel = styled.span`
|
||||
font-weight: 600;
|
||||
`
|
||||
|
||||
const ModeDescription = styled.span`
|
||||
font-size: 12px;
|
||||
color: var(--color-text-2);
|
||||
`
|
||||
|
||||
const EnabledCount = styled.span`
|
||||
font-size: 12px;
|
||||
color: var(--color-text-2);
|
||||
margin-bottom: 8px;
|
||||
`
|
||||
|
||||
const EmptyContainer = styled.div`
|
||||
|
||||
@ -8,8 +8,9 @@ import { isDedicatedImageGenerationModel, isEmbeddingModel, isFunctionCallingMod
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import store from '@renderer/store'
|
||||
import { hubMCPServer } from '@renderer/store/mcp'
|
||||
import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
|
||||
import { type FetchChatCompletionParams, isSystemProvider } from '@renderer/types'
|
||||
import { type FetchChatCompletionParams, getEffectiveMcpMode, isSystemProvider } from '@renderer/types'
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import type { Message, ResponseError } from '@renderer/types/newMessage'
|
||||
@ -51,14 +52,60 @@ import type { StreamProcessorCallbacks } from './StreamProcessingService'
|
||||
|
||||
const logger = loggerService.withContext('ApiService')
|
||||
|
||||
export async function fetchMcpTools(assistant: Assistant) {
|
||||
// Get MCP tools (Fix duplicate declaration)
|
||||
let mcpTools: MCPTool[] = [] // Initialize as empty array
|
||||
/**
|
||||
* Get the MCP servers to use based on the assistant's MCP mode.
|
||||
*/
|
||||
export function getMcpServersForAssistant(assistant: Assistant): MCPServer[] {
|
||||
const mode = getEffectiveMcpMode(assistant)
|
||||
const allMcpServers = store.getState().mcp.servers || []
|
||||
const activedMcpServers = allMcpServers.filter((s) => s.isActive)
|
||||
const assistantMcpServers = assistant.mcpServers || []
|
||||
|
||||
const enabledMCPs = activedMcpServers.filter((server) => assistantMcpServers.some((s) => s.id === server.id))
|
||||
switch (mode) {
|
||||
case 'disabled':
|
||||
return []
|
||||
case 'auto':
|
||||
return [hubMCPServer]
|
||||
case 'manual': {
|
||||
const assistantMcpServers = assistant.mcpServers || []
|
||||
return activedMcpServers.filter((server) => assistantMcpServers.some((s) => s.id === server.id))
|
||||
}
|
||||
default:
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
export async function fetchAllActiveServerTools(): Promise<MCPTool[]> {
|
||||
const allMcpServers = store.getState().mcp.servers || []
|
||||
const activedMcpServers = allMcpServers.filter((s) => s.isActive)
|
||||
|
||||
if (activedMcpServers.length === 0) {
|
||||
return []
|
||||
}
|
||||
|
||||
try {
|
||||
const toolPromises = activedMcpServers.map(async (mcpServer: MCPServer) => {
|
||||
try {
|
||||
const tools = await window.api.mcp.listTools(mcpServer)
|
||||
return tools.filter((tool: any) => !mcpServer.disabledTools?.includes(tool.name))
|
||||
} catch (error) {
|
||||
logger.error(`Error fetching tools from MCP server ${mcpServer.name}:`, error as Error)
|
||||
return []
|
||||
}
|
||||
})
|
||||
const results = await Promise.allSettled(toolPromises)
|
||||
return results
|
||||
.filter((result): result is PromiseFulfilledResult<MCPTool[]> => result.status === 'fulfilled')
|
||||
.map((result) => result.value)
|
||||
.flat()
|
||||
} catch (toolError) {
|
||||
logger.error('Error fetching all active server tools:', toolError as Error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
export async function fetchMcpTools(assistant: Assistant) {
|
||||
let mcpTools: MCPTool[] = []
|
||||
const enabledMCPs = getMcpServersForAssistant(assistant)
|
||||
|
||||
if (enabledMCPs && enabledMCPs.length > 0) {
|
||||
try {
|
||||
|
||||
54
src/renderer/src/services/__tests__/mcpMode.test.ts
Normal file
54
src/renderer/src/services/__tests__/mcpMode.test.ts
Normal file
@ -0,0 +1,54 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import type { Assistant, MCPServer } from '../../types'
|
||||
import { getEffectiveMcpMode } from '../../types'
|
||||
|
||||
describe('getEffectiveMcpMode', () => {
|
||||
it('should return mcpMode when explicitly set to auto', () => {
|
||||
const assistant: Partial<Assistant> = { mcpMode: 'auto' }
|
||||
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('auto')
|
||||
})
|
||||
|
||||
it('should return disabled when mcpMode is explicitly disabled', () => {
|
||||
const assistant: Partial<Assistant> = { mcpMode: 'disabled' }
|
||||
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('disabled')
|
||||
})
|
||||
|
||||
it('should return manual when mcpMode is explicitly manual', () => {
|
||||
const assistant: Partial<Assistant> = { mcpMode: 'manual' }
|
||||
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('manual')
|
||||
})
|
||||
|
||||
it('should return manual when no mcpMode but mcpServers has items (backward compatibility)', () => {
|
||||
const assistant: Partial<Assistant> = {
|
||||
mcpServers: [{ id: 'test', name: 'Test Server', isActive: true }] as MCPServer[]
|
||||
}
|
||||
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('manual')
|
||||
})
|
||||
|
||||
it('should return disabled when no mcpMode and no mcpServers (backward compatibility)', () => {
|
||||
const assistant: Partial<Assistant> = {}
|
||||
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('disabled')
|
||||
})
|
||||
|
||||
it('should return disabled when no mcpMode and empty mcpServers (backward compatibility)', () => {
|
||||
const assistant: Partial<Assistant> = { mcpServers: [] }
|
||||
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('disabled')
|
||||
})
|
||||
|
||||
it('should prioritize explicit mcpMode over mcpServers presence', () => {
|
||||
const assistant: Partial<Assistant> = {
|
||||
mcpMode: 'disabled',
|
||||
mcpServers: [{ id: 'test', name: 'Test Server', isActive: true }] as MCPServer[]
|
||||
}
|
||||
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('disabled')
|
||||
})
|
||||
|
||||
it('should return auto when mcpMode is auto regardless of mcpServers', () => {
|
||||
const assistant: Partial<Assistant> = {
|
||||
mcpMode: 'auto',
|
||||
mcpServers: [{ id: 'test', name: 'Test Server', isActive: true }] as MCPServer[]
|
||||
}
|
||||
expect(getEffectiveMcpMode(assistant as Assistant)).toBe('auto')
|
||||
})
|
||||
})
|
||||
@ -86,6 +86,20 @@ export { mcpSlice }
|
||||
// Export the reducer as default export
|
||||
export default mcpSlice.reducer
|
||||
|
||||
/**
|
||||
* Hub MCP server for auto mode - aggregates all MCP servers for LLM code mode.
|
||||
* This server is injected automatically when mcpMode === 'auto'.
|
||||
*/
|
||||
export const hubMCPServer: BuiltinMCPServer = {
|
||||
id: 'hub',
|
||||
name: BuiltinMCPServerNames.hub,
|
||||
type: 'inMemory',
|
||||
isActive: true,
|
||||
provider: 'CherryAI',
|
||||
installSource: 'builtin',
|
||||
isTrusted: true
|
||||
}
|
||||
|
||||
/**
|
||||
* User-installable built-in MCP servers shown in the UI.
|
||||
*
|
||||
|
||||
@ -27,6 +27,8 @@ export * from './ocr'
|
||||
export * from './plugin'
|
||||
export * from './provider'
|
||||
|
||||
export type McpMode = 'disabled' | 'auto' | 'manual'
|
||||
|
||||
export type Assistant = {
|
||||
id: string
|
||||
name: string
|
||||
@ -47,6 +49,8 @@ export type Assistant = {
|
||||
// enableUrlContext 是 Gemini/Anthropic 的特有功能
|
||||
enableUrlContext?: boolean
|
||||
enableGenerateImage?: boolean
|
||||
/** MCP mode: 'disabled' (no MCP), 'auto' (hub server only), 'manual' (user selects servers) */
|
||||
mcpMode?: McpMode
|
||||
mcpServers?: MCPServer[]
|
||||
knowledgeRecognition?: 'off' | 'on'
|
||||
regularPhrases?: QuickPhrase[] // Added for regular phrase
|
||||
@ -57,6 +61,15 @@ export type Assistant = {
|
||||
targetLanguage?: TranslateLanguage
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the effective MCP mode for an assistant with backward compatibility.
|
||||
* Legacy assistants without mcpMode default based on mcpServers presence.
|
||||
*/
|
||||
export function getEffectiveMcpMode(assistant: Assistant): McpMode {
|
||||
if (assistant.mcpMode) return assistant.mcpMode
|
||||
return (assistant.mcpServers?.length ?? 0) > 0 ? 'manual' : 'disabled'
|
||||
}
|
||||
|
||||
export type TranslateAssistant = Assistant & {
|
||||
model: Model
|
||||
content: string
|
||||
|
||||
@ -13,7 +13,7 @@ import { isFunctionCallingModel, isVisionModel } from '@renderer/config/models'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { currentSpan } from '@renderer/services/SpanManagerService'
|
||||
import store from '@renderer/store'
|
||||
import { addMCPServer } from '@renderer/store/mcp'
|
||||
import { addMCPServer, hubMCPServer } from '@renderer/store/mcp'
|
||||
import type {
|
||||
Assistant,
|
||||
MCPCallToolResponse,
|
||||
@ -325,7 +325,16 @@ export function filterMCPTools(
|
||||
|
||||
export function getMcpServerByTool(tool: MCPTool) {
|
||||
const servers = store.getState().mcp.servers
|
||||
return servers.find((s) => s.id === tool.serverId)
|
||||
const server = servers.find((s) => s.id === tool.serverId)
|
||||
if (server) {
|
||||
return server
|
||||
}
|
||||
// For hub server (auto mode), the server isn't in the store
|
||||
// Return the hub server constant if the tool's serverId matches
|
||||
if (tool.serverId === 'hub') {
|
||||
return hubMCPServer
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
export function isToolAutoApproved(tool: MCPTool, server?: MCPServer): boolean {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user