mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-26 11:44:28 +08:00
feat: 添加提取会话ID的功能并更新消息发送逻辑
This commit is contained in:
parent
7cdc80c3e2
commit
b493172090
@ -20,7 +20,11 @@ import type { FileMessageBlock, ImageMessageBlock, Message, MessageBlock } from
|
|||||||
import { AssistantMessageStatus, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
|
import { AssistantMessageStatus, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
|
||||||
import { uuid } from '@renderer/utils'
|
import { uuid } from '@renderer/utils'
|
||||||
import { addAbortController } from '@renderer/utils/abortController'
|
import { addAbortController } from '@renderer/utils/abortController'
|
||||||
import { buildAgentSessionTopicId, isAgentSessionTopicId } from '@renderer/utils/agentSession'
|
import {
|
||||||
|
buildAgentSessionTopicId,
|
||||||
|
extractAgentSessionIdFromTopicId,
|
||||||
|
isAgentSessionTopicId
|
||||||
|
} from '@renderer/utils/agentSession'
|
||||||
import {
|
import {
|
||||||
createAssistantMessage,
|
createAssistantMessage,
|
||||||
createTranslationBlock,
|
createTranslationBlock,
|
||||||
@ -63,11 +67,55 @@ const finishTopicLoading = async (topicId: string) => {
|
|||||||
type AgentSessionContext = {
|
type AgentSessionContext = {
|
||||||
agentId: string
|
agentId: string
|
||||||
sessionId: string
|
sessionId: string
|
||||||
|
agentSessionId?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
const agentSessionRenameLocks = new Set<string>()
|
const agentSessionRenameLocks = new Set<string>()
|
||||||
const dbFacade = DbService.getInstance()
|
const dbFacade = DbService.getInstance()
|
||||||
|
|
||||||
|
const findExistingAgentSessionContext = (
|
||||||
|
state: RootState,
|
||||||
|
topicId: string,
|
||||||
|
assistantId: string
|
||||||
|
): AgentSessionContext | undefined => {
|
||||||
|
if (!isAgentSessionTopicId(topicId)) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
const sessionId = extractAgentSessionIdFromTopicId(topicId)
|
||||||
|
if (!sessionId) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
const messageIds = state.messages.messageIdsByTopic[topicId]
|
||||||
|
let existingAgentSessionId: string | undefined
|
||||||
|
|
||||||
|
if (messageIds?.length) {
|
||||||
|
for (let index = messageIds.length - 1; index >= 0; index -= 1) {
|
||||||
|
const messageId = messageIds[index]
|
||||||
|
const message = state.messages.entities[messageId]
|
||||||
|
const candidate = message?.agentSessionId?.trim()
|
||||||
|
|
||||||
|
if (!candidate) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if (message.assistantId !== assistantId) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
existingAgentSessionId = candidate
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
agentId: assistantId,
|
||||||
|
sessionId,
|
||||||
|
agentSessionId: existingAgentSessionId
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const buildAgentBaseURL = (apiServer: ApiServerConfig) => {
|
const buildAgentBaseURL = (apiServer: ApiServerConfig) => {
|
||||||
const hasProtocol = apiServer.host.startsWith('http://') || apiServer.host.startsWith('https://')
|
const hasProtocol = apiServer.host.startsWith('http://') || apiServer.host.startsWith('https://')
|
||||||
const baseHost = hasProtocol ? apiServer.host : `http://${apiServer.host}`
|
const baseHost = hasProtocol ? apiServer.host : `http://${apiServer.host}`
|
||||||
@ -605,6 +653,7 @@ const fetchAndProcessAgentResponseImpl = async (
|
|||||||
}
|
}
|
||||||
|
|
||||||
latestAgentSessionId = sessionId
|
latestAgentSessionId = sessionId
|
||||||
|
agentSession.agentSessionId = sessionId
|
||||||
|
|
||||||
logger.debug(`Agent session ID updated`, {
|
logger.debug(`Agent session ID updated`, {
|
||||||
topicId,
|
topicId,
|
||||||
@ -650,7 +699,7 @@ const fetchAndProcessAgentResponseImpl = async (
|
|||||||
}
|
}
|
||||||
|
|
||||||
const adapter = new AiSdkToChunkAdapter(streamProcessorCallbacks, [], false, false, (sessionId) => {
|
const adapter = new AiSdkToChunkAdapter(streamProcessorCallbacks, [], false, false, (sessionId) => {
|
||||||
void persistAgentSessionId(sessionId)
|
persistAgentSessionId(sessionId)
|
||||||
})
|
})
|
||||||
|
|
||||||
await adapter.processStream({
|
await adapter.processStream({
|
||||||
@ -658,13 +707,6 @@ const fetchAndProcessAgentResponseImpl = async (
|
|||||||
text: Promise.resolve('')
|
text: Promise.resolve('')
|
||||||
})
|
})
|
||||||
|
|
||||||
// No longer need persistAgentExchange here since:
|
|
||||||
// 1. User message is already saved via appendMessage when created
|
|
||||||
// 2. Assistant message is saved via appendMessage when created
|
|
||||||
// 3. Updates during streaming are saved via updateMessageAndBlocks
|
|
||||||
// This eliminates the duplicate save issue
|
|
||||||
|
|
||||||
// Attempt final persistence in case the session id arrived late in the stream
|
|
||||||
if (latestAgentSessionId) {
|
if (latestAgentSessionId) {
|
||||||
await persistAgentSessionId(latestAgentSessionId)
|
await persistAgentSessionId(latestAgentSessionId)
|
||||||
}
|
}
|
||||||
@ -858,6 +900,19 @@ export const sendMessage =
|
|||||||
logger.warn('sendMessage: No blocks in the provided message.')
|
logger.warn('sendMessage: No blocks in the provided message.')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const stateBeforeSend = getState()
|
||||||
|
let activeAgentSession = agentSession ?? findExistingAgentSessionContext(stateBeforeSend, topicId, assistant.id)
|
||||||
|
if (activeAgentSession) {
|
||||||
|
const derivedSession = findExistingAgentSessionContext(stateBeforeSend, topicId, assistant.id)
|
||||||
|
if (derivedSession?.agentSessionId && derivedSession.agentSessionId !== activeAgentSession.agentSessionId) {
|
||||||
|
activeAgentSession = { ...activeAgentSession, agentSessionId: derivedSession.agentSessionId }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (activeAgentSession?.agentSessionId && !userMessage.agentSessionId) {
|
||||||
|
userMessage.agentSessionId = activeAgentSession.agentSessionId
|
||||||
|
}
|
||||||
|
|
||||||
await saveMessageAndBlocksToDB(userMessage, userMessageBlocks)
|
await saveMessageAndBlocksToDB(userMessage, userMessageBlocks)
|
||||||
dispatch(newMessagesActions.addMessage({ topicId, message: userMessage }))
|
dispatch(newMessagesActions.addMessage({ topicId, message: userMessage }))
|
||||||
if (userMessageBlocks.length > 0) {
|
if (userMessageBlocks.length > 0) {
|
||||||
@ -867,12 +922,15 @@ export const sendMessage =
|
|||||||
|
|
||||||
const queue = getTopicQueue(topicId)
|
const queue = getTopicQueue(topicId)
|
||||||
|
|
||||||
if (agentSession) {
|
if (activeAgentSession) {
|
||||||
const assistantMessage = createAssistantMessage(assistant.id, topicId, {
|
const assistantMessage = createAssistantMessage(assistant.id, topicId, {
|
||||||
askId: userMessage.id,
|
askId: userMessage.id,
|
||||||
model: assistant.model,
|
model: assistant.model,
|
||||||
traceId: userMessage.traceId
|
traceId: userMessage.traceId
|
||||||
})
|
})
|
||||||
|
if (activeAgentSession.agentSessionId && !assistantMessage.agentSessionId) {
|
||||||
|
assistantMessage.agentSessionId = activeAgentSession.agentSessionId
|
||||||
|
}
|
||||||
await saveMessageAndBlocksToDB(assistantMessage, [])
|
await saveMessageAndBlocksToDB(assistantMessage, [])
|
||||||
dispatch(newMessagesActions.addMessage({ topicId, message: assistantMessage }))
|
dispatch(newMessagesActions.addMessage({ topicId, message: assistantMessage }))
|
||||||
|
|
||||||
@ -881,7 +939,7 @@ export const sendMessage =
|
|||||||
topicId,
|
topicId,
|
||||||
assistant,
|
assistant,
|
||||||
assistantMessage,
|
assistantMessage,
|
||||||
agentSession,
|
agentSession: activeAgentSession,
|
||||||
userMessageId: userMessage.id
|
userMessageId: userMessage.id
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@ -946,7 +1004,7 @@ export const loadAgentSessionMessagesThunk =
|
|||||||
}
|
}
|
||||||
dispatch(newMessagesActions.messagesReceived({ topicId, messages }))
|
dispatch(newMessagesActions.messagesReceived({ topicId, messages }))
|
||||||
|
|
||||||
logger.info(`Loaded ${messages.length} messages for agent session ${sessionId}`)
|
logger.silly(`Loaded ${messages.length} messages for agent session ${sessionId}`)
|
||||||
} else {
|
} else {
|
||||||
dispatch(newMessagesActions.messagesReceived({ topicId, messages: [] }))
|
dispatch(newMessagesActions.messagesReceived({ topicId, messages: [] }))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -26,6 +26,8 @@ export const loadTopicMessagesThunkV2 =
|
|||||||
async (dispatch: AppDispatch, getState: () => RootState) => {
|
async (dispatch: AppDispatch, getState: () => RootState) => {
|
||||||
const state = getState()
|
const state = getState()
|
||||||
|
|
||||||
|
dispatch(newMessagesActions.setCurrentTopicId(topicId))
|
||||||
|
|
||||||
// Skip if already cached and not forcing reload
|
// Skip if already cached and not forcing reload
|
||||||
if (!forceReload && state.messages.messageIdsByTopic[topicId]) {
|
if (!forceReload && state.messages.messageIdsByTopic[topicId]) {
|
||||||
return
|
return
|
||||||
|
|||||||
@ -7,3 +7,7 @@ export const buildAgentSessionTopicId = (sessionId: string): string => {
|
|||||||
export const isAgentSessionTopicId = (topicId: string): boolean => {
|
export const isAgentSessionTopicId = (topicId: string): boolean => {
|
||||||
return topicId.startsWith(SESSION_TOPIC_PREFIX)
|
return topicId.startsWith(SESSION_TOPIC_PREFIX)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const extractAgentSessionIdFromTopicId = (topicId: string): string => {
|
||||||
|
return topicId.replace(SESSION_TOPIC_PREFIX, '')
|
||||||
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user