mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-10 07:19:02 +08:00
* fix(agents): align MCP tool IDs for permissions Normalize legacy MCP allowlist entries so auto-approval matches SDK tool names. Signed-off-by: mathholic <h.p.zhumeng@gmail.com> * fix: normalize mcp tool ids in sessions Signed-off-by: macmini <h.p.zhumeng@gmail.com> * fix: align mcp tool ids with buildFunctionCallToolName --------- Signed-off-by: mathholic <h.p.zhumeng@gmail.com> Signed-off-by: macmini <h.p.zhumeng@gmail.com>
219 lines
7.2 KiB
TypeScript
219 lines
7.2 KiB
TypeScript
import path from 'node:path'
|
|
|
|
import { loggerService } from '@logger'
|
|
import { pluginService } from '@main/services/agents/plugins/PluginService'
|
|
import { getDataPath } from '@main/utils'
|
|
import type {
|
|
AgentEntity,
|
|
CreateAgentRequest,
|
|
CreateAgentResponse,
|
|
GetAgentResponse,
|
|
ListOptions,
|
|
UpdateAgentRequest,
|
|
UpdateAgentResponse
|
|
} from '@types'
|
|
import { AgentBaseSchema } from '@types'
|
|
import { asc, count, desc, eq } from 'drizzle-orm'
|
|
|
|
import { BaseService } from '../BaseService'
|
|
import { type AgentRow, agentsTable, type InsertAgentRow } from '../database/schema'
|
|
import type { AgentModelField } from '../errors'
|
|
|
|
const logger = loggerService.withContext('AgentService')
|
|
|
|
export class AgentService extends BaseService {
|
|
private static instance: AgentService | null = null
|
|
private readonly modelFields: AgentModelField[] = ['model', 'plan_model', 'small_model']
|
|
|
|
static getInstance(): AgentService {
|
|
if (!AgentService.instance) {
|
|
AgentService.instance = new AgentService()
|
|
}
|
|
return AgentService.instance
|
|
}
|
|
|
|
// Agent Methods
|
|
async createAgent(req: CreateAgentRequest): Promise<CreateAgentResponse> {
|
|
const id = `agent_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`
|
|
const now = new Date().toISOString()
|
|
|
|
if (!req.accessible_paths || req.accessible_paths.length === 0) {
|
|
const defaultPath = path.join(getDataPath(), 'agents', id)
|
|
req.accessible_paths = [defaultPath]
|
|
}
|
|
|
|
if (req.accessible_paths !== undefined) {
|
|
req.accessible_paths = this.ensurePathsExist(req.accessible_paths)
|
|
}
|
|
|
|
await this.validateAgentModels(req.type, {
|
|
model: req.model,
|
|
plan_model: req.plan_model,
|
|
small_model: req.small_model
|
|
})
|
|
|
|
const serializedReq = this.serializeJsonFields(req)
|
|
|
|
const insertData: InsertAgentRow = {
|
|
id,
|
|
type: req.type,
|
|
name: req.name || 'New Agent',
|
|
description: req.description,
|
|
instructions: req.instructions || 'You are a helpful assistant.',
|
|
model: req.model,
|
|
plan_model: req.plan_model,
|
|
small_model: req.small_model,
|
|
configuration: serializedReq.configuration,
|
|
accessible_paths: serializedReq.accessible_paths,
|
|
created_at: now,
|
|
updated_at: now
|
|
}
|
|
|
|
const database = await this.getDatabase()
|
|
await database.insert(agentsTable).values(insertData)
|
|
const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
|
|
if (!result[0]) {
|
|
throw new Error('Failed to create agent')
|
|
}
|
|
|
|
const agent = this.deserializeJsonFields(result[0]) as AgentEntity
|
|
return agent
|
|
}
|
|
|
|
async getAgent(id: string): Promise<GetAgentResponse | null> {
|
|
const database = await this.getDatabase()
|
|
const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
|
|
|
|
if (!result[0]) {
|
|
return null
|
|
}
|
|
|
|
const agent = this.deserializeJsonFields(result[0]) as GetAgentResponse
|
|
const { tools, legacyIdMap } = await this.listMcpTools(agent.type, agent.mcps)
|
|
agent.tools = tools
|
|
agent.allowed_tools = this.normalizeAllowedTools(agent.allowed_tools, agent.tools, legacyIdMap)
|
|
|
|
// Load installed_plugins from cache file instead of database
|
|
const workdir = agent.accessible_paths?.[0]
|
|
if (workdir) {
|
|
try {
|
|
agent.installed_plugins = await pluginService.listInstalledFromCache(workdir)
|
|
} catch (error) {
|
|
// Log error but don't fail the request
|
|
logger.warn(`Failed to load installed plugins for agent ${id}`, {
|
|
workdir,
|
|
error: error instanceof Error ? error.message : String(error)
|
|
})
|
|
agent.installed_plugins = []
|
|
}
|
|
} else {
|
|
agent.installed_plugins = []
|
|
}
|
|
|
|
return agent
|
|
}
|
|
|
|
async listAgents(options: ListOptions = {}): Promise<{ agents: AgentEntity[]; total: number }> {
|
|
// Build query with pagination
|
|
const database = await this.getDatabase()
|
|
const totalResult = await database.select({ count: count() }).from(agentsTable)
|
|
|
|
const sortBy = options.sortBy || 'created_at'
|
|
const orderBy = options.orderBy || 'desc'
|
|
|
|
const sortField = agentsTable[sortBy]
|
|
const orderFn = orderBy === 'asc' ? asc : desc
|
|
|
|
const baseQuery = database.select().from(agentsTable).orderBy(orderFn(sortField))
|
|
|
|
const result =
|
|
options.limit !== undefined
|
|
? options.offset !== undefined
|
|
? await baseQuery.limit(options.limit).offset(options.offset)
|
|
: await baseQuery.limit(options.limit)
|
|
: await baseQuery
|
|
|
|
const agents = result.map((row) => this.deserializeJsonFields(row)) as GetAgentResponse[]
|
|
|
|
for (const agent of agents) {
|
|
const { tools, legacyIdMap } = await this.listMcpTools(agent.type, agent.mcps)
|
|
agent.tools = tools
|
|
agent.allowed_tools = this.normalizeAllowedTools(agent.allowed_tools, agent.tools, legacyIdMap)
|
|
}
|
|
|
|
return { agents, total: totalResult[0].count }
|
|
}
|
|
|
|
async updateAgent(
|
|
id: string,
|
|
updates: UpdateAgentRequest,
|
|
options: { replace?: boolean } = {}
|
|
): Promise<UpdateAgentResponse | null> {
|
|
// Check if agent exists
|
|
const existing = await this.getAgent(id)
|
|
if (!existing) {
|
|
return null
|
|
}
|
|
|
|
const now = new Date().toISOString()
|
|
|
|
if (updates.accessible_paths !== undefined) {
|
|
updates.accessible_paths = this.ensurePathsExist(updates.accessible_paths)
|
|
}
|
|
|
|
const modelUpdates: Partial<Record<AgentModelField, string | undefined>> = {}
|
|
for (const field of this.modelFields) {
|
|
if (Object.prototype.hasOwnProperty.call(updates, field)) {
|
|
modelUpdates[field] = updates[field as keyof UpdateAgentRequest] as string | undefined
|
|
}
|
|
}
|
|
|
|
if (Object.keys(modelUpdates).length > 0) {
|
|
await this.validateAgentModels(existing.type, modelUpdates)
|
|
}
|
|
|
|
const serializedUpdates = this.serializeJsonFields(updates)
|
|
|
|
const updateData: Partial<AgentRow> = {
|
|
updated_at: now
|
|
}
|
|
const replaceableFields = Object.keys(AgentBaseSchema.shape) as (keyof AgentRow)[]
|
|
const shouldReplace = options.replace ?? false
|
|
|
|
for (const field of replaceableFields) {
|
|
if (shouldReplace || Object.prototype.hasOwnProperty.call(serializedUpdates, field)) {
|
|
if (Object.prototype.hasOwnProperty.call(serializedUpdates, field)) {
|
|
const value = serializedUpdates[field as keyof typeof serializedUpdates]
|
|
;(updateData as Record<string, unknown>)[field] = value ?? null
|
|
} else if (shouldReplace) {
|
|
;(updateData as Record<string, unknown>)[field] = null
|
|
}
|
|
}
|
|
}
|
|
|
|
const database = await this.getDatabase()
|
|
await database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
|
|
return await this.getAgent(id)
|
|
}
|
|
|
|
async deleteAgent(id: string): Promise<boolean> {
|
|
const database = await this.getDatabase()
|
|
const result = await database.delete(agentsTable).where(eq(agentsTable.id, id))
|
|
|
|
return result.rowsAffected > 0
|
|
}
|
|
|
|
async agentExists(id: string): Promise<boolean> {
|
|
const database = await this.getDatabase()
|
|
const result = await database
|
|
.select({ id: agentsTable.id })
|
|
.from(agentsTable)
|
|
.where(eq(agentsTable.id, id))
|
|
.limit(1)
|
|
|
|
return result.length > 0
|
|
}
|
|
}
|
|
|
|
export const agentService = AgentService.getInstance()
|