From e7c37231e0305e7f8ccf687afba9c045a59b6990 Mon Sep 17 00:00:00 2001 From: Vaayne Date: Thu, 18 Sep 2025 00:30:43 +0800 Subject: [PATCH] feat: Enhance message handling with user message persistence and improved stream management --- .../routes/agents/handlers/messages.ts | 79 +++- .../agents/services/SessionMessageService.ts | 352 ++++++++++++++---- .../agents/services/claudecode/transform.ts | 135 ++----- 3 files changed, 383 insertions(+), 183 deletions(-) diff --git a/src/main/apiServer/routes/agents/handlers/messages.ts b/src/main/apiServer/routes/agents/handlers/messages.ts index fde9a07a0e..571778f337 100644 --- a/src/main/apiServer/routes/agents/handlers/messages.ts +++ b/src/main/apiServer/routes/agents/handlers/messages.ts @@ -35,6 +35,12 @@ export const createMessage = async (req: Request, res: Response): Promise logger.info(`Creating streaming message for session: ${sessionId}`) logger.debug('Streaming message data:', messageData) + // Step 1: Save user message first + const userMessage = await sessionMessageService.saveUserMessage( + sessionId, + messageData.content + ) + // Set SSE headers res.setHeader('Content-Type', 'text/event-stream') res.setHeader('Cache-Control', 'no-cache') @@ -42,13 +48,36 @@ export const createMessage = async (req: Request, res: Response): Promise res.setHeader('Access-Control-Allow-Origin', '*') res.setHeader('Access-Control-Allow-Headers', 'Cache-Control') - // Send initial connection event - res.write('data: {"type":"start"}\n\n') - const messageStream = sessionMessageService.createSessionMessage(session, messageData) + const messageStream = sessionMessageService.createSessionMessage(session, messageData, userMessage.id) - // Track if the response has ended to prevent further writes + // Track stream lifecycle so we keep the SSE connection open until persistence finishes let responseEnded = false + let streamFinished = false + let awaitingPersistence = false + let persistenceResolved = false + + const finalizeResponse = () => { + if (responseEnded) { + return + } + + if (!streamFinished) { + return + } + + if (awaitingPersistence && !persistenceResolved) { + return + } + + responseEnded = true + try { + res.write('data: [DONE]\n\n') + } catch (writeError) { + logger.error('Error writing final sentinel to SSE stream:', { error: writeError as Error }) + } + res.end() + } // Handle client disconnect req.on('close', () => { @@ -76,18 +105,44 @@ export const createMessage = async (req: Request, res: Response): Promise } res.write(`data: ${JSON.stringify(errorChunk)}\n\n`) logger.error(`Streaming message error for session: ${sessionId}:`, event.error) - responseEnded = true - res.write('data: [DONE]\n\n') - res.end() + + streamFinished = true + awaitingPersistence = Boolean(event.persistScheduled) + + if (!awaitingPersistence) { + persistenceResolved = true + } + + finalizeResponse() break } - case 'complete': - // Send completion marker following AI SDK protocol + case 'complete': { logger.info(`Streaming message completed for session: ${sessionId}`) - responseEnded = true - res.write('data: [DONE]\n\n') - res.end() + res.write(`data: ${JSON.stringify({ type: 'complete', result: event.result })}\n\n`) + + streamFinished = true + awaitingPersistence = true + finalizeResponse() + break + } + + case 'persisted': + // Send persistence success event + res.write(`data: ${JSON.stringify(event)}\n\n`) + logger.debug(`Session message persisted for session: ${sessionId}`, { messageId: event.message?.id }) + + persistenceResolved = true + finalizeResponse() + break + + case 'persist-error': + // Send persistence error event + res.write(`data: ${JSON.stringify(event)}\n\n`) + logger.error(`Failed to persist session message for session: ${sessionId}:`, event.error) + + persistenceResolved = true + finalizeResponse() break default: diff --git a/src/main/services/agents/services/SessionMessageService.ts b/src/main/services/agents/services/SessionMessageService.ts index 57a427a8be..1ab016b198 100644 --- a/src/main/services/agents/services/SessionMessageService.ts +++ b/src/main/services/agents/services/SessionMessageService.ts @@ -7,15 +7,175 @@ import type { GetAgentSessionResponse, ListOptions, } from '@types' -import { UIMessageChunk } from 'ai' +import { ModelMessage, UIMessage, UIMessageChunk } from 'ai' +import { convertToModelMessages, readUIMessageStream } from 'ai' import { count, eq } from 'drizzle-orm' import { BaseService } from '../BaseService' -import { sessionMessagesTable } from '../database/schema' +import { InsertSessionMessageRow, sessionMessagesTable } from '../database/schema' import ClaudeCodeService from './claudecode' const logger = loggerService.withContext('SessionMessageService') + +// Collapse a UIMessageChunk stream into a final UIMessage, then convert to ModelMessage[] +export async function chunksToModelMessages( + chunkStream: ReadableStream, + priorUiHistory: UIMessage[] = [] +): Promise { + let latest: UIMessage | undefined + + for await (const uiMsg of readUIMessageStream({ stream: chunkStream })) { + latest = uiMsg // each yield is a newer state; keep the last one + } + + const uiMessages = latest ? [...priorUiHistory, latest] : priorUiHistory + return convertToModelMessages(uiMessages) // -> ModelMessage[] +} + +// Utility function to normalize content to ModelMessage +function normalizeModelMessage(content: string | ModelMessage): ModelMessage { + if (typeof content === 'string') { + return { + role: 'user', + content: content + } + } + return content +} + +// Ensure errors emitted through SSE are serializable +function serializeError(error: unknown): { message: string; name?: string; stack?: string } { + if (error instanceof Error) { + return { + message: error.message, + name: error.name, + stack: error.stack + } + } + + if (typeof error === 'string') { + return { message: error } + } + + return { + message: 'Unknown error' + } +} + +// Interface for persistence context +interface PersistContext { + session: GetAgentSessionResponse + accumulator: ChunkAccumulator + userMessageId: number + sessionStream: EventEmitter +} + +// Chunk accumulator class to collect and reconstruct streaming data +class ChunkAccumulator { + private streamedChunks: UIMessageChunk[] = [] + private rawAgentMessages: any[] = [] + private agentResult: any = null + private agentType: string = 'unknown' + private uniqueIds: Set = new Set() + + addChunk(chunk: UIMessageChunk): void { + this.streamedChunks.push(chunk) + } + + addRawMessage(message: any): void { + if (message.uuid && this.uniqueIds.has(message.uuid)) { + // Duplicate message based on uuid; skip adding + return + } + if (message.uuid) { + this.uniqueIds.add(message.uuid) + } + this.rawAgentMessages.push(message) + } + + setAgentResult(result: any): void { + this.agentResult = result + if (result?.agentType) { + this.agentType = result.agentType + } + } + + buildStructuredContent() { + return { + aiSDKChunks: this.streamedChunks, + rawAgentMessages: this.rawAgentMessages, + agentResult: this.agentResult, + agentType: this.agentType + } + } + + // Create a ReadableStream from accumulated chunks + createChunkStream(): ReadableStream { + const chunks = [...this.streamedChunks] + + return new ReadableStream({ + start(controller) { + // Enqueue all chunks + for (const chunk of chunks) { + controller.enqueue(chunk) + } + controller.close() + } + }) + } + + // Convert accumulated chunks to ModelMessages using chunksToModelMessages + async toModelMessages(priorUiHistory: UIMessage[] = []): Promise { + const chunkStream = this.createChunkStream() + return await chunksToModelMessages(chunkStream, priorUiHistory) + } + + toModelMessage(role: ModelMessage['role'] = 'assistant'): ModelMessage { + // Reconstruct the content from chunks + let textContent = '' + const toolCalls: any[] = [] + + for (const chunk of this.streamedChunks) { + if (chunk.type === 'text-delta' && 'delta' in chunk) { + textContent += chunk.delta + } else if (chunk.type === 'tool-input-available' && 'toolCallId' in chunk && 'toolName' in chunk) { + // Handle tool calls - use tool-input-available chunks + const toolCall = { + toolCallId: chunk.toolCallId, + toolName: chunk.toolName, + args: chunk.input || {} + } + toolCalls.push(toolCall) + } + } + + const message: any = { + role, + content: textContent + } + + // Add tool invocations if any + if (toolCalls.length > 0) { + message.toolInvocations = toolCalls + } + + return message as ModelMessage + } + + getChunkCount(): number { + return this.streamedChunks.length + } + + getRawMessageCount(): number { + return this.rawAgentMessages.length + } + + getAgentType(): string { + return this.agentType + } +} + export class SessionMessageService extends BaseService { private static instance: SessionMessageService | null = null private cc: ClaudeCodeService = new ClaudeCodeService() @@ -76,14 +236,44 @@ export class SessionMessageService extends BaseService { return { messages, total } } - createSessionMessage(session: GetAgentSessionResponse, messageData: CreateSessionMessageRequest): EventEmitter { + async saveUserMessage(sessionId: string, content: ModelMessage | string): Promise { + this.ensureInitialized() + + const now = new Date().toISOString() + const userContent: ModelMessage = normalizeModelMessage(content) + + const insertData: InsertSessionMessageRow = { + session_id: sessionId, + role: 'user', + content: JSON.stringify(userContent), + metadata: JSON.stringify({ + timestamp: now, + source: 'api' + }), + created_at: now, + updated_at: now + } + + const [saved] = await this.database + .insert(sessionMessagesTable) + .values(insertData) + .returning() + + return this.deserializeSessionMessage(saved) as AgentSessionMessageEntity + } + + createSessionMessage( + session: GetAgentSessionResponse, + messageData: CreateSessionMessageRequest, + userMessageId: number + ): EventEmitter { this.ensureInitialized() // Create a new EventEmitter to manage the session message lifecycle const sessionStream = new EventEmitter() // No parent validation needed, start immediately - this.startSessionMessageStream(session, messageData, sessionStream) + this.startSessionMessageStream(session, messageData, sessionStream, userMessageId) return sessionStream } @@ -91,7 +281,8 @@ export class SessionMessageService extends BaseService { private startSessionMessageStream( session: GetAgentSessionResponse, req: CreateSessionMessageRequest, - sessionStream: EventEmitter + sessionStream: EventEmitter, + userMessageId: number ): void { const previousMessages = session.messages || [] let session_id: string = '' @@ -112,8 +303,8 @@ export class SessionMessageService extends BaseService { maxTurns: session.configuration?.maxTurns || 10 }) - const streamedChunks: UIMessageChunk[] = [] - const rawAgentMessages: any[] = [] // Generic agent messages storage + // Use chunk accumulator to manage streaming data + const accumulator = new ChunkAccumulator() // Handle agent stream events (agent-agnostic) claudeStream.on('data', async (event: any) => { @@ -123,11 +314,11 @@ export class SessionMessageService extends BaseService { // Forward UIMessageChunk directly and collect raw agent messages if (event.chunk) { const chunk = event.chunk as UIMessageChunk - streamedChunks.push(chunk) + accumulator.addChunk(chunk) // Collect raw agent message if available (agent-agnostic) if (event.rawAgentMessage) { - rawAgentMessages.push(event.rawAgentMessage) + accumulator.addRawMessage(event.rawAgentMessage) } sessionStream.emit('data', { @@ -139,76 +330,55 @@ export class SessionMessageService extends BaseService { } break - case 'error': + case 'error': { + const underlyingError = event.error || (event.data?.stderr ? new Error(event.data.stderr) : undefined) + const persistScheduled = accumulator.getChunkCount() > 0 + + if (persistScheduled) { + // Try to save partial state with error metadata when possible + accumulator.setAgentResult({ + error: serializeError(underlyingError), + agentType: 'claude-code', + incomplete: true + }) + + void this.persistSessionMessageAsync({ + session, + accumulator, + userMessageId, + sessionStream + }) + } + sessionStream.emit('data', { type: 'error', - error: event.error || (event.data?.stderr ? new Error(event.data.stderr) : undefined) + error: serializeError(underlyingError), + persistScheduled }) break + } case 'complete': { - // Save the final message to database when agent completes - logger.info('Agent stream completed, saving message to database') - // Extract additional raw agent messages from agentResult if available if (event.agentResult?.rawSDKMessages) { - rawAgentMessages.push(...event.agentResult.rawSDKMessages) + event.agentResult.rawSDKMessages.forEach((msg: any) => accumulator.addRawMessage(msg)) } - // Create structured content with both AI SDK format and raw data - const structuredContent = { - aiSDKChunks: streamedChunks, // For UI consumption - rawAgentMessages: rawAgentMessages, // Original agent-specific messages - agentResult: event.agentResult, // Complete result from the agent - agentType: event.agentResult?.agentType || 'unknown' // Store agent type for future reference - } + // Set the agent result in the accumulator + accumulator.setAgentResult(event.agentResult) - // const now = new Date().toISOString() - // const insertData: InsertSessionMessageRow = { - // session_id: req.session_id, - // parent_id: req.parent_id || null, - // role: req.role, - // type: req.type, - // content: JSON.stringify(structuredContent), - // metadata: req.metadata - // ? JSON.stringify({ - // ...req.metadata, - // chunkCount: streamedChunks.length, - // rawMessageCount: rawAgentMessages.length, - // agentType: event.agentResult?.agentType || 'unknown', - // completedAt: now - // }) - // : JSON.stringify({ - // chunkCount: streamedChunks.length, - // rawMessageCount: rawAgentMessages.length, - // agentType: event.agentResult?.agentType || 'unknown', - // completedAt: now - // }), - // created_at: now, - // updated_at: now - // } + // // Emit SSE completion FIRST before persistence + // sessionStream.emit('data', { + // type: 'complete', + // result: accumulator.buildStructuredContent() + // }) - // const result = await this.database.insert(sessionMessagesTable).values(insertData).returning() - - // if (result[0]) { - // sessionMessage = this.deserializeSessionMessage(result[0]) as AgentSessionMessageEntity - // logger.info(`Session message saved with ID: ${sessionMessage.id}`) - - // // Emit the complete event with the saved message and structured data - // sessionStream.emit('data', { - // type: 'complete', - // result: structuredContent, - // message: sessionMessage - // }) - // } else { - // sessionStream.emit('data', { - // type: 'error', - // error: new Error('Failed to save session message to database') - // }) - // } - sessionStream.emit('data', { - type: 'complete', - result: structuredContent + // Then handle async persistence + void this.persistSessionMessageAsync({ + session, + accumulator, + userMessageId, + sessionStream }) break } @@ -223,12 +393,58 @@ export class SessionMessageService extends BaseService { logger.error('Error handling Claude Code stream event:', { error }) sessionStream.emit('data', { type: 'error', - error: error as Error + error: serializeError(error) }) } }) } + private async persistSessionMessageAsync({ session, accumulator, userMessageId, sessionStream }: PersistContext) { + if (!session?.id) { + const missingSessionError = new Error('Missing session_id for persisted message') + logger.error(missingSessionError.message, { error: missingSessionError }) + sessionStream.emit('data', { type: 'persist-error', error: serializeError(missingSessionError) }) + return + } + + const sessionId = session.id + const now = new Date().toISOString() + const structured = accumulator.buildStructuredContent() + + try { + // Use chunksToModelMessages to convert chunks to ModelMessages + const modelMessages = await accumulator.toModelMessages() + // Get the last message (should be the assistant's response) + const modelMessage = + modelMessages.length > 0 ? modelMessages[modelMessages.length - 1] : accumulator.toModelMessage('assistant') + + const metadata = { + userMessageId, + chunkCount: accumulator.getChunkCount(), + rawMessageCount: accumulator.getRawMessageCount(), + agentType: accumulator.getAgentType(), + completedAt: now + } + + const insertData: InsertSessionMessageRow = { + session_id: sessionId, + role: 'assistant', + content: JSON.stringify({ modelMessage, ...structured }), + metadata: JSON.stringify(metadata), + created_at: now, + updated_at: now + } + + const [row] = await this.database.insert(sessionMessagesTable).values(insertData).returning() + + const entity = this.deserializeSessionMessage(row) as AgentSessionMessageEntity + sessionStream.emit('data', { type: 'persisted', message: entity }) + } catch (error) { + logger.error('Failed to persist session message', { error }) + sessionStream.emit('data', { type: 'persist-error', error: serializeError(error) }) + } + } + private deserializeSessionMessage(data: any): AgentSessionMessageEntity { if (!data) return data diff --git a/src/main/services/agents/services/claudecode/transform.ts b/src/main/services/agents/services/claudecode/transform.ts index 31b8be0cf9..fc8f55132b 100644 --- a/src/main/services/agents/services/claudecode/transform.ts +++ b/src/main/services/agents/services/claudecode/transform.ts @@ -74,14 +74,7 @@ export function transformSDKMessageToUIChunk(sdkMessage: SDKMessage): UIMessageC function sdkMessageToProviderMetadata(message: SDKMessage): ProviderMetadata { const meta: ProviderMetadata = { - raw: message as Record, - claudeCode: { - originalSDKMessage: JSON.parse(JSON.stringify(message)), // Serialize to ensure JSON compatibility - uuid: message.uuid || null, - session_id: message.session_id || null, - timestamp: new Date().toISOString(), - type: message.type - } + message: message as Record } return meta } @@ -89,7 +82,7 @@ function sdkMessageToProviderMetadata(message: SDKMessage): ProviderMetadata { // Handle assistant messages function handleAssistantMessage(message: Extract): UIMessageChunk[] { const chunks: UIMessageChunk[] = [] - const messageId = generateMessageId() + const messageId = message.uuid // Extract text content const textContent = extractTextContent(message.message as MessageParam) @@ -97,36 +90,18 @@ function handleAssistantMessage(message: Extract): const chunks: UIMessageChunk[] = [] if (message.subtype === 'init') { + chunks.push({ + type: 'start', + messageId: message.session_id + }) + // System initialization - could emit as a data chunk or skip chunks.push({ type: 'data-system' as any, data: { type: 'init', - cwd: message.cwd, - tools: message.tools, - model: message.model, - mcp_servers: message.mcp_servers, + session_id: message.session_id, raw: message } }) @@ -319,63 +296,14 @@ function handleSystemMessage(message: Extract): function handleResultMessage(message: Extract): UIMessageChunk[] { const chunks: UIMessageChunk[] = [] + const messageId = message.uuid if (message.subtype === 'success') { - // Emit the final result text if available - if (message.result) { - const messageId = generateMessageId() - chunks.push( - { - type: 'text-start', - id: messageId, - providerMetadata: { - anthropic: { - uuid: message.uuid, - session_id: message.session_id, - final_result: true - }, - raw: sdkMessageToProviderMetadata(message) - } - }, - { - type: 'text-delta', - id: messageId, - delta: message.result, - providerMetadata: { - anthropic: { - uuid: message.uuid, - session_id: message.session_id, - final_result: true - }, - raw: sdkMessageToProviderMetadata(message) - } - }, - { - type: 'text-end', - id: messageId, - providerMetadata: { - anthropic: { - uuid: message.uuid, - session_id: message.session_id, - final_result: true - }, - raw: sdkMessageToProviderMetadata(message) - } - } - ) - } - - // Emit usage and cost data + // Emit final result data chunks.push({ - type: 'data-usage' as any, - data: { - duration_ms: message.duration_ms, - duration_api_ms: message.duration_api_ms, - num_turns: message.num_turns, - total_cost_usd: message.total_cost_usd, - usage: message.usage, - modelUsage: message.modelUsage, - permission_denials: message.permission_denials - } + type: 'data-result' as any, + id: messageId, + data: message, + transient: true }) } else { // Handle error cases @@ -383,22 +311,23 @@ function handleResultMessage(message: Extract): type: 'error', errorText: `${message.subtype}: Process failed after ${message.num_turns} turns` }) - - // Still emit usage data for failed requests - chunks.push({ - type: 'data-usage' as any, - data: { - duration_ms: message.duration_ms, - duration_api_ms: message.duration_api_ms, - num_turns: message.num_turns, - total_cost_usd: message.total_cost_usd, - usage: message.usage, - modelUsage: message.modelUsage, - permission_denials: message.permission_denials - } - }) } + // Emit usage and cost data + chunks.push({ + type: 'data-usage' as any, + data: { + cost: message.total_cost_usd, + usage: { + input_tokens: message.usage.input_tokens, + cache_creation_input_tokens: message.usage.cache_creation_input_tokens, + cache_read_input_tokens: message.usage.cache_read_input_tokens, + output_tokens: message.usage.output_tokens, + service_tier: 'standard' + } + } + }) + // Always emit a finish chunk at the end chunks.push({ type: 'finish'