cherry-studio/src/main/services/agents/services/SessionService.ts
nujabse ed4353b054
fix: align MCP tool ids for permissions (#12127)
* fix(agents): align MCP tool IDs for permissions

Normalize legacy MCP allowlist entries so auto-approval matches SDK tool names.

Signed-off-by: mathholic <h.p.zhumeng@gmail.com>

* fix: normalize mcp tool ids in sessions

Signed-off-by: macmini <h.p.zhumeng@gmail.com>

* fix: align mcp tool ids with buildFunctionCallToolName

---------

Signed-off-by: mathholic <h.p.zhumeng@gmail.com>
Signed-off-by: macmini <h.p.zhumeng@gmail.com>
2025-12-30 13:33:09 +08:00

289 lines
9.8 KiB
TypeScript

import { loggerService } from '@logger'
import type { SlashCommand, UpdateSessionResponse } from '@types'
import {
AgentBaseSchema,
type AgentEntity,
type AgentSessionEntity,
type CreateSessionRequest,
type GetAgentSessionResponse,
type ListOptions,
type UpdateSessionRequest
} from '@types'
import { and, count, desc, eq, type SQL } from 'drizzle-orm'
import { BaseService } from '../BaseService'
import { agentsTable, type InsertSessionRow, type SessionRow, sessionsTable } from '../database/schema'
import type { AgentModelField } from '../errors'
import { pluginService } from '../plugins/PluginService'
import { builtinSlashCommands } from './claudecode/commands'
// Module-level logger; every entry is created through the 'SessionService' context.
const logger = loggerService.withContext('SessionService')
/**
 * CRUD service for agent sessions stored in `sessionsTable`.
 *
 * Sessions are always addressed by the pair (agent id, session id). Sessions
 * returned by `getSession` / `listSessions` are enriched with their MCP tool
 * list and a normalized `allowed_tools` value before being handed back.
 */
export class SessionService extends BaseService {
  private static instance: SessionService | null = null
  /** Model-bearing fields that are validated whenever they are created or updated. */
  private readonly modelFields: AgentModelField[] = ['model', 'plan_model', 'small_model']

  /** Returns the shared instance, creating it lazily on first call. */
  static getInstance(): SessionService {
    if (!SessionService.instance) {
      SessionService.instance = new SessionService()
    }
    return SessionService.instance
  }

  /**
   * Override BaseService.listSlashCommands to merge builtin and plugin commands.
   *
   * @param agentType - Builtin commands are only included for `'claude-code'`.
   * @param agentId - When given, installed plugins of type `'command'` for this agent are appended.
   * @returns Builtin commands (if applicable) followed by plugin commands. A failure while
   *   listing plugins is logged as a warning and the commands gathered so far are returned.
   */
  async listSlashCommands(agentType: string, agentId?: string): Promise<SlashCommand[]> {
    const commands: SlashCommand[] = []

    if (agentType === 'claude-code') {
      commands.push(...builtinSlashCommands)
    }
    // Count what was actually added so the log stays accurate for other agent types.
    const builtinCount = commands.length

    if (!agentId) {
      return commands
    }

    try {
      const installedPlugins = await pluginService.listInstalled(agentId)
      const commandPlugins = installedPlugins.filter((p) => p.type === 'command')

      for (const plugin of commandPlugins) {
        // The command name is the plugin's filename without its `.md` extension.
        const commandName = plugin.metadata.filename.replace(/\.md$/i, '')
        commands.push({
          command: `/${commandName}`,
          description: plugin.metadata.description
        })
      }

      logger.info('Listed slash commands', {
        agentType,
        agentId,
        builtinCount,
        localCount: commandPlugins.length,
        totalCount: commands.length
      })
    } catch (error) {
      // Best effort: plugin discovery problems must not prevent returning builtin commands.
      logger.warn('Failed to list local command plugins', {
        agentId,
        error: error instanceof Error ? error.message : String(error)
      })
    }

    return commands
  }

  /**
   * Creates a session for an agent. The session inherits the agent's stored
   * configuration, with any field present in `req` taking precedence.
   *
   * @param agentId - Id of the owning agent; it must exist in `agentsTable`.
   * @param req - Optional per-session overrides.
   * @returns The stored session as produced by `getSession`.
   * @throws Error `'Agent not found'` when no agent row matches `agentId`.
   * @throws Error `'Failed to create session'` when the row cannot be read back after insert.
   */
  async createSession(
    agentId: string,
    req: Partial<CreateSessionRequest> = {}
  ): Promise<GetAgentSessionResponse | null> {
    const database = await this.getDatabase()

    // Look the agent up directly in the table rather than going through AgentService.
    const agents = await database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1)
    if (!agents[0]) {
      throw new Error('Agent not found')
    }
    const agent = this.deserializeJsonFields(agents[0]) as AgentEntity

    const id = `session_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`
    const now = new Date().toISOString()

    // Inherit configuration from the agent by default; request fields override it.
    const sessionData: Partial<CreateSessionRequest> = {
      ...agent,
      ...req
    }

    await this.validateAgentModels(agent.type, {
      model: sessionData.model,
      plan_model: sessionData.plan_model,
      small_model: sessionData.small_model
    })

    if (sessionData.accessible_paths !== undefined) {
      sessionData.accessible_paths = this.ensurePathsExist(sessionData.accessible_paths)
    }

    const serializedData = this.serializeJsonFields(sessionData)
    // `||` is deliberate here: any falsy serialized value is stored as NULL.
    const insertData: InsertSessionRow = {
      id,
      agent_id: agentId,
      agent_type: agent.type,
      name: serializedData.name || null,
      description: serializedData.description || null,
      accessible_paths: serializedData.accessible_paths || null,
      instructions: serializedData.instructions || null,
      model: serializedData.model || null,
      plan_model: serializedData.plan_model || null,
      small_model: serializedData.small_model || null,
      mcps: serializedData.mcps || null,
      allowed_tools: serializedData.allowed_tools || null,
      configuration: serializedData.configuration || null,
      created_at: now,
      updated_at: now
    }

    await database.insert(sessionsTable).values(insertData)

    // Reading back through getSession both confirms the insert and returns the enriched shape.
    const created = await this.getSession(agentId, id)
    if (!created) {
      throw new Error('Failed to create session')
    }
    return created
  }

  /**
   * Loads one session belonging to the given agent.
   *
   * @returns The session with `tools`, normalized `allowed_tools` and `slash_commands`
   *   populated, or `null` when no row matches both ids.
   */
  async getSession(agentId: string, id: string): Promise<GetAgentSessionResponse | null> {
    const database = await this.getDatabase()
    const result = await database
      .select()
      .from(sessionsTable)
      .where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
      .limit(1)

    if (!result[0]) {
      return null
    }

    const session = this.deserializeJsonFields(result[0]) as GetAgentSessionResponse
    await this.attachMcpTools(session)

    // If slash_commands is not in database yet (e.g., first invoke before init message),
    // fall back to builtin + local commands. Otherwise, use the merged commands from database.
    if (!session.slash_commands || session.slash_commands.length === 0) {
      session.slash_commands = await this.listSlashCommands(session.agent_type, agentId)
    }

    return session
  }

  /**
   * Lists sessions, newest `updated_at` first.
   *
   * @param agentId - When given, only this agent's sessions are returned and counted.
   * @param options - `limit` enables pagination; `offset` is only applied together with `limit`.
   * @returns The requested page plus the total number of rows matching the filter.
   */
  async listSessions(
    agentId?: string,
    options: ListOptions = {}
  ): Promise<{ sessions: AgentSessionEntity[]; total: number }> {
    const whereConditions: SQL[] = []
    if (agentId) {
      whereConditions.push(eq(sessionsTable.agent_id, agentId))
    }

    const whereClause =
      whereConditions.length > 1
        ? and(...whereConditions)
        : whereConditions.length === 1
          ? whereConditions[0]
          : undefined

    const database = await this.getDatabase()

    const totalResult = await database.select({ count: count() }).from(sessionsTable).where(whereClause)
    // Guard the index access instead of assuming the aggregate always yields a row.
    const total = totalResult[0]?.count ?? 0

    const baseQuery = database.select().from(sessionsTable).where(whereClause).orderBy(desc(sessionsTable.updated_at))

    const result =
      options.limit !== undefined
        ? options.offset !== undefined
          ? await baseQuery.limit(options.limit).offset(options.offset)
          : await baseQuery.limit(options.limit)
        : await baseQuery

    const sessions = result.map((row) => this.deserializeJsonFields(row)) as GetAgentSessionResponse[]

    // Kept sequential on purpose: each call resolves the tool list for that session's MCP servers.
    for (const session of sessions) {
      await this.attachMcpTools(session)
    }

    return { sessions, total }
  }

  /**
   * Applies a partial update to a session owned by `agentId`.
   *
   * Only keys that are present on `updates` and are part of `AgentBaseSchema` are written;
   * a present key whose value is nullish is stored as NULL. The caller's `updates` object
   * is not modified.
   *
   * @returns The refreshed session, or `null` when the session does not exist for this agent.
   */
  async updateSession(
    agentId: string,
    id: string,
    updates: UpdateSessionRequest
  ): Promise<UpdateSessionResponse | null> {
    const existing = await this.getSession(agentId, id)
    if (!existing) {
      return null
    }

    const now = new Date().toISOString()

    // Work on a shallow copy so the request object passed in by the caller stays untouched.
    const pendingUpdates: UpdateSessionRequest = { ...updates }
    if (pendingUpdates.accessible_paths !== undefined) {
      pendingUpdates.accessible_paths = this.ensurePathsExist(pendingUpdates.accessible_paths)
    }

    // Only validate the model fields the caller actually supplied.
    const modelUpdates: Partial<Record<AgentModelField, string | undefined>> = {}
    for (const field of this.modelFields) {
      if (Object.prototype.hasOwnProperty.call(pendingUpdates, field)) {
        modelUpdates[field] = pendingUpdates[field as keyof UpdateSessionRequest] as string | undefined
      }
    }
    if (Object.keys(modelUpdates).length > 0) {
      await this.validateAgentModels(existing.agent_type, modelUpdates)
    }

    const serializedUpdates = this.serializeJsonFields(pendingUpdates)

    const updateData: Partial<SessionRow> = {
      updated_at: now
    }
    const replaceableFields = Object.keys(AgentBaseSchema.shape) as (keyof SessionRow)[]
    for (const field of replaceableFields) {
      if (Object.prototype.hasOwnProperty.call(serializedUpdates, field)) {
        const value = serializedUpdates[field as keyof typeof serializedUpdates]
        ;(updateData as Record<string, unknown>)[field] = value ?? null
      }
    }

    const database = await this.getDatabase()
    // Scope the write by agent as well, matching getSession/deleteSession/sessionExists.
    await database
      .update(sessionsTable)
      .set(updateData)
      .where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))

    return await this.getSession(agentId, id)
  }

  /**
   * Deletes a session owned by `agentId`.
   *
   * @returns `true` when at least one row was removed.
   */
  async deleteSession(agentId: string, id: string): Promise<boolean> {
    const database = await this.getDatabase()
    const result = await database
      .delete(sessionsTable)
      .where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))

    return result.rowsAffected > 0
  }

  /** Reports whether a session with this id exists for the given agent. */
  async sessionExists(agentId: string, id: string): Promise<boolean> {
    const database = await this.getDatabase()
    const result = await database
      .select({ id: sessionsTable.id })
      .from(sessionsTable)
      .where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
      .limit(1)

    return result.length > 0
  }

  /**
   * Populates `session.tools` from the session's MCP servers and rewrites
   * `session.allowed_tools` through `normalizeAllowedTools` using the legacy id map
   * returned alongside the tools. Mutates `session` in place.
   */
  private async attachMcpTools(session: GetAgentSessionResponse): Promise<void> {
    const { tools, legacyIdMap } = await this.listMcpTools(session.agent_type, session.mcps)
    session.tools = tools
    session.allowed_tools = this.normalizeAllowedTools(session.allowed_tools, session.tools, legacyIdMap)
  }
}
/** Shared `SessionService` instance obtained from `SessionService.getInstance()`. */
export const sessionService = SessionService.getInstance()