mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 14:41:24 +08:00
feat(session messages): implement user message persistence and retrieve last agent session ID
This commit is contained in:
parent
1c19e529ac
commit
d8b47e30c4
@ -9,7 +9,7 @@ import type {
|
||||
} from '@types'
|
||||
import { ModelMessage, UIMessage, UIMessageChunk } from 'ai'
|
||||
import { convertToModelMessages, readUIMessageStream } from 'ai'
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { desc, eq } from 'drizzle-orm'
|
||||
|
||||
import { BaseService } from '../BaseService'
|
||||
import { InsertSessionMessageRow, sessionMessagesTable } from '../database/schema'
|
||||
@ -170,29 +170,6 @@ export class SessionMessageService extends BaseService {
|
||||
return { messages }
|
||||
}
|
||||
|
||||
async saveUserMessage(
|
||||
tx: any,
|
||||
sessionId: string,
|
||||
prompt: string,
|
||||
agentSessionId: string
|
||||
): Promise<AgentSessionMessageEntity> {
|
||||
this.ensureInitialized()
|
||||
|
||||
const now = new Date().toISOString()
|
||||
const insertData: InsertSessionMessageRow = {
|
||||
session_id: sessionId,
|
||||
role: 'user',
|
||||
content: prompt,
|
||||
agent_session_id: agentSessionId,
|
||||
created_at: now,
|
||||
updated_at: now
|
||||
}
|
||||
|
||||
const [saved] = await tx.insert(sessionMessagesTable).values(insertData).returning()
|
||||
|
||||
return this.deserializeSessionMessage(saved) as AgentSessionMessageEntity
|
||||
}
|
||||
|
||||
createSessionMessage(session: GetAgentSessionResponse, messageData: CreateSessionMessageRequest): EventEmitter {
|
||||
this.ensureInitialized()
|
||||
|
||||
@ -210,12 +187,8 @@ export class SessionMessageService extends BaseService {
|
||||
req: CreateSessionMessageRequest,
|
||||
sessionStream: EventEmitter
|
||||
): Promise<void> {
|
||||
const previousMessages = session.messages || []
|
||||
let agentSessionId: string = ''
|
||||
if (previousMessages.length > 0) {
|
||||
agentSessionId = previousMessages[previousMessages.length - 1].agent_session_id
|
||||
}
|
||||
|
||||
const agentSessionId = await this.getLastAgentSessionId(session.id)
|
||||
let newAgentSessionId = ''
|
||||
logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId })
|
||||
|
||||
if (session.agent_type !== 'claude-code') {
|
||||
@ -223,7 +196,6 @@ export class SessionMessageService extends BaseService {
|
||||
logger.error('Unsupported agent type for streaming:', { agent_type: session.agent_type })
|
||||
throw new Error('Unsupported agent type for streaming')
|
||||
}
|
||||
let newAgentSessionId = ''
|
||||
|
||||
// Create the streaming agent invocation (using invokeStream for streaming)
|
||||
const claudeStream = this.cc.invoke(req.content, session.accessible_paths[0], agentSessionId, {
|
||||
@ -273,8 +245,8 @@ export class SessionMessageService extends BaseService {
|
||||
case 'complete': {
|
||||
// Then handle async persistence
|
||||
this.database.transaction(async (tx) => {
|
||||
await this.saveUserMessage(tx, session.id, req.content, newAgentSessionId)
|
||||
await this.persistSessionMessageAsync({
|
||||
await this.persistUserMessage(tx, session.id, req.content, newAgentSessionId)
|
||||
await this.persistAssistantMessage({
|
||||
tx,
|
||||
session,
|
||||
accumulator,
|
||||
@ -304,7 +276,51 @@ export class SessionMessageService extends BaseService {
|
||||
})
|
||||
}
|
||||
|
||||
private async persistSessionMessageAsync({
|
||||
private async getLastAgentSessionId(sessionId: string): Promise<string> {
|
||||
this.ensureInitialized()
|
||||
|
||||
try {
|
||||
const result = await this.database
|
||||
.select({ agent_session_id: sessionMessagesTable.agent_session_id })
|
||||
.from(sessionMessagesTable)
|
||||
.where(eq(sessionMessagesTable.session_id, sessionId))
|
||||
.orderBy(desc(sessionMessagesTable.created_at))
|
||||
.limit(1)
|
||||
|
||||
return result[0]?.agent_session_id || ''
|
||||
} catch (error) {
|
||||
logger.error('Failed to get last agent session ID', {
|
||||
sessionId,
|
||||
error
|
||||
})
|
||||
return ''
|
||||
}
|
||||
}
|
||||
|
||||
async persistUserMessage(
|
||||
tx: any,
|
||||
sessionId: string,
|
||||
prompt: string,
|
||||
agentSessionId: string
|
||||
): Promise<AgentSessionMessageEntity> {
|
||||
this.ensureInitialized()
|
||||
|
||||
const now = new Date().toISOString()
|
||||
const insertData: InsertSessionMessageRow = {
|
||||
session_id: sessionId,
|
||||
role: 'user',
|
||||
content: prompt,
|
||||
agent_session_id: agentSessionId,
|
||||
created_at: now,
|
||||
updated_at: now
|
||||
}
|
||||
|
||||
const [saved] = await tx.insert(sessionMessagesTable).values(insertData).returning()
|
||||
|
||||
return this.deserializeSessionMessage(saved) as AgentSessionMessageEntity
|
||||
}
|
||||
|
||||
private async persistAssistantMessage({
|
||||
tx,
|
||||
session,
|
||||
accumulator,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user