feat: Enhance message handling with user message persistence and improved stream management

This commit is contained in:
Vaayne 2025-09-18 00:30:43 +08:00
parent c196a02c95
commit e7c37231e0
3 changed files with 383 additions and 183 deletions

View File

@ -35,6 +35,12 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
logger.info(`Creating streaming message for session: ${sessionId}`)
logger.debug('Streaming message data:', messageData)
// Step 1: Save user message first
const userMessage = await sessionMessageService.saveUserMessage(
sessionId,
messageData.content
)
// Set SSE headers
res.setHeader('Content-Type', 'text/event-stream')
res.setHeader('Cache-Control', 'no-cache')
@ -42,13 +48,36 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
res.setHeader('Access-Control-Allow-Origin', '*')
res.setHeader('Access-Control-Allow-Headers', 'Cache-Control')
// Send initial connection event
res.write('data: {"type":"start"}\n\n')
const messageStream = sessionMessageService.createSessionMessage(session, messageData)
const messageStream = sessionMessageService.createSessionMessage(session, messageData, userMessage.id)
// Track if the response has ended to prevent further writes
// Track stream lifecycle so we keep the SSE connection open until persistence finishes
let responseEnded = false
let streamFinished = false
let awaitingPersistence = false
let persistenceResolved = false
const finalizeResponse = () => {
if (responseEnded) {
return
}
if (!streamFinished) {
return
}
if (awaitingPersistence && !persistenceResolved) {
return
}
responseEnded = true
try {
res.write('data: [DONE]\n\n')
} catch (writeError) {
logger.error('Error writing final sentinel to SSE stream:', { error: writeError as Error })
}
res.end()
}
// Handle client disconnect
req.on('close', () => {
@ -76,18 +105,44 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
}
res.write(`data: ${JSON.stringify(errorChunk)}\n\n`)
logger.error(`Streaming message error for session: ${sessionId}:`, event.error)
responseEnded = true
res.write('data: [DONE]\n\n')
res.end()
streamFinished = true
awaitingPersistence = Boolean(event.persistScheduled)
if (!awaitingPersistence) {
persistenceResolved = true
}
finalizeResponse()
break
}
case 'complete':
// Send completion marker following AI SDK protocol
case 'complete': {
logger.info(`Streaming message completed for session: ${sessionId}`)
responseEnded = true
res.write('data: [DONE]\n\n')
res.end()
res.write(`data: ${JSON.stringify({ type: 'complete', result: event.result })}\n\n`)
streamFinished = true
awaitingPersistence = true
finalizeResponse()
break
}
case 'persisted':
// Send persistence success event
res.write(`data: ${JSON.stringify(event)}\n\n`)
logger.debug(`Session message persisted for session: ${sessionId}`, { messageId: event.message?.id })
persistenceResolved = true
finalizeResponse()
break
case 'persist-error':
// Send persistence error event
res.write(`data: ${JSON.stringify(event)}\n\n`)
logger.error(`Failed to persist session message for session: ${sessionId}:`, event.error)
persistenceResolved = true
finalizeResponse()
break
default:

View File

