mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 14:41:24 +08:00
feat(session messages): implement user message persistence and retrieve last agent session ID
This commit is contained in:
parent
1c19e529ac
commit
d8b47e30c4
@ -9,7 +9,7 @@ import type {
|
|||||||
} from '@types'
|
} from '@types'
|
||||||
import { ModelMessage, UIMessage, UIMessageChunk } from 'ai'
|
import { ModelMessage, UIMessage, UIMessageChunk } from 'ai'
|
||||||
import { convertToModelMessages, readUIMessageStream } from 'ai'
|
import { convertToModelMessages, readUIMessageStream } from 'ai'
|
||||||
import { eq } from 'drizzle-orm'
|
import { desc, eq } from 'drizzle-orm'
|
||||||
|
|
||||||
import { BaseService } from '../BaseService'
|
import { BaseService } from '../BaseService'
|
||||||
import { InsertSessionMessageRow, sessionMessagesTable } from '../database/schema'
|
import { InsertSessionMessageRow, sessionMessagesTable } from '../database/schema'
|
||||||
@ -170,29 +170,6 @@ export class SessionMessageService extends BaseService {
|
|||||||
return { messages }
|
return { messages }
|
||||||
}
|
}
|
||||||
|
|
||||||
async saveUserMessage(
|
|
||||||
tx: any,
|
|
||||||
sessionId: string,
|
|
||||||
prompt: string,
|
|
||||||
agentSessionId: string
|
|
||||||
): Promise<AgentSessionMessageEntity> {
|
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
const now = new Date().toISOString()
|
|
||||||
const insertData: InsertSessionMessageRow = {
|
|
||||||
session_id: sessionId,
|
|
||||||
role: 'user',
|
|
||||||
content: prompt,
|
|
||||||
agent_session_id: agentSessionId,
|
|
||||||
created_at: now,
|
|
||||||
updated_at: now
|
|
||||||
}
|
|
||||||
|
|
||||||
const [saved] = await tx.insert(sessionMessagesTable).values(insertData).returning()
|
|
||||||
|
|
||||||
return this.deserializeSessionMessage(saved) as AgentSessionMessageEntity
|
|
||||||
}
|
|
||||||
|
|
||||||
createSessionMessage(session: GetAgentSessionResponse, messageData: CreateSessionMessageRequest): EventEmitter {
|
createSessionMessage(session: GetAgentSessionResponse, messageData: CreateSessionMessageRequest): EventEmitter {
|
||||||
this.ensureInitialized()
|
this.ensureInitialized()
|
||||||
|
|
||||||
@ -210,12 +187,8 @@ export class SessionMessageService extends BaseService {
|
|||||||
req: CreateSessionMessageRequest,
|
req: CreateSessionMessageRequest,
|
||||||
sessionStream: EventEmitter
|
sessionStream: EventEmitter
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
const previousMessages = session.messages || []
|
const agentSessionId = await this.getLastAgentSessionId(session.id)
|
||||||
let agentSessionId: string = ''
|
let newAgentSessionId = ''
|
||||||
if (previousMessages.length > 0) {
|
|
||||||
agentSessionId = previousMessages[previousMessages.length - 1].agent_session_id
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId })
|
logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId })
|
||||||
|
|
||||||
if (session.agent_type !== 'claude-code') {
|
if (session.agent_type !== 'claude-code') {
|
||||||
@ -223,7 +196,6 @@ export class SessionMessageService extends BaseService {
|
|||||||
logger.error('Unsupported agent type for streaming:', { agent_type: session.agent_type })
|
logger.error('Unsupported agent type for streaming:', { agent_type: session.agent_type })
|
||||||
throw new Error('Unsupported agent type for streaming')
|
throw new Error('Unsupported agent type for streaming')
|
||||||
}
|
}
|
||||||
let newAgentSessionId = ''
|
|
||||||
|
|
||||||
// Create the streaming agent invocation (using invokeStream for streaming)
|
// Create the streaming agent invocation (using invokeStream for streaming)
|
||||||
const claudeStream = this.cc.invoke(req.content, session.accessible_paths[0], agentSessionId, {
|
const claudeStream = this.cc.invoke(req.content, session.accessible_paths[0], agentSessionId, {
|
||||||
@ -273,8 +245,8 @@ export class SessionMessageService extends BaseService {
|
|||||||
case 'complete': {
|
case 'complete': {
|
||||||
// Then handle async persistence
|
// Then handle async persistence
|
||||||
this.database.transaction(async (tx) => {
|
this.database.transaction(async (tx) => {
|
||||||
await this.saveUserMessage(tx, session.id, req.content, newAgentSessionId)
|
await this.persistUserMessage(tx, session.id, req.content, newAgentSessionId)
|
||||||
await this.persistSessionMessageAsync({
|
await this.persistAssistantMessage({
|
||||||
tx,
|
tx,
|
||||||
session,
|
session,
|
||||||
accumulator,
|
accumulator,
|
||||||
@ -304,7 +276,51 @@ export class SessionMessageService extends BaseService {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
private async persistSessionMessageAsync({
|
private async getLastAgentSessionId(sessionId: string): Promise<string> {
|
||||||
|
this.ensureInitialized()
|
||||||
|
|
||||||
|
try {
|
||||||
|
const result = await this.database
|
||||||
|
.select({ agent_session_id: sessionMessagesTable.agent_session_id })
|
||||||
|
.from(sessionMessagesTable)
|
||||||
|
.where(eq(sessionMessagesTable.session_id, sessionId))
|
||||||
|
.orderBy(desc(sessionMessagesTable.created_at))
|
||||||
|
.limit(1)
|
||||||
|
|
||||||
|
return result[0]?.agent_session_id || ''
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to get last agent session ID', {
|
||||||
|
sessionId,
|
||||||
|
error
|
||||||
|
})
|
||||||
|
return ''
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async persistUserMessage(
|
||||||
|
tx: any,
|
||||||
|
sessionId: string,
|
||||||
|
prompt: string,
|
||||||
|
agentSessionId: string
|
||||||
|
): Promise<AgentSessionMessageEntity> {
|
||||||
|
this.ensureInitialized()
|
||||||
|
|
||||||
|
const now = new Date().toISOString()
|
||||||
|
const insertData: InsertSessionMessageRow = {
|
||||||
|
session_id: sessionId,
|
||||||
|
role: 'user',
|
||||||
|
content: prompt,
|
||||||
|
agent_session_id: agentSessionId,
|
||||||
|
created_at: now,
|
||||||
|
updated_at: now
|
||||||
|
}
|
||||||
|
|
||||||
|
const [saved] = await tx.insert(sessionMessagesTable).values(insertData).returning()
|
||||||
|
|
||||||
|
return this.deserializeSessionMessage(saved) as AgentSessionMessageEntity
|
||||||
|
}
|
||||||
|
|
||||||
|
private async persistAssistantMessage({
|
||||||
tx,
|
tx,
|
||||||
session,
|
session,
|
||||||
accumulator,
|
accumulator,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user