feat(transform): refactor message handling to unify user and assistant processing

This commit is contained in:
Vaayne 2025-09-19 22:36:19 +08:00
parent 027ef17a2e
commit c426876d0d
2 changed files with 65 additions and 111 deletions

View File

@ -42,7 +42,6 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
res.setHeader('Access-Control-Allow-Origin', '*') res.setHeader('Access-Control-Allow-Origin', '*')
res.setHeader('Access-Control-Allow-Headers', 'Cache-Control') res.setHeader('Access-Control-Allow-Headers', 'Cache-Control')
const messageStream = sessionMessageService.createSessionMessage(session, messageData) const messageStream = sessionMessageService.createSessionMessage(session, messageData)
// Track stream lifecycle so we keep the SSE connection open until persistence finishes // Track stream lifecycle so we keep the SSE connection open until persistence finishes
@ -66,6 +65,7 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
responseEnded = true responseEnded = true
try { try {
res.write('data: {"type":"finish"}\n\n')
res.write('data: [DONE]\n\n') res.write('data: [DONE]\n\n')
} catch (writeError) { } catch (writeError) {
logger.error('Error writing final sentinel to SSE stream:', { error: writeError as Error }) logger.error('Error writing final sentinel to SSE stream:', { error: writeError as Error })
@ -113,7 +113,7 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
case 'complete': { case 'complete': {
logger.info(`Streaming message completed for session: ${sessionId}`) logger.info(`Streaming message completed for session: ${sessionId}`)
res.write(`data: ${JSON.stringify({ type: 'complete', result: event.result })}\n\n`) // res.write(`data: ${JSON.stringify({ type: 'complete', result: event.result })}\n\n`)
streamFinished = true streamFinished = true
awaitingPersistence = true awaitingPersistence = true
@ -123,7 +123,7 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
case 'persisted': case 'persisted':
// Send persistence success event // Send persistence success event
res.write(`data: ${JSON.stringify(event)}\n\n`) // res.write(`data: ${JSON.stringify(event)}\n\n`)
logger.debug(`Session message persisted for session: ${sessionId}`, { messageId: event.message?.id }) logger.debug(`Session message persisted for session: ${sessionId}`, { messageId: event.message?.id })
persistenceResolved = true persistenceResolved = true
@ -132,7 +132,7 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
case 'persist-error': case 'persist-error':
// Send persistence error event // Send persistence error event
res.write(`data: ${JSON.stringify(event)}\n\n`) // res.write(`data: ${JSON.stringify(event)}\n\n`)
logger.error(`Failed to persist session message for session: ${sessionId}:`, event.error) logger.error(`Failed to persist session message for session: ${sessionId}:`, event.error)
persistenceResolved = true persistenceResolved = true

View File

@ -1,7 +1,6 @@
// This file is used to transform claude code json response to aisdk streaming format // This file is used to transform claude code json response to aisdk streaming format
import { SDKMessage } from '@anthropic-ai/claude-code' import { SDKMessage } from '@anthropic-ai/claude-code'
import { MessageParam } from '@anthropic-ai/sdk/resources'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { ProviderMetadata, UIMessageChunk } from 'ai' import { ProviderMetadata, UIMessageChunk } from 'ai'
import { v4 as uuidv4 } from 'uuid' import { v4 as uuidv4 } from 'uuid'
@ -13,42 +12,14 @@ const generateMessageId = (): string => {
return `msg_${uuidv4().replace(/-/g, '')}` return `msg_${uuidv4().replace(/-/g, '')}`
} }
// Helper function to extract text content from Anthropic messages
const extractTextContent = (message: MessageParam): string => {
if (typeof message.content === 'string') {
return message.content
}
if (Array.isArray(message.content)) {
return message.content
.filter((block) => block.type === 'text')
.map((block) => ('text' in block ? block.text : ''))
.join('')
}
return ''
}
// Helper function to extract tool calls from assistant messages
const extractToolCalls = (message: any): any[] => {
if (!message.content || !Array.isArray(message.content)) {
return []
}
return message.content.filter((block: any) => block.type === 'tool_use')
}
// Main transform function // Main transform function
export function transformSDKMessageToUIChunk(sdkMessage: SDKMessage): UIMessageChunk[] { export function transformSDKMessageToUIChunk(sdkMessage: SDKMessage): UIMessageChunk[] {
const chunks: UIMessageChunk[] = [] const chunks: UIMessageChunk[] = []
switch (sdkMessage.type) { switch (sdkMessage.type) {
case 'assistant': case 'assistant':
chunks.push(...handleAssistantMessage(sdkMessage))
break
case 'user': case 'user':
chunks.push(...handleUserMessage(sdkMessage)) chunks.push(...handleUserOrAssistantMessage(sdkMessage))
break break
case 'stream_event': case 'stream_event':
@ -79,89 +50,72 @@ function sdkMessageToProviderMetadata(message: SDKMessage): ProviderMetadata {
return meta return meta
} }
// Handle assistant messages function generateTextChunks(id: string, text: string, message: SDKMessage): UIMessageChunk[] {
function handleAssistantMessage(message: Extract<SDKMessage, { type: 'assistant' }>): UIMessageChunk[] { return [
const chunks: UIMessageChunk[] = []
const messageId = message.uuid
// Extract text content
const textContent = extractTextContent(message.message as MessageParam)
if (textContent) {
chunks.push(
{ {
type: 'text-start', type: 'text-start',
id: messageId id
}, },
{ {
type: 'text-delta', type: 'text-delta',
id: messageId, id,
delta: textContent delta: text
}, },
{ {
type: 'text-end', type: 'text-end',
id: messageId, id,
providerMetadata: { providerMetadata: {
rawMessage: sdkMessageToProviderMetadata(message) rawMessage: sdkMessageToProviderMetadata(message)
} }
} }
) ]
} }
// Handle tool calls function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assistant' | 'user' }>): UIMessageChunk[] {
const toolCalls = extractToolCalls(message.message) const chunks: UIMessageChunk[] = []
for (const toolCall of toolCalls) { const messageId = message.uuid?.toString() || generateMessageId()
// handle normal text content
if (typeof message.message.content === 'string') {
const textContent = message.message.content
if (textContent) {
chunks.push(...generateTextChunks(messageId, textContent, message))
}
} else if (Array.isArray(message.message.content)) {
for (const block of message.message.content) {
switch (block.type) {
case 'text':
chunks.push(...generateTextChunks(messageId, block.text, message))
break
case 'tool_use':
chunks.push({ chunks.push({
type: 'tool-input-available', type: 'tool-input-available',
toolCallId: toolCall.id, toolCallId: block.id,
toolName: toolCall.name, toolName: block.name,
input: toolCall.input, input: block.input,
providerExecuted: true providerExecuted: true,
providerMetadata: {
rawMessage: sdkMessageToProviderMetadata(message)
}
}) })
} break
case 'tool_result':
return chunks chunks.push({
} type: 'tool-output-available',
toolCallId: block.tool_use_id,
// Handle user messages output: block.content,
function handleUserMessage(message: Extract<SDKMessage, { type: 'user' }>): UIMessageChunk[] { providerExecuted: true,
const chunks: UIMessageChunk[] = [] dynamic: false,
const messageId = generateMessageId() preliminary: false
})
const textContent = extractTextContent(message.message) break
if (textContent) { default:
chunks.push( logger.warn('Unknown content block type in user/assistant message:', {
{ type: (block as any).type
type: 'text-start', })
id: messageId, break
providerMetadata: {
anthropic: {
session_id: message.session_id,
role: 'user'
} }
} }
},
{
type: 'text-delta',
id: messageId,
delta: textContent,
providerMetadata: {
anthropic: {
session_id: message.session_id,
role: 'user'
}
}
},
{
type: 'text-end',
id: messageId,
providerMetadata: {
anthropic: {
session_id: message.session_id,
role: 'user'
}
}
}
)
} }
return chunks return chunks