@ -7,15 +7,175 @@ import type {
GetAgentSessionResponse,
ListOptions,
} from '@types'
import { UIMessageChunk } from 'ai'
import { ModelMessage, UIMessage, UIMessageChunk } from 'ai'
import { convertToModelMessages, readUIMessageStream } from 'ai'
import { count, eq } from 'drizzle-orm'
import { BaseService } from '../BaseService'
import { sessionMessagesTable } from '../database/schema'
import { InsertSessionMessageRow, sessionMessagesTable } from '../database/schema'
import ClaudeCodeService from './claudecode'
const logger = loggerService.withContext('SessionMessageService')
// Collapse a UIMessageChunk stream into a final UIMessage, then convert to ModelMessage[]
export async function chunksToModelMessages(
chunkStream: ReadableStream<UIMessageChunk>,
priorUiHistory: UIMessage[] = []
): Promise<ModelMessage[]> {
let latest: UIMessage | undefined
for await (const uiMsg of readUIMessageStream({ stream: chunkStream })) {
latest = uiMsg // each yield is a newer state; keep the last one
}
const uiMessages = latest ? [...priorUiHistory, latest] : priorUiHistory
return convertToModelMessages(uiMessages) // -> ModelMessage[]
}
// Utility function to normalize content to ModelMessage
function normalizeModelMessage(content: string | ModelMessage): ModelMessage {
if (typeof content === 'string') {
return {
role: 'user',
content: content
}
}
return content
}
// Ensure errors emitted through SSE are serializable
function serializeError(error: unknown): { message: string; name?: string; stack?: string } {
if (error instanceof Error) {
return {
message: error.message,
name: error.name,
stack: error.stack
}
}
if (typeof error === 'string') {
return { message: error }
}
return {
message: 'Unknown error'
}
}
// Interface for persistence context
interface PersistContext {
session: GetAgentSessionResponse
accumulator: ChunkAccumulator
userMessageId: number
sessionStream: EventEmitter
}
// Chunk accumulator class to collect and reconstruct streaming data
class ChunkAccumulator {
private streamedChunks: UIMessageChunk[] = []
private rawAgentMessages: any[] = []
private agentResult: any = null
private agentType: string = 'unknown'
private uniqueIds: Set<string> = new Set()
addChunk(chunk: UIMessageChunk): void {
this.streamedChunks.push(chunk)
}
addRawMessage(message: any): void {
if (message.uuid && this.uniqueIds.has(message.uuid)) {
// Duplicate message based on uuid; skip adding
return
}
if (message.uuid) {
this.uniqueIds.add(message.uuid)
}
this.rawAgentMessages.push(message)
}
setAgentResult(result: any): void {
this.agentResult = result
if (result?.agentType) {
this.agentType = result.agentType
}
}
buildStructuredContent() {
return {
aiSDKChunks: this.streamedChunks,
rawAgentMessages: this.rawAgentMessages,
agentResult: this.agentResult,
agentType: this.agentType
}
}
// Create a ReadableStream from accumulated chunks
createChunkStream(): ReadableStream<UIMessageChunk> {
const chunks = [...this.streamedChunks]
return new ReadableStream<UIMessageChunk>({
start(controller) {
// Enqueue all chunks
for (const chunk of chunks) {
controller.enqueue(chunk)
}
controller.close()
}
})
}
// Convert accumulated chunks to ModelMessages using chunksToModelMessages
async toModelMessages(priorUiHistory: UIMessage[] = []): Promise<ModelMessage[]> {
const chunkStream = this.createChunkStream()
return await chunksToModelMessages(chunkStream, priorUiHistory)
}
toModelMessage(role: ModelMessage['role'] = 'assistant'): ModelMessage {
// Reconstruct the content from chunks
let textContent = ''
const toolCalls: any[] = []
for (const chunk of this.streamedChunks) {
if (chunk.type === 'text-delta' && 'delta' in chunk) {
textContent += chunk.delta
} else if (chunk.type === 'tool-input-available' && 'toolCallId' in chunk && 'toolName' in chunk) {
// Handle tool calls - use tool-input-available chunks
const toolCall = {
toolCallId: chunk.toolCallId,
toolName: chunk.toolName,
args: chunk.input || {}
}
toolCalls.push(toolCall)
}
}
const message: any = {
role,
content: textContent
}
// Add tool invocations if any
if (toolCalls.length > 0) {
message.toolInvocations = toolCalls
}
return message as ModelMessage
}
getChunkCount(): number {
return this.streamedChunks.length
}
getRawMessageCount(): number {
return this.rawAgentMessages.length
}
getAgentType(): string {
return this.agentType
}
}
export class SessionMessageService extends BaseService {
private static instance: SessionMessageService | null = null
private cc: ClaudeCodeService = new ClaudeCodeService()
@ -76,14 +236,44 @@ export class SessionMessageService extends BaseService {
return { messages, total }
}
createSessionMessage(session: GetAgentSessionResponse, messageData: CreateSessionMessageRequest): EventEmitter {
async saveUserMessage(sessionId: string, content: ModelMessage | string): Promise<AgentSessionMessageEntity> {
this.ensureInitialized()
const now = new Date().toISOString()
const userContent: ModelMessage = normalizeModelMessage(content)
const insertData: InsertSessionMessageRow = {
session_id: sessionId,
role: 'user',
content: JSON.stringify(userContent),
metadata: JSON.stringify({
timestamp: now,
source: 'api'
}),
created_at: now,
updated_at: now
}
const [saved] = await this.database
.insert(sessionMessagesTable)
.values(insertData)
.returning()
return this.deserializeSessionMessage(saved) as AgentSessionMessageEntity
}
createSessionMessage(
session: GetAgentSessionResponse,
messageData: CreateSessionMessageRequest,
userMessageId: number
): EventEmitter {
this.ensureInitialized()
// Create a new EventEmitter to manage the session message lifecycle
const sessionStream = new EventEmitter()
// No parent validation needed, start immediately
this.startSessionMessageStream(session, messageData, sessionStream)
this.startSessionMessageStream(session, messageData, sessionStream, userMessageId)
return sessionStream
}
@ -91,7 +281,8 @@ export class SessionMessageService extends BaseService {
private startSessionMessageStream(
session: GetAgentSessionResponse,
req: CreateSessionMessageRequest,
sessionStream: EventEmitter
sessionStream: EventEmitter,
userMessageId: number
): void {
const previousMessages = session.messages || []
let session_id: string = ''
@ -112,8 +303,8 @@ export class SessionMessageService extends BaseService {
maxTurns: session.configuration?.maxTurns || 10
})
const streamedChunks: UIMessageChunk[] = []
const rawAgentMessages: any[] = [] // Generic agent messages storage
// Use chunk accumulator to manage streaming data
const accumulator = new ChunkAccumulator()
// Handle agent stream events (agent-agnostic)
claudeStream.on('data', async (event: any) => {
@ -123,11 +314,11 @@ export class SessionMessageService extends BaseService {
// Forward UIMessageChunk directly and collect raw agent messages
if (event.chunk) {
const chunk = event.chunk as UIMessageChunk
streamedChunks.push(chunk)
accumulator.addChunk(chunk)
// Collect raw agent message if available (agent-agnostic)
if (event.rawAgentMessage) {
rawAgentMessages.push(event.rawAgentMessage)
accumulator.addRawMessage(event.rawAgentMessage)
}
sessionStream.emit('data', {
@ -139,76 +330,55 @@ export class SessionMessageService extends BaseService {
}
break
case 'error':
case 'error': {
const underlyingError = event.error || (event.data?.stderr ? new Error(event.data.stderr) : undefined)
const persistScheduled = accumulator.getChunkCount() > 0
if (persistScheduled) {
// Try to save partial state with error metadata when possible
accumulator.setAgentResult({
error: serializeError(underlyingError),
agentType: 'claude-code',
incomplete: true
})
void this.persistSessionMessageAsync({
session,
accumulator,
userMessageId,
sessionStream
})
}
sessionStream.emit('data', {
type: 'error',
error: event.error || (event.data?.stderr ? new Error(event.data.stderr) : undefined)
error: serializeError(underlyingError),
persistScheduled
})
break
}
case 'complete': {
// Save the final message to database when agent completes
logger.info('Agent stream completed, saving message to database')
// Extract additional raw agent messages from agentResult if available
if (event.agentResult?.rawSDKMessages) {
rawAgentMessages.push(...event.agentResult.rawSDKMessages)
event.agentResult.rawSDKMessages.forEach((msg: any) => accumulator.addRawMessage(msg))
}
// Create structured content with both AI SDK format and raw data
const structuredContent = {
aiSDKChunks: streamedChunks, // For UI consumption
rawAgentMessages: rawAgentMessages, // Original agent-specific messages
agentResult: event.agentResult, // Complete result from the agent
agentType: event.agentResult?.agentType || 'unknown' // Store agent type for future reference
}
// Set the agent result in the accumulator
accumulator.setAgentResult(event.agentResult)
// const now = new Date().toISOString()
// const insertData: InsertSessionMessageRow = {
// session_id: req.session_id,
// parent_id: req.parent_id || null,
// role: req.role,
// type: req.type,
// content: JSON.stringify(structuredContent),
// metadata: req.metadata
// ? JSON.stringify({
// ...req.metadata,
// chunkCount: streamedChunks.length,
// rawMessageCount: rawAgentMessages.length,
// agentType: event.agentResult?.agentType || 'unknown',
// completedAt: now
// })
// : JSON.stringify({
// chunkCount: streamedChunks.length,
// rawMessageCount: rawAgentMessages.length,
// agentType: event.agentResult?.agentType || 'unknown',
// completedAt: now
// }),
// created_at: now,
// updated_at: now
// }
// // Emit SSE completion FIRST before persistence
// sessionStream.emit('data', {
// type: 'complete',
// result: accumulator.buildStructuredContent()
// })
// const result = await this.database.insert(sessionMessagesTable).values(insertData).returning()
// if (result[0]) {
// sessionMessage = this.deserializeSessionMessage(result[0]) as AgentSessionMessageEntity
// logger.info(`Session message saved with ID: ${sessionMessage.id}`)
// // Emit the complete event with the saved message and structured data
// sessionStream.emit('data', {
// type: 'complete',
// result: structuredContent,
// message: sessionMessage
// })
// } else {
// sessionStream.emit('data', {
// type: 'error',
// error: new Error('Failed to save session message to database')
// })
// }
sessionStream.emit('data', {
type: 'complete',
result: structuredContent
// Then handle async persistence
void this.persistSessionMessageAsync({
session,
accumulator,
userMessageId,
sessionStream
})
break
}
@ -223,12 +393,58 @@ export class SessionMessageService extends BaseService {
logger.error('Error handling Claude Code stream event:', { error })
sessionStream.emit('data', {
type: 'error',
error: error as Error
error: serializeError(error)
})
}
})
}
private async persistSessionMessageAsync({ session, accumulator, userMessageId, sessionStream }: PersistContext) {
if (!session?.id) {
const missingSessionError = new Error('Missing session_id for persisted message')
logger.error(missingSessionError.message, { error: missingSessionError })
sessionStream.emit('data', { type: 'persist-error', error: serializeError(missingSessionError) })
return
}
const sessionId = session.id
const now = new Date().toISOString()
const structured = accumulator.buildStructuredContent()
try {
// Use chunksToModelMessages to convert chunks to ModelMessages
const modelMessages = await accumulator.toModelMessages()
// Get the last message (should be the assistant's response)
const modelMessage =
modelMessages.length > 0 ? modelMessages[modelMessages.length - 1] : accumulator.toModelMessage('assistant')
const metadata = {
userMessageId,
chunkCount: accumulator.getChunkCount(),
rawMessageCount: accumulator.getRawMessageCount(),
agentType: accumulator.getAgentType(),
completedAt: now
}
const insertData: InsertSessionMessageRow = {
session_id: sessionId,
role: 'assistant',
content: JSON.stringify({ modelMessage, ...structured }),
metadata: JSON.stringify(metadata),
created_at: now,
updated_at: now
}
const [row] = await this.database.insert(sessionMessagesTable).values(insertData).returning()
const entity = this.deserializeSessionMessage(row) as AgentSessionMessageEntity
sessionStream.emit('data', { type: 'persisted', message: entity })
} catch (error) {
logger.error('Failed to persist session message', { error })
sessionStream.emit('data', { type: 'persist-error', error: serializeError(error) })
}
}
private deserializeSessionMessage(data: any): AgentSessionMessageEntity {
if (!data) return data

View File

@ -74,14 +74,7 @@ export function transformSDKMessageToUIChunk(sdkMessage: SDKMessage): UIMessageC
function sdkMessageToProviderMetadata(message: SDKMessage): ProviderMetadata {
const meta: ProviderMetadata = {
raw: message as Record<string, any>,
claudeCode: {
originalSDKMessage: JSON.parse(JSON.stringify(message)), // Serialize to ensure JSON compatibility
uuid: message.uuid || null,
session_id: message.session_id || null,
timestamp: new Date().toISOString(),
type: message.type
}
message: message as Record<string, any>
}
return meta
}
@ -89,7 +82,7 @@ function sdkMessageToProviderMetadata(message: SDKMessage): ProviderMetadata {
// Handle assistant messages
function handleAssistantMessage(message: Extract<SDKMessage, { type: 'assistant' }>): UIMessageChunk[] {
const chunks: UIMessageChunk[] = []
const messageId = generateMessageId()
const messageId = message.uuid
// Extract text content
const textContent = extractTextContent(message.message as MessageParam)
@ -97,36 +90,18 @@ function handleAssistantMessage(message: Extract<SDKMessage, { type: 'assistant'
chunks.push(
{
type: 'text-start',
id: messageId,
providerMetadata: {
anthropic: {
uuid: message.uuid,
session_id: message.session_id
},
raw: sdkMessageToProviderMetadata(message)
}
id: messageId
},
{
type: 'text-delta',
id: messageId,
delta: textContent,
providerMetadata: {
anthropic: {
uuid: message.uuid,
session_id: message.session_id
},
raw: sdkMessageToProviderMetadata(message)
}
delta: textContent
},
{
type: 'text-end',
id: messageId,
providerMetadata: {
anthropic: {
uuid: message.uuid,
session_id: message.session_id
},
raw: sdkMessageToProviderMetadata(message)
rawMessage: sdkMessageToProviderMetadata(message)
}
}
)
@ -289,15 +264,17 @@ function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>):
const chunks: UIMessageChunk[] = []
if (message.subtype === 'init') {
chunks.push({
type: 'start',
messageId: message.session_id
})
// System initialization - could emit as a data chunk or skip
chunks.push({
type: 'data-system' as any,
data: {
type: 'init',
cwd: message.cwd,
tools: message.tools,
model: message.model,
mcp_servers: message.mcp_servers,
session_id: message.session_id,
raw: message
}
})
@ -319,63 +296,14 @@ function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>):
function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): UIMessageChunk[] {
const chunks: UIMessageChunk[] = []
const messageId = message.uuid
if (message.subtype === 'success') {
// Emit the final result text if available
if (message.result) {
const messageId = generateMessageId()
chunks.push(
{
type: 'text-start',
id: messageId,
providerMetadata: {
anthropic: {
uuid: message.uuid,
session_id: message.session_id,
final_result: true
},
raw: sdkMessageToProviderMetadata(message)
}
},
{
type: 'text-delta',
id: messageId,
delta: message.result,
providerMetadata: {
anthropic: {
uuid: message.uuid,
session_id: message.session_id,
final_result: true
},
raw: sdkMessageToProviderMetadata(message)
}
},
{
type: 'text-end',
id: messageId,
providerMetadata: {
anthropic: {
uuid: message.uuid,
session_id: message.session_id,
final_result: true
},
raw: sdkMessageToProviderMetadata(message)
}
}
)
}
// Emit usage and cost data
// Emit final result data
chunks.push({
type: 'data-usage' as any,
data: {
duration_ms: message.duration_ms,
duration_api_ms: message.duration_api_ms,
num_turns: message.num_turns,
total_cost_usd: message.total_cost_usd,
usage: message.usage,
modelUsage: message.modelUsage,
permission_denials: message.permission_denials
}
type: 'data-result' as any,
id: messageId,
data: message,
transient: true
})
} else {
// Handle error cases
@ -383,22 +311,23 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
type: 'error',
errorText: `${message.subtype}: Process failed after ${message.num_turns} turns`
})
// Still emit usage data for failed requests
chunks.push({
type: 'data-usage' as any,
data: {
duration_ms: message.duration_ms,
duration_api_ms: message.duration_api_ms,
num_turns: message.num_turns,
total_cost_usd: message.total_cost_usd,
usage: message.usage,
modelUsage: message.modelUsage,
permission_denials: message.permission_denials
}
})
}
// Emit usage and cost data
chunks.push({
type: 'data-usage' as any,
data: {
cost: message.total_cost_usd,
usage: {
input_tokens: message.usage.input_tokens,
cache_creation_input_tokens: message.usage.cache_creation_input_tokens,
cache_read_input_tokens: message.usage.cache_read_input_tokens,
output_tokens: message.usage.output_tokens,
service_tier: 'standard'
}
}
})
// Always emit a finish chunk at the end
chunks.push({
type: 'finish'