refactor: streamline session handling and enhance message transformation

This commit is contained in:
Vaayne 2025-09-21 00:31:04 +08:00
parent 01ffd4c4ca
commit c3b2af5a15
4 changed files with 37 additions and 81 deletions

View File

@ -83,28 +83,6 @@ class TextStreamAccumulator {
break break
} }
} }
toModelMessage(role: ModelMessage['role'] = 'assistant'): ModelMessage {
const content = this.totalText || this.textBuffer || ''
const toolInvocations = Array.from(this.toolCalls.entries()).map(([toolCallId, info]) => ({
toolCallId,
toolName: info.toolName,
args: info.input,
result: this.toolResults.get(toolCallId)
}))
const message: Record<string, unknown> = {
role,
content
}
if (toolInvocations.length > 0) {
message.toolInvocations = toolInvocations
}
return message as ModelMessage
}
} }
export class SessionMessageService extends BaseService { export class SessionMessageService extends BaseService {
@ -175,7 +153,6 @@ export class SessionMessageService extends BaseService {
abortController: AbortController abortController: AbortController
): Promise<SessionStreamResult> { ): Promise<SessionStreamResult> {
const agentSessionId = await this.getLastAgentSessionId(session.id) const agentSessionId = await this.getLastAgentSessionId(session.id)
let newAgentSessionId = ''
logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId }) logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId })
if (session.agent_type !== 'claude-code') { if (session.agent_type !== 'claude-code') {
@ -222,10 +199,6 @@ export class SessionMessageService extends BaseService {
return return
} }
if (chunk.type === 'start' && chunk.messageId) {
newAgentSessionId = chunk.messageId
}
accumulator.add(chunk) accumulator.add(chunk)
controller.enqueue(chunk) controller.enqueue(chunk)
break break

View File

@ -70,23 +70,18 @@ function generateTextChunks(id: string, text: string, message: SDKMessage): Agen
return [ return [
{ {
type: 'text-start', type: 'text-start',
id, id
providerMetadata
}, },
{ {
type: 'text-delta', type: 'text-delta',
id, id,
text, text
providerMetadata
}, },
{ {
type: 'text-end', type: 'text-end',
id, id,
providerMetadata: { providerMetadata: {
...providerMetadata, ...providerMetadata
text: {
value: text
}
} }
} }
] ]
@ -119,17 +114,22 @@ function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assi
}) })
break break
case 'tool_result': case 'tool_result':
// chunks.push({ chunks.push({
// type: 'tool-result', type: 'tool-result',
// toolCallId: block.tool_use_id, toolCallId: block.tool_use_id,
// output: block.content, toolName: '',
// providerMetadata: sdkMessageToProviderMetadata(message) input: '',
// }) output: block.content,
})
break break
default: default:
logger.warn('Unknown content block type in user/assistant message:', { logger.warn('Unknown content block type in user/assistant message:', {
type: block.type type: block.type
}) })
chunks.push({
type: 'raw',
rawValue: block
})
break break
} }
} }
@ -142,7 +142,7 @@ function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assi
function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }>): AgentStreamPart[] { function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }>): AgentStreamPart[] {
const chunks: AgentStreamPart[] = [] const chunks: AgentStreamPart[] = []
const event = message.event const event = message.event
const blockKey = `${message.uuid ?? message.session_id ?? 'session'}:${event.index}` const blockKey = `${message.uuid ?? message.session_id ?? 'session'}:${event.type}`
switch (event.type) { switch (event.type) {
case 'message_start': case 'message_start':
@ -255,15 +255,13 @@ function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }
} }
contentBlockState.delete(blockKey) contentBlockState.delete(blockKey)
} }
break
case 'message_delta': case 'message_delta':
// Handle usage updates or other message-level deltas // Handle usage updates or other message-level deltas
break break
case 'message_stop': case 'message_stop':
// This could signal the end of the message // This could signal the end of the message
break break
default: default:
logger.warn('Unknown stream event type:', { type: (event as any).type }) logger.warn('Unknown stream event type:', { type: (event as any).type })
break break
@ -283,9 +281,19 @@ function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>):
chunks.push({ chunks.push({
type: 'start' type: 'start'
}) })
chunks.push({
type: 'raw',
rawValue: {
type: 'init',
session_id: message.session_id,
slash_commands: message.slash_commands,
tools: message.tools,
raw: message
}
})
} }
} }
return [] return chunks
} }
// Handle result messages (completion with usage stats) // Handle result messages (completion with usage stats)
@ -295,14 +303,9 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
let usage: LanguageModelV2Usage | undefined let usage: LanguageModelV2Usage | undefined
if ('usage' in message) { if ('usage' in message) {
usage = { usage = {
inputTokens: inputTokens: message.usage.input_tokens ?? 0,
(message.usage.cache_creation_input_tokens ?? 0) +
(message.usage.cache_read_input_tokens ?? 0) +
(message.usage.input_tokens ?? 0),
outputTokens: message.usage.output_tokens ?? 0, outputTokens: message.usage.output_tokens ?? 0,
totalTokens: totalTokens:
(message.usage.cache_creation_input_tokens ?? 0) +
(message.usage.cache_read_input_tokens ?? 0) +
(message.usage.input_tokens ?? 0) + (message.usage.input_tokens ?? 0) +
(message.usage.output_tokens ?? 0) (message.usage.output_tokens ?? 0)
} }
@ -330,25 +333,3 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
} }
return chunks return chunks
} }
// Convenience function to transform a stream of SDKMessages
export function* transformSDKMessageStream(sdkMessages: SDKMessage[]): Generator<AgentStreamPart> {
for (const sdkMessage of sdkMessages) {
const chunks = transformSDKMessageToStreamParts(sdkMessage)
for (const chunk of chunks) {
yield chunk
}
}
}
// Async version for async iterables
export async function* transformSDKMessageStreamAsync(
sdkMessages: AsyncIterable<SDKMessage>
): AsyncGenerator<AgentStreamPart> {
for await (const sdkMessage of sdkMessages) {
const chunks = transformSDKMessageToStreamParts(sdkMessage)
for (const chunk of chunks) {
yield chunk
}
}
}

