diff --git a/src/main/apiServer/routes/agents/handlers/agents.ts b/src/main/apiServer/routes/agents/handlers/agents.ts index cd1cbf850b..aa021af063 100644 --- a/src/main/apiServer/routes/agents/handlers/agents.ts +++ b/src/main/apiServer/routes/agents/handlers/agents.ts @@ -1,8 +1,9 @@ import { loggerService } from '@logger' -import { ListAgentsResponse } from '@types' +import { ListAgentsResponse,type ReplaceAgentRequest, type UpdateAgentRequest } from '@types' import { Request, Response } from 'express' import { agentService } from '../../../../services/agents' +import type { ValidationRequest } from '../validators/zodValidator' const logger = loggerService.withContext('ApiServerAgentsHandlers') @@ -263,7 +264,10 @@ export const updateAgent = async (req: Request, res: Response): Promise logger.info(`Partially updating agent: ${agentId}`) logger.debug('Partial update data:', req.body) - const agent = await agentService.updateAgent(agentId, req.body) + const { validatedBody } = req as ValidationRequest + const updatePayload = (validatedBody ?? {}) as UpdateAgentRequest + + const agent = await agentService.updateAgent(agentId, updatePayload) if (!agent) { logger.warn(`Agent not found for partial update: ${agentId}`) diff --git a/src/main/apiServer/routes/agents/handlers/sessions.ts b/src/main/apiServer/routes/agents/handlers/sessions.ts index f164d6cd52..fb28515558 100644 --- a/src/main/apiServer/routes/agents/handlers/sessions.ts +++ b/src/main/apiServer/routes/agents/handlers/sessions.ts @@ -1,8 +1,10 @@ import { loggerService } from '@logger' import { sessionMessageService, sessionService } from '@main/services/agents' -import { CreateSessionResponse, ListAgentSessionsResponse } from '@types' +import { CreateSessionResponse, ListAgentSessionsResponse,type ReplaceSessionRequest } from '@types' import { Request, Response } from 'express' +import type { ValidationRequest } from '../validators/zodValidator' + const logger = loggerService.withContext('ApiServerSessionsHandlers') export const createSession = async (req: Request, res: Response): Promise => { @@ -131,9 +133,10 @@ export const updateSession = async (req: Request, res: Response): Promise { sessionsRouter.put( '/:sessionId', validateSessionId, - validateSessionUpdate, + validateSessionReplace, handleValidationErrors, sessionHandlers.updateSession ) diff --git a/src/main/apiServer/routes/agents/validators/agents.ts b/src/main/apiServer/routes/agents/validators/agents.ts index d9c23b13f1..4b29e66929 100644 --- a/src/main/apiServer/routes/agents/validators/agents.ts +++ b/src/main/apiServer/routes/agents/validators/agents.ts @@ -1,4 +1,9 @@ -import { AgentIdParamSchema, CreateAgentRequestSchema, UpdateAgentRequestSchema } from '@types' +import { + AgentIdParamSchema, + CreateAgentRequestSchema, + ReplaceAgentRequestSchema, + UpdateAgentRequestSchema +} from '@types' import { createZodValidator } from './zodValidator' @@ -6,6 +11,10 @@ export const validateAgent = createZodValidator({ body: CreateAgentRequestSchema }) +export const validateAgentReplace = createZodValidator({ + body: ReplaceAgentRequestSchema +}) + export const validateAgentUpdate = createZodValidator({ body: UpdateAgentRequestSchema }) diff --git a/src/main/apiServer/routes/agents/validators/sessions.ts b/src/main/apiServer/routes/agents/validators/sessions.ts index 3c87349230..5081849649 100644 --- a/src/main/apiServer/routes/agents/validators/sessions.ts +++ b/src/main/apiServer/routes/agents/validators/sessions.ts @@ -1,4 +1,9 @@ -import { CreateSessionRequestSchema, SessionIdParamSchema, UpdateSessionRequestSchema } from '@types' +import { + CreateSessionRequestSchema, + ReplaceSessionRequestSchema, + SessionIdParamSchema, + UpdateSessionRequestSchema +} from '@types' import { createZodValidator } from './zodValidator' @@ -6,6 +11,10 @@ export const validateSession = createZodValidator({ body: CreateSessionRequestSchema }) +export const validateSessionReplace = createZodValidator({ + body: ReplaceSessionRequestSchema +}) + export const validateSessionUpdate = createZodValidator({ body: UpdateSessionRequestSchema }) diff --git a/src/main/services/agents/BaseService.ts b/src/main/services/agents/BaseService.ts index 2da1a8aa96..40a3d059dc 100644 --- a/src/main/services/agents/BaseService.ts +++ b/src/main/services/agents/BaseService.ts @@ -27,7 +27,7 @@ export abstract class BaseService { protected static db: LibSQLDatabase | null = null protected static isInitialized = false protected static initializationPromise: Promise | null = null - protected jsonFields: string[] = ['built_in_tools', 'mcps', 'configuration', 'accessible_paths'] + protected jsonFields: string[] = ['built_in_tools', 'mcps', 'configuration', 'accessible_paths', 'allowed_tools'] /** * Initialize database with retry logic and proper error handling diff --git a/src/main/services/agents/services/AgentService.ts b/src/main/services/agents/services/AgentService.ts index ec72bacd54..03110cc55a 100644 --- a/src/main/services/agents/services/AgentService.ts +++ b/src/main/services/agents/services/AgentService.ts @@ -1,7 +1,8 @@ import path from 'node:path' import { getDataPath } from '@main/utils' -import type { +import { + AgentBaseSchema, AgentEntity, CreateAgentRequest, CreateAgentResponse, @@ -111,7 +112,11 @@ export class AgentService extends BaseService { return { agents, total: totalResult[0].count } } - async updateAgent(id: string, updates: UpdateAgentRequest): Promise { + async updateAgent( + id: string, + updates: UpdateAgentRequest, + options: { replace?: boolean } = {} + ): Promise { this.ensureInitialized() // Check if agent exists @@ -126,18 +131,20 @@ export class AgentService extends BaseService { const updateData: Partial = { updated_at: now } + const replaceableFields = Object.keys(AgentBaseSchema.shape) as (keyof AgentRow)[] + const shouldReplace = options.replace ?? false + + for (const field of replaceableFields) { + if (shouldReplace || Object.prototype.hasOwnProperty.call(serializedUpdates, field)) { + if (Object.prototype.hasOwnProperty.call(serializedUpdates, field)) { + const value = serializedUpdates[field as keyof typeof serializedUpdates] + ;(updateData as Record)[field] = value ?? null + } else if (shouldReplace) { + ;(updateData as Record)[field] = null + } + } + } - // Only update fields that are provided - if (serializedUpdates.name !== undefined) updateData.name = serializedUpdates.name - if (serializedUpdates.description !== undefined) updateData.description = serializedUpdates.description - if (serializedUpdates.instructions !== undefined) updateData.instructions = serializedUpdates.instructions - if (serializedUpdates.model !== undefined) updateData.model = serializedUpdates.model - if (serializedUpdates.plan_model !== undefined) updateData.plan_model = serializedUpdates.plan_model - if (serializedUpdates.small_model !== undefined) updateData.small_model = serializedUpdates.small_model - if (serializedUpdates.mcps !== undefined) updateData.mcps = serializedUpdates.mcps - if (serializedUpdates.configuration !== undefined) updateData.configuration = serializedUpdates.configuration - if (serializedUpdates.accessible_paths !== undefined) - updateData.accessible_paths = serializedUpdates.accessible_paths await this.database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id)) return await this.getAgent(id) } diff --git a/src/main/services/agents/services/SessionService.ts b/src/main/services/agents/services/SessionService.ts index 63a64c1efb..07c4c40cce 100644 --- a/src/main/services/agents/services/SessionService.ts +++ b/src/main/services/agents/services/SessionService.ts @@ -1,11 +1,12 @@ -import type { - AgentEntity, - AgentSessionEntity, - CreateSessionRequest, - CreateSessionResponse, - GetAgentSessionResponse, - ListOptions, - UpdateSessionRequest +import { + AgentBaseSchema, + type AgentEntity, + type AgentSessionEntity, + type CreateSessionRequest, + type CreateSessionResponse, + type GetAgentSessionResponse, + type ListOptions, + type UpdateSessionRequest } from '@types' import { and, count, eq, type SQL } from 'drizzle-orm' @@ -149,7 +150,11 @@ export class SessionService extends BaseService { return { sessions, total } } - async updateSession(agentId: string, id: string, updates: UpdateSessionRequest): Promise { + async updateSession( + agentId: string, + id: string, + updates: UpdateSessionRequest + ): Promise { this.ensureInitialized() // Check if session exists @@ -167,19 +172,14 @@ export class SessionService extends BaseService { const updateData: Partial = { updated_at: now } + const replaceableFields = Object.keys(AgentBaseSchema.shape) as (keyof SessionRow)[] - // Only update fields that are provided - if (serializedUpdates.name !== undefined) updateData.name = serializedUpdates.name - - if (serializedUpdates.model !== undefined) updateData.model = serializedUpdates.model - if (serializedUpdates.plan_model !== undefined) updateData.plan_model = serializedUpdates.plan_model - if (serializedUpdates.small_model !== undefined) updateData.small_model = serializedUpdates.small_model - - if (serializedUpdates.mcps !== undefined) updateData.mcps = serializedUpdates.mcps - - if (serializedUpdates.configuration !== undefined) updateData.configuration = serializedUpdates.configuration - if (serializedUpdates.accessible_paths !== undefined) - updateData.accessible_paths = serializedUpdates.accessible_paths + for (const field of replaceableFields) { + if (Object.prototype.hasOwnProperty.call(serializedUpdates, field)) { + const value = serializedUpdates[field as keyof typeof serializedUpdates] + ;(updateData as Record)[field] = value ?? null + } + } await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id)) diff --git a/src/renderer/src/types/agent.ts b/src/renderer/src/types/agent.ts index ecce2941cf..9e2f967a0d 100644 --- a/src/renderer/src/types/agent.ts +++ b/src/renderer/src/types/agent.ts @@ -163,6 +163,8 @@ export type CreateAgentResponse = AgentEntity export interface UpdateAgentRequest extends Partial {} +export type ReplaceAgentRequest = AgentBase + export const GetAgentResponseSchema = AgentEntitySchema.extend({ built_in_tools: z.array(ToolSchema).optional() // Built-in tools available to the agent }) @@ -250,6 +252,8 @@ export const CreateAgentRequestSchema = agentCreatableSchema.extend({ export const UpdateAgentRequestSchema = AgentBaseSchema.partial() +export const ReplaceAgentRequestSchema = AgentBaseSchema + const sessionCreatableSchema = AgentBaseSchema.extend({ model: z.string().min(1, 'Model is required') }) @@ -258,6 +262,10 @@ export const CreateSessionRequestSchema = sessionCreatableSchema export const UpdateSessionRequestSchema = sessionCreatableSchema.partial() +export const ReplaceSessionRequestSchema = sessionCreatableSchema + +export type ReplaceSessionRequest = z.infer + export const CreateSessionMessageRequestSchema = z.object({ content: z.string().min(1, 'Content must be a valid string') })