feat(agents, sessions): implement replace functionality for agent and session updates

This commit is contained in:
Vaayne 2025-09-19 11:13:05 +08:00
parent df1d4cd62b
commit 514b60f704
9 changed files with 97 additions and 46 deletions

View File

@ -1,8 +1,9 @@
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { ListAgentsResponse } from '@types' import { ListAgentsResponse,type ReplaceAgentRequest, type UpdateAgentRequest } from '@types'
import { Request, Response } from 'express' import { Request, Response } from 'express'
import { agentService } from '../../../../services/agents' import { agentService } from '../../../../services/agents'
import type { ValidationRequest } from '../validators/zodValidator'
const logger = loggerService.withContext('ApiServerAgentsHandlers') const logger = loggerService.withContext('ApiServerAgentsHandlers')
@ -263,7 +264,10 @@ export const updateAgent = async (req: Request, res: Response): Promise<Response
logger.info(`Updating agent: ${agentId}`) logger.info(`Updating agent: ${agentId}`)
logger.debug('Update data:', req.body) logger.debug('Update data:', req.body)
const agent = await agentService.updateAgent(agentId, req.body) const { validatedBody } = req as ValidationRequest
const replacePayload = (validatedBody ?? {}) as ReplaceAgentRequest
const agent = await agentService.updateAgent(agentId, replacePayload, { replace: true })
if (!agent) { if (!agent) {
logger.warn(`Agent not found for update: ${agentId}`) logger.warn(`Agent not found for update: ${agentId}`)
@ -395,7 +399,10 @@ export const patchAgent = async (req: Request, res: Response): Promise<Response>
logger.info(`Partially updating agent: ${agentId}`) logger.info(`Partially updating agent: ${agentId}`)
logger.debug('Partial update data:', req.body) logger.debug('Partial update data:', req.body)
const agent = await agentService.updateAgent(agentId, req.body) const { validatedBody } = req as ValidationRequest
const updatePayload = (validatedBody ?? {}) as UpdateAgentRequest
const agent = await agentService.updateAgent(agentId, updatePayload)
if (!agent) { if (!agent) {
logger.warn(`Agent not found for partial update: ${agentId}`) logger.warn(`Agent not found for partial update: ${agentId}`)

View File

@ -1,8 +1,10 @@
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { sessionMessageService, sessionService } from '@main/services/agents' import { sessionMessageService, sessionService } from '@main/services/agents'
import { CreateSessionResponse, ListAgentSessionsResponse } from '@types' import { CreateSessionResponse, ListAgentSessionsResponse,type ReplaceSessionRequest } from '@types'
import { Request, Response } from 'express' import { Request, Response } from 'express'
import type { ValidationRequest } from '../validators/zodValidator'
const logger = loggerService.withContext('ApiServerSessionsHandlers') const logger = loggerService.withContext('ApiServerSessionsHandlers')
export const createSession = async (req: Request, res: Response): Promise<Response> => { export const createSession = async (req: Request, res: Response): Promise<Response> => {
@ -131,9 +133,10 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
}) })
} }
// For PUT, we replace the entire resource const { validatedBody } = req as ValidationRequest
const sessionData = { ...req.body, main_agent_id: agentId } const replacePayload = (validatedBody ?? {}) as ReplaceSessionRequest
const session = await sessionService.updateSession(agentId, sessionId, sessionData)
const session = await sessionService.updateSession(agentId, sessionId, replacePayload)
if (!session) { if (!session) {
logger.warn(`Session not found for update: ${sessionId}`) logger.warn(`Session not found for update: ${sessionId}`)

View File

@ -5,11 +5,13 @@ import { checkAgentExists, handleValidationErrors } from './middleware'
import { import {
validateAgent, validateAgent,
validateAgentId, validateAgentId,
validateAgentReplace,
validateAgentUpdate, validateAgentUpdate,
validatePagination, validatePagination,
validateSession, validateSession,
validateSessionId, validateSessionId,
validateSessionMessage, validateSessionMessage,
validateSessionReplace,
validateSessionUpdate validateSessionUpdate
} from './validators' } from './validators'
@ -152,7 +154,13 @@ const agentsRouter = express.Router()
agentsRouter.post('/', validateAgent, handleValidationErrors, agentHandlers.createAgent) agentsRouter.post('/', validateAgent, handleValidationErrors, agentHandlers.createAgent)
agentsRouter.get('/', validatePagination, handleValidationErrors, agentHandlers.listAgents) agentsRouter.get('/', validatePagination, handleValidationErrors, agentHandlers.listAgents)
agentsRouter.get('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.getAgent) agentsRouter.get('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.getAgent)
agentsRouter.put('/:agentId', validateAgentId, validateAgentUpdate, handleValidationErrors, agentHandlers.updateAgent) agentsRouter.put(
'/:agentId',
validateAgentId,
validateAgentReplace,
handleValidationErrors,
agentHandlers.updateAgent
)
agentsRouter.patch('/:agentId', validateAgentId, validateAgentUpdate, handleValidationErrors, agentHandlers.patchAgent) agentsRouter.patch('/:agentId', validateAgentId, validateAgentUpdate, handleValidationErrors, agentHandlers.patchAgent)
agentsRouter.delete('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.deleteAgent) agentsRouter.delete('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.deleteAgent)
@ -167,7 +175,7 @@ const createSessionsRouter = (): express.Router => {
sessionsRouter.put( sessionsRouter.put(
'/:sessionId', '/:sessionId',
validateSessionId, validateSessionId,
validateSessionUpdate, validateSessionReplace,
handleValidationErrors, handleValidationErrors,
sessionHandlers.updateSession sessionHandlers.updateSession
) )

View File

@ -1,4 +1,9 @@
import { AgentIdParamSchema, CreateAgentRequestSchema, UpdateAgentRequestSchema } from '@types' import {
AgentIdParamSchema,
CreateAgentRequestSchema,
ReplaceAgentRequestSchema,
UpdateAgentRequestSchema
} from '@types'
import { createZodValidator } from './zodValidator' import { createZodValidator } from './zodValidator'
@ -6,6 +11,10 @@ export const validateAgent = createZodValidator({
body: CreateAgentRequestSchema body: CreateAgentRequestSchema
}) })
export const validateAgentReplace = createZodValidator({
body: ReplaceAgentRequestSchema
})
export const validateAgentUpdate = createZodValidator({ export const validateAgentUpdate = createZodValidator({
body: UpdateAgentRequestSchema body: UpdateAgentRequestSchema
}) })

View File

@ -1,4 +1,9 @@
import { CreateSessionRequestSchema, SessionIdParamSchema, UpdateSessionRequestSchema } from '@types' import {
CreateSessionRequestSchema,
ReplaceSessionRequestSchema,
SessionIdParamSchema,
UpdateSessionRequestSchema
} from '@types'
import { createZodValidator } from './zodValidator' import { createZodValidator } from './zodValidator'
@ -6,6 +11,10 @@ export const validateSession = createZodValidator({
body: CreateSessionRequestSchema body: CreateSessionRequestSchema
}) })
export const validateSessionReplace = createZodValidator({
body: ReplaceSessionRequestSchema
})
export const validateSessionUpdate = createZodValidator({ export const validateSessionUpdate = createZodValidator({
body: UpdateSessionRequestSchema body: UpdateSessionRequestSchema
}) })

