From d6468f33c5fe20c78c55169c3acb7bdd5664b4be Mon Sep 17 00:00:00 2001 From: Vaayne Date: Fri, 19 Sep 2025 19:54:38 +0800 Subject: [PATCH] feat(agents): implement model validation for agent and session creation/updating --- .../routes/agents/handlers/agents.ts | 46 ++++++++++++- .../routes/agents/handlers/sessions.ts | 55 ++++++++++++++-- src/main/services/agents/BaseService.ts | 64 ++++++++++++++++++- src/main/services/agents/README.md | 9 ++- src/main/services/agents/errors.ts | 23 +++++++ src/main/services/agents/index.ts | 3 + .../services/agents/services/AgentService.ts | 20 ++++++ .../agents/services/SessionService.ts | 20 ++++++ .../agents/services/claudecode/index.ts | 8 +-- 9 files changed, 232 insertions(+), 16 deletions(-) create mode 100644 src/main/services/agents/errors.ts diff --git a/src/main/apiServer/routes/agents/handlers/agents.ts b/src/main/apiServer/routes/agents/handlers/agents.ts index aa021af063..1ae661c726 100644 --- a/src/main/apiServer/routes/agents/handlers/agents.ts +++ b/src/main/apiServer/routes/agents/handlers/agents.ts @@ -1,12 +1,20 @@ import { loggerService } from '@logger' +import { AgentModelValidationError, agentService } from '@main/services/agents' import { ListAgentsResponse,type ReplaceAgentRequest, type UpdateAgentRequest } from '@types' import { Request, Response } from 'express' -import { agentService } from '../../../../services/agents' import type { ValidationRequest } from '../validators/zodValidator' const logger = loggerService.withContext('ApiServerAgentsHandlers') +const modelValidationErrorBody = (error: AgentModelValidationError) => ({ + error: { + message: `Invalid ${error.context.field}: ${error.detail.message}`, + type: 'invalid_request_error', + code: error.detail.code + } +}) + /** * @swagger * /v1/agents: @@ -50,6 +58,16 @@ export const createAgent = async (req: Request, res: Response): Promise = * $ref: '#/components/schemas/Error' */ export const updateAgent = async (req: Request, res: Response): Promise => { + const { agentId } = req.params try { - const { agentId } = req.params logger.info(`Updating agent: ${agentId}`) logger.debug('Update data:', req.body) @@ -283,6 +301,17 @@ export const updateAgent = async (req: Request, res: Response): Promise => { + const { agentId } = req.params try { - const { agentId } = req.params logger.info(`Partially updating agent: ${agentId}`) logger.debug('Partial update data:', req.body) @@ -418,6 +447,17 @@ export const patchAgent = async (req: Request, res: Response): Promise logger.info(`Agent partially updated successfully: ${agentId}`) return res.json(agent) } catch (error: any) { + if (error instanceof AgentModelValidationError) { + logger.warn('Agent model validation error during partial update:', { + agentId, + agentType: error.context.agentType, + field: error.context.field, + model: error.context.model, + detail: error.detail + }) + return res.status(400).json(modelValidationErrorBody(error)) + } + logger.error('Error partially updating agent:', error) return res.status(500).json({ error: { diff --git a/src/main/apiServer/routes/agents/handlers/sessions.ts b/src/main/apiServer/routes/agents/handlers/sessions.ts index 8b9b161048..81d3036a18 100644 --- a/src/main/apiServer/routes/agents/handlers/sessions.ts +++ b/src/main/apiServer/routes/agents/handlers/sessions.ts @@ -1,5 +1,9 @@ import { loggerService } from '@logger' -import { sessionMessageService, sessionService } from '@main/services/agents' +import { + AgentModelValidationError, + sessionMessageService, + sessionService +} from '@main/services/agents' import { CreateSessionResponse, ListAgentSessionsResponse, @@ -12,9 +16,17 @@ import type { ValidationRequest } from '../validators/zodValidator' const logger = loggerService.withContext('ApiServerSessionsHandlers') +const modelValidationErrorBody = (error: AgentModelValidationError) => ({ + error: { + message: `Invalid ${error.context.field}: ${error.detail.message}`, + type: 'invalid_request_error', + code: error.detail.code + } +}) + export const createSession = async (req: Request, res: Response): Promise => { + const { agentId } = req.params try { - const { agentId } = req.params const sessionData = req.body logger.info(`Creating new session for agent: ${agentId}`) @@ -25,6 +37,17 @@ export const createSession = async (req: Request, res: Response): Promise } export const updateSession = async (req: Request, res: Response): Promise => { + const { agentId, sessionId } = req.params try { - const { agentId, sessionId } = req.params logger.info(`Updating session: ${sessionId} for agent: ${agentId}`) logger.debug('Update data:', req.body) @@ -157,6 +180,18 @@ export const updateSession = async (req: Request, res: Response): Promise => { + const { agentId, sessionId } = req.params try { - const { agentId, sessionId } = req.params logger.info(`Patching session: ${sessionId} for agent: ${agentId}`) logger.debug('Patch data:', req.body) @@ -204,6 +239,18 @@ export const patchSession = async (req: Request, res: Response): Promise> + ): Promise { + const entries = Object.entries(models) as [AgentModelField, string | undefined][] + if (entries.length === 0) { + return + } + + for (const [field, rawValue] of entries) { + if (rawValue === undefined || rawValue === null) { + continue + } + + const modelValue = rawValue + const validation = await validateModelId(modelValue) + + if (!validation.valid || !validation.provider) { + const detail: ModelValidationError = validation.error ?? { + type: 'invalid_format', + message: 'Unknown model validation error', + code: 'validation_error' + } + + throw new AgentModelValidationError({ agentType, field, model: modelValue }, detail) + } + + if (!validation.provider.apiKey) { + throw new AgentModelValidationError( + { agentType, field, model: modelValue }, + { + type: 'invalid_format', + message: `Provider '${validation.provider.id}' is missing an API key`, + code: 'provider_api_key_missing' + } + ) + } + + // different agent types may have different provider requirements + const agentTypeProviderRequirements: Record = { + 'claude-code': 'anthropic' + } + for (const [ak, pk] of Object.entries(agentTypeProviderRequirements)) { + if (agentType === ak && validation.provider.type !== pk) { + throw new AgentModelValidationError( + { agentType, field, model: modelValue }, + { + type: 'unsupported_provider_type', + message: `Provider type '${validation.provider.type}' is not supported for agent type '${agentType}'. Expected '${pk}'`, + code: 'unsupported_provider_type' + } + ) + } + } + } + } + static async reinitialize(): Promise { BaseService.isInitialized = false BaseService.initializationPromise = null diff --git a/src/main/services/agents/README.md b/src/main/services/agents/README.md index 542fba6f61..986ac8b8df 100644 --- a/src/main/services/agents/README.md +++ b/src/main/services/agents/README.md @@ -8,6 +8,7 @@ Simplified Drizzle ORM implementation for agent and session management in Cherry - **Zero CLI dependencies** in production - **Auto-initialization** with retry logic - **Full TypeScript** type safety +- **Model validation** to ensure models exist and provider configuration matches the agent type ## Schema @@ -24,10 +25,16 @@ import { agentService } from './services' const agent = await agentService.createAgent({ type: 'custom', name: 'My Agent', - model: 'claude-3-5-sonnet-20241022' + model: 'anthropic:claude-3-5-sonnet-20241022' }) ``` +## Model Validation + +- Model identifiers must use the `provider:model_id` format (for example `anthropic:claude-3-5-sonnet-20241022`). +- `model`, `plan_model`, and `small_model` are validated against the configured providers before the database is touched. +- Invalid configurations return a `400 invalid_request_error` response and the create/update operation is aborted. + ## Development Commands ```bash diff --git a/src/main/services/agents/errors.ts b/src/main/services/agents/errors.ts new file mode 100644 index 0000000000..a831867afc --- /dev/null +++ b/src/main/services/agents/errors.ts @@ -0,0 +1,23 @@ +import { ModelValidationError } from '@main/apiServer/utils' +import { AgentType } from '@types' + +export type AgentModelField = 'model' | 'plan_model' | 'small_model' + +export interface AgentModelValidationContext { + agentType: AgentType + field: AgentModelField + model?: string +} + +export class AgentModelValidationError extends Error { + readonly context: AgentModelValidationContext + readonly detail: ModelValidationError + + constructor(context: AgentModelValidationContext, detail: ModelValidationError) { + super(`Validation failed for ${context.agentType}.${context.field}: ${detail.message}`) + this.name = 'AgentModelValidationError' + this.context = context + this.detail = detail + } +} + diff --git a/src/main/services/agents/index.ts b/src/main/services/agents/index.ts index 9440889033..00409e7c64 100644 --- a/src/main/services/agents/index.ts +++ b/src/main/services/agents/index.ts @@ -13,6 +13,9 @@ // Main service classes and singleton instances export * from './services' +// === Error Types === +export { type AgentModelField, AgentModelValidationError } from './errors' + // === Base Infrastructure === // Shared database utilities and base service class export { BaseService } from './BaseService' diff --git a/src/main/services/agents/services/AgentService.ts b/src/main/services/agents/services/AgentService.ts index ff2dcb18f2..9d3c372647 100644 --- a/src/main/services/agents/services/AgentService.ts +++ b/src/main/services/agents/services/AgentService.ts @@ -15,10 +15,12 @@ import { count, eq } from 'drizzle-orm' import { BaseService } from '../BaseService' import { type AgentRow, agentsTable, type InsertAgentRow } from '../database/schema' +import { AgentModelField } from '../errors' import { builtinTools } from './claudecode/tools' export class AgentService extends BaseService { private static instance: AgentService | null = null + private readonly modelFields: AgentModelField[] = ['model', 'plan_model', 'small_model'] static getInstance(): AgentService { if (!AgentService.instance) { @@ -43,6 +45,12 @@ export class AgentService extends BaseService { req.accessible_paths = [defaultPath] } + await this.validateAgentModels(req.type, { + model: req.model, + plan_model: req.plan_model, + small_model: req.small_model + }) + this.ensurePathsExist(req.accessible_paths) const serializedReq = this.serializeJsonFields(req) @@ -132,6 +140,18 @@ export class AgentService extends BaseService { if (updates.accessible_paths) { this.ensurePathsExist(updates.accessible_paths) } + + const modelUpdates: Partial> = {} + for (const field of this.modelFields) { + if (Object.prototype.hasOwnProperty.call(updates, field)) { + modelUpdates[field] = updates[field as keyof UpdateAgentRequest] as string | undefined + } + } + + if (Object.keys(modelUpdates).length > 0) { + await this.validateAgentModels(existing.type, modelUpdates) + } + const serializedUpdates = this.serializeJsonFields(updates) const updateData: Partial = { diff --git a/src/main/services/agents/services/SessionService.ts b/src/main/services/agents/services/SessionService.ts index eb1820cbd6..3c328f6ee9 100644 --- a/src/main/services/agents/services/SessionService.ts +++ b/src/main/services/agents/services/SessionService.ts @@ -13,9 +13,11 @@ import { and, count, eq, type SQL } from 'drizzle-orm' import { BaseService } from '../BaseService' import { agentsTable, type InsertSessionRow, type SessionRow, sessionsTable } from '../database/schema' +import { AgentModelField } from '../errors' export class SessionService extends BaseService { private static instance: SessionService | null = null + private readonly modelFields: AgentModelField[] = ['model', 'plan_model', 'small_model'] static getInstance(): SessionService { if (!SessionService.instance) { @@ -50,6 +52,12 @@ export class SessionService extends BaseService { ...req } + await this.validateAgentModels(agent.type, { + model: sessionData.model, + plan_model: sessionData.plan_model, + small_model: sessionData.small_model + }) + this.ensurePathsExist(sessionData.accessible_paths) const serializedData = this.serializeJsonFields(sessionData) @@ -174,6 +182,18 @@ export class SessionService extends BaseService { if (updates.accessible_paths) { this.ensurePathsExist(updates.accessible_paths) } + + const modelUpdates: Partial> = {} + for (const field of this.modelFields) { + if (Object.prototype.hasOwnProperty.call(updates, field)) { + modelUpdates[field] = updates[field as keyof UpdateSessionRequest] as string | undefined + } + } + + if (Object.keys(modelUpdates).length > 0) { + await this.validateAgentModels(existing.agent_type, modelUpdates) + } + const serializedUpdates = this.serializeJsonFields(updates) const updateData: Partial = { diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts index 46ad151323..42ac0747e6 100644 --- a/src/main/services/agents/services/claudecode/index.ts +++ b/src/main/services/agents/services/claudecode/index.ts @@ -50,14 +50,12 @@ class ClaudeCodeService implements AgentServiceInterface { return aiStream } - // Validate model - const modelId = session.model - logger.info('Invoking Claude Code with model', { modelId, cwd }) - const modelInfo = await validateModelId(modelId) + // Validate model info + const modelInfo = await validateModelId(session.model) if (!modelInfo.valid) { aiStream.emit('data', { type: 'error', - error: new Error(`Invalid model ID '${modelId}': ${JSON.stringify(modelInfo.error)}`) + error: new Error(`Invalid model ID '${session.model}': ${JSON.stringify(modelInfo.error)}`) }) return aiStream }