Mirror of https://github.com/CherryHQ/cherry-studio.git

refactor: streamline session handling and enhance message transformation

commit c3b2af5a15
parent 01ffd4c4ca
@@ -83,28 +83,6 @@ class TextStreamAccumulator {
         break
       }
     }
-
-  toModelMessage(role: ModelMessage['role'] = 'assistant'): ModelMessage {
-    const content = this.totalText || this.textBuffer || ''
-
-    const toolInvocations = Array.from(this.toolCalls.entries()).map(([toolCallId, info]) => ({
-      toolCallId,
-      toolName: info.toolName,
-      args: info.input,
-      result: this.toolResults.get(toolCallId)
-    }))
-
-    const message: Record<string, unknown> = {
-      role,
-      content
-    }
-
-    if (toolInvocations.length > 0) {
-      message.toolInvocations = toolInvocations
-    }
-
-    return message as ModelMessage
-  }
 }
 
 export class SessionMessageService extends BaseService {
@@ -175,7 +153,6 @@ export class SessionMessageService extends BaseService {
     abortController: AbortController
   ): Promise<SessionStreamResult> {
     const agentSessionId = await this.getLastAgentSessionId(session.id)
-    let newAgentSessionId = ''
     logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId })
 
     if (session.agent_type !== 'claude-code') {
@@ -222,10 +199,6 @@ export class SessionMessageService extends BaseService {
             return
           }
 
-          if (chunk.type === 'start' && chunk.messageId) {
-            newAgentSessionId = chunk.messageId
-          }
-
           accumulator.add(chunk)
           controller.enqueue(chunk)
           break
@@ -70,23 +70,18 @@ function generateTextChunks(id: string, text: string, message: SDKMessage): AgentStreamPart[] {
   return [
     {
       type: 'text-start',
-      id,
-      providerMetadata
+      id
     },
     {
       type: 'text-delta',
       id,
-      text,
-      providerMetadata
+      text
    },
     {
       type: 'text-end',
       id,
       providerMetadata: {
-        ...providerMetadata,
-        text: {
-          value: text
-        }
+        ...providerMetadata
       }
     }
   ]
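Note: with this change the per-part providerMetadata only survives on the text-end part, and the full text is no longer mirrored into providerMetadata.text.value. A consumer that relied on that field would now accumulate the deltas itself; a minimal sketch under that assumption (the chunk shapes are taken from the hunk above, nothing here is repo code):

// Rebuild the streamed text per id from the parts emitted above.
// TextChunk is a local stand-in for the relevant AgentStreamPart members.
type TextChunk =
  | { type: 'text-start'; id: string }
  | { type: 'text-delta'; id: string; text: string }
  | { type: 'text-end'; id: string; providerMetadata?: Record<string, unknown> }

function collectText(chunks: TextChunk[]): Map<string, string> {
  const buffers = new Map<string, string>()
  for (const chunk of chunks) {
    if (chunk.type === 'text-start') buffers.set(chunk.id, '')
    if (chunk.type === 'text-delta') buffers.set(chunk.id, (buffers.get(chunk.id) ?? '') + chunk.text)
  }
  return buffers
}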
@@ -119,17 +114,22 @@ function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assi
         })
         break
       case 'tool_result':
-        // chunks.push({
-        //   type: 'tool-result',
-        //   toolCallId: block.tool_use_id,
-        //   output: block.content,
-        //   providerMetadata: sdkMessageToProviderMetadata(message)
-        // })
+        chunks.push({
+          type: 'tool-result',
+          toolCallId: block.tool_use_id,
+          toolName: '',
+          input: '',
+          output: block.content,
+        })
         break
       default:
         logger.warn('Unknown content block type in user/assistant message:', {
           type: block.type
         })
+        chunks.push({
+          type: 'raw',
+          rawValue: block
+        })
         break
     }
   }
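Note: the newly activated tool-result part ships with empty toolName and input, so anything downstream that needs those fields has to join on toolCallId against the earlier tool-call part. A rough sketch of that join, with assumed local types (not from this codebase):

// Pair a tool-result part (toolName: '', input: '') with the tool-call part
// that introduced the same toolCallId; fall back to the placeholders otherwise.
type ToolCallPart = { type: 'tool-call'; toolCallId: string; toolName: string; input: unknown }
type ToolResultPart = { type: 'tool-result'; toolCallId: string; toolName: string; input: unknown; output: unknown }

function enrichToolResult(calls: Map<string, ToolCallPart>, result: ToolResultPart): ToolResultPart {
  const call = calls.get(result.toolCallId)
  return call ? { ...result, toolName: call.toolName, input: call.input } : result
}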
@@ -142,7 +142,7 @@ function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assi
 function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }>): AgentStreamPart[] {
   const chunks: AgentStreamPart[] = []
   const event = message.event
-  const blockKey = `${message.uuid ?? message.session_id ?? 'session'}:${event.index}`
+  const blockKey = `${message.uuid ?? message.session_id ?? 'session'}:${event.type}`
 
   switch (event.type) {
     case 'message_start':
@@ -255,15 +255,13 @@ function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }>): AgentStreamPart[] {
         }
         contentBlockState.delete(blockKey)
       }
-
+      break
     case 'message_delta':
       // Handle usage updates or other message-level deltas
       break
-
    case 'message_stop':
       // This could signal the end of the message
       break
-
    default:
       logger.warn('Unknown stream event type:', { type: (event as any).type })
       break
@@ -283,9 +281,19 @@ function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>): AgentStreamPart[] {
       chunks.push({
         type: 'start'
       })
+      chunks.push({
+        type: 'raw',
+        rawValue: {
+          type: 'init',
+          session_id: message.session_id,
+          slash_commands: message.slash_commands,
+          tools: message.tools,
+          raw: message
+        }
+      })
     }
   }
-  return []
+  return chunks
 }
 
 // Handle result messages (completion with usage stats)
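Note: for an init system message the handler now returns the accumulated chunks instead of an empty array, so callers see a 'start' part followed by a 'raw' part whose rawValue carries the session_id; the AiSdkToChunkAdapter change further down keys on exactly that shape. Roughly, the emitted sequence looks like this (values are illustrative, field names come from the hunk above):

// Illustrative shape of the parts emitted for an init system message.
const exampleInitParts = [
  { type: 'start' },
  {
    type: 'raw',
    rawValue: {
      type: 'init',
      session_id: 'session-abc',     // picked up by the adapter further down
      slash_commands: ['/example'],  // illustrative
      tools: ['ExampleTool'],        // illustrative
      raw: {} as unknown             // the original SDKMessage
    }
  }
]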
@@ -295,14 +303,9 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): AgentStreamPart[] {
   let usage: LanguageModelV2Usage | undefined
   if ('usage' in message) {
     usage = {
-      inputTokens:
-        (message.usage.cache_creation_input_tokens ?? 0) +
-        (message.usage.cache_read_input_tokens ?? 0) +
-        (message.usage.input_tokens ?? 0),
+      inputTokens: message.usage.input_tokens ?? 0,
       outputTokens: message.usage.output_tokens ?? 0,
       totalTokens:
-        (message.usage.cache_creation_input_tokens ?? 0) +
-        (message.usage.cache_read_input_tokens ?? 0) +
         (message.usage.input_tokens ?? 0) +
         (message.usage.output_tokens ?? 0)
     }
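Note: after this change cache tokens no longer count toward inputTokens but still count toward totalTokens, so totalTokens can exceed inputTokens + outputTokens. A quick worked example with made-up numbers:

// Made-up usage payload: 12 prompt tokens, 50 completion tokens,
// 200 cache-creation tokens, 1000 cache-read tokens.
type ClaudeUsage = {
  input_tokens?: number
  output_tokens?: number
  cache_creation_input_tokens?: number
  cache_read_input_tokens?: number
}
const usage: ClaudeUsage = { input_tokens: 12, output_tokens: 50, cache_creation_input_tokens: 200, cache_read_input_tokens: 1000 }

const inputTokens = usage.input_tokens ?? 0    // 12 (previously 1212, cache tokens included)
const outputTokens = usage.output_tokens ?? 0  // 50
const totalTokens =
  (usage.cache_creation_input_tokens ?? 0) +
  (usage.cache_read_input_tokens ?? 0) +
  (usage.input_tokens ?? 0) +
  (usage.output_tokens ?? 0)                   // 1262, unchanged by this commit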
@@ -330,25 +333,3 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): AgentStreamPart[] {
   }
   return chunks
 }
-
-// Convenience function to transform a stream of SDKMessages
-export function* transformSDKMessageStream(sdkMessages: SDKMessage[]): Generator<AgentStreamPart> {
-  for (const sdkMessage of sdkMessages) {
-    const chunks = transformSDKMessageToStreamParts(sdkMessage)
-    for (const chunk of chunks) {
-      yield chunk
-    }
-  }
-}
-
-// Async version for async iterables
-export async function* transformSDKMessageStreamAsync(
-  sdkMessages: AsyncIterable<SDKMessage>
-): AsyncGenerator<AgentStreamPart> {
-  for await (const sdkMessage of sdkMessages) {
-    const chunks = transformSDKMessageToStreamParts(sdkMessage)
-    for (const chunk of chunks) {
-      yield chunk
-    }
-  }
-}
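Note: the two generator helpers are removed without a replacement in this file, so a caller that still needs to flatten an SDKMessage stream can inline the same loop around transformSDKMessageToStreamParts. A minimal sketch, assuming that function and the SDKMessage/AgentStreamPart types are still in scope from this module:

// Inline replacement for the removed transformSDKMessageStreamAsync helper.
async function* flattenSdkMessages(messages: AsyncIterable<SDKMessage>): AsyncGenerator<AgentStreamPart> {
  for await (const sdkMessage of messages) {
    yield* transformSDKMessageToStreamParts(sdkMessage)
  }
}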
@@ -111,13 +111,10 @@ export class AiSdkToChunkAdapter {
     chunk: TextStreamPart<any>,
     final: { text: string; reasoningContent: string; webSearchResults: AISDKWebSearchResult[]; reasoningId: string }
   ) {
-    const sessionId =
-      (chunk.providerMetadata as any)?.anthropic?.session_id ??
-      (chunk.providerMetadata as any)?.anthropic?.sessionId ??
-      (chunk.providerMetadata as any)?.raw?.session_id
-
-    if (typeof sessionId === 'string' && sessionId) {
-      this.onSessionUpdate?.(sessionId)
+    // @ts-ignore
+    if (chunk.type === 'raw' && chunk.rawValue.type === 'init' && chunk.rawValue.session_id) {
+      // @ts-ignore
+      this.onSessionUpdate?.(chunk.rawValue.session_id)
     }
 
     logger.silly(`AI SDK chunk type: ${chunk.type}`, chunk)
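Note: the @ts-ignore comments are needed because TextStreamPart does not model the custom rawValue payload built in handleSystemMessage above. One possible way to avoid the casts is a small type guard; this is only a sketch of that idea, assuming the init payload shape shown earlier, not existing repo code:

// Narrow a raw chunk to the init payload instead of using @ts-ignore.
interface RawInitChunk {
  type: 'raw'
  rawValue: { type: 'init'; session_id: string }
}

function isRawInitChunk(chunk: { type: string; rawValue?: unknown }): chunk is RawInitChunk {
  if (chunk.type !== 'raw' || typeof chunk.rawValue !== 'object' || chunk.rawValue === null) return false
  const raw = chunk.rawValue as { type?: unknown; session_id?: unknown }
  return raw.type === 'init' && typeof raw.session_id === 'string'
}

// Inside the adapter this would read:
//   if (isRawInitChunk(chunk)) this.onSessionUpdate?.(chunk.rawValue.session_id)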
@@ -1,3 +1,4 @@
+import type {ProviderMetadata} from "ai";
 import type { CompletionUsage } from 'openai/resources'
 
 import type {
@@ -203,6 +204,10 @@ export type Message = {
 
   // Trace ID
   traceId?: string
+
+  // raw data
+  // TODO: add this providerMetadata to MessageBlock to save raw provider data for each block
+  providerMetadata?: ProviderMetadata
 }
 
 export interface Response {