mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-08 14:29:15 +08:00
refactor: streamline session handling and enhance message transformation
This commit is contained in:
parent
01ffd4c4ca
commit
c3b2af5a15
@ -83,28 +83,6 @@ class TextStreamAccumulator {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
toModelMessage(role: ModelMessage['role'] = 'assistant'): ModelMessage {
|
|
||||||
const content = this.totalText || this.textBuffer || ''
|
|
||||||
|
|
||||||
const toolInvocations = Array.from(this.toolCalls.entries()).map(([toolCallId, info]) => ({
|
|
||||||
toolCallId,
|
|
||||||
toolName: info.toolName,
|
|
||||||
args: info.input,
|
|
||||||
result: this.toolResults.get(toolCallId)
|
|
||||||
}))
|
|
||||||
|
|
||||||
const message: Record<string, unknown> = {
|
|
||||||
role,
|
|
||||||
content
|
|
||||||
}
|
|
||||||
|
|
||||||
if (toolInvocations.length > 0) {
|
|
||||||
message.toolInvocations = toolInvocations
|
|
||||||
}
|
|
||||||
|
|
||||||
return message as ModelMessage
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export class SessionMessageService extends BaseService {
|
export class SessionMessageService extends BaseService {
|
||||||
@ -175,7 +153,6 @@ export class SessionMessageService extends BaseService {
|
|||||||
abortController: AbortController
|
abortController: AbortController
|
||||||
): Promise<SessionStreamResult> {
|
): Promise<SessionStreamResult> {
|
||||||
const agentSessionId = await this.getLastAgentSessionId(session.id)
|
const agentSessionId = await this.getLastAgentSessionId(session.id)
|
||||||
let newAgentSessionId = ''
|
|
||||||
logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId })
|
logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId })
|
||||||
|
|
||||||
if (session.agent_type !== 'claude-code') {
|
if (session.agent_type !== 'claude-code') {
|
||||||
@ -222,10 +199,6 @@ export class SessionMessageService extends BaseService {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (chunk.type === 'start' && chunk.messageId) {
|
|
||||||
newAgentSessionId = chunk.messageId
|
|
||||||
}
|
|
||||||
|
|
||||||
accumulator.add(chunk)
|
accumulator.add(chunk)
|
||||||
controller.enqueue(chunk)
|
controller.enqueue(chunk)
|
||||||
break
|
break
|
||||||
|
|||||||
@ -70,23 +70,18 @@ function generateTextChunks(id: string, text: string, message: SDKMessage): Agen
|
|||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
type: 'text-start',
|
type: 'text-start',
|
||||||
id,
|
id
|
||||||
providerMetadata
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: 'text-delta',
|
type: 'text-delta',
|
||||||
id,
|
id,
|
||||||
text,
|
text
|
||||||
providerMetadata
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: 'text-end',
|
type: 'text-end',
|
||||||
id,
|
id,
|
||||||
providerMetadata: {
|
providerMetadata: {
|
||||||
...providerMetadata,
|
...providerMetadata
|
||||||
text: {
|
|
||||||
value: text
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@ -119,17 +114,22 @@ function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assi
|
|||||||
})
|
})
|
||||||
break
|
break
|
||||||
case 'tool_result':
|
case 'tool_result':
|
||||||
// chunks.push({
|
chunks.push({
|
||||||
// type: 'tool-result',
|
type: 'tool-result',
|
||||||
// toolCallId: block.tool_use_id,
|
toolCallId: block.tool_use_id,
|
||||||
// output: block.content,
|
toolName: '',
|
||||||
// providerMetadata: sdkMessageToProviderMetadata(message)
|
input: '',
|
||||||
// })
|
output: block.content,
|
||||||
|
})
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
logger.warn('Unknown content block type in user/assistant message:', {
|
logger.warn('Unknown content block type in user/assistant message:', {
|
||||||
type: block.type
|
type: block.type
|
||||||
})
|
})
|
||||||
|
chunks.push({
|
||||||
|
type: 'raw',
|
||||||
|
rawValue: block
|
||||||
|
})
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -142,7 +142,7 @@ function handleUserOrAssistantMessage(message: Extract<SDKMessage, { type: 'assi
|
|||||||
function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }>): AgentStreamPart[] {
|
function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }>): AgentStreamPart[] {
|
||||||
const chunks: AgentStreamPart[] = []
|
const chunks: AgentStreamPart[] = []
|
||||||
const event = message.event
|
const event = message.event
|
||||||
const blockKey = `${message.uuid ?? message.session_id ?? 'session'}:${event.index}`
|
const blockKey = `${message.uuid ?? message.session_id ?? 'session'}:${event.type}`
|
||||||
|
|
||||||
switch (event.type) {
|
switch (event.type) {
|
||||||
case 'message_start':
|
case 'message_start':
|
||||||
@ -255,15 +255,13 @@ function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }
|
|||||||
}
|
}
|
||||||
contentBlockState.delete(blockKey)
|
contentBlockState.delete(blockKey)
|
||||||
}
|
}
|
||||||
|
break
|
||||||
case 'message_delta':
|
case 'message_delta':
|
||||||
// Handle usage updates or other message-level deltas
|
// Handle usage updates or other message-level deltas
|
||||||
break
|
break
|
||||||
|
|
||||||
case 'message_stop':
|
case 'message_stop':
|
||||||
// This could signal the end of the message
|
// This could signal the end of the message
|
||||||
break
|
break
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logger.warn('Unknown stream event type:', { type: (event as any).type })
|
logger.warn('Unknown stream event type:', { type: (event as any).type })
|
||||||
break
|
break
|
||||||
@ -283,9 +281,19 @@ function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>):
|
|||||||
chunks.push({
|
chunks.push({
|
||||||
type: 'start'
|
type: 'start'
|
||||||
})
|
})
|
||||||
|
chunks.push({
|
||||||
|
type: 'raw',
|
||||||
|
rawValue: {
|
||||||
|
type: 'init',
|
||||||
|
session_id: message.session_id,
|
||||||
|
slash_commands: message.slash_commands,
|
||||||
|
tools: message.tools,
|
||||||
|
raw: message
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return []
|
return chunks
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle result messages (completion with usage stats)
|
// Handle result messages (completion with usage stats)
|
||||||
@ -295,14 +303,9 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
|
|||||||
let usage: LanguageModelV2Usage | undefined
|
let usage: LanguageModelV2Usage | undefined
|
||||||
if ('usage' in message) {
|
if ('usage' in message) {
|
||||||
usage = {
|
usage = {
|
||||||
inputTokens:
|
inputTokens: message.usage.input_tokens ?? 0,
|
||||||
(message.usage.cache_creation_input_tokens ?? 0) +
|
|
||||||
(message.usage.cache_read_input_tokens ?? 0) +
|
|
||||||
(message.usage.input_tokens ?? 0),
|
|
||||||
outputTokens: message.usage.output_tokens ?? 0,
|
outputTokens: message.usage.output_tokens ?? 0,
|
||||||
totalTokens:
|
totalTokens:
|
||||||
(message.usage.cache_creation_input_tokens ?? 0) +
|
|
||||||
(message.usage.cache_read_input_tokens ?? 0) +
|
|
||||||
(message.usage.input_tokens ?? 0) +
|
(message.usage.input_tokens ?? 0) +
|
||||||
(message.usage.output_tokens ?? 0)
|
(message.usage.output_tokens ?? 0)
|
||||||
}
|
}
|
||||||
@ -330,25 +333,3 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
|
|||||||
}
|
}
|
||||||
return chunks
|
return chunks
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convenience function to transform a stream of SDKMessages
|
|
||||||
export function* transformSDKMessageStream(sdkMessages: SDKMessage[]): Generator<AgentStreamPart> {
|
|
||||||
for (const sdkMessage of sdkMessages) {
|
|
||||||
const chunks = transformSDKMessageToStreamParts(sdkMessage)
|
|
||||||
for (const chunk of chunks) {
|
|
||||||
yield chunk
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Async version for async iterables
|
|
||||||
export async function* transformSDKMessageStreamAsync(
|
|
||||||
sdkMessages: AsyncIterable<SDKMessage>
|
|
||||||
): AsyncGenerator<AgentStreamPart> {
|
|
||||||
for await (const sdkMessage of sdkMessages) {
|
|
||||||
const chunks = transformSDKMessageToStreamParts(sdkMessage)
|
|
||||||
for (const chunk of chunks) {
|
|
||||||
yield chunk
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@ -111,13 +111,10 @@ export class AiSdkToChunkAdapter {
|
|||||||
chunk: TextStreamPart<any>,
|
chunk: TextStreamPart<any>,
|
||||||
final: { text: string; reasoningContent: string; webSearchResults: AISDKWebSearchResult[]; reasoningId: string }
|
final: { text: string; reasoningContent: string; webSearchResults: AISDKWebSearchResult[]; reasoningId: string }
|
||||||
) {
|
) {
|
||||||
const sessionId =
|
// @ts-ignore
|
||||||
(chunk.providerMetadata as any)?.anthropic?.session_id ??
|
if (chunk.type === 'raw' && chunk.rawValue.type === 'init' && chunk.rawValue.session_id) {
|
||||||
(chunk.providerMetadata as any)?.anthropic?.sessionId ??
|
// @ts-ignore
|
||||||
(chunk.providerMetadata as any)?.raw?.session_id
|
this.onSessionUpdate?.(chunk.rawValue.session_id)
|
||||||
|
|
||||||
if (typeof sessionId === 'string' && sessionId) {
|
|
||||||
this.onSessionUpdate?.(sessionId)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.silly(`AI SDK chunk type: ${chunk.type}`, chunk)
|
logger.silly(`AI SDK chunk type: ${chunk.type}`, chunk)
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import type {ProviderMetadata} from "ai";
|
||||||
import type { CompletionUsage } from 'openai/resources'
|
import type { CompletionUsage } from 'openai/resources'
|
||||||
|
|
||||||
import type {
|
import type {
|
||||||
@ -203,6 +204,10 @@ export type Message = {
|
|||||||
|
|
||||||
// 跟踪Id
|
// 跟踪Id
|
||||||
traceId?: string
|
traceId?: string
|
||||||
|
|
||||||
|
// raw data
|
||||||
|
// TODO: add this providerMetadata to MessageBlock to save raw provider data for each block
|
||||||
|
providerMetadata?: ProviderMetadata
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface Response {
|
export interface Response {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user