mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-09 06:49:02 +08:00
feat: refactor message persistence and block update logic for agent sessions
This commit is contained in:
parent
7631d9d730
commit
e40e1d0b36
@ -7,10 +7,10 @@ import type {
|
|||||||
AgentPersistedMessage,
|
AgentPersistedMessage,
|
||||||
AgentSessionMessageEntity
|
AgentSessionMessageEntity
|
||||||
} from '@types'
|
} from '@types'
|
||||||
import { asc, eq } from 'drizzle-orm'
|
import { and, asc, eq } from 'drizzle-orm'
|
||||||
|
|
||||||
import { BaseService } from '../BaseService'
|
import { BaseService } from '../BaseService'
|
||||||
import type { InsertSessionMessageRow } from './schema'
|
import type { InsertSessionMessageRow, SessionMessageRow } from './schema'
|
||||||
import { sessionMessagesTable } from './schema'
|
import { sessionMessagesTable } from './schema'
|
||||||
|
|
||||||
const logger = loggerService.withContext('AgentMessageRepository')
|
const logger = loggerService.withContext('AgentMessageRepository')
|
||||||
@ -91,19 +91,86 @@ class AgentMessageRepository extends BaseService {
|
|||||||
return tx ?? this.database
|
return tx ?? this.database
|
||||||
}
|
}
|
||||||
|
|
||||||
async persistUserMessage(params: PersistUserMessageParams): Promise<AgentSessionMessageEntity> {
|
private async findExistingMessageRow(
|
||||||
|
writer: TxClient,
|
||||||
|
sessionId: string,
|
||||||
|
role: string,
|
||||||
|
messageId: string
|
||||||
|
): Promise<SessionMessageRow | null> {
|
||||||
|
const candidateRows: SessionMessageRow[] = await writer
|
||||||
|
.select()
|
||||||
|
.from(sessionMessagesTable)
|
||||||
|
.where(and(eq(sessionMessagesTable.session_id, sessionId), eq(sessionMessagesTable.role, role)))
|
||||||
|
.orderBy(asc(sessionMessagesTable.created_at))
|
||||||
|
|
||||||
|
for (const row of candidateRows) {
|
||||||
|
if (!row?.content) continue
|
||||||
|
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(row.content) as AgentPersistedMessage | undefined
|
||||||
|
if (parsed?.message?.id === messageId) {
|
||||||
|
return row
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.warn('Failed to parse session message content JSON during lookup', error as Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
private async upsertMessage(
|
||||||
|
params: PersistUserMessageParams | PersistAssistantMessageParams
|
||||||
|
): Promise<AgentSessionMessageEntity> {
|
||||||
await AgentMessageRepository.initialize()
|
await AgentMessageRepository.initialize()
|
||||||
this.ensureInitialized()
|
this.ensureInitialized()
|
||||||
|
|
||||||
const writer = this.getWriter(params.tx)
|
const { sessionId, agentSessionId = '', payload, metadata, createdAt, tx } = params
|
||||||
const now = params.createdAt ?? params.payload.message.createdAt ?? new Date().toISOString()
|
|
||||||
|
if (!payload?.message?.role) {
|
||||||
|
throw new Error('Message payload missing role')
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!payload.message.id) {
|
||||||
|
throw new Error('Message payload missing id')
|
||||||
|
}
|
||||||
|
|
||||||
|
const writer = this.getWriter(tx)
|
||||||
|
const now = createdAt ?? payload.message.createdAt ?? new Date().toISOString()
|
||||||
|
const serializedPayload = this.serializeMessage(payload)
|
||||||
|
const serializedMetadata = this.serializeMetadata(metadata)
|
||||||
|
|
||||||
|
const existingRow = await this.findExistingMessageRow(writer, sessionId, payload.message.role, payload.message.id)
|
||||||
|
|
||||||
|
if (existingRow) {
|
||||||
|
const metadataToPersist = serializedMetadata ?? existingRow.metadata ?? undefined
|
||||||
|
const agentSessionToPersist = agentSessionId || existingRow.agent_session_id || ''
|
||||||
|
|
||||||
|
await writer
|
||||||
|
.update(sessionMessagesTable)
|
||||||
|
.set({
|
||||||
|
content: serializedPayload,
|
||||||
|
metadata: metadataToPersist,
|
||||||
|
agent_session_id: agentSessionToPersist,
|
||||||
|
updated_at: now
|
||||||
|
})
|
||||||
|
.where(eq(sessionMessagesTable.id, existingRow.id))
|
||||||
|
|
||||||
|
return this.deserialize({
|
||||||
|
...existingRow,
|
||||||
|
content: serializedPayload,
|
||||||
|
metadata: metadataToPersist,
|
||||||
|
agent_session_id: agentSessionToPersist,
|
||||||
|
updated_at: now
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
const insertData: InsertSessionMessageRow = {
|
const insertData: InsertSessionMessageRow = {
|
||||||
session_id: params.sessionId,
|
session_id: sessionId,
|
||||||
role: params.payload.message.role,
|
role: payload.message.role,
|
||||||
content: this.serializeMessage(params.payload),
|
content: serializedPayload,
|
||||||
agent_session_id: params.agentSessionId ?? '',
|
agent_session_id: agentSessionId,
|
||||||
metadata: this.serializeMetadata(params.metadata),
|
metadata: serializedMetadata,
|
||||||
created_at: now,
|
created_at: now,
|
||||||
updated_at: now
|
updated_at: now
|
||||||
}
|
}
|
||||||
@ -113,26 +180,12 @@ class AgentMessageRepository extends BaseService {
|
|||||||
return this.deserialize(saved)
|
return this.deserialize(saved)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async persistUserMessage(params: PersistUserMessageParams): Promise<AgentSessionMessageEntity> {
|
||||||
|
return this.upsertMessage({ ...params, agentSessionId: params.agentSessionId ?? '' })
|
||||||
|
}
|
||||||
|
|
||||||
async persistAssistantMessage(params: PersistAssistantMessageParams): Promise<AgentSessionMessageEntity> {
|
async persistAssistantMessage(params: PersistAssistantMessageParams): Promise<AgentSessionMessageEntity> {
|
||||||
await AgentMessageRepository.initialize()
|
return this.upsertMessage(params)
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
const writer = this.getWriter(params.tx)
|
|
||||||
const now = params.createdAt ?? params.payload.message.createdAt ?? new Date().toISOString()
|
|
||||||
|
|
||||||
const insertData: InsertSessionMessageRow = {
|
|
||||||
session_id: params.sessionId,
|
|
||||||
role: params.payload.message.role,
|
|
||||||
content: this.serializeMessage(params.payload),
|
|
||||||
agent_session_id: params.agentSessionId,
|
|
||||||
metadata: this.serializeMetadata(params.metadata),
|
|
||||||
created_at: now,
|
|
||||||
updated_at: now
|
|
||||||
}
|
|
||||||
|
|
||||||
const [saved] = await writer.insert(sessionMessagesTable).values(insertData).returning()
|
|
||||||
|
|
||||||
return this.deserialize(saved)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async persistExchange(params: PersistExchangeParams): Promise<PersistExchangeResult> {
|
async persistExchange(params: PersistExchangeParams): Promise<PersistExchangeResult> {
|
||||||
@ -145,9 +198,6 @@ class AgentMessageRepository extends BaseService {
|
|||||||
const exchangeResult: PersistExchangeResult = {}
|
const exchangeResult: PersistExchangeResult = {}
|
||||||
|
|
||||||
if (user?.payload) {
|
if (user?.payload) {
|
||||||
if (!user.payload.message?.role) {
|
|
||||||
throw new Error('User message payload missing role')
|
|
||||||
}
|
|
||||||
exchangeResult.userMessage = await this.persistUserMessage({
|
exchangeResult.userMessage = await this.persistUserMessage({
|
||||||
sessionId,
|
sessionId,
|
||||||
agentSessionId,
|
agentSessionId,
|
||||||
@ -159,9 +209,6 @@ class AgentMessageRepository extends BaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (assistant?.payload) {
|
if (assistant?.payload) {
|
||||||
if (!assistant.payload.message?.role) {
|
|
||||||
throw new Error('Assistant message payload missing role')
|
|
||||||
}
|
|
||||||
exchangeResult.assistantMessage = await this.persistAssistantMessage({
|
exchangeResult.assistantMessage = await this.persistAssistantMessage({
|
||||||
sessionId,
|
sessionId,
|
||||||
agentSessionId,
|
agentSessionId,
|
||||||
|
|||||||
@ -70,13 +70,13 @@ class TextStreamAccumulator {
|
|||||||
if (part.toolCallId) {
|
if (part.toolCallId) {
|
||||||
this.toolCalls.set(part.toolCallId, {
|
this.toolCalls.set(part.toolCallId, {
|
||||||
toolName: part.toolName,
|
toolName: part.toolName,
|
||||||
input: part.input ?? part.args ?? part.providerMetadata?.raw?.input
|
input: part.input ?? part.providerMetadata?.raw?.input
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
case 'tool-result':
|
case 'tool-result':
|
||||||
if (part.toolCallId) {
|
if (part.toolCallId) {
|
||||||
this.toolResults.set(part.toolCallId, part.output ?? part.result ?? part.providerMetadata?.raw)
|
this.toolResults.set(part.toolCallId, part.output)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
|
import store from '@renderer/store'
|
||||||
import type { AgentPersistedMessage } from '@renderer/types/agent'
|
import type { AgentPersistedMessage } from '@renderer/types/agent'
|
||||||
import type { Message, MessageBlock } from '@renderer/types/newMessage'
|
import type { Message, MessageBlock } from '@renderer/types/newMessage'
|
||||||
import { IpcChannel } from '@shared/IpcChannel'
|
import { IpcChannel } from '@shared/IpcChannel'
|
||||||
@ -21,6 +22,7 @@ const streamingMessageCache = new LRUCache<
|
|||||||
blocks: MessageBlock[]
|
blocks: MessageBlock[]
|
||||||
isComplete: boolean
|
isComplete: boolean
|
||||||
sessionId: string
|
sessionId: string
|
||||||
|
agentSessionId?: string
|
||||||
}
|
}
|
||||||
>({
|
>({
|
||||||
max: 100,
|
max: 100,
|
||||||
@ -51,13 +53,14 @@ export class AgentMessageDataSource implements MessageDataSource {
|
|||||||
const cached = streamingMessageCache.get(messageId)
|
const cached = streamingMessageCache.get(messageId)
|
||||||
if (!cached) return
|
if (!cached) return
|
||||||
|
|
||||||
const { message, blocks, sessionId, isComplete } = cached
|
const { message, blocks, sessionId, isComplete, agentSessionId } = cached
|
||||||
|
const sessionPointer = agentSessionId ?? message.agentSessionId ?? ''
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Persist to backend
|
// Persist to backend
|
||||||
await window.electron.ipcRenderer.invoke(IpcChannel.AgentMessage_PersistExchange, {
|
await window.electron.ipcRenderer.invoke(IpcChannel.AgentMessage_PersistExchange, {
|
||||||
sessionId,
|
sessionId,
|
||||||
agentSessionId: '',
|
agentSessionId: sessionPointer,
|
||||||
...(message.role === 'user'
|
...(message.role === 'user'
|
||||||
? { user: { payload: { message, blocks } } }
|
? { user: { payload: { message, blocks } } }
|
||||||
: { assistant: { payload: { message, blocks } } })
|
: { assistant: { payload: { message, blocks } } })
|
||||||
@ -100,6 +103,42 @@ export class AgentMessageDataSource implements MessageDataSource {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private mergeBlockUpdates(existingBlocks: MessageBlock[], updates: MessageBlock[]): MessageBlock[] {
|
||||||
|
if (existingBlocks.length === 0) {
|
||||||
|
return [...updates]
|
||||||
|
}
|
||||||
|
|
||||||
|
const existingById = new Map(existingBlocks.map((block) => [block.id, block]))
|
||||||
|
|
||||||
|
for (const update of updates) {
|
||||||
|
if (!update?.id) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
existingById.set(update.id, update)
|
||||||
|
}
|
||||||
|
|
||||||
|
const merged: MessageBlock[] = []
|
||||||
|
|
||||||
|
for (const original of existingBlocks) {
|
||||||
|
const updated = existingById.get(original.id)
|
||||||
|
if (updated) {
|
||||||
|
merged.push(updated)
|
||||||
|
existingById.delete(original.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const update of updates) {
|
||||||
|
if (!update?.id) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if (!merged.some((block) => block.id === update.id)) {
|
||||||
|
merged.push(update)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
// ============ Read Operations ============
|
// ============ Read Operations ============
|
||||||
|
|
||||||
async fetchMessages(topicId: string): Promise<{
|
async fetchMessages(topicId: string): Promise<{
|
||||||
@ -146,6 +185,7 @@ export class AgentMessageDataSource implements MessageDataSource {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ============ Write Operations ============
|
// ============ Write Operations ============
|
||||||
|
// oxlint-disable-next-line no-unused-vars
|
||||||
async appendMessage(topicId: string, message: Message, blocks: MessageBlock[], _insertIndex?: number): Promise<void> {
|
async appendMessage(topicId: string, message: Message, blocks: MessageBlock[], _insertIndex?: number): Promise<void> {
|
||||||
const sessionId = extractSessionId(topicId)
|
const sessionId = extractSessionId(topicId)
|
||||||
if (!sessionId) {
|
if (!sessionId) {
|
||||||
@ -154,6 +194,7 @@ export class AgentMessageDataSource implements MessageDataSource {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
const isStreaming = this.isMessageStreaming(message)
|
const isStreaming = this.isMessageStreaming(message)
|
||||||
|
const agentSessionId = message.agentSessionId ?? ''
|
||||||
|
|
||||||
// Always persist immediately for visibility in UI
|
// Always persist immediately for visibility in UI
|
||||||
const payload: AgentPersistedMessage = {
|
const payload: AgentPersistedMessage = {
|
||||||
@ -163,7 +204,7 @@ export class AgentMessageDataSource implements MessageDataSource {
|
|||||||
|
|
||||||
await window.electron.ipcRenderer.invoke(IpcChannel.AgentMessage_PersistExchange, {
|
await window.electron.ipcRenderer.invoke(IpcChannel.AgentMessage_PersistExchange, {
|
||||||
sessionId,
|
sessionId,
|
||||||
agentSessionId: '',
|
agentSessionId,
|
||||||
...(message.role === 'user' ? { user: { payload } } : { assistant: { payload } })
|
...(message.role === 'user' ? { user: { payload } } : { assistant: { payload } })
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -180,7 +221,8 @@ export class AgentMessageDataSource implements MessageDataSource {
|
|||||||
message,
|
message,
|
||||||
blocks,
|
blocks,
|
||||||
isComplete: false,
|
isComplete: false,
|
||||||
sessionId
|
sessionId,
|
||||||
|
agentSessionId
|
||||||
})
|
})
|
||||||
|
|
||||||
// Set up throttled persister for future updates
|
// Set up throttled persister for future updates
|
||||||
@ -218,16 +260,26 @@ export class AgentMessageDataSource implements MessageDataSource {
|
|||||||
|
|
||||||
// Merge updates with existing message
|
// Merge updates with existing message
|
||||||
const updatedMessage = { ...existingMessage.message, ...updates }
|
const updatedMessage = { ...existingMessage.message, ...updates }
|
||||||
|
const agentSessionId = updatedMessage.agentSessionId ?? existingMessage.message.agentSessionId ?? ''
|
||||||
|
|
||||||
// Save updated message back to backend
|
// Save updated message back to backend
|
||||||
await window.electron.ipcRenderer.invoke(IpcChannel.AgentMessage_PersistExchange, {
|
await window.electron.ipcRenderer.invoke(IpcChannel.AgentMessage_PersistExchange, {
|
||||||
sessionId,
|
sessionId,
|
||||||
agentSessionId: '',
|
agentSessionId,
|
||||||
...(updatedMessage.role === 'user'
|
...(updatedMessage.role === 'user'
|
||||||
? { user: { payload: { message: updatedMessage, blocks: existingMessage.blocks || [] } } }
|
? { user: { payload: { message: updatedMessage, blocks: existingMessage.blocks || [] } } }
|
||||||
: { assistant: { payload: { message: updatedMessage, blocks: existingMessage.blocks || [] } } })
|
: { assistant: { payload: { message: updatedMessage, blocks: existingMessage.blocks || [] } } })
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const cacheEntry = streamingMessageCache.get(messageId)
|
||||||
|
if (cacheEntry) {
|
||||||
|
streamingMessageCache.set(messageId, {
|
||||||
|
...cacheEntry,
|
||||||
|
message: updatedMessage,
|
||||||
|
agentSessionId: agentSessionId || cacheEntry.agentSessionId
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
logger.info(`Updated message ${messageId} in agent session ${sessionId}`)
|
logger.info(`Updated message ${messageId} in agent session ${sessionId}`)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`Failed to update message ${messageId} in agent session ${topicId}:`, error as Error)
|
logger.error(`Failed to update message ${messageId} in agent session ${topicId}:`, error as Error)
|
||||||
@ -284,12 +336,15 @@ export class AgentMessageDataSource implements MessageDataSource {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const agentSessionId = currentMessage.agentSessionId ?? cached?.agentSessionId ?? ''
|
||||||
|
|
||||||
// Update cache
|
// Update cache
|
||||||
streamingMessageCache.set(messageUpdates.id, {
|
streamingMessageCache.set(messageUpdates.id, {
|
||||||
message: currentMessage,
|
message: currentMessage,
|
||||||
blocks: currentBlocks,
|
blocks: currentBlocks,
|
||||||
isComplete: false,
|
isComplete: false,
|
||||||
sessionId
|
sessionId,
|
||||||
|
agentSessionId
|
||||||
})
|
})
|
||||||
|
|
||||||
// Trigger throttled persist
|
// Trigger throttled persist
|
||||||
@ -331,20 +386,23 @@ export class AgentMessageDataSource implements MessageDataSource {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const agentSessionId = finalMessage.agentSessionId ?? cached?.agentSessionId ?? ''
|
||||||
|
|
||||||
// Mark as complete in cache if it was streaming
|
// Mark as complete in cache if it was streaming
|
||||||
if (cached) {
|
if (cached) {
|
||||||
streamingMessageCache.set(messageUpdates.id, {
|
streamingMessageCache.set(messageUpdates.id, {
|
||||||
message: finalMessage,
|
message: finalMessage,
|
||||||
blocks: finalBlocks,
|
blocks: finalBlocks,
|
||||||
isComplete: true,
|
isComplete: true,
|
||||||
sessionId
|
sessionId,
|
||||||
|
agentSessionId
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Persist to backend
|
// Persist to backend
|
||||||
await window.electron.ipcRenderer.invoke(IpcChannel.AgentMessage_PersistExchange, {
|
await window.electron.ipcRenderer.invoke(IpcChannel.AgentMessage_PersistExchange, {
|
||||||
sessionId,
|
sessionId,
|
||||||
agentSessionId: '',
|
agentSessionId,
|
||||||
...(finalMessage.role === 'user'
|
...(finalMessage.role === 'user'
|
||||||
? { user: { payload: { message: finalMessage, blocks: finalBlocks } } }
|
? { user: { payload: { message: finalMessage, blocks: finalBlocks } } }
|
||||||
: { assistant: { payload: { message: finalMessage, blocks: finalBlocks } } })
|
: { assistant: { payload: { message: finalMessage, blocks: finalBlocks } } })
|
||||||
@ -364,6 +422,7 @@ export class AgentMessageDataSource implements MessageDataSource {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// oxlint-disable-next-line no-unused-vars
|
||||||
async deleteMessage(topicId: string, _messageId: string): Promise<void> {
|
async deleteMessage(topicId: string, _messageId: string): Promise<void> {
|
||||||
// Agent session messages cannot be deleted individually
|
// Agent session messages cannot be deleted individually
|
||||||
logger.warn(`deleteMessage called for agent session ${topicId}, operation not supported`)
|
logger.warn(`deleteMessage called for agent session ${topicId}, operation not supported`)
|
||||||
@ -373,6 +432,7 @@ export class AgentMessageDataSource implements MessageDataSource {
|
|||||||
// 2. Or just hide from UI without actual deletion
|
// 2. Or just hide from UI without actual deletion
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// oxlint-disable-next-line no-unused-vars
|
||||||
async deleteMessages(topicId: string, _messageIds: string[]): Promise<void> {
|
async deleteMessages(topicId: string, _messageIds: string[]): Promise<void> {
|
||||||
// Agent session messages cannot be deleted in batch
|
// Agent session messages cannot be deleted in batch
|
||||||
logger.warn(`deleteMessages called for agent session ${topicId}, operation not supported`)
|
logger.warn(`deleteMessages called for agent session ${topicId}, operation not supported`)
|
||||||
@ -382,6 +442,7 @@ export class AgentMessageDataSource implements MessageDataSource {
|
|||||||
// 2. Update local state accordingly
|
// 2. Update local state accordingly
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// oxlint-disable-next-line no-unused-vars
|
||||||
async deleteMessagesByAskId(topicId: string, _askId: string): Promise<void> {
|
async deleteMessagesByAskId(topicId: string, _askId: string): Promise<void> {
|
||||||
// Agent session messages cannot be deleted
|
// Agent session messages cannot be deleted
|
||||||
logger.warn(`deleteMessagesByAskId called for agent session ${topicId}, operation not supported`)
|
logger.warn(`deleteMessagesByAskId called for agent session ${topicId}, operation not supported`)
|
||||||
@ -389,11 +450,134 @@ export class AgentMessageDataSource implements MessageDataSource {
|
|||||||
|
|
||||||
// ============ Block Operations ============
|
// ============ Block Operations ============
|
||||||
|
|
||||||
async updateBlocks(_blocks: MessageBlock[]): Promise<void> {
|
async updateBlocks(blocks: MessageBlock[]): Promise<void> {
|
||||||
// Blocks are updated through persistExchange for agent sessions
|
if (!blocks.length) {
|
||||||
logger.warn('updateBlocks called for agent session, operation not supported individually')
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (!window.electron?.ipcRenderer) {
|
||||||
|
logger.warn('IPC renderer not available for agent block update')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const state = store.getState()
|
||||||
|
|
||||||
|
const sessionMessageMap = new Map<
|
||||||
|
string,
|
||||||
|
Map<
|
||||||
|
string,
|
||||||
|
{
|
||||||
|
message: Message | undefined
|
||||||
|
updates: MessageBlock[]
|
||||||
|
baseBlocks?: MessageBlock[]
|
||||||
|
}
|
||||||
|
>
|
||||||
|
>()
|
||||||
|
|
||||||
|
for (const block of blocks) {
|
||||||
|
const messageId = block.messageId
|
||||||
|
if (!messageId) {
|
||||||
|
logger.warn('Skipping block update without messageId')
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
const cached = streamingMessageCache.get(messageId)
|
||||||
|
const storeMessage = cached?.message ?? state.messages.entities[messageId]
|
||||||
|
|
||||||
|
if (!storeMessage) {
|
||||||
|
logger.warn(`Unable to locate parent message ${messageId} for block update`)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
const sessionId = cached?.sessionId ?? extractSessionId(storeMessage.topicId)
|
||||||
|
if (!sessionId) {
|
||||||
|
logger.warn(`Unable to determine session for message ${messageId}`)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!sessionMessageMap.has(sessionId)) {
|
||||||
|
sessionMessageMap.set(sessionId, new Map())
|
||||||
|
}
|
||||||
|
|
||||||
|
const messageMap = sessionMessageMap.get(sessionId)!
|
||||||
|
if (!messageMap.has(messageId)) {
|
||||||
|
messageMap.set(messageId, {
|
||||||
|
message: storeMessage,
|
||||||
|
updates: [],
|
||||||
|
baseBlocks: cached?.blocks
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
messageMap.get(messageId)!.updates.push(block)
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const [sessionId, messageMap] of sessionMessageMap) {
|
||||||
|
let historyMap: Map<string, AgentPersistedMessage> | null = null
|
||||||
|
|
||||||
|
for (const [messageId, pending] of messageMap) {
|
||||||
|
let baseBlocks = pending.baseBlocks
|
||||||
|
let message = pending.message
|
||||||
|
|
||||||
|
if (!baseBlocks) {
|
||||||
|
if (!historyMap) {
|
||||||
|
const historicalMessages: AgentPersistedMessage[] = await window.electron.ipcRenderer.invoke(
|
||||||
|
IpcChannel.AgentMessage_GetHistory,
|
||||||
|
{ sessionId }
|
||||||
|
)
|
||||||
|
historyMap = new Map(
|
||||||
|
(historicalMessages || [])
|
||||||
|
.filter((persisted) => persisted?.message?.id)
|
||||||
|
.map((persisted) => [persisted.message.id, persisted])
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const persisted = historyMap.get(messageId)
|
||||||
|
if (persisted) {
|
||||||
|
baseBlocks = persisted.blocks || []
|
||||||
|
if (!message) {
|
||||||
|
message = persisted.message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!message) {
|
||||||
|
logger.warn(`Failed to resolve message payload for ${messageId}, skipping block persist`)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
const mergedBlocks = this.mergeBlockUpdates(baseBlocks || [], pending.updates)
|
||||||
|
const cacheEntry = streamingMessageCache.get(messageId)
|
||||||
|
const agentSessionId = message.agentSessionId ?? cacheEntry?.agentSessionId ?? ''
|
||||||
|
|
||||||
|
if (cacheEntry) {
|
||||||
|
streamingMessageCache.set(messageId, {
|
||||||
|
...cacheEntry,
|
||||||
|
blocks: mergedBlocks,
|
||||||
|
agentSessionId
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
await window.electron.ipcRenderer.invoke(IpcChannel.AgentMessage_PersistExchange, {
|
||||||
|
sessionId,
|
||||||
|
agentSessionId,
|
||||||
|
...(message.role === 'user'
|
||||||
|
? { user: { payload: { message, blocks: mergedBlocks } } }
|
||||||
|
: { assistant: { payload: { message, blocks: mergedBlocks } } })
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.debug(`Persisted block updates for message ${messageId} in agent session ${sessionId}`, {
|
||||||
|
blockCount: mergedBlocks.length
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to update agent message blocks:', error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// oxlint-disable-next-line no-unused-vars
|
||||||
async deleteBlocks(_blockIds: string[]): Promise<void> {
|
async deleteBlocks(_blockIds: string[]): Promise<void> {
|
||||||
// Blocks cannot be deleted individually for agent sessions
|
// Blocks cannot be deleted individually for agent sessions
|
||||||
logger.warn('deleteBlocks called for agent session, operation not supported')
|
logger.warn('deleteBlocks called for agent session, operation not supported')
|
||||||
@ -456,21 +640,25 @@ export class AgentMessageDataSource implements MessageDataSource {
|
|||||||
|
|
||||||
// ============ Additional Methods for Interface Compatibility ============
|
// ============ Additional Methods for Interface Compatibility ============
|
||||||
|
|
||||||
|
// oxlint-disable-next-line no-unused-vars
|
||||||
async updateSingleBlock(blockId: string, _updates: Partial<MessageBlock>): Promise<void> {
|
async updateSingleBlock(blockId: string, _updates: Partial<MessageBlock>): Promise<void> {
|
||||||
// Agent session blocks are immutable once persisted
|
// Agent session blocks are immutable once persisted
|
||||||
logger.warn(`updateSingleBlock called for agent session block ${blockId}, operation not supported`)
|
logger.warn(`updateSingleBlock called for agent session block ${blockId}, operation not supported`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// oxlint-disable-next-line no-unused-vars
|
||||||
async bulkAddBlocks(_blocks: MessageBlock[]): Promise<void> {
|
async bulkAddBlocks(_blocks: MessageBlock[]): Promise<void> {
|
||||||
// Agent session blocks are added through persistExchange
|
// Agent session blocks are added through persistExchange
|
||||||
logger.warn(`bulkAddBlocks called for agent session, operation not supported individually`)
|
logger.warn(`bulkAddBlocks called for agent session, operation not supported individually`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// oxlint-disable-next-line no-unused-vars
|
||||||
async updateFileCount(fileId: string, _delta: number, _deleteIfZero?: boolean): Promise<void> {
|
async updateFileCount(fileId: string, _delta: number, _deleteIfZero?: boolean): Promise<void> {
|
||||||
// Agent sessions don't manage file reference counts locally
|
// Agent sessions don't manage file reference counts locally
|
||||||
logger.warn(`updateFileCount called for agent session file ${fileId}, operation not supported`)
|
logger.warn(`updateFileCount called for agent session file ${fileId}, operation not supported`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// oxlint-disable-next-line no-unused-vars
|
||||||
async updateFileCounts(_files: Array<{ id: string; delta: number; deleteIfZero?: boolean }>): Promise<void> {
|
async updateFileCounts(_files: Array<{ id: string; delta: number; deleteIfZero?: boolean }>): Promise<void> {
|
||||||
// Agent sessions don't manage file reference counts locally
|
// Agent sessions don't manage file reference counts locally
|
||||||
logger.warn(`updateFileCounts called for agent session, operation not supported`)
|
logger.warn(`updateFileCounts called for agent session, operation not supported`)
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
|
import store from '@renderer/store'
|
||||||
import type { Message, MessageBlock } from '@renderer/types/newMessage'
|
import type { Message, MessageBlock } from '@renderer/types/newMessage'
|
||||||
|
|
||||||
import { AgentMessageDataSource } from './AgentMessageDataSource'
|
import { AgentMessageDataSource } from './AgentMessageDataSource'
|
||||||
@ -94,12 +95,31 @@ class DbService implements MessageDataSource {
|
|||||||
// ============ Block Operations ============
|
// ============ Block Operations ============
|
||||||
|
|
||||||
async updateBlocks(blocks: MessageBlock[]): Promise<void> {
|
async updateBlocks(blocks: MessageBlock[]): Promise<void> {
|
||||||
// For block operations, we need to infer the source from the first block's message
|
if (blocks.length === 0) {
|
||||||
// This is a limitation of the current design where blocks don't have topicId
|
return
|
||||||
// In practice, blocks are usually updated in context of a topic operation
|
}
|
||||||
|
|
||||||
// Default to Dexie for now since agent blocks are updated through persistExchange
|
const state = store.getState()
|
||||||
return this.dexieSource.updateBlocks(blocks)
|
|
||||||
|
const agentBlocks: MessageBlock[] = []
|
||||||
|
const regularBlocks: MessageBlock[] = []
|
||||||
|
|
||||||
|
for (const block of blocks) {
|
||||||
|
const parentMessage = state.messages.entities[block.messageId]
|
||||||
|
if (parentMessage && isAgentSessionTopicId(parentMessage.topicId)) {
|
||||||
|
agentBlocks.push(block)
|
||||||
|
} else {
|
||||||
|
regularBlocks.push(block)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (agentBlocks.length > 0) {
|
||||||
|
await this.agentSource.updateBlocks(agentBlocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (regularBlocks.length > 0) {
|
||||||
|
await this.dexieSource.updateBlocks(regularBlocks)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async deleteBlocks(blockIds: string[]): Promise<void> {
|
async deleteBlocks(blockIds: string[]): Promise<void> {
|
||||||
|
|||||||
@ -249,25 +249,29 @@ const updateExistingMessageAndBlocksInDB = async (
|
|||||||
updatedBlocks: MessageBlock[]
|
updatedBlocks: MessageBlock[]
|
||||||
) => {
|
) => {
|
||||||
try {
|
try {
|
||||||
if (isAgentSessionTopicId(updatedMessage.topicId)) {
|
// Always update blocks if provided
|
||||||
return
|
if (updatedBlocks.length > 0) {
|
||||||
}
|
if (featureFlags.USE_UNIFIED_DB_SERVICE) {
|
||||||
await db.transaction('rw', db.topics, db.message_blocks, async () => {
|
await updateBlocksV2(updatedBlocks)
|
||||||
// Always update blocks if provided
|
} else {
|
||||||
if (updatedBlocks.length > 0) {
|
await db.transaction('rw', db.topics, db.message_blocks, async () => {
|
||||||
// Use V2 implementation if feature flag is enabled
|
|
||||||
if (featureFlags.USE_UNIFIED_DB_SERVICE) {
|
|
||||||
await updateBlocksV2(updatedBlocks)
|
|
||||||
} else {
|
|
||||||
await db.message_blocks.bulkPut(updatedBlocks)
|
await db.message_blocks.bulkPut(updatedBlocks)
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check if there are message properties to update beyond id and topicId
|
// Check if there are message properties to update beyond id and topicId
|
||||||
const messageKeysToUpdate = Object.keys(updatedMessage).filter((key) => key !== 'id' && key !== 'topicId')
|
const messageKeysToUpdate = Object.keys(updatedMessage).filter((key) => key !== 'id' && key !== 'topicId')
|
||||||
|
|
||||||
// Only proceed with topic update if there are actual message changes
|
if (messageKeysToUpdate.length > 0) {
|
||||||
if (messageKeysToUpdate.length > 0) {
|
if (featureFlags.USE_UNIFIED_DB_SERVICE) {
|
||||||
|
const messageUpdatesPayload = messageKeysToUpdate.reduce<Partial<Message>>((acc, key) => {
|
||||||
|
acc[key] = updatedMessage[key]
|
||||||
|
return acc
|
||||||
|
}, {})
|
||||||
|
|
||||||
|
await updateMessageV2(updatedMessage.topicId, updatedMessage.id, messageUpdatesPayload)
|
||||||
|
} else {
|
||||||
// 使用 where().modify() 进行原子更新
|
// 使用 where().modify() 进行原子更新
|
||||||
await db.topics
|
await db.topics
|
||||||
.where('id')
|
.where('id')
|
||||||
@ -283,10 +287,10 @@ const updateExistingMessageAndBlocksInDB = async (
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
store.dispatch(updateTopicUpdatedAt({ topicId: updatedMessage.topicId }))
|
|
||||||
}
|
}
|
||||||
})
|
|
||||||
|
store.dispatch(updateTopicUpdatedAt({ topicId: updatedMessage.topicId }))
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`[updateExistingMsg] Failed to update message ${updatedMessage.id}:`, error as Error)
|
logger.error(`[updateExistingMsg] Failed to update message ${updatedMessage.id}:`, error as Error)
|
||||||
}
|
}
|
||||||
@ -494,8 +498,59 @@ const fetchAndProcessAgentResponseImpl = async (
|
|||||||
)
|
)
|
||||||
|
|
||||||
let latestAgentSessionId = ''
|
let latestAgentSessionId = ''
|
||||||
const adapter = new AiSdkToChunkAdapter(streamProcessorCallbacks, [], false, false, (sessionId) => {
|
|
||||||
|
const persistAgentSessionId = async (sessionId: string) => {
|
||||||
|
if (!sessionId || sessionId === latestAgentSessionId) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
latestAgentSessionId = sessionId
|
latestAgentSessionId = sessionId
|
||||||
|
|
||||||
|
logger.debug(`Agent session ID updated`, {
|
||||||
|
topicId,
|
||||||
|
assistantMessageId: assistantMessage.id,
|
||||||
|
value: sessionId
|
||||||
|
})
|
||||||
|
|
||||||
|
try {
|
||||||
|
const stateAfterUpdate = getState()
|
||||||
|
const assistantInState = stateAfterUpdate.messages.entities[assistantMessage.id]
|
||||||
|
const userInState = stateAfterUpdate.messages.entities[userMessageId]
|
||||||
|
|
||||||
|
const persistTasks: Promise<void>[] = []
|
||||||
|
|
||||||
|
if (assistantInState?.agentSessionId !== sessionId) {
|
||||||
|
dispatch(
|
||||||
|
newMessagesActions.updateMessage({
|
||||||
|
topicId,
|
||||||
|
messageId: assistantMessage.id,
|
||||||
|
updates: { agentSessionId: sessionId }
|
||||||
|
})
|
||||||
|
)
|
||||||
|
persistTasks.push(saveUpdatesToDB(assistantMessage.id, topicId, { agentSessionId: sessionId }, []))
|
||||||
|
}
|
||||||
|
|
||||||
|
if (userInState && userInState.agentSessionId !== sessionId) {
|
||||||
|
dispatch(
|
||||||
|
newMessagesActions.updateMessage({
|
||||||
|
topicId,
|
||||||
|
messageId: userMessageId,
|
||||||
|
updates: { agentSessionId: sessionId }
|
||||||
|
})
|
||||||
|
)
|
||||||
|
persistTasks.push(saveUpdatesToDB(userMessageId, topicId, { agentSessionId: sessionId }, []))
|
||||||
|
}
|
||||||
|
|
||||||
|
if (persistTasks.length > 0) {
|
||||||
|
await Promise.all(persistTasks)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to persist agent session ID during stream', error as Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const adapter = new AiSdkToChunkAdapter(streamProcessorCallbacks, [], false, false, (sessionId) => {
|
||||||
|
void persistAgentSessionId(sessionId)
|
||||||
})
|
})
|
||||||
|
|
||||||
await adapter.processStream({
|
await adapter.processStream({
|
||||||
@ -509,10 +564,9 @@ const fetchAndProcessAgentResponseImpl = async (
|
|||||||
// 3. Updates during streaming are saved via updateMessageAndBlocks
|
// 3. Updates during streaming are saved via updateMessageAndBlocks
|
||||||
// This eliminates the duplicate save issue
|
// This eliminates the duplicate save issue
|
||||||
|
|
||||||
// Only persist the agentSessionId update if it changed
|
// Attempt final persistence in case the session id arrived late in the stream
|
||||||
if (latestAgentSessionId) {
|
if (latestAgentSessionId) {
|
||||||
logger.info(`Agent session ID updated to: ${latestAgentSessionId}`)
|
await persistAgentSessionId(latestAgentSessionId)
|
||||||
// In the future, you might want to update some session metadata here
|
|
||||||
}
|
}
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error in fetchAndProcessAgentResponseImpl:', error)
|
logger.error('Error in fetchAndProcessAgentResponseImpl:', error)
|
||||||
@ -759,7 +813,8 @@ export const sendMessage =
|
|||||||
* Loads agent session messages from backend
|
* Loads agent session messages from backend
|
||||||
*/
|
*/
|
||||||
export const loadAgentSessionMessagesThunk =
|
export const loadAgentSessionMessagesThunk =
|
||||||
(sessionId: string) => async (dispatch: AppDispatch, getState: () => RootState) => {
|
// oxlint-disable-next-line no-unused-vars
|
||||||
|
(sessionId: string) => async (dispatch: AppDispatch, _getState: () => RootState) => {
|
||||||
const topicId = buildAgentSessionTopicId(sessionId)
|
const topicId = buildAgentSessionTopicId(sessionId)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|||||||
@ -154,8 +154,13 @@ export const saveMessageAndBlocksToDBV2 = async (
|
|||||||
messageIndex: number = -1
|
messageIndex: number = -1
|
||||||
): Promise<void> => {
|
): Promise<void> => {
|
||||||
try {
|
try {
|
||||||
|
const blockIds = blocks.map((block) => block.id)
|
||||||
|
const shouldSyncBlocks =
|
||||||
|
blockIds.length > 0 && (!message.blocks || blockIds.some((id, index) => message.blocks?.[index] !== id))
|
||||||
|
|
||||||
|
const messageWithBlocks = shouldSyncBlocks ? { ...message, blocks: blockIds } : message
|
||||||
// Direct call without conditional logic, now with messageIndex
|
// Direct call without conditional logic, now with messageIndex
|
||||||
await dbService.appendMessage(topicId, message, blocks, messageIndex)
|
await dbService.appendMessage(topicId, messageWithBlocks, blocks, messageIndex)
|
||||||
logger.info('Saved message and blocks via DbService', {
|
logger.info('Saved message and blocks via DbService', {
|
||||||
topicId,
|
topicId,
|
||||||
messageId: message.id,
|
messageId: message.id,
|
||||||
|
|||||||
@ -206,6 +206,9 @@ export type Message = {
|
|||||||
// 跟踪Id
|
// 跟踪Id
|
||||||
traceId?: string
|
traceId?: string
|
||||||
|
|
||||||
|
// Agent session identifier used to resume Claude Code runs
|
||||||
|
agentSessionId?: string
|
||||||
|
|
||||||
// raw data
|
// raw data
|
||||||
// TODO: add this providerMetadata to MessageBlock to save raw provider data for each block
|
// TODO: add this providerMetadata to MessageBlock to save raw provider data for each block
|
||||||
providerMetadata?: ProviderMetadata
|
providerMetadata?: ProviderMetadata
|
||||||
|
|||||||
@ -378,6 +378,7 @@ export function resetMessage(
|
|||||||
role: originalMessage.role,
|
role: originalMessage.role,
|
||||||
topicId: originalMessage.topicId,
|
topicId: originalMessage.topicId,
|
||||||
assistantId: originalMessage.assistantId,
|
assistantId: originalMessage.assistantId,
|
||||||
|
agentSessionId: originalMessage.agentSessionId,
|
||||||
type: originalMessage.type,
|
type: originalMessage.type,
|
||||||
createdAt: originalMessage.createdAt, // Keep original creation timestamp
|
createdAt: originalMessage.createdAt, // Keep original creation timestamp
|
||||||
|
|
||||||
@ -426,6 +427,7 @@ export const resetAssistantMessage = (
|
|||||||
// --- Retain Identity ---
|
// --- Retain Identity ---
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
assistantId: originalMessage.assistantId,
|
assistantId: originalMessage.assistantId,
|
||||||
|
agentSessionId: originalMessage.agentSessionId,
|
||||||
model: originalMessage.model, // Keep the model information
|
model: originalMessage.model, // Keep the model information
|
||||||
modelId: originalMessage.modelId,
|
modelId: originalMessage.modelId,
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user