mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-09 23:10:20 +08:00
refactor: migrate v1 message streaming to v2 StreamingService for state management
- Replaced Redux dispatch and state management with StreamingService in BlockManager and various callback modules. - Simplified dependencies by removing unnecessary dispatch and getState parameters. - Updated block handling logic to utilize StreamingService for immediate updates and persistence during streaming. - Enhanced architecture for better performance and maintainability as part of the v2 data refactoring initiative.
This commit is contained in:
parent
01d8888601
commit
f2cd361ab8
@ -1,35 +1,48 @@
|
||||
/**
|
||||
* @fileoverview BlockManager - Manages block operations during message streaming
|
||||
*
|
||||
* This module handles the lifecycle and state management of message blocks
|
||||
* during the streaming process. It provides methods for:
|
||||
* - Smart block updates with throttling support
|
||||
* - Block type transitions
|
||||
* - Active block tracking
|
||||
*
|
||||
* ARCHITECTURE NOTE:
|
||||
* BlockManager now uses StreamingService for state management instead of Redux dispatch.
|
||||
* This is part of the v2 data refactoring to use CacheService + Data API.
|
||||
*
|
||||
* Key changes from original design:
|
||||
* - dispatch/getState replaced with streamingService methods
|
||||
* - DB saves removed during streaming (handled by finalize)
|
||||
* - Throttling logic preserved, but internal calls changed
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import type { AppDispatch, RootState } from '@renderer/store'
|
||||
import { updateOneBlock, upsertOneBlock } from '@renderer/store/messageBlock'
|
||||
import { newMessagesActions } from '@renderer/store/newMessage'
|
||||
import type { MessageBlock } from '@renderer/types/newMessage'
|
||||
import { MessageBlockType } from '@renderer/types/newMessage'
|
||||
|
||||
import { streamingService } from './StreamingService'
|
||||
|
||||
const logger = loggerService.withContext('BlockManager')
|
||||
|
||||
/**
|
||||
* Information about the currently active block during streaming
|
||||
*/
|
||||
interface ActiveBlockInfo {
|
||||
id: string
|
||||
type: MessageBlockType
|
||||
}
|
||||
|
||||
/**
|
||||
* Dependencies required by BlockManager
|
||||
*
|
||||
* NOTE: Simplified from original design - removed dispatch, getState, and DB save functions
|
||||
* since StreamingService now handles state management and persistence.
|
||||
*/
|
||||
interface BlockManagerDependencies {
|
||||
dispatch: AppDispatch
|
||||
getState: () => RootState
|
||||
saveUpdatedBlockToDB: (
|
||||
blockId: string | null,
|
||||
messageId: string,
|
||||
topicId: string,
|
||||
getState: () => RootState
|
||||
) => Promise<void>
|
||||
saveUpdatesToDB: (
|
||||
messageId: string,
|
||||
topicId: string,
|
||||
messageUpdates: Partial<any>,
|
||||
blocksToUpdate: MessageBlock[]
|
||||
) => Promise<void>
|
||||
assistantMsgId: string
|
||||
topicId: string
|
||||
// 节流器管理从外部传入
|
||||
assistantMsgId: string
|
||||
// Throttling is still controlled externally by messageThunk.ts
|
||||
throttledBlockUpdate: (id: string, blockUpdate: any) => void
|
||||
cancelThrottledBlockUpdate: (id: string) => void
|
||||
}
|
||||
@ -37,9 +50,9 @@ interface BlockManagerDependencies {
|
||||
export class BlockManager {
|
||||
private deps: BlockManagerDependencies
|
||||
|
||||
// 简化后的状态管理
|
||||
// Simplified state management
|
||||
private _activeBlockInfo: ActiveBlockInfo | null = null
|
||||
private _lastBlockType: MessageBlockType | null = null // 保留用于错误处理
|
||||
private _lastBlockType: MessageBlockType | null = null // Preserved for error handling
|
||||
|
||||
constructor(dependencies: BlockManagerDependencies) {
|
||||
this.deps = dependencies
|
||||
@ -72,7 +85,15 @@ export class BlockManager {
|
||||
}
|
||||
|
||||
/**
|
||||
* 智能更新策略:根据块类型连续性自动判断使用节流还是立即更新
|
||||
* Smart update strategy: automatically decides between throttled and immediate updates
|
||||
* based on block type continuity.
|
||||
*
|
||||
* Behavior:
|
||||
* - If block type changes: cancel previous throttle, immediately update via streamingService
|
||||
* - If block completes: cancel throttle, immediately update via streamingService
|
||||
* - Otherwise: use throttled update (throttler calls streamingService internally)
|
||||
*
|
||||
* NOTE: DB saves are removed - persistence happens during finalize()
|
||||
*/
|
||||
smartBlockUpdate(
|
||||
blockId: string,
|
||||
@ -82,61 +103,60 @@ export class BlockManager {
|
||||
) {
|
||||
const isBlockTypeChanged = this._lastBlockType !== null && this._lastBlockType !== blockType
|
||||
if (isBlockTypeChanged || isComplete) {
|
||||
// 如果块类型改变,则取消上一个块的节流更新
|
||||
// Cancel throttled update for previous block if type changed
|
||||
if (isBlockTypeChanged && this._activeBlockInfo) {
|
||||
this.deps.cancelThrottledBlockUpdate(this._activeBlockInfo.id)
|
||||
}
|
||||
// 如果当前块完成,则取消当前块的节流更新
|
||||
// Cancel throttled update for current block if complete
|
||||
if (isComplete) {
|
||||
this.deps.cancelThrottledBlockUpdate(blockId)
|
||||
this._activeBlockInfo = null // 块完成时清空activeBlockInfo
|
||||
this._activeBlockInfo = null // Clear activeBlockInfo when block completes
|
||||
} else {
|
||||
this._activeBlockInfo = { id: blockId, type: blockType } // 更新活跃块信息
|
||||
this._activeBlockInfo = { id: blockId, type: blockType } // Update active block info
|
||||
}
|
||||
this.deps.dispatch(updateOneBlock({ id: blockId, changes }))
|
||||
this.deps.saveUpdatedBlockToDB(blockId, this.deps.assistantMsgId, this.deps.topicId, this.deps.getState)
|
||||
|
||||
// Immediate update via StreamingService (replaces dispatch + DB save)
|
||||
streamingService.updateBlock(blockId, changes)
|
||||
this._lastBlockType = blockType
|
||||
} else {
|
||||
this._activeBlockInfo = { id: blockId, type: blockType } // 更新活跃块信息
|
||||
this._activeBlockInfo = { id: blockId, type: blockType } // Update active block info
|
||||
// Throttled update (throttler internally calls streamingService.updateBlock)
|
||||
this.deps.throttledBlockUpdate(blockId, changes)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理块转换
|
||||
* Handle block transitions (new block creation during streaming)
|
||||
*
|
||||
* This method:
|
||||
* 1. Updates active block tracking state
|
||||
* 2. Adds new block to StreamingService
|
||||
* 3. Updates message block references
|
||||
*
|
||||
* NOTE: DB saves are removed - persistence happens during finalize()
|
||||
*/
|
||||
async handleBlockTransition(newBlock: MessageBlock, newBlockType: MessageBlockType) {
|
||||
logger.debug('handleBlockTransition', { newBlock, newBlockType })
|
||||
this._lastBlockType = newBlockType
|
||||
this._activeBlockInfo = { id: newBlock.id, type: newBlockType } // 设置新的活跃块信息
|
||||
this._activeBlockInfo = { id: newBlock.id, type: newBlockType } // Set new active block info
|
||||
|
||||
this.deps.dispatch(
|
||||
newMessagesActions.updateMessage({
|
||||
topicId: this.deps.topicId,
|
||||
messageId: this.deps.assistantMsgId,
|
||||
updates: { blockInstruction: { id: newBlock.id } }
|
||||
})
|
||||
)
|
||||
this.deps.dispatch(upsertOneBlock(newBlock))
|
||||
this.deps.dispatch(
|
||||
newMessagesActions.upsertBlockReference({
|
||||
messageId: this.deps.assistantMsgId,
|
||||
blockId: newBlock.id,
|
||||
status: newBlock.status,
|
||||
blockType: newBlock.type
|
||||
})
|
||||
)
|
||||
// Add new block to StreamingService (replaces dispatch(upsertOneBlock))
|
||||
streamingService.addBlock(this.deps.assistantMsgId, newBlock)
|
||||
|
||||
const currentState = this.deps.getState()
|
||||
const updatedMessage = currentState.messages.entities[this.deps.assistantMsgId]
|
||||
if (updatedMessage) {
|
||||
await this.deps.saveUpdatesToDB(this.deps.assistantMsgId, this.deps.topicId, { blocks: updatedMessage.blocks }, [
|
||||
newBlock
|
||||
])
|
||||
} else {
|
||||
logger.error(
|
||||
`[handleBlockTransition] Failed to get updated message ${this.deps.assistantMsgId} from state for DB save.`
|
||||
)
|
||||
}
|
||||
// Update block reference in message (replaces dispatch(upsertBlockReference))
|
||||
streamingService.addBlockReference(this.deps.assistantMsgId, newBlock.id)
|
||||
|
||||
// TEMPORARY: The blockInstruction field was used for UI coordination.
|
||||
// TODO: Evaluate if this is still needed with StreamingService approach
|
||||
// For now, we update it in the message
|
||||
streamingService.updateMessage(this.deps.assistantMsgId, {
|
||||
blockInstruction: { id: newBlock.id }
|
||||
} as any) // Using 'as any' since blockInstruction may not be in Message type
|
||||
|
||||
logger.debug('Block transition completed', {
|
||||
messageId: this.deps.assistantMsgId,
|
||||
blockId: newBlock.id,
|
||||
blockType: newBlockType
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
572
src/renderer/src/services/messageStreaming/StreamingService.ts
Normal file
572
src/renderer/src/services/messageStreaming/StreamingService.ts
Normal file
@ -0,0 +1,572 @@
|
||||
/**
|
||||
* @fileoverview StreamingService - Manages message streaming lifecycle and state
|
||||
*
|
||||
* This service encapsulates the streaming state management during message generation.
|
||||
* It uses CacheService (memoryCache) for temporary storage during streaming,
|
||||
* and persists final data to the database via Data API or dbService.
|
||||
*
|
||||
* Key Design Decisions:
|
||||
* - Uses messageId as primary key for sessions (supports multi-model concurrent streaming)
|
||||
* - Streaming data is stored in memory only (not Redux, not Dexie during streaming)
|
||||
* - On finalize, data is converted to new format and persisted via appropriate data source
|
||||
* - Throttling is handled externally by messageThunk.ts (preserves existing throttle logic)
|
||||
*
|
||||
* Cache Key Strategy:
|
||||
* - Session key: `streaming:session:${messageId}` - Internal session lifecycle management
|
||||
* - Topic sessions index: `streaming:topic:${topicId}:sessions` - Track active sessions per topic
|
||||
* - Message key: `streaming:message:${messageId}` - UI subscription for message-level changes
|
||||
* - Block key: `streaming:block:${blockId}` - UI subscription for block content updates
|
||||
*/
|
||||
|
||||
import { cacheService } from '@data/CacheService'
|
||||
import { dataApiService } from '@data/DataApiService'
|
||||
import { loggerService } from '@logger'
|
||||
import type { Message, MessageBlock } from '@renderer/types/newMessage'
|
||||
import { AssistantMessageStatus, MessageBlockStatus } from '@renderer/types/newMessage'
|
||||
import { isAgentSessionTopicId } from '@renderer/utils/agentSession'
|
||||
import type { UpdateMessageDto } from '@shared/data/api/schemas/messages'
|
||||
import type { MessageDataBlock, MessageStats } from '@shared/data/types/message'
|
||||
|
||||
import { dbService } from '../db'
|
||||
|
||||
const logger = loggerService.withContext('StreamingService')
|
||||
|
||||
// Cache key generators
|
||||
const getSessionKey = (messageId: string) => `streaming:session:${messageId}`
|
||||
const getTopicSessionsKey = (topicId: string) => `streaming:topic:${topicId}:sessions`
|
||||
const getMessageKey = (messageId: string) => `streaming:message:${messageId}`
|
||||
const getBlockKey = (blockId: string) => `streaming:block:${blockId}`
|
||||
const getSiblingsGroupCounterKey = (topicId: string) => `streaming:topic:${topicId}:siblings-counter`
|
||||
|
||||
// Session TTL for auto-cleanup (prevents memory leaks from crashed processes)
|
||||
const SESSION_TTL = 5 * 60 * 1000 // 5 minutes
|
||||
|
||||
/**
|
||||
* Streaming session data structure (stored in memory)
|
||||
*/
|
||||
interface StreamingSession {
|
||||
topicId: string
|
||||
messageId: string
|
||||
|
||||
// Message data (legacy format, compatible with existing logic)
|
||||
message: Message
|
||||
blocks: Record<string, MessageBlock>
|
||||
|
||||
// Tree structure information (v2 new fields)
|
||||
parentId: string // Parent message ID (user message)
|
||||
siblingsGroupId: number // Multi-model group ID (0=normal, >0=multi-model response)
|
||||
|
||||
// Context for usage estimation (messages up to and including user message)
|
||||
contextMessages?: Message[]
|
||||
|
||||
// Metadata
|
||||
startedAt: number
|
||||
}
|
||||
|
||||
/**
|
||||
* Options for starting a streaming session
|
||||
*/
|
||||
interface StartSessionOptions {
|
||||
parentId: string
|
||||
siblingsGroupId?: number // Defaults to 0
|
||||
role: 'assistant'
|
||||
model?: Message['model']
|
||||
modelId?: string
|
||||
assistantId: string
|
||||
askId?: string
|
||||
traceId?: string
|
||||
agentSessionId?: string
|
||||
// Context messages for usage estimation (messages up to and including user message)
|
||||
contextMessages?: Message[]
|
||||
}
|
||||
|
||||
/**
|
||||
* StreamingService - Manages streaming message state during generation
|
||||
*
|
||||
* Responsibilities:
|
||||
* - Session lifecycle management (start, update, finalize, clear)
|
||||
* - Block operations (add, update, get)
|
||||
* - Message operations (update, get)
|
||||
* - Cache-based state management with automatic TTL cleanup
|
||||
*/
|
||||
class StreamingService {
|
||||
// Internal mapping: blockId -> messageId (for efficient block updates)
|
||||
private blockToMessageMap = new Map<string, string>()
|
||||
|
||||
// ============ Session Lifecycle ============
|
||||
|
||||
/**
|
||||
* Start a streaming session for a message
|
||||
*
|
||||
* IMPORTANT: The message must be created via Data API POST before calling this.
|
||||
* This method initializes the in-memory streaming state.
|
||||
*
|
||||
* @param topicId - Topic ID (used for topic sessions index)
|
||||
* @param messageId - Message ID returned from Data API POST
|
||||
* @param options - Session options including parentId and siblingsGroupId
|
||||
*/
|
||||
startSession(topicId: string, messageId: string, options: StartSessionOptions): void {
|
||||
const {
|
||||
parentId,
|
||||
siblingsGroupId = 0,
|
||||
role,
|
||||
model,
|
||||
modelId,
|
||||
assistantId,
|
||||
askId,
|
||||
traceId,
|
||||
agentSessionId,
|
||||
contextMessages
|
||||
} = options
|
||||
|
||||
// Initialize message structure
|
||||
const message: Message = {
|
||||
id: messageId,
|
||||
topicId,
|
||||
role,
|
||||
assistantId,
|
||||
status: AssistantMessageStatus.PENDING,
|
||||
createdAt: new Date().toISOString(),
|
||||
blocks: [],
|
||||
model,
|
||||
modelId,
|
||||
askId,
|
||||
traceId,
|
||||
agentSessionId
|
||||
}
|
||||
|
||||
// Create session
|
||||
const session: StreamingSession = {
|
||||
topicId,
|
||||
messageId,
|
||||
message,
|
||||
blocks: {},
|
||||
parentId,
|
||||
siblingsGroupId,
|
||||
contextMessages,
|
||||
startedAt: Date.now()
|
||||
}
|
||||
|
||||
// Store session with TTL
|
||||
cacheService.setCasual(getSessionKey(messageId), session, SESSION_TTL)
|
||||
|
||||
// Store message data for UI subscription
|
||||
cacheService.setCasual(getMessageKey(messageId), message, SESSION_TTL)
|
||||
|
||||
// Add to topic sessions index
|
||||
const topicSessions = cacheService.getCasual<string[]>(getTopicSessionsKey(topicId)) || []
|
||||
if (!topicSessions.includes(messageId)) {
|
||||
topicSessions.push(messageId)
|
||||
cacheService.setCasual(getTopicSessionsKey(topicId), topicSessions, SESSION_TTL)
|
||||
}
|
||||
|
||||
logger.debug('Started streaming session', { topicId, messageId, parentId, siblingsGroupId })
|
||||
}
|
||||
|
||||
/**
|
||||
* Finalize a streaming session by persisting data to database
|
||||
*
|
||||
* This method:
|
||||
* 1. Converts streaming data to the appropriate format
|
||||
* 2. Routes to Data API (normal topics) or dbService (agent topics)
|
||||
* 3. Cleans up all related cache keys
|
||||
*
|
||||
* @param messageId - Session message ID
|
||||
* @param status - Final message status
|
||||
*/
|
||||
async finalize(messageId: string, status: AssistantMessageStatus): Promise<void> {
|
||||
const session = this.getSession(messageId)
|
||||
if (!session) {
|
||||
logger.warn(`finalize called for non-existent session: ${messageId}`)
|
||||
return
|
||||
}
|
||||
|
||||
const maxRetries = 3
|
||||
let lastError: Error | null = null
|
||||
|
||||
for (let attempt = 1; attempt <= maxRetries; attempt++) {
|
||||
try {
|
||||
const updatePayload = this.convertToUpdatePayload(session, status)
|
||||
|
||||
// TRADEOFF: Using dbService for agent messages instead of Data API
|
||||
// because agent message storage refactoring is planned for later phase.
|
||||
// TODO: Unify to Data API when agent message migration is complete.
|
||||
if (isAgentSessionTopicId(session.topicId)) {
|
||||
await dbService.updateMessageAndBlocks(session.topicId, updatePayload.messageUpdates, updatePayload.blocks)
|
||||
} else {
|
||||
// Normal topic → Use Data API for persistence
|
||||
const dataApiPayload = this.convertToDataApiFormat(session, status)
|
||||
await dataApiService.patch(`/messages/${session.messageId}`, { body: dataApiPayload })
|
||||
}
|
||||
|
||||
// Success - cleanup session
|
||||
this.clearSession(messageId)
|
||||
logger.debug('Finalized streaming session', { messageId, status })
|
||||
return
|
||||
} catch (error) {
|
||||
lastError = error as Error
|
||||
logger.warn(`finalize attempt ${attempt}/${maxRetries} failed:`, error as Error)
|
||||
|
||||
if (attempt < maxRetries) {
|
||||
// Exponential backoff
|
||||
await new Promise((resolve) => setTimeout(resolve, 1000 * attempt))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// All retries failed
|
||||
logger.error(`finalize failed after ${maxRetries} attempts:`, lastError)
|
||||
// TRADEOFF: Don't clear session to allow manual retry
|
||||
// TTL will auto-clean to prevent permanent memory leak
|
||||
throw lastError
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear a streaming session and all related cache keys
|
||||
*
|
||||
* @param messageId - Session message ID
|
||||
*/
|
||||
clearSession(messageId: string): void {
|
||||
const session = this.getSession(messageId)
|
||||
if (!session) {
|
||||
return
|
||||
}
|
||||
|
||||
// Remove block mappings
|
||||
Object.keys(session.blocks).forEach((blockId) => {
|
||||
this.blockToMessageMap.delete(blockId)
|
||||
cacheService.deleteCasual(getBlockKey(blockId))
|
||||
})
|
||||
|
||||
// Remove message cache
|
||||
cacheService.deleteCasual(getMessageKey(messageId))
|
||||
|
||||
// Remove from topic sessions index
|
||||
const topicSessions = cacheService.getCasual<string[]>(getTopicSessionsKey(session.topicId)) || []
|
||||
const updatedTopicSessions = topicSessions.filter((id) => id !== messageId)
|
||||
if (updatedTopicSessions.length > 0) {
|
||||
cacheService.setCasual(getTopicSessionsKey(session.topicId), updatedTopicSessions, SESSION_TTL)
|
||||
} else {
|
||||
cacheService.deleteCasual(getTopicSessionsKey(session.topicId))
|
||||
}
|
||||
|
||||
// Remove session
|
||||
cacheService.deleteCasual(getSessionKey(messageId))
|
||||
|
||||
logger.debug('Cleared streaming session', { messageId, topicId: session.topicId })
|
||||
}
|
||||
|
||||
// ============ Block Operations ============
|
||||
|
||||
/**
|
||||
* Add a new block to a streaming session
|
||||
* (Replaces dispatch(upsertOneBlock))
|
||||
*
|
||||
* @param messageId - Parent message ID
|
||||
* @param block - Block to add
|
||||
*/
|
||||
addBlock(messageId: string, block: MessageBlock): void {
|
||||
const session = this.getSession(messageId)
|
||||
if (!session) {
|
||||
logger.warn(`addBlock called for non-existent session: ${messageId}`)
|
||||
return
|
||||
}
|
||||
|
||||
// Register block mapping
|
||||
this.blockToMessageMap.set(block.id, messageId)
|
||||
|
||||
// Add to session
|
||||
session.blocks[block.id] = block
|
||||
|
||||
// Update message block references
|
||||
if (!session.message.blocks.includes(block.id)) {
|
||||
session.message.blocks = [...session.message.blocks, block.id]
|
||||
}
|
||||
|
||||
// Update caches
|
||||
cacheService.setCasual(getSessionKey(messageId), session, SESSION_TTL)
|
||||
cacheService.setCasual(getBlockKey(block.id), block, SESSION_TTL)
|
||||
cacheService.setCasual(getMessageKey(messageId), session.message, SESSION_TTL)
|
||||
|
||||
logger.debug('Added block to session', { messageId, blockId: block.id, blockType: block.type })
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a block in a streaming session
|
||||
* (Replaces dispatch(updateOneBlock))
|
||||
*
|
||||
* NOTE: This method does NOT include throttling. Throttling is controlled
|
||||
* by the existing throttler in messageThunk.ts.
|
||||
*
|
||||
* @param blockId - Block ID to update
|
||||
* @param changes - Partial block changes
|
||||
*/
|
||||
updateBlock(blockId: string, changes: Partial<MessageBlock>): void {
|
||||
const messageId = this.blockToMessageMap.get(blockId)
|
||||
if (!messageId) {
|
||||
logger.warn(`updateBlock: Block ${blockId} not found in blockToMessageMap`)
|
||||
return
|
||||
}
|
||||
|
||||
const session = this.getSession(messageId)
|
||||
if (!session) {
|
||||
logger.warn(`updateBlock: Session not found for message ${messageId}`)
|
||||
return
|
||||
}
|
||||
|
||||
const existingBlock = session.blocks[blockId]
|
||||
if (!existingBlock) {
|
||||
logger.warn(`updateBlock: Block ${blockId} not found in session`)
|
||||
return
|
||||
}
|
||||
|
||||
// Merge changes - use type assertion since we're updating the same block type
|
||||
const updatedBlock = { ...existingBlock, ...changes } as MessageBlock
|
||||
session.blocks[blockId] = updatedBlock
|
||||
|
||||
// Update caches
|
||||
cacheService.setCasual(getSessionKey(messageId), session, SESSION_TTL)
|
||||
cacheService.setCasual(getBlockKey(blockId), updatedBlock, SESSION_TTL)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a block from the streaming session
|
||||
*
|
||||
* @param blockId - Block ID
|
||||
* @returns Block or null if not found
|
||||
*/
|
||||
getBlock(blockId: string): MessageBlock | null {
|
||||
return cacheService.getCasual<MessageBlock>(getBlockKey(blockId)) || null
|
||||
}
|
||||
|
||||
// ============ Message Operations ============
|
||||
|
||||
/**
|
||||
* Update message properties in the streaming session
|
||||
* (Replaces dispatch(newMessagesActions.updateMessage))
|
||||
*
|
||||
* @param messageId - Message ID
|
||||
* @param updates - Partial message updates
|
||||
*/
|
||||
updateMessage(messageId: string, updates: Partial<Message>): void {
|
||||
const session = this.getSession(messageId)
|
||||
if (!session) {
|
||||
logger.warn(`updateMessage called for non-existent session: ${messageId}`)
|
||||
return
|
||||
}
|
||||
|
||||
// Merge updates
|
||||
session.message = { ...session.message, ...updates }
|
||||
|
||||
// Update caches
|
||||
cacheService.setCasual(getSessionKey(messageId), session, SESSION_TTL)
|
||||
cacheService.setCasual(getMessageKey(messageId), session.message, SESSION_TTL)
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a block reference to the message
|
||||
* (Replaces dispatch(newMessagesActions.upsertBlockReference))
|
||||
*
|
||||
* Note: In the streaming context, we just need to track the block ID in message.blocks
|
||||
* The block reference details are maintained in the block itself
|
||||
*
|
||||
* @param messageId - Message ID
|
||||
* @param blockId - Block ID to reference
|
||||
*/
|
||||
addBlockReference(messageId: string, blockId: string): void {
|
||||
const session = this.getSession(messageId)
|
||||
if (!session) {
|
||||
logger.warn(`addBlockReference called for non-existent session: ${messageId}`)
|
||||
return
|
||||
}
|
||||
|
||||
if (!session.message.blocks.includes(blockId)) {
|
||||
session.message.blocks = [...session.message.blocks, blockId]
|
||||
cacheService.setCasual(getSessionKey(messageId), session, SESSION_TTL)
|
||||
cacheService.setCasual(getMessageKey(messageId), session.message, SESSION_TTL)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a message from the streaming session
|
||||
*
|
||||
* @param messageId - Message ID
|
||||
* @returns Message or null if not found
|
||||
*/
|
||||
getMessage(messageId: string): Message | null {
|
||||
return cacheService.getCasual<Message>(getMessageKey(messageId)) || null
|
||||
}
|
||||
|
||||
// ============ Query Methods ============
|
||||
|
||||
/**
|
||||
* Check if a topic has any active streaming sessions
|
||||
*
|
||||
* @param topicId - Topic ID
|
||||
* @returns True if streaming is active
|
||||
*/
|
||||
isStreaming(topicId: string): boolean {
|
||||
const topicSessions = cacheService.getCasual<string[]>(getTopicSessionsKey(topicId)) || []
|
||||
return topicSessions.length > 0
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a specific message is currently streaming
|
||||
*
|
||||
* @param messageId - Message ID
|
||||
* @returns True if message is streaming
|
||||
*/
|
||||
isMessageStreaming(messageId: string): boolean {
|
||||
return cacheService.hasCasual(getSessionKey(messageId))
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the streaming session for a message
|
||||
*
|
||||
* @param messageId - Message ID
|
||||
* @returns Session or null if not found
|
||||
*/
|
||||
getSession(messageId: string): StreamingSession | null {
|
||||
return cacheService.getCasual<StreamingSession>(getSessionKey(messageId)) || null
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all active streaming message IDs for a topic
|
||||
*
|
||||
* @param topicId - Topic ID
|
||||
* @returns Array of message IDs
|
||||
*/
|
||||
getActiveMessageIds(topicId: string): string[] {
|
||||
return cacheService.getCasual<string[]>(getTopicSessionsKey(topicId)) || []
|
||||
}
|
||||
|
||||
// ============ siblingsGroupId Generation ============
|
||||
|
||||
/**
|
||||
* Generate the next siblingsGroupId for a topic.
|
||||
*
|
||||
* Used for multi-model responses where multiple assistant messages
|
||||
* share the same parentId and siblingsGroupId (>0).
|
||||
*
|
||||
* The counter is stored in CacheService and auto-increments.
|
||||
* Single-model responses should use siblingsGroupId=0 (not generated here).
|
||||
*
|
||||
* @param topicId - Topic ID
|
||||
* @returns Next siblingsGroupId (always > 0)
|
||||
*/
|
||||
//FIXME [v2] 现在获取 siblingsGroupId 的方式是不正确,后续再做修改调整
|
||||
generateNextGroupId(topicId: string): number {
|
||||
const counterKey = getSiblingsGroupCounterKey(topicId)
|
||||
const currentCounter = cacheService.getCasual<number>(counterKey) || 0
|
||||
const nextGroupId = currentCounter + 1
|
||||
// Store with no TTL (persistent within session, cleared on app restart)
|
||||
cacheService.setCasual(counterKey, nextGroupId)
|
||||
logger.debug('Generated siblingsGroupId', { topicId, siblingsGroupId: nextGroupId })
|
||||
return nextGroupId
|
||||
}
|
||||
|
||||
// ============ Internal Methods ============
|
||||
|
||||
/**
|
||||
* Convert session data to database update payload
|
||||
*
|
||||
* @param session - Streaming session
|
||||
* @param status - Final message status
|
||||
* @returns Update payload for database
|
||||
*/
|
||||
private convertToUpdatePayload(
|
||||
session: StreamingSession,
|
||||
status: AssistantMessageStatus
|
||||
): {
|
||||
messageUpdates: Partial<Message> & Pick<Message, 'id'>
|
||||
blocks: MessageBlock[]
|
||||
} {
|
||||
const blocks = Object.values(session.blocks)
|
||||
|
||||
// Ensure all blocks have final status
|
||||
// Use type assertion since we're only updating the status field
|
||||
const finalizedBlocks: MessageBlock[] = blocks.map((block) => {
|
||||
if (block.status === MessageBlockStatus.STREAMING || block.status === MessageBlockStatus.PROCESSING) {
|
||||
const finalizedBlock = {
|
||||
...block,
|
||||
status: status === AssistantMessageStatus.SUCCESS ? MessageBlockStatus.SUCCESS : MessageBlockStatus.ERROR
|
||||
}
|
||||
return finalizedBlock as typeof block
|
||||
}
|
||||
return block
|
||||
})
|
||||
|
||||
const messageUpdates: Partial<Message> & Pick<Message, 'id'> = {
|
||||
id: session.messageId,
|
||||
status,
|
||||
blocks: session.message.blocks,
|
||||
updatedAt: new Date().toISOString(),
|
||||
// Include usage and metrics if available
|
||||
...(session.message.usage && { usage: session.message.usage }),
|
||||
...(session.message.metrics && { metrics: session.message.metrics })
|
||||
}
|
||||
|
||||
return {
|
||||
messageUpdates,
|
||||
blocks: finalizedBlocks
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert session data to Data API UpdateMessageDto format
|
||||
*
|
||||
* Converts from renderer format (MessageBlock with id/status) to
|
||||
* shared format (MessageDataBlock without id/status) for Data API persistence.
|
||||
*
|
||||
* @param session - Streaming session
|
||||
* @param status - Final message status
|
||||
* @returns UpdateMessageDto for Data API PATCH request
|
||||
*/
|
||||
private convertToDataApiFormat(session: StreamingSession, status: AssistantMessageStatus): UpdateMessageDto {
|
||||
const blocks = Object.values(session.blocks)
|
||||
|
||||
// Convert MessageBlock[] to MessageDataBlock[]
|
||||
// Remove id, status, messageId fields as they are renderer-specific, not part of MessageDataBlock
|
||||
// TRADEOFF: Using 'as unknown as' because renderer's MessageBlockType and shared's BlockType
|
||||
// are structurally identical but TypeScript treats them as incompatible enums.
|
||||
const dataBlocks: MessageDataBlock[] = blocks.map((block) => {
|
||||
// Extract only the fields that belong to MessageDataBlock
|
||||
const {
|
||||
id: _id,
|
||||
status: _blockStatus,
|
||||
messageId: _messageId,
|
||||
...blockData
|
||||
} = block as MessageBlock & {
|
||||
messageId?: string
|
||||
}
|
||||
|
||||
return blockData as unknown as MessageDataBlock
|
||||
})
|
||||
|
||||
// Build MessageStats from usage and metrics
|
||||
// Note: Renderer uses 'time_first_token_millsec' while shared uses 'timeFirstTokenMs'
|
||||
const stats: MessageStats | undefined =
|
||||
session.message.usage || session.message.metrics
|
||||
? {
|
||||
promptTokens: session.message.usage?.prompt_tokens,
|
||||
completionTokens: session.message.usage?.completion_tokens,
|
||||
totalTokens: session.message.usage?.total_tokens,
|
||||
timeFirstTokenMs: session.message.metrics?.time_first_token_millsec,
|
||||
timeCompletionMs: session.message.metrics?.time_completion_millsec
|
||||
}
|
||||
: undefined
|
||||
|
||||
return {
|
||||
data: { blocks: dataBlocks },
|
||||
status: status as 'pending' | 'success' | 'error' | 'paused',
|
||||
...(stats && { stats })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Export singleton instance
|
||||
export const streamingService = new StreamingService()
|
||||
|
||||
// Also export class for testing
|
||||
export { StreamingService }
|
||||
export type { StartSessionOptions, StreamingSession }
|
||||
@ -1,12 +1,26 @@
|
||||
/**
|
||||
* @fileoverview Base callbacks for streaming message processing
|
||||
*
|
||||
* This module provides the core callback handlers for message streaming:
|
||||
* - onLLMResponseCreated: Initialize placeholder block for incoming response
|
||||
* - onError: Handle streaming errors and cleanup
|
||||
* - onComplete: Finalize streaming and persist to database
|
||||
*
|
||||
* ARCHITECTURE NOTE:
|
||||
* These callbacks now use StreamingService for state management instead of Redux dispatch.
|
||||
* This is part of the v2 data refactoring to use CacheService + Data API.
|
||||
*
|
||||
* Key changes:
|
||||
* - dispatch/getState replaced with streamingService methods
|
||||
* - saveUpdatesToDB replaced with streamingService.finalize()
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import { autoRenameTopic } from '@renderer/hooks/useTopic'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
|
||||
import { NotificationService } from '@renderer/services/NotificationService'
|
||||
import { estimateMessagesUsage } from '@renderer/services/TokenService'
|
||||
import { updateOneBlock } from '@renderer/store/messageBlock'
|
||||
import { selectMessagesForTopic } from '@renderer/store/newMessage'
|
||||
import { newMessagesActions } from '@renderer/store/newMessage'
|
||||
import type { Assistant } from '@renderer/types'
|
||||
import type { PlaceholderMessageBlock, Response } from '@renderer/types/newMessage'
|
||||
import { AssistantMessageStatus, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
|
||||
@ -19,47 +33,58 @@ import type { AISDKError } from 'ai'
|
||||
import { NoOutputGeneratedError } from 'ai'
|
||||
|
||||
import type { BlockManager } from '../BlockManager'
|
||||
import { streamingService } from '../StreamingService'
|
||||
|
||||
const logger = loggerService.withContext('BaseCallbacks')
|
||||
|
||||
/**
|
||||
* Dependencies required for base callbacks
|
||||
*
|
||||
* NOTE: Simplified from original design - removed dispatch, getState, and saveUpdatesToDB
|
||||
* since StreamingService now handles state management and persistence.
|
||||
*/
|
||||
interface BaseCallbacksDependencies {
|
||||
blockManager: BlockManager
|
||||
dispatch: any
|
||||
getState: any
|
||||
topicId: string
|
||||
assistantMsgId: string
|
||||
saveUpdatesToDB: any
|
||||
assistant: Assistant
|
||||
}
|
||||
|
||||
export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
|
||||
const { blockManager, dispatch, getState, topicId, assistantMsgId, saveUpdatesToDB, assistant } = deps
|
||||
const { blockManager, topicId, assistantMsgId, assistant } = deps
|
||||
|
||||
const startTime = Date.now()
|
||||
const notificationService = NotificationService.getInstance()
|
||||
|
||||
// 通用的 block 查找函数
|
||||
const findBlockIdForCompletion = (message?: any) => {
|
||||
// 优先使用 BlockManager 中的 activeBlockInfo
|
||||
/**
|
||||
* Find the block ID that should receive completion updates.
|
||||
* Priority: active block > latest block in message > initial placeholder
|
||||
*/
|
||||
const findBlockIdForCompletion = () => {
|
||||
// Priority 1: Use active block from BlockManager
|
||||
const activeBlockInfo = blockManager.activeBlockInfo
|
||||
|
||||
if (activeBlockInfo) {
|
||||
return activeBlockInfo.id
|
||||
}
|
||||
|
||||
// 如果没有活跃的block,从message中查找最新的block作为备选
|
||||
const targetMessage = message || getState().messages.entities[assistantMsgId]
|
||||
if (targetMessage) {
|
||||
const allBlocks = findAllBlocks(targetMessage)
|
||||
// Priority 2: Find latest block from StreamingService message
|
||||
const message = streamingService.getMessage(assistantMsgId)
|
||||
if (message) {
|
||||
const allBlocks = findAllBlocks(message)
|
||||
if (allBlocks.length > 0) {
|
||||
return allBlocks[allBlocks.length - 1].id // 返回最新的block
|
||||
return allBlocks[allBlocks.length - 1].id
|
||||
}
|
||||
}
|
||||
|
||||
// 最后的备选方案:从 blockManager 获取占位符块ID
|
||||
// Priority 3: Initial placeholder block
|
||||
return blockManager.initialPlaceholderBlockId
|
||||
}
|
||||
|
||||
return {
|
||||
/**
|
||||
* Called when LLM response stream is created.
|
||||
* Creates an initial placeholder block to receive streaming content.
|
||||
*/
|
||||
onLLMResponseCreated: async () => {
|
||||
const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, {
|
||||
status: MessageBlockStatus.PROCESSING
|
||||
@ -67,6 +92,10 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
|
||||
await blockManager.handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
|
||||
},
|
||||
|
||||
/**
|
||||
* Called when an error occurs during streaming.
|
||||
* Updates block and message status, creates error block, and finalizes session.
|
||||
*/
|
||||
onError: async (error: AISDKError) => {
|
||||
logger.debug('onError', error)
|
||||
if (NoOutputGeneratedError.isInstance(error)) {
|
||||
@ -79,7 +108,8 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
// 发送错误通知(除了中止错误)
|
||||
|
||||
// Send error notification (except for abort errors)
|
||||
if (!isErrorTypeAbort) {
|
||||
const timeOut = duration > 30 * 1000
|
||||
if ((!isOnHomePage() && timeOut) || (!isFocused() && timeOut)) {
|
||||
@ -98,45 +128,35 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
|
||||
const possibleBlockId = findBlockIdForCompletion()
|
||||
|
||||
if (possibleBlockId) {
|
||||
// 更改上一个block的状态为ERROR
|
||||
// Update previous block status to ERROR/PAUSED
|
||||
const changes = {
|
||||
status: isErrorTypeAbort ? MessageBlockStatus.PAUSED : MessageBlockStatus.ERROR
|
||||
}
|
||||
blockManager.smartBlockUpdate(possibleBlockId, changes, blockManager.lastBlockType!, true)
|
||||
}
|
||||
|
||||
// Fix: 更新所有仍处于 STREAMING 状态的 blocks 为 PAUSED/ERROR
|
||||
// 这修复了停止回复时思考计时器继续运行的问题
|
||||
const currentMessage = getState().messages.entities[assistantMsgId]
|
||||
// Fix: Update all blocks still in STREAMING status to PAUSED/ERROR
|
||||
// This fixes the thinking timer continuing when response is stopped
|
||||
const currentMessage = streamingService.getMessage(assistantMsgId)
|
||||
if (currentMessage) {
|
||||
const allBlockRefs = findAllBlocks(currentMessage)
|
||||
const blockState = getState().messageBlocks
|
||||
for (const blockRef of allBlockRefs) {
|
||||
const block = blockState.entities[blockRef.id]
|
||||
const block = streamingService.getBlock(blockRef.id)
|
||||
if (block && block.status === MessageBlockStatus.STREAMING && block.id !== possibleBlockId) {
|
||||
dispatch(
|
||||
updateOneBlock({
|
||||
id: block.id,
|
||||
changes: { status: isErrorTypeAbort ? MessageBlockStatus.PAUSED : MessageBlockStatus.ERROR }
|
||||
})
|
||||
)
|
||||
streamingService.updateBlock(block.id, {
|
||||
status: isErrorTypeAbort ? MessageBlockStatus.PAUSED : MessageBlockStatus.ERROR
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create error block
|
||||
const errorBlock = createErrorBlock(assistantMsgId, serializableError, { status: MessageBlockStatus.SUCCESS })
|
||||
await blockManager.handleBlockTransition(errorBlock, MessageBlockType.ERROR)
|
||||
const messageErrorUpdate = {
|
||||
status: isErrorTypeAbort ? AssistantMessageStatus.SUCCESS : AssistantMessageStatus.ERROR
|
||||
}
|
||||
dispatch(
|
||||
newMessagesActions.updateMessage({
|
||||
topicId,
|
||||
messageId: assistantMsgId,
|
||||
updates: messageErrorUpdate
|
||||
})
|
||||
)
|
||||
await saveUpdatesToDB(assistantMsgId, topicId, messageErrorUpdate, [])
|
||||
|
||||
// Finalize session with error/success status
|
||||
const finalStatus = isErrorTypeAbort ? AssistantMessageStatus.SUCCESS : AssistantMessageStatus.ERROR
|
||||
await streamingService.finalize(assistantMsgId, finalStatus)
|
||||
|
||||
EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, {
|
||||
id: assistantMsgId,
|
||||
@ -146,18 +166,15 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
|
||||
})
|
||||
},
|
||||
|
||||
/**
|
||||
* Called when streaming completes successfully.
|
||||
* Updates block status, processes usage stats, and finalizes session.
|
||||
*/
|
||||
onComplete: async (status: AssistantMessageStatus, response?: Response) => {
|
||||
const finalStateOnComplete = getState()
|
||||
const finalAssistantMsg = finalStateOnComplete.messages.entities[assistantMsgId]
|
||||
const finalAssistantMsg = streamingService.getMessage(assistantMsgId)
|
||||
|
||||
if (status === 'success' && finalAssistantMsg) {
|
||||
const userMsgId = finalAssistantMsg.askId
|
||||
const orderedMsgs = selectMessagesForTopic(finalStateOnComplete, topicId)
|
||||
const userMsgIndex = orderedMsgs.findIndex((m) => m.id === userMsgId)
|
||||
const contextForUsage = userMsgIndex !== -1 ? orderedMsgs.slice(0, userMsgIndex + 1) : []
|
||||
const finalContextWithAssistant = [...contextForUsage, finalAssistantMsg]
|
||||
|
||||
const possibleBlockId = findBlockIdForCompletion(finalAssistantMsg)
|
||||
const possibleBlockId = findBlockIdForCompletion()
|
||||
|
||||
if (possibleBlockId) {
|
||||
const changes = {
|
||||
@ -170,7 +187,7 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
|
||||
const content = getMainTextContent(finalAssistantMsg)
|
||||
|
||||
const timeOut = duration > 30 * 1000
|
||||
// 发送长时间运行消息的成功通知
|
||||
// Send success notification for long-running messages
|
||||
if ((!isOnHomePage() && timeOut) || (!isFocused() && timeOut)) {
|
||||
await notificationService.send({
|
||||
id: uuid(),
|
||||
@ -184,10 +201,10 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
|
||||
})
|
||||
}
|
||||
|
||||
// 更新topic的name
|
||||
// Rename topic if needed
|
||||
autoRenameTopic(assistant, topicId)
|
||||
|
||||
// 处理usage估算
|
||||
// Process usage estimation
|
||||
// For OpenRouter, always use the accurate usage data from API, don't estimate
|
||||
const isOpenRouter = assistant.model?.provider === 'openrouter'
|
||||
if (
|
||||
@ -197,11 +214,20 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
|
||||
response?.usage?.prompt_tokens === 0 ||
|
||||
response?.usage?.completion_tokens === 0)
|
||||
) {
|
||||
const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant })
|
||||
response.usage = usage
|
||||
// Use context from session for usage estimation
|
||||
const session = streamingService.getSession(assistantMsgId)
|
||||
if (session?.contextMessages && session.contextMessages.length > 0) {
|
||||
// Include the final assistant message in context for accurate estimation
|
||||
const finalContextWithAssistant = [...session.contextMessages, finalAssistantMsg]
|
||||
const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant })
|
||||
response.usage = usage
|
||||
} else {
|
||||
logger.debug('Skipping usage estimation - contextMessages not available in session')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle metrics completion_tokens fallback
|
||||
if (response && response.metrics) {
|
||||
if (response.metrics.completion_tokens === 0 && response.usage?.completion_tokens) {
|
||||
response = {
|
||||
@ -214,15 +240,17 @@ export const createBaseCallbacks = (deps: BaseCallbacksDependencies) => {
|
||||
}
|
||||
}
|
||||
|
||||
const messageUpdates = { status, metrics: response?.metrics, usage: response?.usage }
|
||||
dispatch(
|
||||
newMessagesActions.updateMessage({
|
||||
topicId,
|
||||
messageId: assistantMsgId,
|
||||
updates: messageUpdates
|
||||
// Update message with final stats before finalize
|
||||
if (response) {
|
||||
streamingService.updateMessage(assistantMsgId, {
|
||||
metrics: response.metrics,
|
||||
usage: response.usage
|
||||
})
|
||||
)
|
||||
await saveUpdatesToDB(assistantMsgId, topicId, messageUpdates, [])
|
||||
}
|
||||
|
||||
// Finalize session and persist to database
|
||||
await streamingService.finalize(assistantMsgId, status)
|
||||
|
||||
EventEmitter.emit(EVENT_NAMES.MESSAGE_COMPLETE, { id: assistantMsgId, topicId, status })
|
||||
logger.debug('onComplete finished')
|
||||
}
|
||||
|
||||
@ -1,3 +1,15 @@
|
||||
/**
|
||||
* @fileoverview Citation callbacks for handling web search and knowledge references
|
||||
*
|
||||
* This module provides callbacks for processing citation data during streaming:
|
||||
* - External tool citations (web search, knowledge)
|
||||
* - LLM-integrated web search citations
|
||||
*
|
||||
* ARCHITECTURE NOTE:
|
||||
* These callbacks now use StreamingService for state management instead of Redux dispatch.
|
||||
* This is part of the v2 data refactoring to use CacheService + Data API.
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import type { ExternalToolResult } from '@renderer/types'
|
||||
import type { CitationMessageBlock } from '@renderer/types/newMessage'
|
||||
@ -6,17 +18,22 @@ import { createCitationBlock } from '@renderer/utils/messageUtils/create'
|
||||
import { findMainTextBlocks } from '@renderer/utils/messageUtils/find'
|
||||
|
||||
import type { BlockManager } from '../BlockManager'
|
||||
import { streamingService } from '../StreamingService'
|
||||
|
||||
const logger = loggerService.withContext('CitationCallbacks')
|
||||
|
||||
/**
|
||||
* Dependencies required for citation callbacks
|
||||
*
|
||||
* NOTE: Simplified - removed getState since StreamingService handles state.
|
||||
*/
|
||||
interface CitationCallbacksDependencies {
|
||||
blockManager: BlockManager
|
||||
assistantMsgId: string
|
||||
getState: any
|
||||
}
|
||||
|
||||
export const createCitationCallbacks = (deps: CitationCallbacksDependencies) => {
|
||||
const { blockManager, assistantMsgId, getState } = deps
|
||||
const { blockManager, assistantMsgId } = deps
|
||||
|
||||
// 内部维护的状态
|
||||
let citationBlockId: string | null = null
|
||||
@ -80,15 +97,18 @@ export const createCitationCallbacks = (deps: CitationCallbacksDependencies) =>
|
||||
}
|
||||
blockManager.smartBlockUpdate(blockId, changes, MessageBlockType.CITATION, true)
|
||||
|
||||
const state = getState()
|
||||
const existingMainTextBlocks = findMainTextBlocks(state.messages.entities[assistantMsgId])
|
||||
if (existingMainTextBlocks.length > 0) {
|
||||
const existingMainTextBlock = existingMainTextBlocks[0]
|
||||
const currentRefs = existingMainTextBlock.citationReferences || []
|
||||
const mainTextChanges = {
|
||||
citationReferences: [...currentRefs, { blockId, citationBlockSource: llmWebSearchResult.source }]
|
||||
// Get message from StreamingService
|
||||
const message = streamingService.getMessage(assistantMsgId)
|
||||
if (message) {
|
||||
const existingMainTextBlocks = findMainTextBlocks(message)
|
||||
if (existingMainTextBlocks.length > 0) {
|
||||
const existingMainTextBlock = existingMainTextBlocks[0]
|
||||
const currentRefs = existingMainTextBlock.citationReferences || []
|
||||
const mainTextChanges = {
|
||||
citationReferences: [...currentRefs, { blockId, citationBlockSource: llmWebSearchResult.source }]
|
||||
}
|
||||
blockManager.smartBlockUpdate(existingMainTextBlock.id, mainTextChanges, MessageBlockType.MAIN_TEXT, true)
|
||||
}
|
||||
blockManager.smartBlockUpdate(existingMainTextBlock.id, mainTextChanges, MessageBlockType.MAIN_TEXT, true)
|
||||
}
|
||||
|
||||
if (blockManager.hasInitialPlaceholder) {
|
||||
@ -106,15 +126,18 @@ export const createCitationCallbacks = (deps: CitationCallbacksDependencies) =>
|
||||
)
|
||||
citationBlockId = citationBlock.id
|
||||
|
||||
const state = getState()
|
||||
const existingMainTextBlocks = findMainTextBlocks(state.messages.entities[assistantMsgId])
|
||||
if (existingMainTextBlocks.length > 0) {
|
||||
const existingMainTextBlock = existingMainTextBlocks[0]
|
||||
const currentRefs = existingMainTextBlock.citationReferences || []
|
||||
const mainTextChanges = {
|
||||
citationReferences: [...currentRefs, { citationBlockId, citationBlockSource: llmWebSearchResult.source }]
|
||||
// Get message from StreamingService
|
||||
const message = streamingService.getMessage(assistantMsgId)
|
||||
if (message) {
|
||||
const existingMainTextBlocks = findMainTextBlocks(message)
|
||||
if (existingMainTextBlocks.length > 0) {
|
||||
const existingMainTextBlock = existingMainTextBlocks[0]
|
||||
const currentRefs = existingMainTextBlock.citationReferences || []
|
||||
const mainTextChanges = {
|
||||
citationReferences: [...currentRefs, { citationBlockId, citationBlockSource: llmWebSearchResult.source }]
|
||||
}
|
||||
blockManager.smartBlockUpdate(existingMainTextBlock.id, mainTextChanges, MessageBlockType.MAIN_TEXT, true)
|
||||
}
|
||||
blockManager.smartBlockUpdate(existingMainTextBlock.id, mainTextChanges, MessageBlockType.MAIN_TEXT, true)
|
||||
}
|
||||
await blockManager.handleBlockTransition(citationBlock, MessageBlockType.CITATION)
|
||||
}
|
||||
|
||||
@ -1,22 +1,39 @@
|
||||
/**
|
||||
* @fileoverview Compact callbacks for handling /compact command responses
|
||||
*
|
||||
* This module provides callbacks for processing compact command responses
|
||||
* from Claude Code. It detects compact_boundary messages and creates
|
||||
* compact blocks that contain both summary and compacted content.
|
||||
*
|
||||
* ARCHITECTURE NOTE:
|
||||
* These callbacks now use StreamingService for state management instead of Redux dispatch.
|
||||
* This is part of the v2 data refactoring to use CacheService + Data API.
|
||||
*
|
||||
* Key changes:
|
||||
* - dispatch/getState replaced with streamingService methods
|
||||
* - saveUpdatesToDB removed (handled by finalize)
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import type { AppDispatch, RootState } from '@renderer/store'
|
||||
import { updateOneBlock } from '@renderer/store/messageBlock'
|
||||
import { newMessagesActions } from '@renderer/store/newMessage'
|
||||
import type { MainTextMessageBlock } from '@renderer/types/newMessage'
|
||||
import { MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
|
||||
import type { ClaudeCodeRawValue } from '@shared/agents/claudecode/types'
|
||||
|
||||
import type { BlockManager } from '../BlockManager'
|
||||
import { streamingService } from '../StreamingService'
|
||||
|
||||
const logger = loggerService.withContext('CompactCallbacks')
|
||||
|
||||
/**
|
||||
* Dependencies required for compact callbacks
|
||||
*
|
||||
* NOTE: Simplified from original design - removed dispatch, getState, and saveUpdatesToDB
|
||||
* since StreamingService now handles state management and persistence.
|
||||
*/
|
||||
interface CompactCallbacksDeps {
|
||||
blockManager: BlockManager
|
||||
assistantMsgId: string
|
||||
dispatch: AppDispatch
|
||||
getState: () => RootState
|
||||
topicId: string
|
||||
saveUpdatesToDB: any
|
||||
}
|
||||
|
||||
interface CompactState {
|
||||
@ -27,7 +44,7 @@ interface CompactState {
|
||||
}
|
||||
|
||||
export const createCompactCallbacks = (deps: CompactCallbacksDeps) => {
|
||||
const { blockManager, assistantMsgId, dispatch, getState, topicId, saveUpdatesToDB } = deps
|
||||
const { blockManager, assistantMsgId } = deps
|
||||
|
||||
// State to track compact command processing
|
||||
const compactState: CompactState = {
|
||||
@ -78,9 +95,8 @@ export const createCompactCallbacks = (deps: CompactCallbacksDeps) => {
|
||||
return false
|
||||
}
|
||||
|
||||
// Get the current main text block to check its full content
|
||||
const state = getState()
|
||||
const currentBlock = state.messageBlocks.entities[currentMainTextBlockId] as MainTextMessageBlock | undefined
|
||||
// Get the current main text block from StreamingService
|
||||
const currentBlock = streamingService.getBlock(currentMainTextBlockId) as MainTextMessageBlock | null
|
||||
|
||||
if (!currentBlock) {
|
||||
return false
|
||||
@ -99,14 +115,9 @@ export const createCompactCallbacks = (deps: CompactCallbacksDeps) => {
|
||||
|
||||
// Hide this block by marking it as a placeholder temporarily
|
||||
// We'll convert it to compact block when we get the second block
|
||||
dispatch(
|
||||
updateOneBlock({
|
||||
id: currentMainTextBlockId,
|
||||
changes: {
|
||||
status: MessageBlockStatus.PROCESSING
|
||||
}
|
||||
})
|
||||
)
|
||||
streamingService.updateBlock(currentMainTextBlockId, {
|
||||
status: MessageBlockStatus.PROCESSING
|
||||
})
|
||||
|
||||
return true // Prevent normal text block completion
|
||||
}
|
||||
@ -125,53 +136,25 @@ export const createCompactCallbacks = (deps: CompactCallbacksDeps) => {
|
||||
})
|
||||
|
||||
// Update the summary block to compact type
|
||||
dispatch(
|
||||
updateOneBlock({
|
||||
id: summaryBlockId,
|
||||
changes: {
|
||||
type: MessageBlockType.COMPACT,
|
||||
content: compactState.summaryText,
|
||||
compactedContent: compactedContent,
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
// Update block reference
|
||||
dispatch(
|
||||
newMessagesActions.upsertBlockReference({
|
||||
messageId: assistantMsgId,
|
||||
blockId: summaryBlockId,
|
||||
status: MessageBlockStatus.SUCCESS,
|
||||
blockType: MessageBlockType.COMPACT
|
||||
})
|
||||
)
|
||||
streamingService.updateBlock(summaryBlockId, {
|
||||
type: MessageBlockType.COMPACT,
|
||||
content: compactState.summaryText,
|
||||
compactedContent: compactedContent,
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
} as any) // Using 'as any' for compactedContent which is specific to CompactMessageBlock
|
||||
|
||||
// Clear active block info and update lastBlockType since the compact block is now complete
|
||||
blockManager.activeBlockInfo = null
|
||||
blockManager.lastBlockType = MessageBlockType.COMPACT
|
||||
|
||||
// Remove the current block (the one with XML tags) from message.blocks
|
||||
const currentState = getState()
|
||||
const currentMessage = currentState.messages.entities[assistantMsgId]
|
||||
const currentMessage = streamingService.getMessage(assistantMsgId)
|
||||
if (currentMessage && currentMessage.blocks) {
|
||||
const updatedBlocks = currentMessage.blocks.filter((id) => id !== currentMainTextBlockId)
|
||||
dispatch(
|
||||
newMessagesActions.updateMessage({
|
||||
topicId,
|
||||
messageId: assistantMsgId,
|
||||
updates: { blocks: updatedBlocks }
|
||||
})
|
||||
)
|
||||
streamingService.updateMessage(assistantMsgId, { blocks: updatedBlocks })
|
||||
}
|
||||
|
||||
// Save to DB
|
||||
const updatedState = getState()
|
||||
const updatedMessage = updatedState.messages.entities[assistantMsgId]
|
||||
const updatedBlock = updatedState.messageBlocks.entities[summaryBlockId]
|
||||
if (updatedMessage && updatedBlock) {
|
||||
await saveUpdatesToDB(assistantMsgId, topicId, { blocks: updatedMessage.blocks }, [updatedBlock])
|
||||
}
|
||||
// NOTE: DB save is removed - will be handled by finalize()
|
||||
|
||||
// Reset compact state
|
||||
compactState.compactBoundaryDetected = false
|
||||
|
||||
@ -1,3 +1,22 @@
|
||||
/**
|
||||
* @fileoverview Callbacks factory for streaming message processing
|
||||
*
|
||||
* This module creates and composes all callback handlers used during
|
||||
* message streaming. Each callback type handles specific aspects:
|
||||
* - Base: session lifecycle, error handling, completion
|
||||
* - Text: main text block processing
|
||||
* - Thinking: thinking/reasoning block processing
|
||||
* - Tool: tool call/result processing
|
||||
* - Image: image generation processing
|
||||
* - Citation: web search/knowledge citations
|
||||
* - Video: video content processing
|
||||
* - Compact: /compact command handling
|
||||
*
|
||||
* ARCHITECTURE NOTE:
|
||||
* These callbacks now use StreamingService for state management instead of Redux dispatch.
|
||||
* This is part of the v2 data refactoring to use CacheService + Data API.
|
||||
*/
|
||||
|
||||
import type { Assistant } from '@renderer/types'
|
||||
|
||||
import type { BlockManager } from '../BlockManager'
|
||||
@ -10,31 +29,31 @@ import { createThinkingCallbacks } from './thinkingCallbacks'
|
||||
import { createToolCallbacks } from './toolCallbacks'
|
||||
import { createVideoCallbacks } from './videoCallbacks'
|
||||
|
||||
/**
|
||||
* Dependencies required for creating all callbacks
|
||||
*
|
||||
* NOTE: Simplified from original design - removed dispatch, getState, and saveUpdatesToDB
|
||||
* since StreamingService now handles state management and persistence.
|
||||
*/
|
||||
interface CallbacksDependencies {
|
||||
blockManager: BlockManager
|
||||
dispatch: any
|
||||
getState: any
|
||||
topicId: string
|
||||
assistantMsgId: string
|
||||
saveUpdatesToDB: any
|
||||
assistant: Assistant
|
||||
}
|
||||
|
||||
export const createCallbacks = (deps: CallbacksDependencies) => {
|
||||
const { blockManager, dispatch, getState, topicId, assistantMsgId, saveUpdatesToDB, assistant } = deps
|
||||
const { blockManager, topicId, assistantMsgId, assistant } = deps
|
||||
|
||||
// 创建基础回调
|
||||
// Create base callbacks (lifecycle, error, complete)
|
||||
const baseCallbacks = createBaseCallbacks({
|
||||
blockManager,
|
||||
dispatch,
|
||||
getState,
|
||||
topicId,
|
||||
assistantMsgId,
|
||||
saveUpdatesToDB,
|
||||
assistant
|
||||
})
|
||||
|
||||
// 创建各类回调
|
||||
// Create specialized callbacks for each block type
|
||||
const thinkingCallbacks = createThinkingCallbacks({
|
||||
blockManager,
|
||||
assistantMsgId
|
||||
@ -42,8 +61,7 @@ export const createCallbacks = (deps: CallbacksDependencies) => {
|
||||
|
||||
const toolCallbacks = createToolCallbacks({
|
||||
blockManager,
|
||||
assistantMsgId,
|
||||
dispatch
|
||||
assistantMsgId
|
||||
})
|
||||
|
||||
const imageCallbacks = createImageCallbacks({
|
||||
@ -53,8 +71,7 @@ export const createCallbacks = (deps: CallbacksDependencies) => {
|
||||
|
||||
const citationCallbacks = createCitationCallbacks({
|
||||
blockManager,
|
||||
assistantMsgId,
|
||||
getState
|
||||
assistantMsgId
|
||||
})
|
||||
|
||||
const videoCallbacks = createVideoCallbacks({ blockManager, assistantMsgId })
|
||||
@ -62,23 +79,19 @@ export const createCallbacks = (deps: CallbacksDependencies) => {
|
||||
const compactCallbacks = createCompactCallbacks({
|
||||
blockManager,
|
||||
assistantMsgId,
|
||||
dispatch,
|
||||
getState,
|
||||
topicId,
|
||||
saveUpdatesToDB
|
||||
topicId
|
||||
})
|
||||
|
||||
// 创建textCallbacks时传入citationCallbacks的getCitationBlockId方法和compactCallbacks的handleTextComplete方法
|
||||
// Create textCallbacks with citation and compact handlers
|
||||
const textCallbacks = createTextCallbacks({
|
||||
blockManager,
|
||||
getState,
|
||||
assistantMsgId,
|
||||
getCitationBlockId: citationCallbacks.getCitationBlockId,
|
||||
getCitationBlockIdFromTool: toolCallbacks.getCitationBlockId,
|
||||
handleCompactTextComplete: compactCallbacks.handleTextComplete
|
||||
})
|
||||
|
||||
// 组合所有回调
|
||||
// Compose all callbacks
|
||||
return {
|
||||
...baseCallbacks,
|
||||
...textCallbacks,
|
||||
@ -88,10 +101,10 @@ export const createCallbacks = (deps: CallbacksDependencies) => {
|
||||
...citationCallbacks,
|
||||
...videoCallbacks,
|
||||
...compactCallbacks,
|
||||
// 清理资源的方法
|
||||
// Cleanup method (throttling is managed by messageThunk)
|
||||
cleanup: () => {
|
||||
// 清理由 messageThunk 中的节流函数管理,这里不需要特别处理
|
||||
// 如果需要,可以调用 blockManager 的相关清理方法
|
||||
// Cleanup is managed by messageThunk throttle functions
|
||||
// Add any additional cleanup here if needed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,3 +1,16 @@
|
||||
/**
|
||||
* @fileoverview Text callbacks for handling main text block streaming
|
||||
*
|
||||
* This module provides callbacks for processing text content during streaming:
|
||||
* - Text start: initialize or transform placeholder to main text block
|
||||
* - Text chunk: update content during streaming
|
||||
* - Text complete: finalize the block
|
||||
*
|
||||
* ARCHITECTURE NOTE:
|
||||
* These callbacks now use StreamingService for state management instead of Redux dispatch.
|
||||
* This is part of the v2 data refactoring to use CacheService + Data API.
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import { WebSearchSource } from '@renderer/types'
|
||||
import type { CitationMessageBlock, MessageBlock } from '@renderer/types/newMessage'
|
||||
@ -5,12 +18,17 @@ import { MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage
|
||||
import { createMainTextBlock } from '@renderer/utils/messageUtils/create'
|
||||
|
||||
import type { BlockManager } from '../BlockManager'
|
||||
import { streamingService } from '../StreamingService'
|
||||
|
||||
const logger = loggerService.withContext('TextCallbacks')
|
||||
|
||||
/**
|
||||
* Dependencies required for text callbacks
|
||||
*
|
||||
* NOTE: Simplified - removed getState since StreamingService handles state.
|
||||
*/
|
||||
interface TextCallbacksDependencies {
|
||||
blockManager: BlockManager
|
||||
getState: any
|
||||
assistantMsgId: string
|
||||
getCitationBlockId: () => string | null
|
||||
getCitationBlockIdFromTool: () => string | null
|
||||
@ -18,14 +36,8 @@ interface TextCallbacksDependencies {
|
||||
}
|
||||
|
||||
export const createTextCallbacks = (deps: TextCallbacksDependencies) => {
|
||||
const {
|
||||
blockManager,
|
||||
getState,
|
||||
assistantMsgId,
|
||||
getCitationBlockId,
|
||||
getCitationBlockIdFromTool,
|
||||
handleCompactTextComplete
|
||||
} = deps
|
||||
const { blockManager, assistantMsgId, getCitationBlockId, getCitationBlockIdFromTool, handleCompactTextComplete } =
|
||||
deps
|
||||
|
||||
// 内部维护的状态
|
||||
let mainTextBlockId: string | null = null
|
||||
@ -52,9 +64,12 @@ export const createTextCallbacks = (deps: TextCallbacksDependencies) => {
|
||||
|
||||
onTextChunk: async (text: string) => {
|
||||
const citationBlockId = getCitationBlockId() || getCitationBlockIdFromTool()
|
||||
const citationBlockSource = citationBlockId
|
||||
? (getState().messageBlocks.entities[citationBlockId] as CitationMessageBlock).response?.source
|
||||
: WebSearchSource.WEBSEARCH
|
||||
// Get citation block from StreamingService to determine source
|
||||
const citationBlock = citationBlockId
|
||||
? (streamingService.getBlock(citationBlockId) as CitationMessageBlock | null)
|
||||
: null
|
||||
const citationBlockSource = citationBlock?.response?.source ?? WebSearchSource.WEBSEARCH
|
||||
|
||||
if (text) {
|
||||
const blockChanges: Partial<MessageBlock> = {
|
||||
content: text,
|
||||
|
||||
@ -1,5 +1,20 @@
|
||||
/**
|
||||
* @fileoverview Tool callbacks for handling MCP tool calls during streaming
|
||||
*
|
||||
* This module provides callbacks for processing tool calls:
|
||||
* - Tool call pending: create tool block when tool is called
|
||||
* - Tool call complete: update with result or error
|
||||
*
|
||||
* ARCHITECTURE NOTE:
|
||||
* These callbacks now use StreamingService for state management instead of Redux dispatch.
|
||||
* This is part of the v2 data refactoring to use CacheService + Data API.
|
||||
*
|
||||
* NOTE: toolPermissionsActions dispatch is still required for permission management
|
||||
* as this is outside the scope of streaming state management.
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import type { AppDispatch } from '@renderer/store'
|
||||
import store from '@renderer/store'
|
||||
import { toolPermissionsActions } from '@renderer/store/toolPermissions'
|
||||
import type { MCPToolResponse } from '@renderer/types'
|
||||
import { WebSearchSource } from '@renderer/types'
|
||||
@ -11,14 +26,19 @@ import type { BlockManager } from '../BlockManager'
|
||||
|
||||
const logger = loggerService.withContext('ToolCallbacks')
|
||||
|
||||
/**
|
||||
* Dependencies required for tool callbacks
|
||||
*
|
||||
* NOTE: dispatch removed - toolPermissions uses store.dispatch directly
|
||||
* since it's outside streaming state scope.
|
||||
*/
|
||||
interface ToolCallbacksDependencies {
|
||||
blockManager: BlockManager
|
||||
assistantMsgId: string
|
||||
dispatch: AppDispatch
|
||||
}
|
||||
|
||||
export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
|
||||
const { blockManager, assistantMsgId, dispatch } = deps
|
||||
const { blockManager, assistantMsgId } = deps
|
||||
|
||||
// 内部维护的状态
|
||||
const toolCallIdToBlockIdMap = new Map<string, string>()
|
||||
@ -57,7 +77,8 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
|
||||
|
||||
onToolCallComplete: (toolResponse: MCPToolResponse) => {
|
||||
if (toolResponse?.id) {
|
||||
dispatch(toolPermissionsActions.removeByToolCallId({ toolCallId: toolResponse.id }))
|
||||
// Use store.dispatch for permission cleanup (outside streaming state scope)
|
||||
store.dispatch(toolPermissionsActions.removeByToolCallId({ toolCallId: toolResponse.id }))
|
||||
}
|
||||
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
|
||||
toolCallIdToBlockIdMap.delete(toolResponse.id)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
export { BlockManager } from './BlockManager'
|
||||
export type { createCallbacks as CreateCallbacksFunction } from './callbacks'
|
||||
export { createCallbacks } from './callbacks'
|
||||
export type { StartSessionOptions, StreamingSession } from './StreamingService'
|
||||
export { StreamingService, streamingService } from './StreamingService'
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import { combineReducers, configureStore } from '@reduxjs/toolkit'
|
||||
import { BlockManager } from '@renderer/services/messageStreaming/BlockManager'
|
||||
import { createCallbacks } from '@renderer/services/messageStreaming/callbacks'
|
||||
import { streamingService } from '@renderer/services/messageStreaming/StreamingService'
|
||||
import { createStreamProcessor } from '@renderer/services/StreamProcessingService'
|
||||
import type { AppDispatch } from '@renderer/store'
|
||||
import { messageBlocksSlice } from '@renderer/store/messageBlock'
|
||||
import { messagesSlice } from '@renderer/store/newMessage'
|
||||
import type { Assistant, ExternalToolResult, MCPTool, Model } from '@renderer/types'
|
||||
@ -12,33 +12,41 @@ import { ChunkType } from '@renderer/types/chunk'
|
||||
import { AssistantMessageStatus, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { RootState } from '../../index'
|
||||
|
||||
/**
|
||||
* Create mock callbacks for testing.
|
||||
*
|
||||
* NOTE: Updated to use simplified dependencies after StreamingService refactoring.
|
||||
* Now we need to initialize StreamingService session before creating callbacks.
|
||||
*/
|
||||
const createMockCallbacks = (
|
||||
mockAssistantMsgId: string,
|
||||
mockTopicId: string,
|
||||
mockAssistant: Assistant,
|
||||
dispatch: AppDispatch,
|
||||
getState: () => ReturnType<typeof reducer> & RootState
|
||||
) =>
|
||||
createCallbacks({
|
||||
mockAssistant: Assistant
|
||||
// dispatch and getState are no longer needed after StreamingService refactoring
|
||||
) => {
|
||||
// Initialize streaming session for tests
|
||||
streamingService.startSession(mockTopicId, mockAssistantMsgId, {
|
||||
parentId: 'test-user-msg-id',
|
||||
role: 'assistant',
|
||||
assistantId: mockAssistant.id,
|
||||
model: mockAssistant.model
|
||||
})
|
||||
|
||||
return createCallbacks({
|
||||
blockManager: new BlockManager({
|
||||
dispatch,
|
||||
getState,
|
||||
saveUpdatedBlockToDB: vi.fn(),
|
||||
saveUpdatesToDB: vi.fn(),
|
||||
assistantMsgId: mockAssistantMsgId,
|
||||
topicId: mockTopicId,
|
||||
throttledBlockUpdate: vi.fn(),
|
||||
throttledBlockUpdate: vi.fn((blockId, changes) => {
|
||||
// In tests, immediately update the block
|
||||
streamingService.updateBlock(blockId, changes)
|
||||
}),
|
||||
cancelThrottledBlockUpdate: vi.fn()
|
||||
}),
|
||||
dispatch,
|
||||
getState,
|
||||
topicId: mockTopicId,
|
||||
assistantMsgId: mockAssistantMsgId,
|
||||
saveUpdatesToDB: vi.fn(),
|
||||
assistant: mockAssistant
|
||||
})
|
||||
}
|
||||
|
||||
// Mock external dependencies
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
@ -311,8 +319,7 @@ const processChunks = async (chunks: Chunk[], callbacks: ReturnType<typeof creat
|
||||
|
||||
describe('streamCallback Integration Tests', () => {
|
||||
let store: ReturnType<typeof createMockStore>
|
||||
let dispatch: AppDispatch
|
||||
let getState: () => ReturnType<typeof reducer> & RootState
|
||||
// dispatch and getState are no longer needed after StreamingService refactoring
|
||||
|
||||
const mockTopicId = 'test-topic-id'
|
||||
const mockAssistantMsgId = 'test-assistant-msg-id'
|
||||
@ -334,10 +341,8 @@ describe('streamCallback Integration Tests', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
store = createMockStore()
|
||||
dispatch = store.dispatch
|
||||
getState = store.getState as () => ReturnType<typeof reducer> & RootState
|
||||
|
||||
// 为测试消息添加初始状态
|
||||
// Add initial message state for tests
|
||||
store.dispatch(
|
||||
messagesSlice.actions.addMessage({
|
||||
topicId: mockTopicId,
|
||||
@ -360,7 +365,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle complete text streaming flow', async () => {
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant)
|
||||
|
||||
const chunks: Chunk[] = [
|
||||
{ type: ChunkType.LLM_RESPONSE_CREATED },
|
||||
@ -387,7 +392,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = getState()
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
expect(blocks.length).toBeGreaterThan(0)
|
||||
|
||||
@ -403,7 +408,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle thinking flow', async () => {
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant)
|
||||
|
||||
const chunks: Chunk[] = [
|
||||
{ type: ChunkType.LLM_RESPONSE_CREATED },
|
||||
@ -418,7 +423,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = getState()
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
|
||||
const thinkingBlock = blocks.find((block) => block.type === MessageBlockType.THINKING)
|
||||
@ -432,7 +437,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle tool call flow', async () => {
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant)
|
||||
|
||||
const mockTool: MCPTool = {
|
||||
id: 'tool-1',
|
||||
@ -492,7 +497,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = getState()
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
|
||||
const toolBlock = blocks.find((block) => block.type === MessageBlockType.TOOL)
|
||||
@ -503,7 +508,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle image generation flow', async () => {
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant)
|
||||
|
||||
const chunks: Chunk[] = [
|
||||
{ type: ChunkType.LLM_RESPONSE_CREATED },
|
||||
@ -532,7 +537,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = getState()
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
const imageBlock = blocks.find((block) => block.type === MessageBlockType.IMAGE)
|
||||
expect(imageBlock).toBeDefined()
|
||||
@ -543,7 +548,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle web search flow', async () => {
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant)
|
||||
|
||||
const mockWebSearchResult = {
|
||||
source: WebSearchSource.WEBSEARCH,
|
||||
@ -560,7 +565,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = getState()
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
const citationBlock = blocks.find((block) => block.type === MessageBlockType.CITATION)
|
||||
expect(citationBlock).toBeDefined()
|
||||
@ -569,7 +574,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle mixed content flow (thinking + tool + text)', async () => {
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant)
|
||||
|
||||
const mockCalculatorTool: MCPTool = {
|
||||
id: 'tool-1',
|
||||
@ -652,7 +657,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = getState()
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
|
||||
expect(blocks.length).toBeGreaterThan(2) // 至少有思考块、工具块、文本块
|
||||
@ -671,7 +676,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle error flow', async () => {
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant)
|
||||
|
||||
const mockError = new Error('Test error')
|
||||
|
||||
@ -685,7 +690,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = getState()
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
|
||||
expect(blocks.length).toBeGreaterThan(0)
|
||||
@ -701,7 +706,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle external tool flow', async () => {
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant)
|
||||
|
||||
const mockExternalToolResult: ExternalToolResult = {
|
||||
webSearch: {
|
||||
@ -728,7 +733,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = getState()
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
|
||||
const citationBlock = blocks.find((block) => block.type === MessageBlockType.CITATION)
|
||||
@ -739,7 +744,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should handle abort error correctly', async () => {
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant)
|
||||
|
||||
// 创建一个模拟的 abort 错误
|
||||
const abortError = new Error('Request aborted')
|
||||
@ -755,7 +760,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = getState()
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
|
||||
expect(blocks.length).toBeGreaterThan(0)
|
||||
@ -770,7 +775,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
})
|
||||
|
||||
it('should maintain block reference integrity during streaming', async () => {
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant, dispatch, getState)
|
||||
const callbacks = createMockCallbacks(mockAssistantMsgId, mockTopicId, mockAssistant)
|
||||
|
||||
const chunks: Chunk[] = [
|
||||
{ type: ChunkType.LLM_RESPONSE_CREATED },
|
||||
@ -784,7 +789,7 @@ describe('streamCallback Integration Tests', () => {
|
||||
await processChunks(chunks, callbacks)
|
||||
|
||||
// 验证 Redux 状态
|
||||
const state = getState()
|
||||
const state = store.getState()
|
||||
const blocks = Object.values(state.messageBlocks.entities)
|
||||
const message = state.messages.entities[mockAssistantMsgId]
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
* --------------------------------------------------------------------------
|
||||
*/
|
||||
import { cacheService } from '@data/CacheService'
|
||||
import { dataApiService } from '@data/DataApiService'
|
||||
import { loggerService } from '@logger'
|
||||
import { AiSdkToChunkAdapter } from '@renderer/aiCore/chunk/AiSdkToChunkAdapter'
|
||||
import { AgentApiClient } from '@renderer/api/agent'
|
||||
@ -25,6 +26,7 @@ import { DbService } from '@renderer/services/db/DbService'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import { BlockManager } from '@renderer/services/messageStreaming/BlockManager'
|
||||
import { createCallbacks } from '@renderer/services/messageStreaming/callbacks'
|
||||
import { streamingService } from '@renderer/services/messageStreaming/StreamingService'
|
||||
import { endSpan } from '@renderer/services/SpanManagerService'
|
||||
import { createStreamProcessor, type StreamProcessorCallbacks } from '@renderer/services/StreamProcessingService'
|
||||
import store from '@renderer/store'
|
||||
@ -48,6 +50,8 @@ import {
|
||||
} from '@renderer/utils/messageUtils/create'
|
||||
import { getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { getTopicQueue, waitForTopicQueue } from '@renderer/utils/queue'
|
||||
import type { CreateMessageDto } from '@shared/data/api/schemas/messages'
|
||||
import type { Message as SharedMessage } from '@shared/data/types/message'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import type { TextStreamPart } from 'ai'
|
||||
@ -74,6 +78,35 @@ import { newMessagesActions, selectMessagesForTopic } from '../newMessage'
|
||||
|
||||
const logger = loggerService.withContext('MessageThunk')
|
||||
|
||||
/**
|
||||
* Convert shared Message format (from Data API) to renderer Message format
|
||||
*
|
||||
* The Data API returns messages with `data: { blocks: MessageDataBlock[] }` format,
|
||||
* but the renderer expects `blocks: string[]` format.
|
||||
*
|
||||
* For newly created pending messages, blocks are empty, so conversion is straightforward.
|
||||
* For messages with content, this would need to store blocks separately and return IDs.
|
||||
*
|
||||
* @param shared - Message from Data API response
|
||||
* @param model - Optional Model object to include
|
||||
* @returns Renderer-format Message
|
||||
*/
|
||||
const convertSharedToRendererMessage = (shared: SharedMessage, assistantId: string, model?: Model): Message => {
|
||||
return {
|
||||
id: shared.id,
|
||||
topicId: shared.topicId,
|
||||
role: shared.role,
|
||||
assistantId,
|
||||
status: shared.status as AssistantMessageStatus,
|
||||
blocks: [], // For new pending messages, blocks are empty
|
||||
createdAt: shared.createdAt,
|
||||
askId: shared.parentId ?? undefined,
|
||||
modelId: shared.modelId ?? undefined,
|
||||
traceId: shared.traceId ?? undefined,
|
||||
model
|
||||
}
|
||||
}
|
||||
|
||||
const finishTopicLoading = async (topicId: string) => {
|
||||
await waitForTopicQueue(topicId)
|
||||
store.dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
@ -418,23 +451,34 @@ const blockUpdateRafs = new LRUCache<string, number>({
|
||||
})
|
||||
|
||||
/**
|
||||
* 获取或创建消息块专用的节流函数。
|
||||
* Get or create a dedicated throttle function for a message block.
|
||||
*
|
||||
* ARCHITECTURE NOTE:
|
||||
* Updated to use StreamingService.updateBlock instead of Redux dispatch.
|
||||
* This is part of the v2 data refactoring to use CacheService + Data API.
|
||||
*
|
||||
* The throttler now:
|
||||
* 1. Uses RAF for visual consistency
|
||||
* 2. Updates StreamingService (memory cache) for immediate reactivity
|
||||
* 3. Removes the DB update (moved to finalize)
|
||||
*/
|
||||
const getBlockThrottler = (id: string) => {
|
||||
if (!blockUpdateThrottlers.has(id)) {
|
||||
const throttler = throttle(async (blockUpdate: any) => {
|
||||
const throttler = throttle((blockUpdate: any) => {
|
||||
const existingRAF = blockUpdateRafs.get(id)
|
||||
if (existingRAF) {
|
||||
cancelAnimationFrame(existingRAF)
|
||||
}
|
||||
|
||||
const rafId = requestAnimationFrame(() => {
|
||||
store.dispatch(updateOneBlock({ id, changes: blockUpdate }))
|
||||
// Update StreamingService instead of Redux store
|
||||
streamingService.updateBlock(id, blockUpdate)
|
||||
blockUpdateRafs.delete(id)
|
||||
})
|
||||
|
||||
blockUpdateRafs.set(id, rafId)
|
||||
await updateSingleBlock(id, blockUpdate)
|
||||
// NOTE: DB update removed - persistence happens during finalize()
|
||||
// await updateSingleBlock(id, blockUpdate)
|
||||
}, 150)
|
||||
|
||||
blockUpdateThrottlers.set(id, throttler)
|
||||
@ -516,25 +560,26 @@ const saveUpdatesToDB = async (
|
||||
}
|
||||
}
|
||||
|
||||
// 新增: 辅助函数,用于获取并保存单个更新后的 Block 到数据库
|
||||
const saveUpdatedBlockToDB = async (
|
||||
blockId: string | null,
|
||||
messageId: string,
|
||||
topicId: string,
|
||||
getState: () => RootState
|
||||
) => {
|
||||
if (!blockId) {
|
||||
logger.warn('[DB Save Single Block] Received null/undefined blockId. Skipping save.')
|
||||
return
|
||||
}
|
||||
const state = getState()
|
||||
const blockToSave = state.messageBlocks.entities[blockId]
|
||||
if (blockToSave) {
|
||||
await saveUpdatesToDB(messageId, topicId, {}, [blockToSave]) // Pass messageId, topicId, empty message updates, and the block
|
||||
} else {
|
||||
logger.warn(`[DB Save Single Block] Block ${blockId} not found in state. Cannot save.`)
|
||||
}
|
||||
}
|
||||
// NOTE: saveUpdatedBlockToDB was removed as part of StreamingService refactoring.
|
||||
// Block persistence is now handled by StreamingService.finalize().
|
||||
// const saveUpdatedBlockToDB = async (
|
||||
// blockId: string | null,
|
||||
// messageId: string,
|
||||
// topicId: string,
|
||||
// getState: () => RootState
|
||||
// ) => {
|
||||
// if (!blockId) {
|
||||
// logger.warn('[DB Save Single Block] Received null/undefined blockId. Skipping save.')
|
||||
// return
|
||||
// }
|
||||
// const state = getState()
|
||||
// const blockToSave = state.messageBlocks.entities[blockId]
|
||||
// if (blockToSave) {
|
||||
// await saveUpdatesToDB(messageId, topicId, {}, [blockToSave]) // Pass messageId, topicId, empty message updates, and the block
|
||||
// } else {
|
||||
// logger.warn(`[DB Save Single Block] Block ${blockId} not found in state. Cannot save.`)
|
||||
// }
|
||||
// }
|
||||
|
||||
interface AgentStreamParams {
|
||||
topicId: string
|
||||
@ -553,24 +598,32 @@ const fetchAndProcessAgentResponseImpl = async (
|
||||
try {
|
||||
dispatch(newMessagesActions.setTopicLoading({ topicId, loading: true }))
|
||||
|
||||
// Initialize streaming session in StreamingService
|
||||
streamingService.startSession(topicId, assistantMessage.id, {
|
||||
parentId: userMessageId,
|
||||
siblingsGroupId: 0,
|
||||
role: 'assistant',
|
||||
model: assistant.model,
|
||||
modelId: assistant.model?.id,
|
||||
assistantId: assistant.id,
|
||||
askId: userMessageId,
|
||||
traceId: assistantMessage.traceId,
|
||||
agentSessionId: agentSession.agentSessionId
|
||||
})
|
||||
|
||||
// Create BlockManager with simplified dependencies (no dispatch/getState/saveUpdatesToDB)
|
||||
const blockManager = new BlockManager({
|
||||
dispatch,
|
||||
getState,
|
||||
saveUpdatedBlockToDB,
|
||||
saveUpdatesToDB,
|
||||
assistantMsgId: assistantMessage.id,
|
||||
topicId,
|
||||
throttledBlockUpdate,
|
||||
cancelThrottledBlockUpdate
|
||||
})
|
||||
|
||||
// Create callbacks with simplified dependencies
|
||||
callbacks = createCallbacks({
|
||||
blockManager,
|
||||
dispatch,
|
||||
getState,
|
||||
topicId,
|
||||
assistantMsgId: assistantMessage.id,
|
||||
saveUpdatesToDB,
|
||||
assistant
|
||||
})
|
||||
|
||||
@ -718,74 +771,85 @@ const dispatchMultiModelResponses = async (
|
||||
mentionedModels: Model[]
|
||||
) => {
|
||||
const assistantMessageStubs: Message[] = []
|
||||
const tasksToQueue: { assistantConfig: Assistant; messageStub: Message }[] = []
|
||||
const tasksToQueue: { assistantConfig: Assistant; messageStub: Message; siblingsGroupId: number }[] = []
|
||||
|
||||
// Generate siblingsGroupId for multi-model responses (all share the same group ID)
|
||||
const siblingsGroupId = mentionedModels.length > 1 ? streamingService.generateNextGroupId(topicId) : 0
|
||||
|
||||
for (const mentionedModel of mentionedModels) {
|
||||
const assistantForThisMention = { ...assistant, model: mentionedModel }
|
||||
const assistantMessage = createAssistantMessage(assistant.id, topicId, {
|
||||
askId: triggeringMessage.id,
|
||||
model: mentionedModel,
|
||||
|
||||
// Create message via Data API instead of local creation
|
||||
const createDto: CreateMessageDto = {
|
||||
parentId: triggeringMessage.id,
|
||||
role: 'assistant',
|
||||
data: { blocks: [] },
|
||||
status: 'pending',
|
||||
siblingsGroupId,
|
||||
assistantId: assistant.id,
|
||||
modelId: mentionedModel.id,
|
||||
traceId: triggeringMessage.traceId
|
||||
})
|
||||
traceId: triggeringMessage.traceId ?? undefined
|
||||
}
|
||||
|
||||
const sharedMessage = await dataApiService.post(`/topics/${topicId}/messages`, { body: createDto })
|
||||
const assistantMessage = convertSharedToRendererMessage(sharedMessage, assistant.id, mentionedModel)
|
||||
|
||||
dispatch(newMessagesActions.addMessage({ topicId, message: assistantMessage }))
|
||||
assistantMessageStubs.push(assistantMessage)
|
||||
tasksToQueue.push({ assistantConfig: assistantForThisMention, messageStub: assistantMessage })
|
||||
tasksToQueue.push({ assistantConfig: assistantForThisMention, messageStub: assistantMessage, siblingsGroupId })
|
||||
}
|
||||
|
||||
const topicFromDB = await db.topics.get(topicId)
|
||||
if (topicFromDB) {
|
||||
const currentTopicMessageIds = getState().messages.messageIdsByTopic[topicId] || []
|
||||
const currentEntities = getState().messages.entities
|
||||
const messagesToSaveInDB = currentTopicMessageIds.map((id) => currentEntities[id]).filter((m): m is Message => !!m)
|
||||
await db.topics.update(topicId, { messages: messagesToSaveInDB })
|
||||
} else {
|
||||
logger.error(`[dispatchMultiModelResponses] Topic ${topicId} not found in DB during multi-model save.`)
|
||||
throw new Error(`Topic ${topicId} not found in DB.`)
|
||||
}
|
||||
// Note: Dexie save removed - messages are now persisted via Data API POST above
|
||||
// const topicFromDB = await db.topics.get(topicId)
|
||||
// if (topicFromDB) {
|
||||
// const currentTopicMessageIds = getState().messages.messageIdsByTopic[topicId] || []
|
||||
// const currentEntities = getState().messages.entities
|
||||
// const messagesToSaveInDB = currentTopicMessageIds.map((id) => currentEntities[id]).filter((m): m is Message => !!m)
|
||||
// await db.topics.update(topicId, { messages: messagesToSaveInDB })
|
||||
// } else {
|
||||
// logger.error(`[dispatchMultiModelResponses] Topic ${topicId} not found in DB during multi-model save.`)
|
||||
// throw new Error(`Topic ${topicId} not found in DB.`)
|
||||
// }
|
||||
|
||||
const queue = getTopicQueue(topicId)
|
||||
for (const task of tasksToQueue) {
|
||||
queue.add(async () => {
|
||||
await fetchAndProcessAssistantResponseImpl(dispatch, getState, topicId, task.assistantConfig, task.messageStub)
|
||||
await fetchAndProcessAssistantResponseImpl(
|
||||
dispatch,
|
||||
getState,
|
||||
topicId,
|
||||
task.assistantConfig,
|
||||
task.messageStub,
|
||||
task.siblingsGroupId
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- End Helper Function ---
|
||||
// 发送和处理助手响应的实现函数,话题提示词在此拼接
|
||||
// Send and process assistant response implementation - topic prompts are concatenated here
|
||||
const fetchAndProcessAssistantResponseImpl = async (
|
||||
dispatch: AppDispatch,
|
||||
getState: () => RootState,
|
||||
topicId: string,
|
||||
origAssistant: Assistant,
|
||||
assistantMessage: Message // Pass the prepared assistant message (new or reset)
|
||||
assistantMessage: Message, // Pass the prepared assistant message (new or reset)
|
||||
siblingsGroupId: number = 0 // Multi-model group ID (0=normal, >0=multi-model response)
|
||||
) => {
|
||||
const topic = origAssistant.topics.find((t) => t.id === topicId)
|
||||
const assistant = topic?.prompt
|
||||
? { ...origAssistant, prompt: `${origAssistant.prompt}\n${topic.prompt}` }
|
||||
: origAssistant
|
||||
const assistantMsgId = assistantMessage.id
|
||||
const userMessageId = assistantMessage.askId
|
||||
let callbacks: StreamProcessorCallbacks = {}
|
||||
try {
|
||||
dispatch(newMessagesActions.setTopicLoading({ topicId, loading: true }))
|
||||
|
||||
// 创建 BlockManager 实例
|
||||
const blockManager = new BlockManager({
|
||||
dispatch,
|
||||
getState,
|
||||
saveUpdatedBlockToDB,
|
||||
saveUpdatesToDB,
|
||||
assistantMsgId,
|
||||
topicId,
|
||||
throttledBlockUpdate,
|
||||
cancelThrottledBlockUpdate
|
||||
})
|
||||
|
||||
// Build context messages first (needed for startSession)
|
||||
const allMessagesForTopic = selectMessagesForTopic(getState(), topicId)
|
||||
|
||||
let messagesForContext: Message[] = []
|
||||
const userMessageId = assistantMessage.askId
|
||||
const userMessageIndex = allMessagesForTopic.findIndex((m) => m?.id === userMessageId)
|
||||
|
||||
if (userMessageIndex === -1) {
|
||||
@ -812,13 +876,32 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize streaming session in StreamingService (includes context for usage estimation)
|
||||
streamingService.startSession(topicId, assistantMsgId, {
|
||||
parentId: userMessageId!,
|
||||
siblingsGroupId,
|
||||
role: 'assistant',
|
||||
model: assistant.model,
|
||||
modelId: assistant.model?.id,
|
||||
assistantId: assistant.id,
|
||||
askId: userMessageId,
|
||||
traceId: assistantMessage.traceId,
|
||||
contextMessages: messagesForContext
|
||||
})
|
||||
|
||||
// Create BlockManager with simplified dependencies (no dispatch/getState/saveUpdatesToDB)
|
||||
const blockManager = new BlockManager({
|
||||
assistantMsgId,
|
||||
topicId,
|
||||
throttledBlockUpdate,
|
||||
cancelThrottledBlockUpdate
|
||||
})
|
||||
|
||||
// Create callbacks with simplified dependencies
|
||||
callbacks = createCallbacks({
|
||||
blockManager,
|
||||
dispatch,
|
||||
getState,
|
||||
topicId,
|
||||
assistantMsgId,
|
||||
saveUpdatesToDB,
|
||||
assistant
|
||||
})
|
||||
const streamProcessorCallbacks = createStreamProcessor(callbacks)
|
||||
@ -931,12 +1014,21 @@ export const sendMessage =
|
||||
if (mentionedModels && mentionedModels.length > 0) {
|
||||
await dispatchMultiModelResponses(dispatch, getState, topicId, userMessage, assistant, mentionedModels)
|
||||
} else {
|
||||
const assistantMessage = createAssistantMessage(assistant.id, topicId, {
|
||||
askId: userMessage.id,
|
||||
model: assistant.model,
|
||||
traceId: userMessage.traceId
|
||||
})
|
||||
await saveMessageAndBlocksToDB(topicId, assistantMessage, [])
|
||||
// Create message via Data API for normal topics
|
||||
const createDto: CreateMessageDto = {
|
||||
parentId: userMessage.id,
|
||||
role: 'assistant',
|
||||
data: { blocks: [] },
|
||||
status: 'pending',
|
||||
siblingsGroupId: 0,
|
||||
assistantId: assistant.id,
|
||||
modelId: assistant.model?.id,
|
||||
traceId: userMessage.traceId ?? undefined
|
||||
}
|
||||
|
||||
const sharedMessage = await dataApiService.post(`/topics/${topicId}/messages`, { body: createDto })
|
||||
const assistantMessage = convertSharedToRendererMessage(sharedMessage, assistant.id, assistant.model)
|
||||
|
||||
dispatch(newMessagesActions.addMessage({ topicId, message: assistantMessage }))
|
||||
|
||||
queue.add(async () => {
|
||||
@ -1126,12 +1218,21 @@ export const resendMessageThunk =
|
||||
|
||||
if (assistantMessagesToReset.length === 0 && !userMessageToResend?.mentions?.length) {
|
||||
// 没有相关的助手消息且没有提及模型时,使用助手模型创建一条消息
|
||||
// Create message via Data API
|
||||
const createDto: CreateMessageDto = {
|
||||
parentId: userMessageToResend.id,
|
||||
role: 'assistant',
|
||||
data: { blocks: [] },
|
||||
status: 'pending',
|
||||
siblingsGroupId: 0,
|
||||
assistantId: assistant.id,
|
||||
modelId: assistant.model?.id,
|
||||
traceId: userMessageToResend.traceId ?? undefined
|
||||
}
|
||||
|
||||
const sharedMessage = await dataApiService.post(`/topics/${topicId}/messages`, { body: createDto })
|
||||
const assistantMessage = convertSharedToRendererMessage(sharedMessage, assistant.id, assistant.model)
|
||||
|
||||
const assistantMessage = createAssistantMessage(assistant.id, topicId, {
|
||||
askId: userMessageToResend.id,
|
||||
model: assistant.model
|
||||
})
|
||||
assistantMessage.traceId = userMessageToResend.traceId
|
||||
resetDataList.push(assistantMessage)
|
||||
|
||||
resetDataList.forEach((message) => {
|
||||
@ -1166,11 +1267,21 @@ export const resendMessageThunk =
|
||||
const mentionedModelSet = new Set(userMessageToResend.mentions ?? [])
|
||||
const newModelSet = new Set([...mentionedModelSet].filter((m) => !originModelSet.has(m)))
|
||||
for (const model of newModelSet) {
|
||||
const assistantMessage = createAssistantMessage(assistant.id, topicId, {
|
||||
askId: userMessageToResend.id,
|
||||
model: model,
|
||||
modelId: model.id
|
||||
})
|
||||
// Create message via Data API for new mentioned models
|
||||
const createDto: CreateMessageDto = {
|
||||
parentId: userMessageToResend.id,
|
||||
role: 'assistant',
|
||||
data: { blocks: [] },
|
||||
status: 'pending',
|
||||
siblingsGroupId: 0,
|
||||
assistantId: assistant.id,
|
||||
modelId: model.id,
|
||||
traceId: userMessageToResend.traceId ?? undefined
|
||||
}
|
||||
|
||||
const sharedMessage = await dataApiService.post(`/topics/${topicId}/messages`, { body: createDto })
|
||||
const assistantMessage = convertSharedToRendererMessage(sharedMessage, assistant.id, model)
|
||||
|
||||
resetDataList.push(assistantMessage)
|
||||
dispatch(newMessagesActions.addMessage({ topicId, message: assistantMessage }))
|
||||
}
|
||||
@ -1178,10 +1289,14 @@ export const resendMessageThunk =
|
||||
messagesToUpdateInRedux.forEach((update) => dispatch(newMessagesActions.updateMessage(update)))
|
||||
cleanupMultipleBlocks(dispatch, allBlockIdsToDelete)
|
||||
|
||||
// Note: Block deletion still uses Dexie for now
|
||||
// TODO: Migrate block deletion to Data API when block endpoints are available
|
||||
try {
|
||||
if (allBlockIdsToDelete.length > 0) {
|
||||
await db.message_blocks.bulkDelete(allBlockIdsToDelete)
|
||||
}
|
||||
// Note: Dexie topic update removed for new messages - they are created via Data API
|
||||
// However, existing message updates still need Dexie sync for now
|
||||
const finalMessagesToSave = selectMessagesForTopic(getState(), topicId)
|
||||
await db.topics.update(topicId, { messages: finalMessagesToSave })
|
||||
} catch (dbError) {
|
||||
@ -1467,21 +1582,28 @@ export const appendAssistantResponseThunk =
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Create the new assistant message stub
|
||||
const newAssistantMessageStub = createAssistantMessage(assistant.id, topicId, {
|
||||
askId: askId, // Crucial: Use the original askId
|
||||
model: newModel,
|
||||
// 2. Create the new assistant message via Data API
|
||||
const createDto: CreateMessageDto = {
|
||||
parentId: askId, // Crucial: Use the original askId
|
||||
role: 'assistant',
|
||||
data: { blocks: [] },
|
||||
status: 'pending',
|
||||
siblingsGroupId: 0,
|
||||
assistantId: assistant.id,
|
||||
modelId: newModel.id,
|
||||
traceId: traceId
|
||||
})
|
||||
traceId: traceId ?? undefined
|
||||
}
|
||||
|
||||
const sharedMessage = await dataApiService.post(`/topics/${topicId}/messages`, { body: createDto })
|
||||
const newAssistantMessageStub = convertSharedToRendererMessage(sharedMessage, assistant.id, newModel)
|
||||
|
||||
// 3. Update Redux Store
|
||||
const currentTopicMessageIds = getState().messages.messageIdsByTopic[topicId] || []
|
||||
const existingMessageIndex = currentTopicMessageIds.findIndex((id) => id === existingAssistantMessageId)
|
||||
const insertAtIndex = existingMessageIndex !== -1 ? existingMessageIndex + 1 : currentTopicMessageIds.length
|
||||
|
||||
// 4. Update Database (Save the stub to the topic's message list)
|
||||
await saveMessageAndBlocksToDB(topicId, newAssistantMessageStub, [], insertAtIndex)
|
||||
// 4. Message already saved via Data API POST above
|
||||
// await saveMessageAndBlocksToDB(topicId, newAssistantMessageStub, [], insertAtIndex)
|
||||
|
||||
dispatch(
|
||||
newMessagesActions.insertMessageAtIndex({ topicId, message: newAssistantMessageStub, index: insertAtIndex })
|
||||
|
||||
Loading…
Reference in New Issue
Block a user