mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-24 10:40:07 +08:00
refactor(sessions): update getSession and related methods to include agentId parameter
This commit is contained in:
parent
4839b91cef
commit
df1d4cd62b
@ -12,7 +12,7 @@ const verifyAgentAndSession = async (agentId: string, sessionId: string) => {
|
||||
throw { status: 404, code: 'agent_not_found', message: 'Agent not found' }
|
||||
}
|
||||
|
||||
const session = await sessionService.getSession(sessionId)
|
||||
const session = await sessionService.getSession(agentId, sessionId)
|
||||
if (!session) {
|
||||
throw { status: 404, code: 'session_not_found', message: 'Session not found' }
|
||||
}
|
||||
|
||||
@ -64,7 +64,7 @@ export const getSession = async (req: Request, res: Response): Promise<Response>
|
||||
const { agentId, sessionId } = req.params
|
||||
logger.info(`Getting session: ${sessionId} for agent: ${agentId}`)
|
||||
|
||||
const session = await sessionService.getSession(sessionId)
|
||||
const session = await sessionService.getSession(agentId, sessionId)
|
||||
|
||||
if (!session) {
|
||||
logger.warn(`Session not found: ${sessionId}`)
|
||||
@ -119,7 +119,7 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
|
||||
logger.debug('Update data:', req.body)
|
||||
|
||||
// First check if session exists and belongs to agent
|
||||
const existingSession = await sessionService.getSession(sessionId)
|
||||
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
||||
return res.status(404).json({
|
||||
@ -133,7 +133,7 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
|
||||
|
||||
// For PUT, we replace the entire resource
|
||||
const sessionData = { ...req.body, main_agent_id: agentId }
|
||||
const session = await sessionService.updateSession(sessionId, sessionData)
|
||||
const session = await sessionService.updateSession(agentId, sessionId, sessionData)
|
||||
|
||||
if (!session) {
|
||||
logger.warn(`Session not found for update: ${sessionId}`)
|
||||
@ -167,7 +167,7 @@ export const patchSession = async (req: Request, res: Response): Promise<Respons
|
||||
logger.debug('Patch data:', req.body)
|
||||
|
||||
// First check if session exists and belongs to agent
|
||||
const existingSession = await sessionService.getSession(sessionId)
|
||||
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
||||
return res.status(404).json({
|
||||
@ -180,7 +180,7 @@ export const patchSession = async (req: Request, res: Response): Promise<Respons
|
||||
}
|
||||
|
||||
const updateSession = { ...existingSession, ...req.body }
|
||||
const session = await sessionService.updateSession(sessionId, updateSession)
|
||||
const session = await sessionService.updateSession(agentId, sessionId, updateSession)
|
||||
|
||||
if (!session) {
|
||||
logger.warn(`Session not found for patch: ${sessionId}`)
|
||||
@ -213,7 +213,7 @@ export const deleteSession = async (req: Request, res: Response): Promise<Respon
|
||||
logger.info(`Deleting session: ${sessionId} for agent: ${agentId}`)
|
||||
|
||||
// First check if session exists and belongs to agent
|
||||
const existingSession = await sessionService.getSession(sessionId)
|
||||
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
||||
return res.status(404).json({
|
||||
@ -225,7 +225,7 @@ export const deleteSession = async (req: Request, res: Response): Promise<Respon
|
||||
})
|
||||
}
|
||||
|
||||
const deleted = await sessionService.deleteSession(sessionId)
|
||||
const deleted = await sessionService.deleteSession(agentId, sessionId)
|
||||
|
||||
if (!deleted) {
|
||||
logger.warn(`Session not found for deletion: ${sessionId}`)
|
||||
@ -287,7 +287,7 @@ export const getSessionById = async (req: Request, res: Response): Promise<Respo
|
||||
const { sessionId } = req.params
|
||||
logger.info(`Getting session: ${sessionId}`)
|
||||
|
||||
const session = await sessionService.getSession(sessionId)
|
||||
const session = await sessionService.getSessionById(sessionId)
|
||||
|
||||
if (!session) {
|
||||
logger.warn(`Session not found: ${sessionId}`)
|
||||
|
||||
@ -9,7 +9,7 @@ import type {
|
||||
} from '@types'
|
||||
import { ModelMessage, UIMessage, UIMessageChunk } from 'ai'
|
||||
import { convertToModelMessages, readUIMessageStream } from 'ai'
|
||||
import { count, eq } from 'drizzle-orm'
|
||||
import { eq } from 'drizzle-orm'
|
||||
|
||||
import { BaseService } from '../BaseService'
|
||||
import { InsertSessionMessageRow, sessionMessagesTable } from '../database/schema'
|
||||
@ -148,17 +148,9 @@ export class SessionMessageService extends BaseService {
|
||||
async listSessionMessages(
|
||||
sessionId: string,
|
||||
options: ListOptions = {}
|
||||
): Promise<{ messages: AgentSessionMessageEntity[]; total: number }> {
|
||||
): Promise<{ messages: AgentSessionMessageEntity[] }> {
|
||||
this.ensureInitialized()
|
||||
|
||||
// Get total count
|
||||
const totalResult = await this.database
|
||||
.select({ count: count() })
|
||||
.from(sessionMessagesTable)
|
||||
.where(eq(sessionMessagesTable.session_id, sessionId))
|
||||
|
||||
const total = totalResult[0].count
|
||||
|
||||
// Get messages with pagination
|
||||
const baseQuery = this.database
|
||||
.select()
|
||||
@ -175,7 +167,7 @@ export class SessionMessageService extends BaseService {
|
||||
|
||||
const messages = result.map((row) => this.deserializeSessionMessage(row)) as AgentSessionMessageEntity[]
|
||||
|
||||
return { messages, total }
|
||||
return { messages }
|
||||
}
|
||||
|
||||
async saveUserMessage(
|
||||
|
||||
@ -78,7 +78,25 @@ export class SessionService extends BaseService {
|
||||
return this.deserializeJsonFields(result[0]) as AgentSessionEntity
|
||||
}
|
||||
|
||||
async getSession(id: string): Promise<GetAgentSessionResponse | null> {
|
||||
async getSession(agentId: string, id: string): Promise<GetAgentSessionResponse | null> {
|
||||
this.ensureInitialized()
|
||||
|
||||
const result = await this.database
|
||||
.select()
|
||||
.from(sessionsTable)
|
||||
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
||||
.limit(1)
|
||||
|
||||
if (!result[0]) {
|
||||
return null
|
||||
}
|
||||
|
||||
const session = this.deserializeJsonFields(result[0]) as GetAgentSessionResponse
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
async getSessionById(id: string): Promise<GetAgentSessionResponse | null> {
|
||||
this.ensureInitialized()
|
||||
|
||||
const result = await this.database.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
|
||||
@ -92,14 +110,6 @@ export class SessionService extends BaseService {
|
||||
return session
|
||||
}
|
||||
|
||||
async getSessionWithAgent(id: string): Promise<any | null> {
|
||||
this.ensureInitialized()
|
||||
|
||||
// TODO: Implement join query with agents table when needed
|
||||
// For now, just return the session
|
||||
return await this.getSession(id)
|
||||
}
|
||||
|
||||
async listSessions(
|
||||
agentId?: string,
|
||||
options: ListOptions = {}
|
||||
@ -139,11 +149,11 @@ export class SessionService extends BaseService {
|
||||
return { sessions, total }
|
||||
}
|
||||
|
||||
async updateSession(id: string, updates: UpdateSessionRequest): Promise<GetAgentSessionResponse | null> {
|
||||
async updateSession(agentId: string, id: string, updates: UpdateSessionRequest): Promise<GetAgentSessionResponse | null> {
|
||||
this.ensureInitialized()
|
||||
|
||||
// Check if session exists
|
||||
const existing = await this.getSession(id)
|
||||
const existing = await this.getSession(agentId, id)
|
||||
if (!existing) {
|
||||
return null
|
||||
}
|
||||
@ -173,24 +183,26 @@ export class SessionService extends BaseService {
|
||||
|
||||
await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
|
||||
|
||||
return await this.getSession(id)
|
||||
return await this.getSession(agentId, id)
|
||||
}
|
||||
|
||||
async deleteSession(id: string): Promise<boolean> {
|
||||
async deleteSession(agentId: string, id: string): Promise<boolean> {
|
||||
this.ensureInitialized()
|
||||
|
||||
const result = await this.database.delete(sessionsTable).where(eq(sessionsTable.id, id))
|
||||
const result = await this.database
|
||||
.delete(sessionsTable)
|
||||
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
||||
|
||||
return result.rowsAffected > 0
|
||||
}
|
||||
|
||||
async sessionExists(id: string): Promise<boolean> {
|
||||
async sessionExists(agentId: string, id: string): Promise<boolean> {
|
||||
this.ensureInitialized()
|
||||
|
||||
const result = await this.database
|
||||
.select({ id: sessionsTable.id })
|
||||
.from(sessionsTable)
|
||||
.where(eq(sessionsTable.id, id))
|
||||
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
||||
.limit(1)
|
||||
|
||||
return result.length > 0
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
* Database entity types for Agent, Session, and SessionMessage
|
||||
* Shared between main and renderer processes
|
||||
*/
|
||||
import { ModelMessage, modelMessageSchema, TextStreamPart, UIMessageChunk } from 'ai'
|
||||
import { ModelMessage, modelMessageSchema, TextStreamPart } from 'ai'
|
||||
import { z } from 'zod'
|
||||
|
||||
// ------------------ Core enums and helper types ------------------
|
||||
@ -119,14 +119,6 @@ export type AgentSessionMessageEntity = z.infer<typeof AgentSessionMessageEntity
|
||||
|
||||
// ------------------ Session message payload ------------------
|
||||
|
||||
// Structured content for session messages that preserves both AI SDK and raw data
|
||||
export interface SessionMessageContent {
|
||||
chunk: UIMessageChunk[] // UI-friendly AI SDK chunks for rendering
|
||||
raw: any[] // Original agent-specific messages for data integrity (agent-agnostic)
|
||||
agentResult?: any // Complete result from the underlying agent service
|
||||
agentType: string // The type of agent that generated this message (e.g., 'claude-code', 'openai', etc.)
|
||||
}
|
||||
|
||||
// Not implemented fields:
|
||||
// - plan_model: Optional model for planning/thinking tasks
|
||||
// - small_model: Optional lightweight model for quick responses
|
||||
|
||||
Loading…
Reference in New Issue
Block a user