feat(transform): refactor message handling to unify user and assistant processing

This commit is contained in:
Vaayne 2025-09-19 22:36:19 +08:00
parent 027ef17a2e
commit c426876d0d
2 changed files with 65 additions and 111 deletions

View File

@ -42,7 +42,6 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
res.setHeader('Access-Control-Allow-Origin', '*') res.setHeader('Access-Control-Allow-Origin', '*')
res.setHeader('Access-Control-Allow-Headers', 'Cache-Control') res.setHeader('Access-Control-Allow-Headers', 'Cache-Control')
const messageStream = sessionMessageService.createSessionMessage(session, messageData) const messageStream = sessionMessageService.createSessionMessage(session, messageData)
// Track stream lifecycle so we keep the SSE connection open until persistence finishes // Track stream lifecycle so we keep the SSE connection open until persistence finishes
@ -66,6 +65,7 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
responseEnded = true responseEnded = true
try { try {
res.write('data: {"type":"finish"}\n\n')
res.write('data: [DONE]\n\n') res.write('data: [DONE]\n\n')
} catch (writeError) { } catch (writeError) {
logger.error('Error writing final sentinel to SSE stream:', { error: writeError as Error }) logger.error('Error writing final sentinel to SSE stream:', { error: writeError as Error })
@ -113,7 +113,7 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
case 'complete': { case 'complete': {
logger.info(`Streaming message completed for session: ${sessionId}`) logger.info(`Streaming message completed for session: ${sessionId}`)
res.write(`data: ${JSON.stringify({ type: 'complete', result: event.result })}\n\n`) // res.write(`data: ${JSON.stringify({ type: 'complete', result: event.result })}\n\n`)
streamFinished = true streamFinished = true
awaitingPersistence = true awaitingPersistence = true
@ -123,7 +123,7 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
case 'persisted': case 'persisted':
// Send persistence success event // Send persistence success event
res.write(`data: ${JSON.stringify(event)}\n\n`) // res.write(`data: ${JSON.stringify(event)}\n\n`)
logger.debug(`Session message persisted for session: ${sessionId}`, { messageId: event.message?.id }) logger.debug(`Session message persisted for session: ${sessionId}`, { messageId: event.message?.id })
persistenceResolved = true persistenceResolved = true
@ -132,7 +132,7 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
case 'persist-error': case 'persist-error':
// Send persistence error event // Send persistence error event
res.write(`data: ${JSON.stringify(event)}\n\n`) // res.write(`data: ${JSON.stringify(event)}\n\n`)
logger.error(`Failed to persist session message for session: ${sessionId}:`, event.error) logger.error(`Failed to persist session message for session: ${sessionId}:`, event.error)
persistenceResolved = true persistenceResolved = true

View File

@ -1,7 +1,6 @@
// This file is used to transform claude code json response to aisdk streaming format // This file is used to transform claude code json response to aisdk streaming format
import { SDKMessage } from '@anthropic-ai/claude-code' import { SDKMessage } from '@anthropic-ai/claude-code'
import { MessageParam } from '@anthropic-ai/sdk/resources'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { ProviderMetadata, UIMessageChunk } from 'ai' import { ProviderMetadata, UIMessageChunk } from 'ai'
import { v4 as uuidv4 } from 'uuid' import { v4 as uuidv4 } from 'uuid'
@ -13,42 +12,14 @@ const generateMessageId = (): string => {
return `msg_${uuidv4().replace(/-/g, '')}` return `msg_${uuidv4().replace(/-/g, '')}`
} }
// Helper function to extract text content from Anthropic messages
const extractTextContent = (message: MessageParam): string => {
if (typeof message.content === 'string') {
return message.content
}
if (Array.isArray(message.content)) {
return message.content
.filter((block) => block.type === 'text')
.map((block) => ('text' in block ? block.text : ''))
.join('')
}
return ''
}
// Helper function to extract tool calls from assistant messages
const extractToolCalls = (message: any): any[] => {
if (!message.content || !Array.isArray(message.content)) {
return []
}
return message.content.filter((block: any) => block.type === 'tool_use')
}
// Main transform function // Main transform function
export function transformSDKMessageToUIChunk(sdkMessage: SDKMessage): UIMessageChunk[] { export function transformSDKMessageToUIChunk(sdkMessage: SDKMessage): UIMessageChunk[] {
const chunks: UIMessageChunk[] = [] const chunks: UIMessageChunk[] = []
switch (sdkMessage.type) { switch (sdkMessage.type) {
case 'assistant': case 'assistant':
chunks.push(...handleAssistantMessage(sdkMessage))
break
case 'user': case 'user':
chunks.push(...handleUserMessage(sdkMessage)) chunks.push(...handleUserOrAssistantMessage(sdkMessage))
break break
case 'stream_event': case 'stream_event':
@ -79,89 +50,72 @@ function sdkMessageToProviderMetadata(message: SDKMessage): ProviderMetadata {
return meta return meta
} }
// Handle assistant messages function generateTextChunks(id: string, text: string, message: SDKMessage): UIMessageChunk[] {
function handleAssistantMessage(message: Extract<SDKMessage, { type: 'assistant' }>): UIMessageChunk[] { return [
const chunks: UIMessageChunk[] = [] {
const messageId = message.uuid type: 'text-start',
id
// Extract text content },
const textContent = extractTextContent(message.message as MessageParam) {
if (textContent) { type: 'text-delta',
chunks.push( id,
{ delta: text
type: 'text-start', },
id: messageId {
}, type: 'text-end',
{ id,
type: 'text-delta', providerMetadata: {
id: messageId, rawMessage: sdkMessageToProviderMetadata(message)
delta: textContent
},
{
type: 'text-end',
id: messageId,
providerMetadata: {
rawMessage: sdkMessageToProviderMetadata(message)
}
} }
) }
} ]
// Handle tool calls
const toolCalls = extractToolCalls(message.message)
for (const toolCall of toolCalls) {
chunks.push({
type: 'tool-input-available',
toolCallId: toolCall.id,
toolName: toolCall.name,
input: toolCall.input,
providerExecuted: true
})
}
return chunks
} }
// Handle user messages function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assistant' | 'user' }>): UIMessageChunk[] {
function handleUserMessage(message: Extract<SDKMessage, { type: 'user' }>): UIMessageChunk[] {
const chunks: UIMessageChunk[] = [] const chunks: UIMessageChunk[] = []
const messageId = generateMessageId() const messageId = message.uuid?.toString() || generateMessageId()
const textContent = extractTextContent(message.message) // handle normal text content
if (textContent) { if (typeof message.message.content === 'string') {
chunks.push( const textContent = message.message.content
{ if (textContent) {
type: 'text-start', chunks.push(...generateTextChunks(messageId, textContent, message))
id: messageId, }
providerMetadata: { } else if (Array.isArray(message.message.content)) {
anthropic: { for (const block of message.message.content) {
session_id: message.session_id, switch (block.type) {
role: 'user' case 'text':
} chunks.push(...generateTextChunks(messageId, block.text, message))
} break
}, case 'tool_use':
{ chunks.push({
type: 'text-delta', type: 'tool-input-available',
id: messageId, toolCallId: block.id,
delta: textContent, toolName: block.name,
providerMetadata: { input: block.input,
anthropic: { providerExecuted: true,
session_id: message.session_id, providerMetadata: {
role: 'user' rawMessage: sdkMessageToProviderMetadata(message)
} }
} })
}, break
{ case 'tool_result':
type: 'text-end', chunks.push({
id: messageId, type: 'tool-output-available',
providerMetadata: { toolCallId: block.tool_use_id,
anthropic: { output: block.content,
session_id: message.session_id, providerExecuted: true,
role: 'user' dynamic: false,
} preliminary: false
} })
break
default:
logger.warn('Unknown content block type in user/assistant message:', {
type: (block as any).type
})
break
} }
) }
} }
return chunks return chunks