View File

@ -111,13 +111,10 @@ export class AiSdkToChunkAdapter {
chunk: TextStreamPart<any>, chunk: TextStreamPart<any>,
final: { text: string; reasoningContent: string; webSearchResults: AISDKWebSearchResult[]; reasoningId: string } final: { text: string; reasoningContent: string; webSearchResults: AISDKWebSearchResult[]; reasoningId: string }
) { ) {
const sessionId = // @ts-ignore
(chunk.providerMetadata as any)?.anthropic?.session_id ?? if (chunk.type === 'raw' && chunk.rawValue.type === 'init' && chunk.rawValue.session_id) {
(chunk.providerMetadata as any)?.anthropic?.sessionId ?? // @ts-ignore
(chunk.providerMetadata as any)?.raw?.session_id this.onSessionUpdate?.(chunk.rawValue.session_id)
if (typeof sessionId === 'string' && sessionId) {
this.onSessionUpdate?.(sessionId)
} }
logger.silly(`AI SDK chunk type: ${chunk.type}`, chunk) logger.silly(`AI SDK chunk type: ${chunk.type}`, chunk)

View File

@ -1,3 +1,4 @@
import type {ProviderMetadata} from "ai";
import type { CompletionUsage } from 'openai/resources' import type { CompletionUsage } from 'openai/resources'
import type { import type {
@ -203,6 +204,10 @@ export type Message = {
// 跟踪Id // 跟踪Id
traceId?: string traceId?: string
// raw data
// TODO: add this providerMetadata to MessageBlock to save raw provider data for each block
providerMetadata?: ProviderMetadata
} }
export interface Response { export interface Response {