mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-24 02:20:10 +08:00
feat(agents): implement model validation for agent and session creation/updating
This commit is contained in:
parent
1515f511a1
commit
d6468f33c5
@ -1,12 +1,20 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { AgentModelValidationError, agentService } from '@main/services/agents'
|
||||
import { ListAgentsResponse,type ReplaceAgentRequest, type UpdateAgentRequest } from '@types'
|
||||
import { Request, Response } from 'express'
|
||||
|
||||
import { agentService } from '../../../../services/agents'
|
||||
import type { ValidationRequest } from '../validators/zodValidator'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerAgentsHandlers')
|
||||
|
||||
const modelValidationErrorBody = (error: AgentModelValidationError) => ({
|
||||
error: {
|
||||
message: `Invalid ${error.context.field}: ${error.detail.message}`,
|
||||
type: 'invalid_request_error',
|
||||
code: error.detail.code
|
||||
}
|
||||
})
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/agents:
|
||||
@ -50,6 +58,16 @@ export const createAgent = async (req: Request, res: Response): Promise<Response
|
||||
logger.info(`Agent created successfully: ${agent.id}`)
|
||||
return res.status(201).json(agent)
|
||||
} catch (error: any) {
|
||||
if (error instanceof AgentModelValidationError) {
|
||||
logger.warn('Agent model validation error during create:', {
|
||||
agentType: error.context.agentType,
|
||||
field: error.context.field,
|
||||
model: error.context.model,
|
||||
detail: error.detail
|
||||
})
|
||||
return res.status(400).json(modelValidationErrorBody(error))
|
||||
}
|
||||
|
||||
logger.error('Error creating agent:', error)
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
@ -259,8 +277,8 @@ export const getAgent = async (req: Request, res: Response): Promise<Response> =
|
||||
* $ref: '#/components/schemas/Error'
|
||||
*/
|
||||
export const updateAgent = async (req: Request, res: Response): Promise<Response> => {
|
||||
const { agentId } = req.params
|
||||
try {
|
||||
const { agentId } = req.params
|
||||
logger.info(`Updating agent: ${agentId}`)
|
||||
logger.debug('Update data:', req.body)
|
||||
|
||||
@ -283,6 +301,17 @@ export const updateAgent = async (req: Request, res: Response): Promise<Response
|
||||
logger.info(`Agent updated successfully: ${agentId}`)
|
||||
return res.json(agent)
|
||||
} catch (error: any) {
|
||||
if (error instanceof AgentModelValidationError) {
|
||||
logger.warn('Agent model validation error during update:', {
|
||||
agentId,
|
||||
agentType: error.context.agentType,
|
||||
field: error.context.field,
|
||||
model: error.context.model,
|
||||
detail: error.detail
|
||||
})
|
||||
return res.status(400).json(modelValidationErrorBody(error))
|
||||
}
|
||||
|
||||
logger.error('Error updating agent:', error)
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
@ -394,8 +423,8 @@ export const updateAgent = async (req: Request, res: Response): Promise<Response
|
||||
* $ref: '#/components/schemas/Error'
|
||||
*/
|
||||
export const patchAgent = async (req: Request, res: Response): Promise<Response> => {
|
||||
const { agentId } = req.params
|
||||
try {
|
||||
const { agentId } = req.params
|
||||
logger.info(`Partially updating agent: ${agentId}`)
|
||||
logger.debug('Partial update data:', req.body)
|
||||
|
||||
@ -418,6 +447,17 @@ export const patchAgent = async (req: Request, res: Response): Promise<Response>
|
||||
logger.info(`Agent partially updated successfully: ${agentId}`)
|
||||
return res.json(agent)
|
||||
} catch (error: any) {
|
||||
if (error instanceof AgentModelValidationError) {
|
||||
logger.warn('Agent model validation error during partial update:', {
|
||||
agentId,
|
||||
agentType: error.context.agentType,
|
||||
field: error.context.field,
|
||||
model: error.context.model,
|
||||
detail: error.detail
|
||||
})
|
||||
return res.status(400).json(modelValidationErrorBody(error))
|
||||
}
|
||||
|
||||
logger.error('Error partially updating agent:', error)
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { sessionMessageService, sessionService } from '@main/services/agents'
|
||||
import {
|
||||
AgentModelValidationError,
|
||||
sessionMessageService,
|
||||
sessionService
|
||||
} from '@main/services/agents'
|
||||
import {
|
||||
CreateSessionResponse,
|
||||
ListAgentSessionsResponse,
|
||||
@ -12,9 +16,17 @@ import type { ValidationRequest } from '../validators/zodValidator'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerSessionsHandlers')
|
||||
|
||||
const modelValidationErrorBody = (error: AgentModelValidationError) => ({
|
||||
error: {
|
||||
message: `Invalid ${error.context.field}: ${error.detail.message}`,
|
||||
type: 'invalid_request_error',
|
||||
code: error.detail.code
|
||||
}
|
||||
})
|
||||
|
||||
export const createSession = async (req: Request, res: Response): Promise<Response> => {
|
||||
const { agentId } = req.params
|
||||
try {
|
||||
const { agentId } = req.params
|
||||
const sessionData = req.body
|
||||
|
||||
logger.info(`Creating new session for agent: ${agentId}`)
|
||||
@ -25,6 +37,17 @@ export const createSession = async (req: Request, res: Response): Promise<Respon
|
||||
logger.info(`Session created successfully: ${session.id}`)
|
||||
return res.status(201).json(session)
|
||||
} catch (error: any) {
|
||||
if (error instanceof AgentModelValidationError) {
|
||||
logger.warn('Session model validation error during create:', {
|
||||
agentId,
|
||||
agentType: error.context.agentType,
|
||||
field: error.context.field,
|
||||
model: error.context.model,
|
||||
detail: error.detail
|
||||
})
|
||||
return res.status(400).json(modelValidationErrorBody(error))
|
||||
}
|
||||
|
||||
logger.error('Error creating session:', error)
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
@ -120,8 +143,8 @@ export const getSession = async (req: Request, res: Response): Promise<Response>
|
||||
}
|
||||
|
||||
export const updateSession = async (req: Request, res: Response): Promise<Response> => {
|
||||
const { agentId, sessionId } = req.params
|
||||
try {
|
||||
const { agentId, sessionId } = req.params
|
||||
logger.info(`Updating session: ${sessionId} for agent: ${agentId}`)
|
||||
logger.debug('Update data:', req.body)
|
||||
|
||||
@ -157,6 +180,18 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
|
||||
logger.info(`Session updated successfully: ${sessionId}`)
|
||||
return res.json(session satisfies UpdateSessionResponse)
|
||||
} catch (error: any) {
|
||||
if (error instanceof AgentModelValidationError) {
|
||||
logger.warn('Session model validation error during update:', {
|
||||
agentId,
|
||||
sessionId,
|
||||
agentType: error.context.agentType,
|
||||
field: error.context.field,
|
||||
model: error.context.model,
|
||||
detail: error.detail
|
||||
})
|
||||
return res.status(400).json(modelValidationErrorBody(error))
|
||||
}
|
||||
|
||||
logger.error('Error updating session:', error)
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
@ -169,8 +204,8 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
|
||||
}
|
||||
|
||||
export const patchSession = async (req: Request, res: Response): Promise<Response> => {
|
||||
const { agentId, sessionId } = req.params
|
||||
try {
|
||||
const { agentId, sessionId } = req.params
|
||||
logger.info(`Patching session: ${sessionId} for agent: ${agentId}`)
|
||||
logger.debug('Patch data:', req.body)
|
||||
|
||||
@ -204,6 +239,18 @@ export const patchSession = async (req: Request, res: Response): Promise<Respons
|
||||
logger.info(`Session patched successfully: ${sessionId}`)
|
||||
return res.json(session)
|
||||
} catch (error: any) {
|
||||
if (error instanceof AgentModelValidationError) {
|
||||
logger.warn('Session model validation error during patch:', {
|
||||
agentId,
|
||||
sessionId,
|
||||
agentType: error.context.agentType,
|
||||
field: error.context.field,
|
||||
model: error.context.model,
|
||||
detail: error.detail
|
||||
})
|
||||
return res.status(400).json(modelValidationErrorBody(error))
|
||||
}
|
||||
|
||||
logger.error('Error patching session:', error)
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import { type Client, createClient } from '@libsql/client'
|
||||
import { loggerService } from '@logger'
|
||||
import { objectKeys } from '@types'
|
||||
import { ModelValidationError, validateModelId } from '@main/apiServer/utils'
|
||||
import { AgentType, objectKeys, Provider } from '@types'
|
||||
import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql'
|
||||
import fs from 'fs'
|
||||
import path from 'path'
|
||||
@ -8,6 +9,7 @@ import path from 'path'
|
||||
import { MigrationService } from './database/MigrationService'
|
||||
import * as schema from './database/schema'
|
||||
import { dbPath } from './drizzle.config'
|
||||
import { AgentModelField, AgentModelValidationError } from './errors'
|
||||
|
||||
const logger = loggerService.withContext('BaseService')
|
||||
|
||||
@ -186,8 +188,7 @@ export abstract class BaseService {
|
||||
}
|
||||
|
||||
const looksLikeFile =
|
||||
(stats && stats.isFile()) ||
|
||||
(!stats && path.extname(resolvedPath) !== '' && !resolvedPath.endsWith(path.sep))
|
||||
(stats && stats.isFile()) || (!stats && path.extname(resolvedPath) !== '' && !resolvedPath.endsWith(path.sep))
|
||||
|
||||
const directoryToEnsure = looksLikeFile ? path.dirname(resolvedPath) : resolvedPath
|
||||
|
||||
@ -208,6 +209,63 @@ export abstract class BaseService {
|
||||
/**
|
||||
* Force re-initialization (for development/testing)
|
||||
*/
|
||||
protected async validateAgentModels(
|
||||
agentType: AgentType,
|
||||
models: Partial<Record<AgentModelField, string | undefined>>
|
||||
): Promise<void> {
|
||||
const entries = Object.entries(models) as [AgentModelField, string | undefined][]
|
||||
if (entries.length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
for (const [field, rawValue] of entries) {
|
||||
if (rawValue === undefined || rawValue === null) {
|
||||
continue
|
||||
}
|
||||
|
||||
const modelValue = rawValue
|
||||
const validation = await validateModelId(modelValue)
|
||||
|
||||
if (!validation.valid || !validation.provider) {
|
||||
const detail: ModelValidationError = validation.error ?? {
|
||||
type: 'invalid_format',
|
||||
message: 'Unknown model validation error',
|
||||
code: 'validation_error'
|
||||
}
|
||||
|
||||
throw new AgentModelValidationError({ agentType, field, model: modelValue }, detail)
|
||||
}
|
||||
|
||||
if (!validation.provider.apiKey) {
|
||||
throw new AgentModelValidationError(
|
||||
{ agentType, field, model: modelValue },
|
||||
{
|
||||
type: 'invalid_format',
|
||||
message: `Provider '${validation.provider.id}' is missing an API key`,
|
||||
code: 'provider_api_key_missing'
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// different agent types may have different provider requirements
|
||||
const agentTypeProviderRequirements: Record<AgentType, Provider['type']> = {
|
||||
'claude-code': 'anthropic'
|
||||
}
|
||||
for (const [ak, pk] of Object.entries(agentTypeProviderRequirements)) {
|
||||
if (agentType === ak && validation.provider.type !== pk) {
|
||||
throw new AgentModelValidationError(
|
||||
{ agentType, field, model: modelValue },
|
||||
{
|
||||
type: 'unsupported_provider_type',
|
||||
message: `Provider type '${validation.provider.type}' is not supported for agent type '${agentType}'. Expected '${pk}'`,
|
||||
code: 'unsupported_provider_type'
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static async reinitialize(): Promise<void> {
|
||||
BaseService.isInitialized = false
|
||||
BaseService.initializationPromise = null
|
||||
|
||||
@ -8,6 +8,7 @@ Simplified Drizzle ORM implementation for agent and session management in Cherry
|
||||
- **Zero CLI dependencies** in production
|
||||
- **Auto-initialization** with retry logic
|
||||
- **Full TypeScript** type safety
|
||||
- **Model validation** to ensure models exist and provider configuration matches the agent type
|
||||
|
||||
## Schema
|
||||
|
||||
@ -24,10 +25,16 @@ import { agentService } from './services'
|
||||
const agent = await agentService.createAgent({
|
||||
type: 'custom',
|
||||
name: 'My Agent',
|
||||
model: 'claude-3-5-sonnet-20241022'
|
||||
model: 'anthropic:claude-3-5-sonnet-20241022'
|
||||
})
|
||||
```
|
||||
|
||||
## Model Validation
|
||||
|
||||
- Model identifiers must use the `provider:model_id` format (for example `anthropic:claude-3-5-sonnet-20241022`).
|
||||
- `model`, `plan_model`, and `small_model` are validated against the configured providers before the database is touched.
|
||||
- Invalid configurations return a `400 invalid_request_error` response and the create/update operation is aborted.
|
||||
|
||||
## Development Commands
|
||||
|
||||
```bash
|
||||
|
||||
23
src/main/services/agents/errors.ts
Normal file
23
src/main/services/agents/errors.ts
Normal file
@ -0,0 +1,23 @@
|
||||
import { ModelValidationError } from '@main/apiServer/utils'
|
||||
import { AgentType } from '@types'
|
||||
|
||||
export type AgentModelField = 'model' | 'plan_model' | 'small_model'
|
||||
|
||||
export interface AgentModelValidationContext {
|
||||
agentType: AgentType
|
||||
field: AgentModelField
|
||||
model?: string
|
||||
}
|
||||
|
||||
export class AgentModelValidationError extends Error {
|
||||
readonly context: AgentModelValidationContext
|
||||
readonly detail: ModelValidationError
|
||||
|
||||
constructor(context: AgentModelValidationContext, detail: ModelValidationError) {
|
||||
super(`Validation failed for ${context.agentType}.${context.field}: ${detail.message}`)
|
||||
this.name = 'AgentModelValidationError'
|
||||
this.context = context
|
||||
this.detail = detail
|
||||
}
|
||||
}
|
||||
|
||||
@ -13,6 +13,9 @@
|
||||
// Main service classes and singleton instances
|
||||
export * from './services'
|
||||
|
||||
// === Error Types ===
|
||||
export { type AgentModelField, AgentModelValidationError } from './errors'
|
||||
|
||||
// === Base Infrastructure ===
|
||||
// Shared database utilities and base service class
|
||||
export { BaseService } from './BaseService'
|
||||
|
||||
@ -15,10 +15,12 @@ import { count, eq } from 'drizzle-orm'
|
||||
|
||||
import { BaseService } from '../BaseService'
|
||||
import { type AgentRow, agentsTable, type InsertAgentRow } from '../database/schema'
|
||||
import { AgentModelField } from '../errors'
|
||||
import { builtinTools } from './claudecode/tools'
|
||||
|
||||
export class AgentService extends BaseService {
|
||||
private static instance: AgentService | null = null
|
||||
private readonly modelFields: AgentModelField[] = ['model', 'plan_model', 'small_model']
|
||||
|
||||
static getInstance(): AgentService {
|
||||
if (!AgentService.instance) {
|
||||
@ -43,6 +45,12 @@ export class AgentService extends BaseService {
|
||||
req.accessible_paths = [defaultPath]
|
||||
}
|
||||
|
||||
await this.validateAgentModels(req.type, {
|
||||
model: req.model,
|
||||
plan_model: req.plan_model,
|
||||
small_model: req.small_model
|
||||
})
|
||||
|
||||
this.ensurePathsExist(req.accessible_paths)
|
||||
|
||||
const serializedReq = this.serializeJsonFields(req)
|
||||
@ -132,6 +140,18 @@ export class AgentService extends BaseService {
|
||||
if (updates.accessible_paths) {
|
||||
this.ensurePathsExist(updates.accessible_paths)
|
||||
}
|
||||
|
||||
const modelUpdates: Partial<Record<AgentModelField, string | undefined>> = {}
|
||||
for (const field of this.modelFields) {
|
||||
if (Object.prototype.hasOwnProperty.call(updates, field)) {
|
||||
modelUpdates[field] = updates[field as keyof UpdateAgentRequest] as string | undefined
|
||||
}
|
||||
}
|
||||
|
||||
if (Object.keys(modelUpdates).length > 0) {
|
||||
await this.validateAgentModels(existing.type, modelUpdates)
|
||||
}
|
||||
|
||||
const serializedUpdates = this.serializeJsonFields(updates)
|
||||
|
||||
const updateData: Partial<AgentRow> = {
|
||||
|
||||
@ -13,9 +13,11 @@ import { and, count, eq, type SQL } from 'drizzle-orm'
|
||||
|
||||
import { BaseService } from '../BaseService'
|
||||
import { agentsTable, type InsertSessionRow, type SessionRow, sessionsTable } from '../database/schema'
|
||||
import { AgentModelField } from '../errors'
|
||||
|
||||
export class SessionService extends BaseService {
|
||||
private static instance: SessionService | null = null
|
||||
private readonly modelFields: AgentModelField[] = ['model', 'plan_model', 'small_model']
|
||||
|
||||
static getInstance(): SessionService {
|
||||
if (!SessionService.instance) {
|
||||
@ -50,6 +52,12 @@ export class SessionService extends BaseService {
|
||||
...req
|
||||
}
|
||||
|
||||
await this.validateAgentModels(agent.type, {
|
||||
model: sessionData.model,
|
||||
plan_model: sessionData.plan_model,
|
||||
small_model: sessionData.small_model
|
||||
})
|
||||
|
||||
this.ensurePathsExist(sessionData.accessible_paths)
|
||||
|
||||
const serializedData = this.serializeJsonFields(sessionData)
|
||||
@ -174,6 +182,18 @@ export class SessionService extends BaseService {
|
||||
if (updates.accessible_paths) {
|
||||
this.ensurePathsExist(updates.accessible_paths)
|
||||
}
|
||||
|
||||
const modelUpdates: Partial<Record<AgentModelField, string | undefined>> = {}
|
||||
for (const field of this.modelFields) {
|
||||
if (Object.prototype.hasOwnProperty.call(updates, field)) {
|
||||
modelUpdates[field] = updates[field as keyof UpdateSessionRequest] as string | undefined
|
||||
}
|
||||
}
|
||||
|
||||
if (Object.keys(modelUpdates).length > 0) {
|
||||
await this.validateAgentModels(existing.agent_type, modelUpdates)
|
||||
}
|
||||
|
||||
const serializedUpdates = this.serializeJsonFields(updates)
|
||||
|
||||
const updateData: Partial<SessionRow> = {
|
||||
|
||||
@ -50,14 +50,12 @@ class ClaudeCodeService implements AgentServiceInterface {
|
||||
return aiStream
|
||||
}
|
||||
|
||||
// Validate model
|
||||
const modelId = session.model
|
||||
logger.info('Invoking Claude Code with model', { modelId, cwd })
|
||||
const modelInfo = await validateModelId(modelId)
|
||||
// Validate model info
|
||||
const modelInfo = await validateModelId(session.model)
|
||||
if (!modelInfo.valid) {
|
||||
aiStream.emit('data', {
|
||||
type: 'error',
|
||||
error: new Error(`Invalid model ID '${modelId}': ${JSON.stringify(modelInfo.error)}`)
|
||||
error: new Error(`Invalid model ID '${session.model}': ${JSON.stringify(modelInfo.error)}`)
|
||||
})
|
||||
return aiStream
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user