mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-06 21:35:52 +08:00
fix: align MCP tool ids for permissions (#12127)
* fix(agents): align MCP tool IDs for permissions Normalize legacy MCP allowlist entries so auto-approval matches SDK tool names. Signed-off-by: mathholic <h.p.zhumeng@gmail.com> * fix: normalize mcp tool ids in sessions Signed-off-by: macmini <h.p.zhumeng@gmail.com> * fix: align mcp tool ids with buildFunctionCallToolName --------- Signed-off-by: mathholic <h.p.zhumeng@gmail.com> Signed-off-by: macmini <h.p.zhumeng@gmail.com>
This commit is contained in:
parent
528d6d37f2
commit
ed4353b054
@ -2,6 +2,7 @@ import { loggerService } from '@logger'
|
|||||||
import { mcpApiService } from '@main/apiServer/services/mcp'
|
import { mcpApiService } from '@main/apiServer/services/mcp'
|
||||||
import type { ModelValidationError } from '@main/apiServer/utils'
|
import type { ModelValidationError } from '@main/apiServer/utils'
|
||||||
import { validateModelId } from '@main/apiServer/utils'
|
import { validateModelId } from '@main/apiServer/utils'
|
||||||
|
import { buildFunctionCallToolName } from '@main/utils/mcp'
|
||||||
import type { AgentType, MCPTool, SlashCommand, Tool } from '@types'
|
import type { AgentType, MCPTool, SlashCommand, Tool } from '@types'
|
||||||
import { objectKeys } from '@types'
|
import { objectKeys } from '@types'
|
||||||
import fs from 'fs'
|
import fs from 'fs'
|
||||||
@ -14,6 +15,17 @@ import { builtinSlashCommands } from './services/claudecode/commands'
|
|||||||
import { builtinTools } from './services/claudecode/tools'
|
import { builtinTools } from './services/claudecode/tools'
|
||||||
|
|
||||||
const logger = loggerService.withContext('BaseService')
|
const logger = loggerService.withContext('BaseService')
|
||||||
|
const MCP_TOOL_ID_PREFIX = 'mcp__'
|
||||||
|
const MCP_TOOL_LEGACY_PREFIX = 'mcp_'
|
||||||
|
|
||||||
|
const buildMcpToolId = (serverId: string, toolName: string) => `${MCP_TOOL_ID_PREFIX}${serverId}__${toolName}`
|
||||||
|
const toLegacyMcpToolId = (toolId: string) => {
|
||||||
|
if (!toolId.startsWith(MCP_TOOL_ID_PREFIX)) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
const rawId = toolId.slice(MCP_TOOL_ID_PREFIX.length)
|
||||||
|
return `${MCP_TOOL_LEGACY_PREFIX}${rawId.replace(/__/g, '_')}`
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base service class providing shared utilities for all agent-related services.
|
* Base service class providing shared utilities for all agent-related services.
|
||||||
@ -35,8 +47,12 @@ export abstract class BaseService {
|
|||||||
'slash_commands'
|
'slash_commands'
|
||||||
]
|
]
|
||||||
|
|
||||||
public async listMcpTools(agentType: AgentType, ids?: string[]): Promise<Tool[]> {
|
public async listMcpTools(
|
||||||
|
agentType: AgentType,
|
||||||
|
ids?: string[]
|
||||||
|
): Promise<{ tools: Tool[]; legacyIdMap: Map<string, string> }> {
|
||||||
const tools: Tool[] = []
|
const tools: Tool[] = []
|
||||||
|
const legacyIdMap = new Map<string, string>()
|
||||||
if (agentType === 'claude-code') {
|
if (agentType === 'claude-code') {
|
||||||
tools.push(...builtinTools)
|
tools.push(...builtinTools)
|
||||||
}
|
}
|
||||||
@ -46,13 +62,21 @@ export abstract class BaseService {
|
|||||||
const server = await mcpApiService.getServerInfo(id)
|
const server = await mcpApiService.getServerInfo(id)
|
||||||
if (server) {
|
if (server) {
|
||||||
server.tools.forEach((tool: MCPTool) => {
|
server.tools.forEach((tool: MCPTool) => {
|
||||||
|
const canonicalId = buildFunctionCallToolName(server.name, tool.name)
|
||||||
|
const serverIdBasedId = buildMcpToolId(id, tool.name)
|
||||||
|
const legacyId = toLegacyMcpToolId(serverIdBasedId)
|
||||||
|
|
||||||
tools.push({
|
tools.push({
|
||||||
id: `mcp_${id}_${tool.name}`,
|
id: canonicalId,
|
||||||
name: tool.name,
|
name: tool.name,
|
||||||
type: 'mcp',
|
type: 'mcp',
|
||||||
description: tool.description || '',
|
description: tool.description || '',
|
||||||
requirePermissions: true
|
requirePermissions: true
|
||||||
})
|
})
|
||||||
|
legacyIdMap.set(serverIdBasedId, canonicalId)
|
||||||
|
if (legacyId) {
|
||||||
|
legacyIdMap.set(legacyId, canonicalId)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@ -64,7 +88,53 @@ export abstract class BaseService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return tools
|
return { tools, legacyIdMap }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Normalize MCP tool IDs in allowed_tools to the current format.
|
||||||
|
*
|
||||||
|
* Legacy formats:
|
||||||
|
* - "mcp__<serverId>__<toolName>" (double underscore separators, server ID based)
|
||||||
|
* - "mcp_<serverId>_<toolName>" (single underscore separators)
|
||||||
|
* Current format: "mcp__<serverName>__<toolName>" (double underscore separators).
|
||||||
|
*
|
||||||
|
* This keeps persisted data compatible without requiring a database migration.
|
||||||
|
*/
|
||||||
|
protected normalizeAllowedTools(
|
||||||
|
allowedTools: string[] | undefined,
|
||||||
|
tools: Tool[],
|
||||||
|
legacyIdMap?: Map<string, string>
|
||||||
|
): string[] | undefined {
|
||||||
|
if (!allowedTools || allowedTools.length === 0) {
|
||||||
|
return allowedTools
|
||||||
|
}
|
||||||
|
|
||||||
|
const resolvedLegacyIdMap = new Map<string, string>()
|
||||||
|
|
||||||
|
if (legacyIdMap) {
|
||||||
|
for (const [legacyId, canonicalId] of legacyIdMap) {
|
||||||
|
resolvedLegacyIdMap.set(legacyId, canonicalId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const tool of tools) {
|
||||||
|
if (tool.type !== 'mcp') {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
const legacyId = toLegacyMcpToolId(tool.id)
|
||||||
|
if (!legacyId) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
resolvedLegacyIdMap.set(legacyId, tool.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (resolvedLegacyIdMap.size === 0) {
|
||||||
|
return allowedTools
|
||||||
|
}
|
||||||
|
|
||||||
|
const normalized = allowedTools.map((toolId) => resolvedLegacyIdMap.get(toolId) ?? toolId)
|
||||||
|
return Array.from(new Set(normalized))
|
||||||
}
|
}
|
||||||
|
|
||||||
public async listSlashCommands(agentType: AgentType): Promise<SlashCommand[]> {
|
public async listSlashCommands(agentType: AgentType): Promise<SlashCommand[]> {
|
||||||
|
|||||||
@ -89,7 +89,9 @@ export class AgentService extends BaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const agent = this.deserializeJsonFields(result[0]) as GetAgentResponse
|
const agent = this.deserializeJsonFields(result[0]) as GetAgentResponse
|
||||||
agent.tools = await this.listMcpTools(agent.type, agent.mcps)
|
const { tools, legacyIdMap } = await this.listMcpTools(agent.type, agent.mcps)
|
||||||
|
agent.tools = tools
|
||||||
|
agent.allowed_tools = this.normalizeAllowedTools(agent.allowed_tools, agent.tools, legacyIdMap)
|
||||||
|
|
||||||
// Load installed_plugins from cache file instead of database
|
// Load installed_plugins from cache file instead of database
|
||||||
const workdir = agent.accessible_paths?.[0]
|
const workdir = agent.accessible_paths?.[0]
|
||||||
@ -134,7 +136,9 @@ export class AgentService extends BaseService {
|
|||||||
const agents = result.map((row) => this.deserializeJsonFields(row)) as GetAgentResponse[]
|
const agents = result.map((row) => this.deserializeJsonFields(row)) as GetAgentResponse[]
|
||||||
|
|
||||||
for (const agent of agents) {
|
for (const agent of agents) {
|
||||||
agent.tools = await this.listMcpTools(agent.type, agent.mcps)
|
const { tools, legacyIdMap } = await this.listMcpTools(agent.type, agent.mcps)
|
||||||
|
agent.tools = tools
|
||||||
|
agent.allowed_tools = this.normalizeAllowedTools(agent.allowed_tools, agent.tools, legacyIdMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
return { agents, total: totalResult[0].count }
|
return { agents, total: totalResult[0].count }
|
||||||
|
|||||||
@ -156,7 +156,9 @@ export class SessionService extends BaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const session = this.deserializeJsonFields(result[0]) as GetAgentSessionResponse
|
const session = this.deserializeJsonFields(result[0]) as GetAgentSessionResponse
|
||||||
session.tools = await this.listMcpTools(session.agent_type, session.mcps)
|
const { tools, legacyIdMap } = await this.listMcpTools(session.agent_type, session.mcps)
|
||||||
|
session.tools = tools
|
||||||
|
session.allowed_tools = this.normalizeAllowedTools(session.allowed_tools, session.tools, legacyIdMap)
|
||||||
|
|
||||||
// If slash_commands is not in database yet (e.g., first invoke before init message),
|
// If slash_commands is not in database yet (e.g., first invoke before init message),
|
||||||
// fall back to builtin + local commands. Otherwise, use the merged commands from database.
|
// fall back to builtin + local commands. Otherwise, use the merged commands from database.
|
||||||
@ -202,6 +204,12 @@ export class SessionService extends BaseService {
|
|||||||
|
|
||||||
const sessions = result.map((row) => this.deserializeJsonFields(row)) as GetAgentSessionResponse[]
|
const sessions = result.map((row) => this.deserializeJsonFields(row)) as GetAgentSessionResponse[]
|
||||||
|
|
||||||
|
for (const session of sessions) {
|
||||||
|
const { tools, legacyIdMap } = await this.listMcpTools(session.agent_type, session.mcps)
|
||||||
|
session.tools = tools
|
||||||
|
session.allowed_tools = this.normalizeAllowedTools(session.allowed_tools, session.tools, legacyIdMap)
|
||||||
|
}
|
||||||
|
|
||||||
return { sessions, total }
|
return { sessions, total }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
91
src/main/services/agents/tests/BaseService.test.ts
Normal file
91
src/main/services/agents/tests/BaseService.test.ts
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import type { Tool } from '@types'
|
||||||
|
import { describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
vi.mock('@main/apiServer/services/mcp', () => ({
|
||||||
|
mcpApiService: {
|
||||||
|
getServerInfo: vi.fn()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@main/apiServer/utils', () => ({
|
||||||
|
validateModelId: vi.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
|
import { BaseService } from '../BaseService'
|
||||||
|
|
||||||
|
class TestBaseService extends BaseService {
|
||||||
|
public normalize(
|
||||||
|
allowedTools: string[] | undefined,
|
||||||
|
tools: Tool[],
|
||||||
|
legacyIdMap?: Map<string, string>
|
||||||
|
): string[] | undefined {
|
||||||
|
return this.normalizeAllowedTools(allowedTools, tools, legacyIdMap)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const buildMcpTool = (id: string): Tool => ({
|
||||||
|
id,
|
||||||
|
name: id,
|
||||||
|
type: 'mcp',
|
||||||
|
description: 'test tool',
|
||||||
|
requirePermissions: true
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('BaseService.normalizeAllowedTools', () => {
|
||||||
|
const service = new TestBaseService()
|
||||||
|
|
||||||
|
it('returns undefined or empty inputs unchanged', () => {
|
||||||
|
expect(service.normalize(undefined, [])).toBeUndefined()
|
||||||
|
expect(service.normalize([], [])).toEqual([])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('normalizes legacy MCP tool IDs and deduplicates entries', () => {
|
||||||
|
const tools: Tool[] = [
|
||||||
|
buildMcpTool('mcp__server_one__tool_one'),
|
||||||
|
buildMcpTool('mcp__server_two__tool_two'),
|
||||||
|
{ id: 'custom_tool', name: 'custom_tool', type: 'custom' }
|
||||||
|
]
|
||||||
|
|
||||||
|
const legacyIdMap = new Map<string, string>([
|
||||||
|
['mcp__server-1__tool-one', 'mcp__server_one__tool_one'],
|
||||||
|
['mcp_server-1_tool-one', 'mcp__server_one__tool_one'],
|
||||||
|
['mcp__server-2__tool-two', 'mcp__server_two__tool_two']
|
||||||
|
])
|
||||||
|
|
||||||
|
const allowedTools = [
|
||||||
|
'mcp__server-1__tool-one',
|
||||||
|
'mcp_server-1_tool-one',
|
||||||
|
'mcp_server_one_tool_one',
|
||||||
|
'mcp__server_one__tool_one',
|
||||||
|
'custom_tool',
|
||||||
|
'mcp__server_two__tool_two',
|
||||||
|
'mcp_server_two_tool_two',
|
||||||
|
'mcp__server-2__tool-two'
|
||||||
|
]
|
||||||
|
|
||||||
|
expect(service.normalize(allowedTools, tools, legacyIdMap)).toEqual([
|
||||||
|
'mcp__server_one__tool_one',
|
||||||
|
'custom_tool',
|
||||||
|
'mcp__server_two__tool_two'
|
||||||
|
])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('keeps legacy IDs when no matching MCP tool exists', () => {
|
||||||
|
const tools: Tool[] = [buildMcpTool('mcp__server_one__tool_one')]
|
||||||
|
const legacyIdMap = new Map<string, string>([['mcp__server-1__tool-one', 'mcp__server_one__tool_one']])
|
||||||
|
|
||||||
|
const allowedTools = ['mcp__unknown__tool', 'mcp__server_one__tool_one']
|
||||||
|
|
||||||
|
expect(service.normalize(allowedTools, tools, legacyIdMap)).toEqual([
|
||||||
|
'mcp__unknown__tool',
|
||||||
|
'mcp__server_one__tool_one'
|
||||||
|
])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns allowed tools unchanged when no MCP tools are available', () => {
|
||||||
|
const allowedTools = ['custom_tool', 'builtin_tool']
|
||||||
|
const tools: Tool[] = [{ id: 'custom_tool', name: 'custom_tool', type: 'custom' }]
|
||||||
|
|
||||||
|
expect(service.normalize(allowedTools, tools)).toEqual(allowedTools)
|
||||||
|
})
|
||||||
|
})
|
||||||
Loading…
Reference in New Issue
Block a user