mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-08 06:19:05 +08:00
refactor(sessions): update getSession and related methods to include agentId parameter
This commit is contained in:
parent
4839b91cef
commit
df1d4cd62b
@ -12,7 +12,7 @@ const verifyAgentAndSession = async (agentId: string, sessionId: string) => {
|
|||||||
throw { status: 404, code: 'agent_not_found', message: 'Agent not found' }
|
throw { status: 404, code: 'agent_not_found', message: 'Agent not found' }
|
||||||
}
|
}
|
||||||
|
|
||||||
const session = await sessionService.getSession(sessionId)
|
const session = await sessionService.getSession(agentId, sessionId)
|
||||||
if (!session) {
|
if (!session) {
|
||||||
throw { status: 404, code: 'session_not_found', message: 'Session not found' }
|
throw { status: 404, code: 'session_not_found', message: 'Session not found' }
|
||||||
}
|
}
|
||||||
|
|||||||
@ -64,7 +64,7 @@ export const getSession = async (req: Request, res: Response): Promise<Response>
|
|||||||
const { agentId, sessionId } = req.params
|
const { agentId, sessionId } = req.params
|
||||||
logger.info(`Getting session: ${sessionId} for agent: ${agentId}`)
|
logger.info(`Getting session: ${sessionId} for agent: ${agentId}`)
|
||||||
|
|
||||||
const session = await sessionService.getSession(sessionId)
|
const session = await sessionService.getSession(agentId, sessionId)
|
||||||
|
|
||||||
if (!session) {
|
if (!session) {
|
||||||
logger.warn(`Session not found: ${sessionId}`)
|
logger.warn(`Session not found: ${sessionId}`)
|
||||||
@ -119,7 +119,7 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
|
|||||||
logger.debug('Update data:', req.body)
|
logger.debug('Update data:', req.body)
|
||||||
|
|
||||||
// First check if session exists and belongs to agent
|
// First check if session exists and belongs to agent
|
||||||
const existingSession = await sessionService.getSession(sessionId)
|
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||||
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
@ -133,7 +133,7 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
|
|||||||
|
|
||||||
// For PUT, we replace the entire resource
|
// For PUT, we replace the entire resource
|
||||||
const sessionData = { ...req.body, main_agent_id: agentId }
|
const sessionData = { ...req.body, main_agent_id: agentId }
|
||||||
const session = await sessionService.updateSession(sessionId, sessionData)
|
const session = await sessionService.updateSession(agentId, sessionId, sessionData)
|
||||||
|
|
||||||
if (!session) {
|
if (!session) {
|
||||||
logger.warn(`Session not found for update: ${sessionId}`)
|
logger.warn(`Session not found for update: ${sessionId}`)
|
||||||
@ -167,7 +167,7 @@ export const patchSession = async (req: Request, res: Response): Promise<Respons
|
|||||||
logger.debug('Patch data:', req.body)
|
logger.debug('Patch data:', req.body)
|
||||||
|
|
||||||
// First check if session exists and belongs to agent
|
// First check if session exists and belongs to agent
|
||||||
const existingSession = await sessionService.getSession(sessionId)
|
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||||
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
@ -180,7 +180,7 @@ export const patchSession = async (req: Request, res: Response): Promise<Respons
|
|||||||
}
|
}
|
||||||
|
|
||||||
const updateSession = { ...existingSession, ...req.body }
|
const updateSession = { ...existingSession, ...req.body }
|
||||||
const session = await sessionService.updateSession(sessionId, updateSession)
|
const session = await sessionService.updateSession(agentId, sessionId, updateSession)
|
||||||
|
|
||||||
if (!session) {
|
if (!session) {
|
||||||
logger.warn(`Session not found for patch: ${sessionId}`)
|
logger.warn(`Session not found for patch: ${sessionId}`)
|
||||||
@ -213,7 +213,7 @@ export const deleteSession = async (req: Request, res: Response): Promise<Respon
|
|||||||
logger.info(`Deleting session: ${sessionId} for agent: ${agentId}`)
|
logger.info(`Deleting session: ${sessionId} for agent: ${agentId}`)
|
||||||
|
|
||||||
// First check if session exists and belongs to agent
|
// First check if session exists and belongs to agent
|
||||||
const existingSession = await sessionService.getSession(sessionId)
|
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||||
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
||||||
return res.status(404).json({
|
return res.status(404).json({
|
||||||
@ -225,7 +225,7 @@ export const deleteSession = async (req: Request, res: Response): Promise<Respon
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
const deleted = await sessionService.deleteSession(sessionId)
|
const deleted = await sessionService.deleteSession(agentId, sessionId)
|
||||||
|
|
||||||
if (!deleted) {
|
if (!deleted) {
|
||||||
logger.warn(`Session not found for deletion: ${sessionId}`)
|
logger.warn(`Session not found for deletion: ${sessionId}`)
|
||||||
@ -287,7 +287,7 @@ export const getSessionById = async (req: Request, res: Response): Promise<Respo
|
|||||||
const { sessionId } = req.params
|
const { sessionId } = req.params
|
||||||
logger.info(`Getting session: ${sessionId}`)
|
logger.info(`Getting session: ${sessionId}`)
|
||||||
|
|
||||||
const session = await sessionService.getSession(sessionId)
|
const session = await sessionService.getSessionById(sessionId)
|
||||||
|
|
||||||
if (!session) {
|
if (!session) {
|
||||||
logger.warn(`Session not found: ${sessionId}`)
|
logger.warn(`Session not found: ${sessionId}`)
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import type {
|
|||||||
} from '@types'
|
} from '@types'
|
||||||
import { ModelMessage, UIMessage, UIMessageChunk } from 'ai'
|
import { ModelMessage, UIMessage, UIMessageChunk } from 'ai'
|
||||||
import { convertToModelMessages, readUIMessageStream } from 'ai'
|
import { convertToModelMessages, readUIMessageStream } from 'ai'
|
||||||
import { count, eq } from 'drizzle-orm'
|
import { eq } from 'drizzle-orm'
|
||||||
|
|
||||||
import { BaseService } from '../BaseService'
|
import { BaseService } from '../BaseService'
|
||||||
import { InsertSessionMessageRow, sessionMessagesTable } from '../database/schema'
|
import { InsertSessionMessageRow, sessionMessagesTable } from '../database/schema'
|
||||||
@ -148,17 +148,9 @@ export class SessionMessageService extends BaseService {
|
|||||||
async listSessionMessages(
|
async listSessionMessages(
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
options: ListOptions = {}
|
options: ListOptions = {}
|
||||||
): Promise<{ messages: AgentSessionMessageEntity[]; total: number }> {
|
): Promise<{ messages: AgentSessionMessageEntity[] }> {
|
||||||
this.ensureInitialized()
|
this.ensureInitialized()
|
||||||
|
|
||||||
// Get total count
|
|
||||||
const totalResult = await this.database
|
|
||||||
.select({ count: count() })
|
|
||||||
.from(sessionMessagesTable)
|
|
||||||
.where(eq(sessionMessagesTable.session_id, sessionId))
|
|
||||||
|
|
||||||
const total = totalResult[0].count
|
|
||||||
|
|
||||||
// Get messages with pagination
|
// Get messages with pagination
|
||||||
const baseQuery = this.database
|
const baseQuery = this.database
|
||||||
.select()
|
.select()
|
||||||
@ -175,7 +167,7 @@ export class SessionMessageService extends BaseService {
|
|||||||
|
|
||||||
const messages = result.map((row) => this.deserializeSessionMessage(row)) as AgentSessionMessageEntity[]
|
const messages = result.map((row) => this.deserializeSessionMessage(row)) as AgentSessionMessageEntity[]
|
||||||
|
|
||||||
return { messages, total }
|
return { messages }
|
||||||
}
|
}
|
||||||
|
|
||||||
async saveUserMessage(
|
async saveUserMessage(
|
||||||
|
|||||||
@ -78,7 +78,25 @@ export class SessionService extends BaseService {
|
|||||||
return this.deserializeJsonFields(result[0]) as AgentSessionEntity
|
return this.deserializeJsonFields(result[0]) as AgentSessionEntity
|
||||||
}
|
}
|
||||||
|
|
||||||
async getSession(id: string): Promise<GetAgentSessionResponse | null> {
|
async getSession(agentId: string, id: string): Promise<GetAgentSessionResponse | null> {
|
||||||
|
this.ensureInitialized()
|
||||||
|
|
||||||
|
const result = await this.database
|
||||||
|
.select()
|
||||||
|
.from(sessionsTable)
|
||||||
|
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
||||||
|
.limit(1)
|
||||||
|
|
||||||
|
if (!result[0]) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
const session = this.deserializeJsonFields(result[0]) as GetAgentSessionResponse
|
||||||
|
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
|
||||||
|
async getSessionById(id: string): Promise<GetAgentSessionResponse | null> {
|
||||||
this.ensureInitialized()
|
this.ensureInitialized()
|
||||||
|
|
||||||
const result = await this.database.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
|
const result = await this.database.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
|
||||||
@ -92,14 +110,6 @@ export class SessionService extends BaseService {
|
|||||||
return session
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
async getSessionWithAgent(id: string): Promise<any | null> {
|
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
// TODO: Implement join query with agents table when needed
|
|
||||||
// For now, just return the session
|
|
||||||
return await this.getSession(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
async listSessions(
|
async listSessions(
|
||||||
agentId?: string,
|
agentId?: string,
|
||||||
options: ListOptions = {}
|
options: ListOptions = {}
|
||||||
@ -139,11 +149,11 @@ export class SessionService extends BaseService {
|
|||||||
return { sessions, total }
|
return { sessions, total }
|
||||||
}
|
}
|
||||||
|
|
||||||
async updateSession(id: string, updates: UpdateSessionRequest): Promise<GetAgentSessionResponse | null> {
|
async updateSession(agentId: string, id: string, updates: UpdateSessionRequest): Promise<GetAgentSessionResponse | null> {
|
||||||
this.ensureInitialized()
|
this.ensureInitialized()
|
||||||
|
|
||||||
// Check if session exists
|
// Check if session exists
|
||||||
const existing = await this.getSession(id)
|
const existing = await this.getSession(agentId, id)
|
||||||
if (!existing) {
|
if (!existing) {
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
@ -173,24 +183,26 @@ export class SessionService extends BaseService {
|
|||||||
|
|
||||||
await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
|
await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
|
||||||
|
|
||||||
return await this.getSession(id)
|
return await this.getSession(agentId, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
async deleteSession(id: string): Promise<boolean> {
|
async deleteSession(agentId: string, id: string): Promise<boolean> {
|
||||||
this.ensureInitialized()
|
this.ensureInitialized()
|
||||||
|
|
||||||
const result = await this.database.delete(sessionsTable).where(eq(sessionsTable.id, id))
|
const result = await this.database
|
||||||
|
.delete(sessionsTable)
|
||||||
|
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
||||||
|
|
||||||
return result.rowsAffected > 0
|
return result.rowsAffected > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
async sessionExists(id: string): Promise<boolean> {
|
async sessionExists(agentId: string, id: string): Promise<boolean> {
|
||||||
this.ensureInitialized()
|
this.ensureInitialized()
|
||||||
|
|
||||||
const result = await this.database
|
const result = await this.database
|
||||||
.select({ id: sessionsTable.id })
|
.select({ id: sessionsTable.id })
|
||||||
.from(sessionsTable)
|
.from(sessionsTable)
|
||||||
.where(eq(sessionsTable.id, id))
|
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
||||||
.limit(1)
|
.limit(1)
|
||||||
|
|
||||||
return result.length > 0
|
return result.length > 0
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
* Database entity types for Agent, Session, and SessionMessage
|
* Database entity types for Agent, Session, and SessionMessage
|
||||||
* Shared between main and renderer processes
|
* Shared between main and renderer processes
|
||||||
*/
|
*/
|
||||||
import { ModelMessage, modelMessageSchema, TextStreamPart, UIMessageChunk } from 'ai'
|
import { ModelMessage, modelMessageSchema, TextStreamPart } from 'ai'
|
||||||
import { z } from 'zod'
|
import { z } from 'zod'
|
||||||
|
|
||||||
// ------------------ Core enums and helper types ------------------
|
// ------------------ Core enums and helper types ------------------
|
||||||
@ -119,14 +119,6 @@ export type AgentSessionMessageEntity = z.infer<typeof AgentSessionMessageEntity
|
|||||||
|
|
||||||
// ------------------ Session message payload ------------------
|
// ------------------ Session message payload ------------------
|
||||||
|
|
||||||
// Structured content for session messages that preserves both AI SDK and raw data
|
|
||||||
export interface SessionMessageContent {
|
|
||||||
chunk: UIMessageChunk[] // UI-friendly AI SDK chunks for rendering
|
|
||||||
raw: any[] // Original agent-specific messages for data integrity (agent-agnostic)
|
|
||||||
agentResult?: any // Complete result from the underlying agent service
|
|
||||||
agentType: string // The type of agent that generated this message (e.g., 'claude-code', 'openai', etc.)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Not implemented fields:
|
// Not implemented fields:
|
||||||
// - plan_model: Optional model for planning/thinking tasks
|
// - plan_model: Optional model for planning/thinking tasks
|
||||||
// - small_model: Optional lightweight model for quick responses
|
// - small_model: Optional lightweight model for quick responses
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user