diff --git a/src/main/services/agents/services/SessionMessageService.ts b/src/main/services/agents/services/SessionMessageService.ts index b0a08e6efb..b3245edbc2 100644 --- a/src/main/services/agents/services/SessionMessageService.ts +++ b/src/main/services/agents/services/SessionMessageService.ts @@ -9,7 +9,7 @@ import type { } from '@types' import { ModelMessage, UIMessage, UIMessageChunk } from 'ai' import { convertToModelMessages, readUIMessageStream } from 'ai' -import { eq } from 'drizzle-orm' +import { desc, eq } from 'drizzle-orm' import { BaseService } from '../BaseService' import { InsertSessionMessageRow, sessionMessagesTable } from '../database/schema' @@ -170,29 +170,6 @@ export class SessionMessageService extends BaseService { return { messages } } - async saveUserMessage( - tx: any, - sessionId: string, - prompt: string, - agentSessionId: string - ): Promise { - this.ensureInitialized() - - const now = new Date().toISOString() - const insertData: InsertSessionMessageRow = { - session_id: sessionId, - role: 'user', - content: prompt, - agent_session_id: agentSessionId, - created_at: now, - updated_at: now - } - - const [saved] = await tx.insert(sessionMessagesTable).values(insertData).returning() - - return this.deserializeSessionMessage(saved) as AgentSessionMessageEntity - } - createSessionMessage(session: GetAgentSessionResponse, messageData: CreateSessionMessageRequest): EventEmitter { this.ensureInitialized() @@ -210,12 +187,8 @@ export class SessionMessageService extends BaseService { req: CreateSessionMessageRequest, sessionStream: EventEmitter ): Promise { - const previousMessages = session.messages || [] - let agentSessionId: string = '' - if (previousMessages.length > 0) { - agentSessionId = previousMessages[previousMessages.length - 1].agent_session_id - } - + const agentSessionId = await this.getLastAgentSessionId(session.id) + let newAgentSessionId = '' logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId }) if (session.agent_type !== 'claude-code') { @@ -223,7 +196,6 @@ export class SessionMessageService extends BaseService { logger.error('Unsupported agent type for streaming:', { agent_type: session.agent_type }) throw new Error('Unsupported agent type for streaming') } - let newAgentSessionId = '' // Create the streaming agent invocation (using invokeStream for streaming) const claudeStream = this.cc.invoke(req.content, session.accessible_paths[0], agentSessionId, { @@ -273,8 +245,8 @@ export class SessionMessageService extends BaseService { case 'complete': { // Then handle async persistence this.database.transaction(async (tx) => { - await this.saveUserMessage(tx, session.id, req.content, newAgentSessionId) - await this.persistSessionMessageAsync({ + await this.persistUserMessage(tx, session.id, req.content, newAgentSessionId) + await this.persistAssistantMessage({ tx, session, accumulator, @@ -304,7 +276,51 @@ export class SessionMessageService extends BaseService { }) } - private async persistSessionMessageAsync({ + private async getLastAgentSessionId(sessionId: string): Promise { + this.ensureInitialized() + + try { + const result = await this.database + .select({ agent_session_id: sessionMessagesTable.agent_session_id }) + .from(sessionMessagesTable) + .where(eq(sessionMessagesTable.session_id, sessionId)) + .orderBy(desc(sessionMessagesTable.created_at)) + .limit(1) + + return result[0]?.agent_session_id || '' + } catch (error) { + logger.error('Failed to get last agent session ID', { + sessionId, + error + }) + return '' + } + } + + async persistUserMessage( + tx: any, + sessionId: string, + prompt: string, + agentSessionId: string + ): Promise { + this.ensureInitialized() + + const now = new Date().toISOString() + const insertData: InsertSessionMessageRow = { + session_id: sessionId, + role: 'user', + content: prompt, + agent_session_id: agentSessionId, + created_at: now, + updated_at: now + } + + const [saved] = await tx.insert(sessionMessagesTable).values(insertData).returning() + + return this.deserializeSessionMessage(saved) as AgentSessionMessageEntity + } + + private async persistAssistantMessage({ tx, session, accumulator,