mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-29 23:12:38 +08:00
feat(agents, sessions): implement replace functionality for agent and session updates
This commit is contained in:
parent
df1d4cd62b
commit
514b60f704
@ -1,8 +1,9 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { ListAgentsResponse } from '@types'
|
import { ListAgentsResponse,type ReplaceAgentRequest, type UpdateAgentRequest } from '@types'
|
||||||
import { Request, Response } from 'express'
|
import { Request, Response } from 'express'
|
||||||
|
|
||||||
import { agentService } from '../../../../services/agents'
|
import { agentService } from '../../../../services/agents'
|
||||||
|
import type { ValidationRequest } from '../validators/zodValidator'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ApiServerAgentsHandlers')
|
const logger = loggerService.withContext('ApiServerAgentsHandlers')
|
||||||
|
|
||||||
@ -263,7 +264,10 @@ export const updateAgent = async (req: Request, res: Response): Promise<Response
|
|||||||
logger.info(`Updating agent: ${agentId}`)
|
logger.info(`Updating agent: ${agentId}`)
|
||||||
logger.debug('Update data:', req.body)
|
logger.debug('Update data:', req.body)
|
||||||
|
|
||||||
const agent = await agentService.updateAgent(agentId, req.body)
|
const { validatedBody } = req as ValidationRequest
|
||||||
|
const replacePayload = (validatedBody ?? {}) as ReplaceAgentRequest
|
||||||
|
|
||||||
|
const agent = await agentService.updateAgent(agentId, replacePayload, { replace: true })
|
||||||
|
|
||||||
if (!agent) {
|
if (!agent) {
|
||||||
logger.warn(`Agent not found for update: ${agentId}`)
|
logger.warn(`Agent not found for update: ${agentId}`)
|
||||||
@ -395,7 +399,10 @@ export const patchAgent = async (req: Request, res: Response): Promise<Response>
|
|||||||
logger.info(`Partially updating agent: ${agentId}`)
|
logger.info(`Partially updating agent: ${agentId}`)
|
||||||
logger.debug('Partial update data:', req.body)
|
logger.debug('Partial update data:', req.body)
|
||||||
|
|
||||||
const agent = await agentService.updateAgent(agentId, req.body)
|
const { validatedBody } = req as ValidationRequest
|
||||||
|
const updatePayload = (validatedBody ?? {}) as UpdateAgentRequest
|
||||||
|
|
||||||
|
const agent = await agentService.updateAgent(agentId, updatePayload)
|
||||||
|
|
||||||
if (!agent) {
|
if (!agent) {
|
||||||
logger.warn(`Agent not found for partial update: ${agentId}`)
|
logger.warn(`Agent not found for partial update: ${agentId}`)
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { sessionMessageService, sessionService } from '@main/services/agents'
|
import { sessionMessageService, sessionService } from '@main/services/agents'
|
||||||
import { CreateSessionResponse, ListAgentSessionsResponse } from '@types'
|
import { CreateSessionResponse, ListAgentSessionsResponse,type ReplaceSessionRequest } from '@types'
|
||||||
import { Request, Response } from 'express'
|
import { Request, Response } from 'express'
|
||||||
|
|
||||||
|
import type { ValidationRequest } from '../validators/zodValidator'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ApiServerSessionsHandlers')
|
const logger = loggerService.withContext('ApiServerSessionsHandlers')
|
||||||
|
|
||||||
export const createSession = async (req: Request, res: Response): Promise<Response> => {
|
export const createSession = async (req: Request, res: Response): Promise<Response> => {
|
||||||
@ -131,9 +133,10 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// For PUT, we replace the entire resource
|
const { validatedBody } = req as ValidationRequest
|
||||||
const sessionData = { ...req.body, main_agent_id: agentId }
|
const replacePayload = (validatedBody ?? {}) as ReplaceSessionRequest
|
||||||
const session = await sessionService.updateSession(agentId, sessionId, sessionData)
|
|
||||||
|
const session = await sessionService.updateSession(agentId, sessionId, replacePayload)
|
||||||
|
|
||||||
if (!session) {
|
if (!session) {
|
||||||
logger.warn(`Session not found for update: ${sessionId}`)
|
logger.warn(`Session not found for update: ${sessionId}`)
|
||||||
|
|||||||
@ -5,11 +5,13 @@ import { checkAgentExists, handleValidationErrors } from './middleware'
|
|||||||
import {
|
import {
|
||||||
validateAgent,
|
validateAgent,
|
||||||
validateAgentId,
|
validateAgentId,
|
||||||
|
validateAgentReplace,
|
||||||
validateAgentUpdate,
|
validateAgentUpdate,
|
||||||
validatePagination,
|
validatePagination,
|
||||||
validateSession,
|
validateSession,
|
||||||
validateSessionId,
|
validateSessionId,
|
||||||
validateSessionMessage,
|
validateSessionMessage,
|
||||||
|
validateSessionReplace,
|
||||||
validateSessionUpdate
|
validateSessionUpdate
|
||||||
} from './validators'
|
} from './validators'
|
||||||
|
|
||||||
@ -152,7 +154,13 @@ const agentsRouter = express.Router()
|
|||||||
agentsRouter.post('/', validateAgent, handleValidationErrors, agentHandlers.createAgent)
|
agentsRouter.post('/', validateAgent, handleValidationErrors, agentHandlers.createAgent)
|
||||||
agentsRouter.get('/', validatePagination, handleValidationErrors, agentHandlers.listAgents)
|
agentsRouter.get('/', validatePagination, handleValidationErrors, agentHandlers.listAgents)
|
||||||
agentsRouter.get('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.getAgent)
|
agentsRouter.get('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.getAgent)
|
||||||
agentsRouter.put('/:agentId', validateAgentId, validateAgentUpdate, handleValidationErrors, agentHandlers.updateAgent)
|
agentsRouter.put(
|
||||||
|
'/:agentId',
|
||||||
|
validateAgentId,
|
||||||
|
validateAgentReplace,
|
||||||
|
handleValidationErrors,
|
||||||
|
agentHandlers.updateAgent
|
||||||
|
)
|
||||||
agentsRouter.patch('/:agentId', validateAgentId, validateAgentUpdate, handleValidationErrors, agentHandlers.patchAgent)
|
agentsRouter.patch('/:agentId', validateAgentId, validateAgentUpdate, handleValidationErrors, agentHandlers.patchAgent)
|
||||||
agentsRouter.delete('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.deleteAgent)
|
agentsRouter.delete('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.deleteAgent)
|
||||||
|
|
||||||
@ -167,7 +175,7 @@ const createSessionsRouter = (): express.Router => {
|
|||||||
sessionsRouter.put(
|
sessionsRouter.put(
|
||||||
'/:sessionId',
|
'/:sessionId',
|
||||||
validateSessionId,
|
validateSessionId,
|
||||||
validateSessionUpdate,
|
validateSessionReplace,
|
||||||
handleValidationErrors,
|
handleValidationErrors,
|
||||||
sessionHandlers.updateSession
|
sessionHandlers.updateSession
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,4 +1,9 @@
|
|||||||
import { AgentIdParamSchema, CreateAgentRequestSchema, UpdateAgentRequestSchema } from '@types'
|
import {
|
||||||
|
AgentIdParamSchema,
|
||||||
|
CreateAgentRequestSchema,
|
||||||
|
ReplaceAgentRequestSchema,
|
||||||
|
UpdateAgentRequestSchema
|
||||||
|
} from '@types'
|
||||||
|
|
||||||
import { createZodValidator } from './zodValidator'
|
import { createZodValidator } from './zodValidator'
|
||||||
|
|
||||||
@ -6,6 +11,10 @@ export const validateAgent = createZodValidator({
|
|||||||
body: CreateAgentRequestSchema
|
body: CreateAgentRequestSchema
|
||||||
})
|
})
|
||||||
|
|
||||||
|
export const validateAgentReplace = createZodValidator({
|
||||||
|
body: ReplaceAgentRequestSchema
|
||||||
|
})
|
||||||
|
|
||||||
export const validateAgentUpdate = createZodValidator({
|
export const validateAgentUpdate = createZodValidator({
|
||||||
body: UpdateAgentRequestSchema
|
body: UpdateAgentRequestSchema
|
||||||
})
|
})
|
||||||
|
|||||||
@ -1,4 +1,9 @@
|
|||||||
import { CreateSessionRequestSchema, SessionIdParamSchema, UpdateSessionRequestSchema } from '@types'
|
import {
|
||||||
|
CreateSessionRequestSchema,
|
||||||
|
ReplaceSessionRequestSchema,
|
||||||
|
SessionIdParamSchema,
|
||||||
|
UpdateSessionRequestSchema
|
||||||
|
} from '@types'
|
||||||
|
|
||||||
import { createZodValidator } from './zodValidator'
|
import { createZodValidator } from './zodValidator'
|
||||||
|
|
||||||
@ -6,6 +11,10 @@ export const validateSession = createZodValidator({
|
|||||||
body: CreateSessionRequestSchema
|
body: CreateSessionRequestSchema
|
||||||
})
|
})
|
||||||
|
|
||||||
|
export const validateSessionReplace = createZodValidator({
|
||||||
|
body: ReplaceSessionRequestSchema
|
||||||
|
})
|
||||||
|
|
||||||
export const validateSessionUpdate = createZodValidator({
|
export const validateSessionUpdate = createZodValidator({
|
||||||
body: UpdateSessionRequestSchema
|
body: UpdateSessionRequestSchema
|
||||||
})
|
})
|
||||||
|
|||||||
@ -27,7 +27,7 @@ export abstract class BaseService {
|
|||||||
protected static db: LibSQLDatabase<typeof schema> | null = null
|
protected static db: LibSQLDatabase<typeof schema> | null = null
|
||||||
protected static isInitialized = false
|
protected static isInitialized = false
|
||||||
protected static initializationPromise: Promise<void> | null = null
|
protected static initializationPromise: Promise<void> | null = null
|
||||||
protected jsonFields: string[] = ['built_in_tools', 'mcps', 'configuration', 'accessible_paths']
|
protected jsonFields: string[] = ['built_in_tools', 'mcps', 'configuration', 'accessible_paths', 'allowed_tools']
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize database with retry logic and proper error handling
|
* Initialize database with retry logic and proper error handling
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
import path from 'node:path'
|
import path from 'node:path'
|
||||||
|
|
||||||
import { getDataPath } from '@main/utils'
|
import { getDataPath } from '@main/utils'
|
||||||
import type {
|
import {
|
||||||
|
AgentBaseSchema,
|
||||||
AgentEntity,
|
AgentEntity,
|
||||||
CreateAgentRequest,
|
CreateAgentRequest,
|
||||||
CreateAgentResponse,
|
CreateAgentResponse,
|
||||||
@ -111,7 +112,11 @@ export class AgentService extends BaseService {
|
|||||||
return { agents, total: totalResult[0].count }
|
return { agents, total: totalResult[0].count }
|
||||||
}
|
}
|
||||||
|
|
||||||
async updateAgent(id: string, updates: UpdateAgentRequest): Promise<UpdateAgentResponse | null> {
|
async updateAgent(
|
||||||
|
id: string,
|
||||||
|
updates: UpdateAgentRequest,
|
||||||
|
options: { replace?: boolean } = {}
|
||||||
|
): Promise<UpdateAgentResponse | null> {
|
||||||
this.ensureInitialized()
|
this.ensureInitialized()
|
||||||
|
|
||||||
// Check if agent exists
|
// Check if agent exists
|
||||||
@ -126,18 +131,20 @@ export class AgentService extends BaseService {
|
|||||||
const updateData: Partial<AgentRow> = {
|
const updateData: Partial<AgentRow> = {
|
||||||
updated_at: now
|
updated_at: now
|
||||||
}
|
}
|
||||||
|
const replaceableFields = Object.keys(AgentBaseSchema.shape) as (keyof AgentRow)[]
|
||||||
|
const shouldReplace = options.replace ?? false
|
||||||
|
|
||||||
|
for (const field of replaceableFields) {
|
||||||
|
if (shouldReplace || Object.prototype.hasOwnProperty.call(serializedUpdates, field)) {
|
||||||
|
if (Object.prototype.hasOwnProperty.call(serializedUpdates, field)) {
|
||||||
|
const value = serializedUpdates[field as keyof typeof serializedUpdates]
|
||||||
|
;(updateData as Record<string, unknown>)[field] = value ?? null
|
||||||
|
} else if (shouldReplace) {
|
||||||
|
;(updateData as Record<string, unknown>)[field] = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Only update fields that are provided
|
|
||||||
if (serializedUpdates.name !== undefined) updateData.name = serializedUpdates.name
|
|
||||||
if (serializedUpdates.description !== undefined) updateData.description = serializedUpdates.description
|
|
||||||
if (serializedUpdates.instructions !== undefined) updateData.instructions = serializedUpdates.instructions
|
|
||||||
if (serializedUpdates.model !== undefined) updateData.model = serializedUpdates.model
|
|
||||||
if (serializedUpdates.plan_model !== undefined) updateData.plan_model = serializedUpdates.plan_model
|
|
||||||
if (serializedUpdates.small_model !== undefined) updateData.small_model = serializedUpdates.small_model
|
|
||||||
if (serializedUpdates.mcps !== undefined) updateData.mcps = serializedUpdates.mcps
|
|
||||||
if (serializedUpdates.configuration !== undefined) updateData.configuration = serializedUpdates.configuration
|
|
||||||
if (serializedUpdates.accessible_paths !== undefined)
|
|
||||||
updateData.accessible_paths = serializedUpdates.accessible_paths
|
|
||||||
await this.database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
|
await this.database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
|
||||||
return await this.getAgent(id)
|
return await this.getAgent(id)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
import type {
|
import {
|
||||||
AgentEntity,
|
AgentBaseSchema,
|
||||||
AgentSessionEntity,
|
type AgentEntity,
|
||||||
CreateSessionRequest,
|
type AgentSessionEntity,
|
||||||
CreateSessionResponse,
|
type CreateSessionRequest,
|
||||||
GetAgentSessionResponse,
|
type CreateSessionResponse,
|
||||||
ListOptions,
|
type GetAgentSessionResponse,
|
||||||
UpdateSessionRequest
|
type ListOptions,
|
||||||
|
type UpdateSessionRequest
|
||||||
} from '@types'
|
} from '@types'
|
||||||
import { and, count, eq, type SQL } from 'drizzle-orm'
|
import { and, count, eq, type SQL } from 'drizzle-orm'
|
||||||
|
|
||||||
@ -149,7 +150,11 @@ export class SessionService extends BaseService {
|
|||||||
return { sessions, total }
|
return { sessions, total }
|
||||||
}
|
}
|
||||||
|
|
||||||
async updateSession(agentId: string, id: string, updates: UpdateSessionRequest): Promise<GetAgentSessionResponse | null> {
|
async updateSession(
|
||||||
|
agentId: string,
|
||||||
|
id: string,
|
||||||
|
updates: UpdateSessionRequest
|
||||||
|
): Promise<GetAgentSessionResponse | null> {
|
||||||
this.ensureInitialized()
|
this.ensureInitialized()
|
||||||
|
|
||||||
// Check if session exists
|
// Check if session exists
|
||||||
@ -167,19 +172,14 @@ export class SessionService extends BaseService {
|
|||||||
const updateData: Partial<SessionRow> = {
|
const updateData: Partial<SessionRow> = {
|
||||||
updated_at: now
|
updated_at: now
|
||||||
}
|
}
|
||||||
|
const replaceableFields = Object.keys(AgentBaseSchema.shape) as (keyof SessionRow)[]
|
||||||
|
|
||||||
// Only update fields that are provided
|
for (const field of replaceableFields) {
|
||||||
if (serializedUpdates.name !== undefined) updateData.name = serializedUpdates.name
|
if (Object.prototype.hasOwnProperty.call(serializedUpdates, field)) {
|
||||||
|
const value = serializedUpdates[field as keyof typeof serializedUpdates]
|
||||||
if (serializedUpdates.model !== undefined) updateData.model = serializedUpdates.model
|
;(updateData as Record<string, unknown>)[field] = value ?? null
|
||||||
if (serializedUpdates.plan_model !== undefined) updateData.plan_model = serializedUpdates.plan_model
|
}
|
||||||
if (serializedUpdates.small_model !== undefined) updateData.small_model = serializedUpdates.small_model
|
}
|
||||||
|
|
||||||
if (serializedUpdates.mcps !== undefined) updateData.mcps = serializedUpdates.mcps
|
|
||||||
|
|
||||||
if (serializedUpdates.configuration !== undefined) updateData.configuration = serializedUpdates.configuration
|
|
||||||
if (serializedUpdates.accessible_paths !== undefined)
|
|
||||||
updateData.accessible_paths = serializedUpdates.accessible_paths
|
|
||||||
|
|
||||||
await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
|
await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
|
||||||
|
|
||||||
|
|||||||
@ -163,6 +163,8 @@ export type CreateAgentResponse = AgentEntity
|
|||||||
|
|
||||||
export interface UpdateAgentRequest extends Partial<AgentBase> {}
|
export interface UpdateAgentRequest extends Partial<AgentBase> {}
|
||||||
|
|
||||||
|
export type ReplaceAgentRequest = AgentBase
|
||||||
|
|
||||||
export const GetAgentResponseSchema = AgentEntitySchema.extend({
|
export const GetAgentResponseSchema = AgentEntitySchema.extend({
|
||||||
built_in_tools: z.array(ToolSchema).optional() // Built-in tools available to the agent
|
built_in_tools: z.array(ToolSchema).optional() // Built-in tools available to the agent
|
||||||
})
|
})
|
||||||
@ -250,6 +252,8 @@ export const CreateAgentRequestSchema = agentCreatableSchema.extend({
|
|||||||
|
|
||||||
export const UpdateAgentRequestSchema = AgentBaseSchema.partial()
|
export const UpdateAgentRequestSchema = AgentBaseSchema.partial()
|
||||||
|
|
||||||
|
export const ReplaceAgentRequestSchema = AgentBaseSchema
|
||||||
|
|
||||||
const sessionCreatableSchema = AgentBaseSchema.extend({
|
const sessionCreatableSchema = AgentBaseSchema.extend({
|
||||||
model: z.string().min(1, 'Model is required')
|
model: z.string().min(1, 'Model is required')
|
||||||
})
|
})
|
||||||
@ -258,6 +262,10 @@ export const CreateSessionRequestSchema = sessionCreatableSchema
|
|||||||
|
|
||||||
export const UpdateSessionRequestSchema = sessionCreatableSchema.partial()
|
export const UpdateSessionRequestSchema = sessionCreatableSchema.partial()
|
||||||
|
|
||||||
|
export const ReplaceSessionRequestSchema = sessionCreatableSchema
|
||||||
|
|
||||||
|
export type ReplaceSessionRequest = z.infer<typeof ReplaceSessionRequestSchema>
|
||||||
|
|
||||||
export const CreateSessionMessageRequestSchema = z.object({
|
export const CreateSessionMessageRequestSchema = z.object({
|
||||||
content: z.string().min(1, 'Content must be a valid string')
|
content: z.string().min(1, 'Content must be a valid string')
|
||||||
})
|
})
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user