mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-29 14:31:35 +08:00
feat: Enhance message handling with user message persistence and improved stream management
This commit is contained in:
parent
c196a02c95
commit
e7c37231e0
@ -35,6 +35,12 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
||||
logger.info(`Creating streaming message for session: ${sessionId}`)
|
||||
logger.debug('Streaming message data:', messageData)
|
||||
|
||||
// Step 1: Save user message first
|
||||
const userMessage = await sessionMessageService.saveUserMessage(
|
||||
sessionId,
|
||||
messageData.content
|
||||
)
|
||||
|
||||
// Set SSE headers
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
@ -42,13 +48,36 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
||||
res.setHeader('Access-Control-Allow-Origin', '*')
|
||||
res.setHeader('Access-Control-Allow-Headers', 'Cache-Control')
|
||||
|
||||
// Send initial connection event
|
||||
res.write('data: {"type":"start"}\n\n')
|
||||
|
||||
const messageStream = sessionMessageService.createSessionMessage(session, messageData)
|
||||
const messageStream = sessionMessageService.createSessionMessage(session, messageData, userMessage.id)
|
||||
|
||||
// Track if the response has ended to prevent further writes
|
||||
// Track stream lifecycle so we keep the SSE connection open until persistence finishes
|
||||
let responseEnded = false
|
||||
let streamFinished = false
|
||||
let awaitingPersistence = false
|
||||
let persistenceResolved = false
|
||||
|
||||
const finalizeResponse = () => {
|
||||
if (responseEnded) {
|
||||
return
|
||||
}
|
||||
|
||||
if (!streamFinished) {
|
||||
return
|
||||
}
|
||||
|
||||
if (awaitingPersistence && !persistenceResolved) {
|
||||
return
|
||||
}
|
||||
|
||||
responseEnded = true
|
||||
try {
|
||||
res.write('data: [DONE]\n\n')
|
||||
} catch (writeError) {
|
||||
logger.error('Error writing final sentinel to SSE stream:', { error: writeError as Error })
|
||||
}
|
||||
res.end()
|
||||
}
|
||||
|
||||
// Handle client disconnect
|
||||
req.on('close', () => {
|
||||
@ -76,18 +105,44 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
|
||||
}
|
||||
res.write(`data: ${JSON.stringify(errorChunk)}\n\n`)
|
||||
logger.error(`Streaming message error for session: ${sessionId}:`, event.error)
|
||||
responseEnded = true
|
||||
res.write('data: [DONE]\n\n')
|
||||
res.end()
|
||||
|
||||
streamFinished = true
|
||||
awaitingPersistence = Boolean(event.persistScheduled)
|
||||
|
||||
if (!awaitingPersistence) {
|
||||
persistenceResolved = true
|
||||
}
|
||||
|
||||
finalizeResponse()
|
||||
break
|
||||
}
|
||||
|
||||
case 'complete':
|
||||
// Send completion marker following AI SDK protocol
|
||||
case 'complete': {
|
||||
logger.info(`Streaming message completed for session: ${sessionId}`)
|
||||
responseEnded = true
|
||||
res.write('data: [DONE]\n\n')
|
||||
res.end()
|
||||
res.write(`data: ${JSON.stringify({ type: 'complete', result: event.result })}\n\n`)
|
||||
|
||||
streamFinished = true
|
||||
awaitingPersistence = true
|
||||
finalizeResponse()
|
||||
break
|
||||
}
|
||||
|
||||
case 'persisted':
|
||||
// Send persistence success event
|
||||
res.write(`data: ${JSON.stringify(event)}\n\n`)
|
||||
logger.debug(`Session message persisted for session: ${sessionId}`, { messageId: event.message?.id })
|
||||
|
||||
persistenceResolved = true
|
||||
finalizeResponse()
|
||||
break
|
||||
|
||||
case 'persist-error':
|
||||
// Send persistence error event
|
||||
res.write(`data: ${JSON.stringify(event)}\n\n`)
|
||||
logger.error(`Failed to persist session message for session: ${sessionId}:`, event.error)
|
||||
|
||||
persistenceResolved = true
|
||||
finalizeResponse()
|
||||
break
|
||||
|
||||
default:
|
||||
|
||||
@ -7,15 +7,175 @@ import type {
|
||||
GetAgentSessionResponse,
|
||||
ListOptions,
|
||||
} from '@types'
|
||||
import { UIMessageChunk } from 'ai'
|
||||
import { ModelMessage, UIMessage, UIMessageChunk } from 'ai'
|
||||
import { convertToModelMessages, readUIMessageStream } from 'ai'
|
||||
import { count, eq } from 'drizzle-orm'
|
||||
|
||||
import { BaseService } from '../BaseService'
|
||||
import { sessionMessagesTable } from '../database/schema'
|
||||
import { InsertSessionMessageRow, sessionMessagesTable } from '../database/schema'
|
||||
import ClaudeCodeService from './claudecode'
|
||||
|
||||
const logger = loggerService.withContext('SessionMessageService')
|
||||
|
||||
|
||||
// Collapse a UIMessageChunk stream into a final UIMessage, then convert to ModelMessage[]
|
||||
export async function chunksToModelMessages(
|
||||
chunkStream: ReadableStream<UIMessageChunk>,
|
||||
priorUiHistory: UIMessage[] = []
|
||||
): Promise<ModelMessage[]> {
|
||||
let latest: UIMessage | undefined
|
||||
|
||||
for await (const uiMsg of readUIMessageStream({ stream: chunkStream })) {
|
||||
latest = uiMsg // each yield is a newer state; keep the last one
|
||||
}
|
||||
|
||||
const uiMessages = latest ? [...priorUiHistory, latest] : priorUiHistory
|
||||
return convertToModelMessages(uiMessages) // -> ModelMessage[]
|
||||
}
|
||||
|
||||
// Utility function to normalize content to ModelMessage
|
||||
function normalizeModelMessage(content: string | ModelMessage): ModelMessage {
|
||||
if (typeof content === 'string') {
|
||||
return {
|
||||
role: 'user',
|
||||
content: content
|
||||
}
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
// Ensure errors emitted through SSE are serializable
|
||||
function serializeError(error: unknown): { message: string; name?: string; stack?: string } {
|
||||
if (error instanceof Error) {
|
||||
return {
|
||||
message: error.message,
|
||||
name: error.name,
|
||||
stack: error.stack
|
||||
}
|
||||
}
|
||||
|
||||
if (typeof error === 'string') {
|
||||
return { message: error }
|
||||
}
|
||||
|
||||
return {
|
||||
message: 'Unknown error'
|
||||
}
|
||||
}
|
||||
|
||||
// Interface for persistence context
|
||||
interface PersistContext {
|
||||
session: GetAgentSessionResponse
|
||||
accumulator: ChunkAccumulator
|
||||
userMessageId: number
|
||||
sessionStream: EventEmitter
|
||||
}
|
||||
|
||||
// Chunk accumulator class to collect and reconstruct streaming data
|
||||
class ChunkAccumulator {
|
||||
private streamedChunks: UIMessageChunk[] = []
|
||||
private rawAgentMessages: any[] = []
|
||||
private agentResult: any = null
|
||||
private agentType: string = 'unknown'
|
||||
private uniqueIds: Set<string> = new Set()
|
||||
|
||||
addChunk(chunk: UIMessageChunk): void {
|
||||
this.streamedChunks.push(chunk)
|
||||
}
|
||||
|
||||
addRawMessage(message: any): void {
|
||||
if (message.uuid && this.uniqueIds.has(message.uuid)) {
|
||||
// Duplicate message based on uuid; skip adding
|
||||
return
|
||||
}
|
||||
if (message.uuid) {
|
||||
this.uniqueIds.add(message.uuid)
|
||||
}
|
||||
this.rawAgentMessages.push(message)
|
||||
}
|
||||
|
||||
setAgentResult(result: any): void {
|
||||
this.agentResult = result
|
||||
if (result?.agentType) {
|
||||
this.agentType = result.agentType
|
||||
}
|
||||
}
|
||||
|
||||
buildStructuredContent() {
|
||||
return {
|
||||
aiSDKChunks: this.streamedChunks,
|
||||
rawAgentMessages: this.rawAgentMessages,
|
||||
agentResult: this.agentResult,
|
||||
agentType: this.agentType
|
||||
}
|
||||
}
|
||||
|
||||
// Create a ReadableStream from accumulated chunks
|
||||
createChunkStream(): ReadableStream<UIMessageChunk> {
|
||||
const chunks = [...this.streamedChunks]
|
||||
|
||||
return new ReadableStream<UIMessageChunk>({
|
||||
start(controller) {
|
||||
// Enqueue all chunks
|
||||
for (const chunk of chunks) {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
controller.close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Convert accumulated chunks to ModelMessages using chunksToModelMessages
|
||||
async toModelMessages(priorUiHistory: UIMessage[] = []): Promise<ModelMessage[]> {
|
||||
const chunkStream = this.createChunkStream()
|
||||
return await chunksToModelMessages(chunkStream, priorUiHistory)
|
||||
}
|
||||
|
||||
toModelMessage(role: ModelMessage['role'] = 'assistant'): ModelMessage {
|
||||
// Reconstruct the content from chunks
|
||||
let textContent = ''
|
||||
const toolCalls: any[] = []
|
||||
|
||||
for (const chunk of this.streamedChunks) {
|
||||
if (chunk.type === 'text-delta' && 'delta' in chunk) {
|
||||
textContent += chunk.delta
|
||||
} else if (chunk.type === 'tool-input-available' && 'toolCallId' in chunk && 'toolName' in chunk) {
|
||||
// Handle tool calls - use tool-input-available chunks
|
||||
const toolCall = {
|
||||
toolCallId: chunk.toolCallId,
|
||||
toolName: chunk.toolName,
|
||||
args: chunk.input || {}
|
||||
}
|
||||
toolCalls.push(toolCall)
|
||||
}
|
||||
}
|
||||
|
||||
const message: any = {
|
||||
role,
|
||||
content: textContent
|
||||
}
|
||||
|
||||
// Add tool invocations if any
|
||||
if (toolCalls.length > 0) {
|
||||
message.toolInvocations = toolCalls
|
||||
}
|
||||
|
||||
return message as ModelMessage
|
||||
}
|
||||
|
||||
getChunkCount(): number {
|
||||
return this.streamedChunks.length
|
||||
}
|
||||
|
||||
getRawMessageCount(): number {
|
||||
return this.rawAgentMessages.length
|
||||
}
|
||||
|
||||
getAgentType(): string {
|
||||
return this.agentType
|
||||
}
|
||||
}
|
||||
|
||||
export class SessionMessageService extends BaseService {
|
||||
private static instance: SessionMessageService | null = null
|
||||
private cc: ClaudeCodeService = new ClaudeCodeService()
|
||||
@ -76,14 +236,44 @@ export class SessionMessageService extends BaseService {
|
||||
return { messages, total }
|
||||
}
|
||||
|
||||
createSessionMessage(session: GetAgentSessionResponse, messageData: CreateSessionMessageRequest): EventEmitter {
|
||||
async saveUserMessage(sessionId: string, content: ModelMessage | string): Promise<AgentSessionMessageEntity> {
|
||||
this.ensureInitialized()
|
||||
|
||||
const now = new Date().toISOString()
|
||||
const userContent: ModelMessage = normalizeModelMessage(content)
|
||||
|
||||
const insertData: InsertSessionMessageRow = {
|
||||
session_id: sessionId,
|
||||
role: 'user',
|
||||
content: JSON.stringify(userContent),
|
||||
metadata: JSON.stringify({
|
||||
timestamp: now,
|
||||
source: 'api'
|
||||
}),
|
||||
created_at: now,
|
||||
updated_at: now
|
||||
}
|
||||
|
||||
const [saved] = await this.database
|
||||
.insert(sessionMessagesTable)
|
||||
.values(insertData)
|
||||
.returning()
|
||||
|
||||
return this.deserializeSessionMessage(saved) as AgentSessionMessageEntity
|
||||
}
|
||||
|
||||
createSessionMessage(
|
||||
session: GetAgentSessionResponse,
|
||||
messageData: CreateSessionMessageRequest,
|
||||
userMessageId: number
|
||||
): EventEmitter {
|
||||
this.ensureInitialized()
|
||||
|
||||
// Create a new EventEmitter to manage the session message lifecycle
|
||||
const sessionStream = new EventEmitter()
|
||||
|
||||
// No parent validation needed, start immediately
|
||||
this.startSessionMessageStream(session, messageData, sessionStream)
|
||||
this.startSessionMessageStream(session, messageData, sessionStream, userMessageId)
|
||||
|
||||
return sessionStream
|
||||
}
|
||||
@ -91,7 +281,8 @@ export class SessionMessageService extends BaseService {
|
||||
private startSessionMessageStream(
|
||||
session: GetAgentSessionResponse,
|
||||
req: CreateSessionMessageRequest,
|
||||
sessionStream: EventEmitter
|
||||
sessionStream: EventEmitter,
|
||||
userMessageId: number
|
||||
): void {
|
||||
const previousMessages = session.messages || []
|
||||
let session_id: string = ''
|
||||
@ -112,8 +303,8 @@ export class SessionMessageService extends BaseService {
|
||||
maxTurns: session.configuration?.maxTurns || 10
|
||||
})
|
||||
|
||||
const streamedChunks: UIMessageChunk[] = []
|
||||
const rawAgentMessages: any[] = [] // Generic agent messages storage
|
||||
// Use chunk accumulator to manage streaming data
|
||||
const accumulator = new ChunkAccumulator()
|
||||
|
||||
// Handle agent stream events (agent-agnostic)
|
||||
claudeStream.on('data', async (event: any) => {
|
||||
@ -123,11 +314,11 @@ export class SessionMessageService extends BaseService {
|
||||
// Forward UIMessageChunk directly and collect raw agent messages
|
||||
if (event.chunk) {
|
||||
const chunk = event.chunk as UIMessageChunk
|
||||
streamedChunks.push(chunk)
|
||||
accumulator.addChunk(chunk)
|
||||
|
||||
// Collect raw agent message if available (agent-agnostic)
|
||||
if (event.rawAgentMessage) {
|
||||
rawAgentMessages.push(event.rawAgentMessage)
|
||||
accumulator.addRawMessage(event.rawAgentMessage)
|
||||
}
|
||||
|
||||
sessionStream.emit('data', {
|
||||
@ -139,76 +330,55 @@ export class SessionMessageService extends BaseService {
|
||||
}
|
||||
break
|
||||
|
||||
case 'error':
|
||||
case 'error': {
|
||||
const underlyingError = event.error || (event.data?.stderr ? new Error(event.data.stderr) : undefined)
|
||||
const persistScheduled = accumulator.getChunkCount() > 0
|
||||
|
||||
if (persistScheduled) {
|
||||
// Try to save partial state with error metadata when possible
|
||||
accumulator.setAgentResult({
|
||||
error: serializeError(underlyingError),
|
||||
agentType: 'claude-code',
|
||||
incomplete: true
|
||||
})
|
||||
|
||||
void this.persistSessionMessageAsync({
|
||||
session,
|
||||
accumulator,
|
||||
userMessageId,
|
||||
sessionStream
|
||||
})
|
||||
}
|
||||
|
||||
sessionStream.emit('data', {
|
||||
type: 'error',
|
||||
error: event.error || (event.data?.stderr ? new Error(event.data.stderr) : undefined)
|
||||
error: serializeError(underlyingError),
|
||||
persistScheduled
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
case 'complete': {
|
||||
// Save the final message to database when agent completes
|
||||
logger.info('Agent stream completed, saving message to database')
|
||||
|
||||
// Extract additional raw agent messages from agentResult if available
|
||||
if (event.agentResult?.rawSDKMessages) {
|
||||
rawAgentMessages.push(...event.agentResult.rawSDKMessages)
|
||||
event.agentResult.rawSDKMessages.forEach((msg: any) => accumulator.addRawMessage(msg))
|
||||
}
|
||||
|
||||
// Create structured content with both AI SDK format and raw data
|
||||
const structuredContent = {
|
||||
aiSDKChunks: streamedChunks, // For UI consumption
|
||||
rawAgentMessages: rawAgentMessages, // Original agent-specific messages
|
||||
agentResult: event.agentResult, // Complete result from the agent
|
||||
agentType: event.agentResult?.agentType || 'unknown' // Store agent type for future reference
|
||||
}
|
||||
// Set the agent result in the accumulator
|
||||
accumulator.setAgentResult(event.agentResult)
|
||||
|
||||
// const now = new Date().toISOString()
|
||||
// const insertData: InsertSessionMessageRow = {
|
||||
// session_id: req.session_id,
|
||||
// parent_id: req.parent_id || null,
|
||||
// role: req.role,
|
||||
// type: req.type,
|
||||
// content: JSON.stringify(structuredContent),
|
||||
// metadata: req.metadata
|
||||
// ? JSON.stringify({
|
||||
// ...req.metadata,
|
||||
// chunkCount: streamedChunks.length,
|
||||
// rawMessageCount: rawAgentMessages.length,
|
||||
// agentType: event.agentResult?.agentType || 'unknown',
|
||||
// completedAt: now
|
||||
// })
|
||||
// : JSON.stringify({
|
||||
// chunkCount: streamedChunks.length,
|
||||
// rawMessageCount: rawAgentMessages.length,
|
||||
// agentType: event.agentResult?.agentType || 'unknown',
|
||||
// completedAt: now
|
||||
// }),
|
||||
// created_at: now,
|
||||
// updated_at: now
|
||||
// }
|
||||
// // Emit SSE completion FIRST before persistence
|
||||
// sessionStream.emit('data', {
|
||||
// type: 'complete',
|
||||
// result: accumulator.buildStructuredContent()
|
||||
// })
|
||||
|
||||
// const result = await this.database.insert(sessionMessagesTable).values(insertData).returning()
|
||||
|
||||
// if (result[0]) {
|
||||
// sessionMessage = this.deserializeSessionMessage(result[0]) as AgentSessionMessageEntity
|
||||
// logger.info(`Session message saved with ID: ${sessionMessage.id}`)
|
||||
|
||||
// // Emit the complete event with the saved message and structured data
|
||||
// sessionStream.emit('data', {
|
||||
// type: 'complete',
|
||||
// result: structuredContent,
|
||||
// message: sessionMessage
|
||||
// })
|
||||
// } else {
|
||||
// sessionStream.emit('data', {
|
||||
// type: 'error',
|
||||
// error: new Error('Failed to save session message to database')
|
||||
// })
|
||||
// }
|
||||
sessionStream.emit('data', {
|
||||
type: 'complete',
|
||||
result: structuredContent
|
||||
// Then handle async persistence
|
||||
void this.persistSessionMessageAsync({
|
||||
session,
|
||||
accumulator,
|
||||
userMessageId,
|
||||
sessionStream
|
||||
})
|
||||
break
|
||||
}
|
||||
@ -223,12 +393,58 @@ export class SessionMessageService extends BaseService {
|
||||
logger.error('Error handling Claude Code stream event:', { error })
|
||||
sessionStream.emit('data', {
|
||||
type: 'error',
|
||||
error: error as Error
|
||||
error: serializeError(error)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private async persistSessionMessageAsync({ session, accumulator, userMessageId, sessionStream }: PersistContext) {
|
||||
if (!session?.id) {
|
||||
const missingSessionError = new Error('Missing session_id for persisted message')
|
||||
logger.error(missingSessionError.message, { error: missingSessionError })
|
||||
sessionStream.emit('data', { type: 'persist-error', error: serializeError(missingSessionError) })
|
||||
return
|
||||
}
|
||||
|
||||
const sessionId = session.id
|
||||
const now = new Date().toISOString()
|
||||
const structured = accumulator.buildStructuredContent()
|
||||
|
||||
try {
|
||||
// Use chunksToModelMessages to convert chunks to ModelMessages
|
||||
const modelMessages = await accumulator.toModelMessages()
|
||||
// Get the last message (should be the assistant's response)
|
||||
const modelMessage =
|
||||
modelMessages.length > 0 ? modelMessages[modelMessages.length - 1] : accumulator.toModelMessage('assistant')
|
||||
|
||||
const metadata = {
|
||||
userMessageId,
|
||||
chunkCount: accumulator.getChunkCount(),
|
||||
rawMessageCount: accumulator.getRawMessageCount(),
|
||||
agentType: accumulator.getAgentType(),
|
||||
completedAt: now
|
||||
}
|
||||
|
||||
const insertData: InsertSessionMessageRow = {
|
||||
session_id: sessionId,
|
||||
role: 'assistant',
|
||||
content: JSON.stringify({ modelMessage, ...structured }),
|
||||
metadata: JSON.stringify(metadata),
|
||||
created_at: now,
|
||||
updated_at: now
|
||||
}
|
||||
|
||||
const [row] = await this.database.insert(sessionMessagesTable).values(insertData).returning()
|
||||
|
||||
const entity = this.deserializeSessionMessage(row) as AgentSessionMessageEntity
|
||||
sessionStream.emit('data', { type: 'persisted', message: entity })
|
||||
} catch (error) {
|
||||
logger.error('Failed to persist session message', { error })
|
||||
sessionStream.emit('data', { type: 'persist-error', error: serializeError(error) })
|
||||
}
|
||||
}
|
||||
|
||||
private deserializeSessionMessage(data: any): AgentSessionMessageEntity {
|
||||
if (!data) return data
|
||||
|
||||
|
||||
@ -74,14 +74,7 @@ export function transformSDKMessageToUIChunk(sdkMessage: SDKMessage): UIMessageC
|
||||
|
||||
function sdkMessageToProviderMetadata(message: SDKMessage): ProviderMetadata {
|
||||
const meta: ProviderMetadata = {
|
||||
raw: message as Record<string, any>,
|
||||
claudeCode: {
|
||||
originalSDKMessage: JSON.parse(JSON.stringify(message)), // Serialize to ensure JSON compatibility
|
||||
uuid: message.uuid || null,
|
||||
session_id: message.session_id || null,
|
||||
timestamp: new Date().toISOString(),
|
||||
type: message.type
|
||||
}
|
||||
message: message as Record<string, any>
|
||||
}
|
||||
return meta
|
||||
}
|
||||
@ -89,7 +82,7 @@ function sdkMessageToProviderMetadata(message: SDKMessage): ProviderMetadata {
|
||||
// Handle assistant messages
|
||||
function handleAssistantMessage(message: Extract<SDKMessage, { type: 'assistant' }>): UIMessageChunk[] {
|
||||
const chunks: UIMessageChunk[] = []
|
||||
const messageId = generateMessageId()
|
||||
const messageId = message.uuid
|
||||
|
||||
// Extract text content
|
||||
const textContent = extractTextContent(message.message as MessageParam)
|
||||
@ -97,36 +90,18 @@ function handleAssistantMessage(message: Extract<SDKMessage, { type: 'assistant'
|
||||
chunks.push(
|
||||
{
|
||||
type: 'text-start',
|
||||
id: messageId,
|
||||
providerMetadata: {
|
||||
anthropic: {
|
||||
uuid: message.uuid,
|
||||
session_id: message.session_id
|
||||
},
|
||||
raw: sdkMessageToProviderMetadata(message)
|
||||
}
|
||||
id: messageId
|
||||
},
|
||||
{
|
||||
type: 'text-delta',
|
||||
id: messageId,
|
||||
delta: textContent,
|
||||
providerMetadata: {
|
||||
anthropic: {
|
||||
uuid: message.uuid,
|
||||
session_id: message.session_id
|
||||
},
|
||||
raw: sdkMessageToProviderMetadata(message)
|
||||
}
|
||||
delta: textContent
|
||||
},
|
||||
{
|
||||
type: 'text-end',
|
||||
id: messageId,
|
||||
providerMetadata: {
|
||||
anthropic: {
|
||||
uuid: message.uuid,
|
||||
session_id: message.session_id
|
||||
},
|
||||
raw: sdkMessageToProviderMetadata(message)
|
||||
rawMessage: sdkMessageToProviderMetadata(message)
|
||||
}
|
||||
}
|
||||
)
|
||||
@ -289,15 +264,17 @@ function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>):
|
||||
const chunks: UIMessageChunk[] = []
|
||||
|
||||
if (message.subtype === 'init') {
|
||||
chunks.push({
|
||||
type: 'start',
|
||||
messageId: message.session_id
|
||||
})
|
||||
|
||||
// System initialization - could emit as a data chunk or skip
|
||||
chunks.push({
|
||||
type: 'data-system' as any,
|
||||
data: {
|
||||
type: 'init',
|
||||
cwd: message.cwd,
|
||||
tools: message.tools,
|
||||
model: message.model,
|
||||
mcp_servers: message.mcp_servers,
|
||||
session_id: message.session_id,
|
||||
raw: message
|
||||
}
|
||||
})
|
||||
@ -319,63 +296,14 @@ function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>):
|
||||
function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): UIMessageChunk[] {
|
||||
const chunks: UIMessageChunk[] = []
|
||||
|
||||
const messageId = message.uuid
|
||||
if (message.subtype === 'success') {
|
||||
// Emit the final result text if available
|
||||
if (message.result) {
|
||||
const messageId = generateMessageId()
|
||||
chunks.push(
|
||||
{
|
||||
type: 'text-start',
|
||||
id: messageId,
|
||||
providerMetadata: {
|
||||
anthropic: {
|
||||
uuid: message.uuid,
|
||||
session_id: message.session_id,
|
||||
final_result: true
|
||||
},
|
||||
raw: sdkMessageToProviderMetadata(message)
|
||||
}
|
||||
},
|
||||
{
|
||||
type: 'text-delta',
|
||||
id: messageId,
|
||||
delta: message.result,
|
||||
providerMetadata: {
|
||||
anthropic: {
|
||||
uuid: message.uuid,
|
||||
session_id: message.session_id,
|
||||
final_result: true
|
||||
},
|
||||
raw: sdkMessageToProviderMetadata(message)
|
||||
}
|
||||
},
|
||||
{
|
||||
type: 'text-end',
|
||||
id: messageId,
|
||||
providerMetadata: {
|
||||
anthropic: {
|
||||
uuid: message.uuid,
|
||||
session_id: message.session_id,
|
||||
final_result: true
|
||||
},
|
||||
raw: sdkMessageToProviderMetadata(message)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// Emit usage and cost data
|
||||
// Emit final result data
|
||||
chunks.push({
|
||||
type: 'data-usage' as any,
|
||||
data: {
|
||||
duration_ms: message.duration_ms,
|
||||
duration_api_ms: message.duration_api_ms,
|
||||
num_turns: message.num_turns,
|
||||
total_cost_usd: message.total_cost_usd,
|
||||
usage: message.usage,
|
||||
modelUsage: message.modelUsage,
|
||||
permission_denials: message.permission_denials
|
||||
}
|
||||
type: 'data-result' as any,
|
||||
id: messageId,
|
||||
data: message,
|
||||
transient: true
|
||||
})
|
||||
} else {
|
||||
// Handle error cases
|
||||
@ -383,22 +311,23 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
|
||||
type: 'error',
|
||||
errorText: `${message.subtype}: Process failed after ${message.num_turns} turns`
|
||||
})
|
||||
|
||||
// Still emit usage data for failed requests
|
||||
chunks.push({
|
||||
type: 'data-usage' as any,
|
||||
data: {
|
||||
duration_ms: message.duration_ms,
|
||||
duration_api_ms: message.duration_api_ms,
|
||||
num_turns: message.num_turns,
|
||||
total_cost_usd: message.total_cost_usd,
|
||||
usage: message.usage,
|
||||
modelUsage: message.modelUsage,
|
||||
permission_denials: message.permission_denials
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Emit usage and cost data
|
||||
chunks.push({
|
||||
type: 'data-usage' as any,
|
||||
data: {
|
||||
cost: message.total_cost_usd,
|
||||
usage: {
|
||||
input_tokens: message.usage.input_tokens,
|
||||
cache_creation_input_tokens: message.usage.cache_creation_input_tokens,
|
||||
cache_read_input_tokens: message.usage.cache_read_input_tokens,
|
||||
output_tokens: message.usage.output_tokens,
|
||||
service_tier: 'standard'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Always emit a finish chunk at the end
|
||||
chunks.push({
|
||||
type: 'finish'
|
||||
|
||||
Loading…
Reference in New Issue
Block a user