feat(agents): implement model validation for agent and session creation/updating

This commit is contained in:
Vaayne 2025-09-19 19:54:38 +08:00
parent 1515f511a1
commit d6468f33c5
9 changed files with 232 additions and 16 deletions

View File

@ -1,12 +1,20 @@
import { loggerService } from '@logger'
import { AgentModelValidationError, agentService } from '@main/services/agents'
import { ListAgentsResponse,type ReplaceAgentRequest, type UpdateAgentRequest } from '@types'
import { Request, Response } from 'express'
import { agentService } from '../../../../services/agents'
import type { ValidationRequest } from '../validators/zodValidator'
const logger = loggerService.withContext('ApiServerAgentsHandlers')
const modelValidationErrorBody = (error: AgentModelValidationError) => ({
error: {
message: `Invalid ${error.context.field}: ${error.detail.message}`,
type: 'invalid_request_error',
code: error.detail.code
}
})
/**
* @swagger
* /v1/agents:
@ -50,6 +58,16 @@ export const createAgent = async (req: Request, res: Response): Promise<Response
logger.info(`Agent created successfully: ${agent.id}`)
return res.status(201).json(agent)
} catch (error: any) {
if (error instanceof AgentModelValidationError) {
logger.warn('Agent model validation error during create:', {
agentType: error.context.agentType,
field: error.context.field,
model: error.context.model,
detail: error.detail
})
return res.status(400).json(modelValidationErrorBody(error))
}
logger.error('Error creating agent:', error)
return res.status(500).json({
error: {
@ -259,8 +277,8 @@ export const getAgent = async (req: Request, res: Response): Promise<Response> =
* $ref: '#/components/schemas/Error'
*/
export const updateAgent = async (req: Request, res: Response): Promise<Response> => {
const { agentId } = req.params
try {
const { agentId } = req.params
logger.info(`Updating agent: ${agentId}`)
logger.debug('Update data:', req.body)
@ -283,6 +301,17 @@ export const updateAgent = async (req: Request, res: Response): Promise<Response
logger.info(`Agent updated successfully: ${agentId}`)
return res.json(agent)
} catch (error: any) {
if (error instanceof AgentModelValidationError) {
logger.warn('Agent model validation error during update:', {
agentId,
agentType: error.context.agentType,
field: error.context.field,
model: error.context.model,
detail: error.detail
})
return res.status(400).json(modelValidationErrorBody(error))
}
logger.error('Error updating agent:', error)
return res.status(500).json({
error: {
@ -394,8 +423,8 @@ export const updateAgent = async (req: Request, res: Response): Promise<Response
* $ref: '#/components/schemas/Error'
*/
export const patchAgent = async (req: Request, res: Response): Promise<Response> => {
const { agentId } = req.params
try {
const { agentId } = req.params
logger.info(`Partially updating agent: ${agentId}`)
logger.debug('Partial update data:', req.body)
@ -418,6 +447,17 @@ export const patchAgent = async (req: Request, res: Response): Promise<Response>
logger.info(`Agent partially updated successfully: ${agentId}`)
return res.json(agent)
} catch (error: any) {
if (error instanceof AgentModelValidationError) {
logger.warn('Agent model validation error during partial update:', {
agentId,
agentType: error.context.agentType,
field: error.context.field,
model: error.context.model,
detail: error.detail
})
return res.status(400).json(modelValidationErrorBody(error))
}
logger.error('Error partially updating agent:', error)
return res.status(500).json({
error: {

View File

@ -1,5 +1,9 @@
import { loggerService } from '@logger'
import { sessionMessageService, sessionService } from '@main/services/agents'
import {
AgentModelValidationError,
sessionMessageService,
sessionService
} from '@main/services/agents'
import {
CreateSessionResponse,
ListAgentSessionsResponse,
@ -12,9 +16,17 @@ import type { ValidationRequest } from '../validators/zodValidator'
const logger = loggerService.withContext('ApiServerSessionsHandlers')
const modelValidationErrorBody = (error: AgentModelValidationError) => ({
error: {
message: `Invalid ${error.context.field}: ${error.detail.message}`,
type: 'invalid_request_error',
code: error.detail.code
}
})
export const createSession = async (req: Request, res: Response): Promise<Response> => {
const { agentId } = req.params
try {
const { agentId } = req.params
const sessionData = req.body
logger.info(`Creating new session for agent: ${agentId}`)
@ -25,6 +37,17 @@ export const createSession = async (req: Request, res: Response): Promise<Respon
logger.info(`Session created successfully: ${session.id}`)
return res.status(201).json(session)
} catch (error: any) {
if (error instanceof AgentModelValidationError) {
logger.warn('Session model validation error during create:', {
agentId,
agentType: error.context.agentType,
field: error.context.field,
model: error.context.model,
detail: error.detail
})
return res.status(400).json(modelValidationErrorBody(error))
}
logger.error('Error creating session:', error)
return res.status(500).json({
error: {
@ -120,8 +143,8 @@ export const getSession = async (req: Request, res: Response): Promise<Response>
}
export const updateSession = async (req: Request, res: Response): Promise<Response> => {
const { agentId, sessionId } = req.params
try {
const { agentId, sessionId } = req.params
logger.info(`Updating session: ${sessionId} for agent: ${agentId}`)
logger.debug('Update data:', req.body)
@ -157,6 +180,18 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
logger.info(`Session updated successfully: ${sessionId}`)
return res.json(session satisfies UpdateSessionResponse)
} catch (error: any) {
if (error instanceof AgentModelValidationError) {
logger.warn('Session model validation error during update:', {
agentId,
sessionId,
agentType: error.context.agentType,
field: error.context.field,
model: error.context.model,
detail: error.detail
})
return res.status(400).json(modelValidationErrorBody(error))
}
logger.error('Error updating session:', error)
return res.status(500).json({
error: {
@ -169,8 +204,8 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
}
export const patchSession = async (req: Request, res: Response): Promise<Response> => {
const { agentId, sessionId } = req.params
try {
const { agentId, sessionId } = req.params
logger.info(`Patching session: ${sessionId} for agent: ${agentId}`)
logger.debug('Patch data:', req.body)
@ -204,6 +239,18 @@ export const patchSession = async (req: Request, res: Response): Promise<Respons
logger.info(`Session patched successfully: ${sessionId}`)
return res.json(session)
} catch (error: any) {
if (error instanceof AgentModelValidationError) {
logger.warn('Session model validation error during patch:', {
agentId,
sessionId,
agentType: error.context.agentType,
field: error.context.field,
model: error.context.model,
detail: error.detail
})
return res.status(400).json(modelValidationErrorBody(error))
}
logger.error('Error patching session:', error)
return res.status(500).json({
error: {

View File

@ -1,6 +1,7 @@
import { type Client, createClient } from '@libsql/client'
import { loggerService } from '@logger'
import { objectKeys } from '@types'
import { ModelValidationError, validateModelId } from '@main/apiServer/utils'
import { AgentType, objectKeys, Provider } from '@types'
import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql'
import fs from 'fs'
import path from 'path'
@ -8,6 +9,7 @@ import path from 'path'
import { MigrationService } from './database/MigrationService'
import * as schema from './database/schema'
import { dbPath } from './drizzle.config'
import { AgentModelField, AgentModelValidationError } from './errors'
const logger = loggerService.withContext('BaseService')
@ -186,8 +188,7 @@ export abstract class BaseService {
}
const looksLikeFile =
(stats && stats.isFile()) ||
(!stats && path.extname(resolvedPath) !== '' && !resolvedPath.endsWith(path.sep))
(stats && stats.isFile()) || (!stats && path.extname(resolvedPath) !== '' && !resolvedPath.endsWith(path.sep))
const directoryToEnsure = looksLikeFile ? path.dirname(resolvedPath) : resolvedPath
@ -208,6 +209,63 @@ export abstract class BaseService {
/**
* Force re-initialization (for development/testing)
*/
protected async validateAgentModels(
agentType: AgentType,
models: Partial<Record<AgentModelField, string | undefined>>
): Promise<void> {
const entries = Object.entries(models) as [AgentModelField, string | undefined][]
if (entries.length === 0) {
return
}
for (const [field, rawValue] of entries) {
if (rawValue === undefined || rawValue === null) {
continue
}
const modelValue = rawValue
const validation = await validateModelId(modelValue)
if (!validation.valid || !validation.provider) {
const detail: ModelValidationError = validation.error ?? {
type: 'invalid_format',
message: 'Unknown model validation error',
code: 'validation_error'
}
throw new AgentModelValidationError({ agentType, field, model: modelValue }, detail)
}
if (!validation.provider.apiKey) {
throw new AgentModelValidationError(
{ agentType, field, model: modelValue },
{
type: 'invalid_format',
message: `Provider '${validation.provider.id}' is missing an API key`,
code: 'provider_api_key_missing'
}
)
}
// different agent types may have different provider requirements
const agentTypeProviderRequirements: Record<AgentType, Provider['type']> = {
'claude-code': 'anthropic'
}
for (const [ak, pk] of Object.entries(agentTypeProviderRequirements)) {
if (agentType === ak && validation.provider.type !== pk) {
throw new AgentModelValidationError(
{ agentType, field, model: modelValue },
{
type: 'unsupported_provider_type',
message: `Provider type '${validation.provider.type}' is not supported for agent type '${agentType}'. Expected '${pk}'`,
code: 'unsupported_provider_type'
}
)
}
}
}
}
static async reinitialize(): Promise<void> {
BaseService.isInitialized = false
BaseService.initializationPromise = null

View File

@ -8,6 +8,7 @@ Simplified Drizzle ORM implementation for agent and session management in Cherry
- **Zero CLI dependencies** in production
- **Auto-initialization** with retry logic
- **Full TypeScript** type safety
- **Model validation** to ensure models exist and provider configuration matches the agent type
## Schema
@ -24,10 +25,16 @@ import { agentService } from './services'
const agent = await agentService.createAgent({
type: 'custom',
name: 'My Agent',
model: 'claude-3-5-sonnet-20241022'
model: 'anthropic:claude-3-5-sonnet-20241022'
})
```
## Model Validation
- Model identifiers must use the `provider:model_id` format (for example `anthropic:claude-3-5-sonnet-20241022`).
- `model`, `plan_model`, and `small_model` are validated against the configured providers before the database is touched.
- Invalid configurations return a `400 invalid_request_error` response and the create/update operation is aborted.
## Development Commands
```bash

View File

@ -0,0 +1,23 @@
import { ModelValidationError } from '@main/apiServer/utils'
import { AgentType } from '@types'
export type AgentModelField = 'model' | 'plan_model' | 'small_model'
export interface AgentModelValidationContext {
agentType: AgentType
field: AgentModelField
model?: string
}
export class AgentModelValidationError extends Error {
readonly context: AgentModelValidationContext
readonly detail: ModelValidationError
constructor(context: AgentModelValidationContext, detail: ModelValidationError) {
super(`Validation failed for ${context.agentType}.${context.field}: ${detail.message}`)
this.name = 'AgentModelValidationError'
this.context = context
this.detail = detail
}
}

View File

@ -13,6 +13,9 @@
// Main service classes and singleton instances
export * from './services'
// === Error Types ===
export { type AgentModelField, AgentModelValidationError } from './errors'
// === Base Infrastructure ===
// Shared database utilities and base service class
export { BaseService } from './BaseService'

View File

@ -15,10 +15,12 @@ import { count, eq } from 'drizzle-orm'
import { BaseService } from '../BaseService'
import { type AgentRow, agentsTable, type InsertAgentRow } from '../database/schema'
import { AgentModelField } from '../errors'
import { builtinTools } from './claudecode/tools'
export class AgentService extends BaseService {
private static instance: AgentService | null = null
private readonly modelFields: AgentModelField[] = ['model', 'plan_model', 'small_model']
static getInstance(): AgentService {
if (!AgentService.instance) {
@ -43,6 +45,12 @@ export class AgentService extends BaseService {
req.accessible_paths = [defaultPath]
}
await this.validateAgentModels(req.type, {
model: req.model,
plan_model: req.plan_model,
small_model: req.small_model
})
this.ensurePathsExist(req.accessible_paths)
const serializedReq = this.serializeJsonFields(req)
@ -132,6 +140,18 @@ export class AgentService extends BaseService {
if (updates.accessible_paths) {
this.ensurePathsExist(updates.accessible_paths)
}
const modelUpdates: Partial<Record<AgentModelField, string | undefined>> = {}
for (const field of this.modelFields) {
if (Object.prototype.hasOwnProperty.call(updates, field)) {
modelUpdates[field] = updates[field as keyof UpdateAgentRequest] as string | undefined
}
}
if (Object.keys(modelUpdates).length > 0) {
await this.validateAgentModels(existing.type, modelUpdates)
}
const serializedUpdates = this.serializeJsonFields(updates)
const updateData: Partial<AgentRow> = {

View File

@ -13,9 +13,11 @@ import { and, count, eq, type SQL } from 'drizzle-orm'
import { BaseService } from '../BaseService'
import { agentsTable, type InsertSessionRow, type SessionRow, sessionsTable } from '../database/schema'
import { AgentModelField } from '../errors'
export class SessionService extends BaseService {
private static instance: SessionService | null = null
private readonly modelFields: AgentModelField[] = ['model', 'plan_model', 'small_model']
static getInstance(): SessionService {
if (!SessionService.instance) {
@ -50,6 +52,12 @@ export class SessionService extends BaseService {
...req
}
await this.validateAgentModels(agent.type, {
model: sessionData.model,
plan_model: sessionData.plan_model,
small_model: sessionData.small_model
})
this.ensurePathsExist(sessionData.accessible_paths)
const serializedData = this.serializeJsonFields(sessionData)
@ -174,6 +182,18 @@ export class SessionService extends BaseService {
if (updates.accessible_paths) {
this.ensurePathsExist(updates.accessible_paths)
}
const modelUpdates: Partial<Record<AgentModelField, string | undefined>> = {}
for (const field of this.modelFields) {
if (Object.prototype.hasOwnProperty.call(updates, field)) {
modelUpdates[field] = updates[field as keyof UpdateSessionRequest] as string | undefined
}
}
if (Object.keys(modelUpdates).length > 0) {
await this.validateAgentModels(existing.agent_type, modelUpdates)
}
const serializedUpdates = this.serializeJsonFields(updates)
const updateData: Partial<SessionRow> = {

View File

@ -50,14 +50,12 @@ class ClaudeCodeService implements AgentServiceInterface {
return aiStream
}
// Validate model
const modelId = session.model
logger.info('Invoking Claude Code with model', { modelId, cwd })
const modelInfo = await validateModelId(modelId)
// Validate model info
const modelInfo = await validateModelId(session.model)
if (!modelInfo.valid) {
aiStream.emit('data', {
type: 'error',
error: new Error(`Invalid model ID '${modelId}': ${JSON.stringify(modelInfo.error)}`)
error: new Error(`Invalid model ID '${session.model}': ${JSON.stringify(modelInfo.error)}`)
})
return aiStream
}