View File

@ -27,7 +27,7 @@ export abstract class BaseService {
protected static db: LibSQLDatabase<typeof schema> | null = null protected static db: LibSQLDatabase<typeof schema> | null = null
protected static isInitialized = false protected static isInitialized = false
protected static initializationPromise: Promise<void> | null = null protected static initializationPromise: Promise<void> | null = null
protected jsonFields: string[] = ['built_in_tools', 'mcps', 'configuration', 'accessible_paths'] protected jsonFields: string[] = ['built_in_tools', 'mcps', 'configuration', 'accessible_paths', 'allowed_tools']
/** /**
* Initialize database with retry logic and proper error handling * Initialize database with retry logic and proper error handling

View File

@ -1,7 +1,8 @@
import path from 'node:path' import path from 'node:path'
import { getDataPath } from '@main/utils' import { getDataPath } from '@main/utils'
import type { import {
AgentBaseSchema,
AgentEntity, AgentEntity,
CreateAgentRequest, CreateAgentRequest,
CreateAgentResponse, CreateAgentResponse,
@ -111,7 +112,11 @@ export class AgentService extends BaseService {
return { agents, total: totalResult[0].count } return { agents, total: totalResult[0].count }
} }
async updateAgent(id: string, updates: UpdateAgentRequest): Promise<UpdateAgentResponse | null> { async updateAgent(
id: string,
updates: UpdateAgentRequest,
options: { replace?: boolean } = {}
): Promise<UpdateAgentResponse | null> {
this.ensureInitialized() this.ensureInitialized()
// Check if agent exists // Check if agent exists
@ -126,18 +131,20 @@ export class AgentService extends BaseService {
const updateData: Partial<AgentRow> = { const updateData: Partial<AgentRow> = {
updated_at: now updated_at: now
} }
const replaceableFields = Object.keys(AgentBaseSchema.shape) as (keyof AgentRow)[]
const shouldReplace = options.replace ?? false
for (const field of replaceableFields) {
if (shouldReplace || Object.prototype.hasOwnProperty.call(serializedUpdates, field)) {
if (Object.prototype.hasOwnProperty.call(serializedUpdates, field)) {
const value = serializedUpdates[field as keyof typeof serializedUpdates]
;(updateData as Record<string, unknown>)[field] = value ?? null
} else if (shouldReplace) {
;(updateData as Record<string, unknown>)[field] = null
}
}
}
// Only update fields that are provided
if (serializedUpdates.name !== undefined) updateData.name = serializedUpdates.name
if (serializedUpdates.description !== undefined) updateData.description = serializedUpdates.description
if (serializedUpdates.instructions !== undefined) updateData.instructions = serializedUpdates.instructions
if (serializedUpdates.model !== undefined) updateData.model = serializedUpdates.model
if (serializedUpdates.plan_model !== undefined) updateData.plan_model = serializedUpdates.plan_model
if (serializedUpdates.small_model !== undefined) updateData.small_model = serializedUpdates.small_model
if (serializedUpdates.mcps !== undefined) updateData.mcps = serializedUpdates.mcps
if (serializedUpdates.configuration !== undefined) updateData.configuration = serializedUpdates.configuration
if (serializedUpdates.accessible_paths !== undefined)
updateData.accessible_paths = serializedUpdates.accessible_paths
await this.database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id)) await this.database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
return await this.getAgent(id) return await this.getAgent(id)
} }

View File

@ -1,11 +1,12 @@
import type { import {
AgentEntity, AgentBaseSchema,
AgentSessionEntity, type AgentEntity,
CreateSessionRequest, type AgentSessionEntity,
CreateSessionResponse, type CreateSessionRequest,
GetAgentSessionResponse, type CreateSessionResponse,
ListOptions, type GetAgentSessionResponse,
UpdateSessionRequest type ListOptions,
type UpdateSessionRequest
} from '@types' } from '@types'
import { and, count, eq, type SQL } from 'drizzle-orm' import { and, count, eq, type SQL } from 'drizzle-orm'
@ -149,7 +150,11 @@ export class SessionService extends BaseService {
return { sessions, total } return { sessions, total }
} }
async updateSession(agentId: string, id: string, updates: UpdateSessionRequest): Promise<GetAgentSessionResponse | null> { async updateSession(
agentId: string,
id: string,
updates: UpdateSessionRequest
): Promise<GetAgentSessionResponse | null> {
this.ensureInitialized() this.ensureInitialized()
// Check if session exists // Check if session exists
@ -167,19 +172,14 @@ export class SessionService extends BaseService {
const updateData: Partial<SessionRow> = { const updateData: Partial<SessionRow> = {
updated_at: now updated_at: now
} }
const replaceableFields = Object.keys(AgentBaseSchema.shape) as (keyof SessionRow)[]
// Only update fields that are provided for (const field of replaceableFields) {
if (serializedUpdates.name !== undefined) updateData.name = serializedUpdates.name if (Object.prototype.hasOwnProperty.call(serializedUpdates, field)) {
const value = serializedUpdates[field as keyof typeof serializedUpdates]
if (serializedUpdates.model !== undefined) updateData.model = serializedUpdates.model ;(updateData as Record<string, unknown>)[field] = value ?? null
if (serializedUpdates.plan_model !== undefined) updateData.plan_model = serializedUpdates.plan_model }
if (serializedUpdates.small_model !== undefined) updateData.small_model = serializedUpdates.small_model }
if (serializedUpdates.mcps !== undefined) updateData.mcps = serializedUpdates.mcps
if (serializedUpdates.configuration !== undefined) updateData.configuration = serializedUpdates.configuration
if (serializedUpdates.accessible_paths !== undefined)
updateData.accessible_paths = serializedUpdates.accessible_paths
await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id)) await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))

View File

@ -163,6 +163,8 @@ export type CreateAgentResponse = AgentEntity
export interface UpdateAgentRequest extends Partial<AgentBase> {} export interface UpdateAgentRequest extends Partial<AgentBase> {}
export type ReplaceAgentRequest = AgentBase
export const GetAgentResponseSchema = AgentEntitySchema.extend({ export const GetAgentResponseSchema = AgentEntitySchema.extend({
built_in_tools: z.array(ToolSchema).optional() // Built-in tools available to the agent built_in_tools: z.array(ToolSchema).optional() // Built-in tools available to the agent
}) })
@ -250,6 +252,8 @@ export const CreateAgentRequestSchema = agentCreatableSchema.extend({
export const UpdateAgentRequestSchema = AgentBaseSchema.partial() export const UpdateAgentRequestSchema = AgentBaseSchema.partial()
export const ReplaceAgentRequestSchema = AgentBaseSchema
const sessionCreatableSchema = AgentBaseSchema.extend({ const sessionCreatableSchema = AgentBaseSchema.extend({
model: z.string().min(1, 'Model is required') model: z.string().min(1, 'Model is required')
}) })
@ -258,6 +262,10 @@ export const CreateSessionRequestSchema = sessionCreatableSchema
export const UpdateSessionRequestSchema = sessionCreatableSchema.partial() export const UpdateSessionRequestSchema = sessionCreatableSchema.partial()
export const ReplaceSessionRequestSchema = sessionCreatableSchema
export type ReplaceSessionRequest = z.infer<typeof ReplaceSessionRequestSchema>
export const CreateSessionMessageRequestSchema = z.object({ export const CreateSessionMessageRequestSchema = z.object({
content: z.string().min(1, 'Content must be a valid string') content: z.string().min(1, 'Content must be a valid string')
}) })