feat(MessageService): optimize message retrieval with CTEs for improved performance

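- Reworked `getTree` to use recursive Common Table Expressions (CTEs): one query fetches the active path and one fetches the tree down to the depth limit, with follow-up queries only for active-path nodes (and their children) that fall outside that limit. The method no longer loads every message in the topic, reducing database load.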
- Updated `getBranchMessages` in the same way: a recursive CTE fetches only the path from the requested node up to the root, and siblings are fetched in one additional batched query, so the topic's full message list is no longer loaded.
- Refactored `getPathToNode` to use a recursive CTE for fetching ancestors, addressing the N+1 query problem in deep message trees.
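- Introduced transaction handling in `create` and `update` so that validation and the write (plus the topic `activeNodeId` update in `create`) are atomic, protecting data integrity during message modifications. Note that the cycle check in `update` still runs before the transaction starts.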
This commit is contained in:
fullex 2025-12-29 16:59:00 +08:00
parent 3d0e7a6c15
commit e4fd1af1b8

View File

@ -27,7 +27,7 @@ import type {
TreeNode,
TreeResponse
} from '@shared/data/types/message'
import { eq, inArray, sql } from 'drizzle-orm'
import { and, eq, inArray, or, sql } from 'drizzle-orm'
const logger = loggerService.withContext('MessageService')
@ -112,6 +112,11 @@ export class MessageService {
/**
* Get tree structure for visualization
*
* Optimized to avoid loading all messages:
* 1. Uses CTE to get active path (single query)
* 2. Uses CTE to get tree nodes within depth limit (single query)
* 3. Fetches active path nodes that lie beyond the depth limit, plus any not-yet-loaded children of active path nodes
*/
async getTree(
topicId: string,
@ -129,19 +134,108 @@ export class MessageService {
const activeNodeId = options.nodeId || topic.activeNodeId
// Get all messages for this topic
const allMessages = await db.select().from(messageTable).where(eq(messageTable.topicId, topicId))
// Find root node if not specified
let rootId = options.rootId
if (!rootId) {
const [root] = await db
.select({ id: messageTable.id })
.from(messageTable)
.where(and(eq(messageTable.topicId, topicId), sql`${messageTable.parentId} IS NULL`))
.limit(1)
rootId = root?.id
}
if (allMessages.length === 0) {
if (!rootId) {
return { nodes: [], siblingsGroups: [], activeNodeId: null }
}
// Build active path via CTE (single query)
const activePath = new Set<string>()
if (activeNodeId) {
const pathRows = await db.all<{ id: string }>(sql`
WITH RECURSIVE path AS (
SELECT id, parent_id FROM message WHERE id = ${activeNodeId}
UNION ALL
SELECT m.id, m.parent_id FROM message m
INNER JOIN path p ON m.id = p.parent_id
)
SELECT id FROM path
`)
pathRows.forEach((r) => activePath.add(r.id))
}
// Get tree with depth limit via CTE
// Use a large depth for unlimited (-1)
const maxDepth = depth === -1 ? 999 : depth
const treeRows = await db.all<typeof messageTable.$inferSelect & { tree_depth: number }>(sql`
WITH RECURSIVE tree AS (
SELECT *, 0 as tree_depth FROM message WHERE id = ${rootId}
UNION ALL
SELECT m.*, t.tree_depth + 1 FROM message m
INNER JOIN tree t ON m.parent_id = t.id
WHERE t.tree_depth < ${maxDepth}
)
SELECT * FROM tree
`)
// Also fetch active path nodes that might be beyond depth limit
const treeNodeIds = new Set(treeRows.map((r) => r.id))
const missingActivePathIds = [...activePath].filter((id) => !treeNodeIds.has(id))
if (missingActivePathIds.length > 0) {
const additionalRows = await db.select().from(messageTable).where(inArray(messageTable.id, missingActivePathIds))
treeRows.push(...additionalRows.map((r) => ({ ...r, tree_depth: maxDepth + 1 })))
}
// Also need children of active path nodes for proper tree building
// Get all children of active path nodes that we haven't loaded yet
const activePathArray = [...activePath]
if (activePathArray.length > 0 && treeNodeIds.size > 0) {
const childrenRows = await db
.select()
.from(messageTable)
.where(
and(
inArray(messageTable.parentId, activePathArray),
sql`${messageTable.id} NOT IN (${sql.join(
[...treeNodeIds].map((id) => sql`${id}`),
sql`, `
)})`
)
)
for (const row of childrenRows) {
if (!treeNodeIds.has(row.id)) {
treeRows.push({ ...row, tree_depth: maxDepth + 1 })
treeNodeIds.add(row.id)
}
}
} else if (activePathArray.length > 0) {
// No tree nodes loaded yet, just get all children of active path
const childrenRows = await db.select().from(messageTable).where(inArray(messageTable.parentId, activePathArray))
for (const row of childrenRows) {
if (!treeNodeIds.has(row.id)) {
treeRows.push({ ...row, tree_depth: maxDepth + 1 })
treeNodeIds.add(row.id)
}
}
}
if (treeRows.length === 0) {
return { nodes: [], siblingsGroups: [], activeNodeId: null }
}
// Build maps for tree processing
const messagesById = new Map<string, Message>()
const childrenMap = new Map<string, string[]>()
const depthMap = new Map<string, number>()
for (const row of allMessages) {
for (const row of treeRows) {
const message = rowToMessage(row)
messagesById.set(message.id, message)
depthMap.set(message.id, row.tree_depth)
const parentId = message.parentId || 'root'
if (!childrenMap.has(parentId)) {
@ -150,21 +244,6 @@ export class MessageService {
childrenMap.get(parentId)!.push(message.id)
}
// Find root node(s) and build active path
const rootIds = childrenMap.get('root') || []
const rootId = options.rootId || rootIds[0]
// Build path from rootId to activeNodeId
const activePath = new Set<string>()
if (activeNodeId) {
let currentId: string | null = activeNodeId
while (currentId) {
activePath.add(currentId)
const message = messagesById.get(currentId)
currentId = message?.parentId || null
}
}
// Collect nodes based on depth
const resultNodes: TreeNode[] = []
const siblingsGroups: SiblingsGroup[] = []
@ -187,7 +266,7 @@ export class MessageService {
const parentChildren = childrenMap.get(message.parentId || 'root') || []
const groupMembers = parentChildren
.map((id) => messagesById.get(id)!)
.filter((m) => m.siblingsGroupId === message.siblingsGroupId)
.filter((m) => m && m.siblingsGroupId === message.siblingsGroupId)
if (groupMembers.length > 1) {
siblingsGroups.push({
@ -221,9 +300,7 @@ export class MessageService {
}
// Start from root
if (rootId) {
collectNodes(rootId, 0, activePath.has(rootId))
}
collectNodes(rootId, 0, activePath.has(rootId))
return {
nodes: resultNodes,
@ -234,6 +311,10 @@ export class MessageService {
/**
* Get branch messages for conversation view
*
* Optimized implementation using recursive CTE to fetch only the path
* from nodeId to root, avoiding loading all messages for large topics.
* Siblings are batch-queried in a single additional query.
*/
async getBranchMessages(
topicId: string,
@ -249,76 +330,108 @@ export class MessageService {
throw DataApiErrorFactory.notFound('Topic', topicId)
}
// Get all messages for this topic
const allMessages = await db.select().from(messageTable).where(eq(messageTable.topicId, topicId))
const nodeId = options.nodeId || topic.activeNodeId
if (allMessages.length === 0) {
// Return empty if no active node
if (!nodeId) {
return { messages: [], activeNodeId: null }
}
// Check for data inconsistency
if (!topic.activeNodeId) {
throw DataApiErrorFactory.dataInconsistent('Topic', 'has messages but no active node')
// Use recursive CTE to get path from nodeId to root (single query)
const pathMessages = await db.all<typeof messageTable.$inferSelect>(sql`
WITH RECURSIVE path AS (
SELECT * FROM message WHERE id = ${nodeId}
UNION ALL
SELECT m.* FROM message m
INNER JOIN path p ON m.id = p.parent_id
)
SELECT * FROM path
`)
if (pathMessages.length === 0) {
throw DataApiErrorFactory.notFound('Message', nodeId)
}
const nodeId = options.nodeId || topic.activeNodeId
const messagesById = new Map<string, Message>()
for (const row of allMessages) {
messagesById.set(row.id, rowToMessage(row))
}
// Build path from root to nodeId
const path: string[] = []
let currentId: string | null = nodeId
while (currentId) {
path.unshift(currentId)
const message = messagesById.get(currentId)
if (!message) {
throw DataApiErrorFactory.notFound('Message', currentId)
}
currentId = message.parentId
}
// Reverse to get root->nodeId order
const fullPath = pathMessages.reverse()
// Apply pagination
let startIndex = 0
let endIndex = fullPath.length
if (options.beforeNodeId) {
const beforeIndex = path.indexOf(options.beforeNodeId)
const beforeIndex = fullPath.findIndex((m) => m.id === options.beforeNodeId)
if (beforeIndex === -1) {
throw DataApiErrorFactory.notFound('Message', options.beforeNodeId)
}
startIndex = Math.max(0, beforeIndex - limit)
endIndex = beforeIndex
} else {
startIndex = Math.max(0, path.length - limit)
startIndex = Math.max(0, fullPath.length - limit)
}
const endIndex = options.beforeNodeId ? path.indexOf(options.beforeNodeId) : path.length
const resultPath = path.slice(startIndex, endIndex)
const paginatedPath = fullPath.slice(startIndex, endIndex)
// Build result with optional siblings
const result: BranchMessage[] = []
for (const msgId of resultPath) {
const message = messagesById.get(msgId)!
if (includeSiblings) {
// Collect unique (parentId, siblingsGroupId) pairs that need siblings
const uniqueGroups = new Set<string>()
const groupsToQuery: Array<{ parentId: string; siblingsGroupId: number }> = []
let siblingsGroup: Message[] | undefined
if (includeSiblings && message.siblingsGroupId !== 0) {
// Find siblings with same parentId and siblingsGroupId
siblingsGroup = allMessages
.filter(
(row) =>
row.parentId === message.parentId &&
row.siblingsGroupId === message.siblingsGroupId &&
row.id !== message.id
)
.map(rowToMessage)
for (const msg of paginatedPath) {
if (msg.siblingsGroupId && msg.siblingsGroupId !== 0 && msg.parentId) {
const key = `${msg.parentId}-${msg.siblingsGroupId}`
if (!uniqueGroups.has(key)) {
uniqueGroups.add(key)
groupsToQuery.push({ parentId: msg.parentId, siblingsGroupId: msg.siblingsGroupId })
}
}
}
result.push({
message,
siblingsGroup
})
// Batch query all siblings if needed
const siblingsMap = new Map<string, Message[]>()
if (groupsToQuery.length > 0) {
// Build OR conditions for batch query
const orConditions = groupsToQuery.map((g) =>
and(eq(messageTable.parentId, g.parentId), eq(messageTable.siblingsGroupId, g.siblingsGroupId))
)
const siblingsRows = await db
.select()
.from(messageTable)
.where(or(...orConditions))
// Group results by parentId-siblingsGroupId
for (const row of siblingsRows) {
const key = `${row.parentId}-${row.siblingsGroupId}`
if (!siblingsMap.has(key)) siblingsMap.set(key, [])
siblingsMap.get(key)!.push(rowToMessage(row))
}
}
// Build result with siblings from map
for (const msg of paginatedPath) {
const message = rowToMessage(msg)
let siblingsGroup: Message[] | undefined
if (msg.siblingsGroupId !== 0 && msg.parentId) {
const key = `${msg.parentId}-${msg.siblingsGroupId}`
const group = siblingsMap.get(key)
if (group && group.length > 1) {
siblingsGroup = group.filter((m) => m.id !== message.id)
}
}
result.push({ message, siblingsGroup })
}
} else {
// No siblings needed, just map messages
for (const msg of paginatedPath) {
result.push({ message: rowToMessage(msg) })
}
}
return {
@ -344,94 +457,114 @@ export class MessageService {
/**
* Create a new message
*
* Uses transaction to ensure atomicity of:
* - Topic existence validation
* - Parent message validation (if specified)
* - Message insertion
* - Topic activeNodeId update
*/
async create(topicId: string, dto: CreateMessageDto): Promise<Message> {
const db = dbService.getDb()
// Verify topic exists
const [topic] = await db.select().from(topicTable).where(eq(topicTable.id, topicId)).limit(1)
return await db.transaction(async (tx) => {
// Verify topic exists
const [topic] = await tx.select().from(topicTable).where(eq(topicTable.id, topicId)).limit(1)
if (!topic) {
throw DataApiErrorFactory.notFound('Topic', topicId)
}
// Verify parent exists if specified
if (dto.parentId) {
const [parent] = await db.select().from(messageTable).where(eq(messageTable.id, dto.parentId)).limit(1)
if (!parent) {
throw DataApiErrorFactory.notFound('Message', dto.parentId)
if (!topic) {
throw DataApiErrorFactory.notFound('Topic', topicId)
}
}
const [row] = await db
.insert(messageTable)
.values({
topicId,
parentId: dto.parentId,
role: dto.role,
data: dto.data,
status: dto.status ?? 'pending',
siblingsGroupId: dto.siblingsGroupId ?? 0,
assistantId: dto.assistantId,
assistantMeta: dto.assistantMeta,
modelId: dto.modelId,
modelMeta: dto.modelMeta,
traceId: dto.traceId,
stats: dto.stats
})
.returning()
// Update activeNodeId if setAsActive is not explicitly false
if (dto.setAsActive !== false) {
await db.update(topicTable).set({ activeNodeId: row.id }).where(eq(topicTable.id, topicId))
}
logger.info('Created message', { id: row.id, topicId, role: dto.role, setAsActive: dto.setAsActive !== false })
return rowToMessage(row)
}
/**
* Update a message
*/
async update(id: string, dto: UpdateMessageDto): Promise<Message> {
const db = dbService.getDb()
// Get existing message
const existing = await this.getById(id)
// Check for cycle if moving to new parent
if (dto.parentId !== undefined && dto.parentId !== existing.parentId) {
if (dto.parentId !== null) {
// Check that new parent is not a descendant
const descendants = await this.getDescendantIds(id)
if (descendants.includes(dto.parentId)) {
throw DataApiErrorFactory.invalidOperation('move message', 'would create cycle')
}
// Verify new parent exists
const [parent] = await db.select().from(messageTable).where(eq(messageTable.id, dto.parentId)).limit(1)
// Verify parent exists if specified
if (dto.parentId) {
const [parent] = await tx.select().from(messageTable).where(eq(messageTable.id, dto.parentId)).limit(1)
if (!parent) {
throw DataApiErrorFactory.notFound('Message', dto.parentId)
}
}
const [row] = await tx
.insert(messageTable)
.values({
topicId,
parentId: dto.parentId,
role: dto.role,
data: dto.data,
status: dto.status ?? 'pending',
siblingsGroupId: dto.siblingsGroupId ?? 0,
assistantId: dto.assistantId,
assistantMeta: dto.assistantMeta,
modelId: dto.modelId,
modelMeta: dto.modelMeta,
traceId: dto.traceId,
stats: dto.stats
})
.returning()
// Update activeNodeId if setAsActive is not explicitly false
if (dto.setAsActive !== false) {
await tx.update(topicTable).set({ activeNodeId: row.id }).where(eq(topicTable.id, topicId))
}
logger.info('Created message', { id: row.id, topicId, role: dto.role, setAsActive: dto.setAsActive !== false })
return rowToMessage(row)
})
}
/**
* Update a message
*
* Uses transaction to ensure atomicity of validation and update.
* Cycle check is performed outside transaction as a read-only safety check.
*/
async update(id: string, dto: UpdateMessageDto): Promise<Message> {
const db = dbService.getDb()
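// Pre-transaction: check for a cycle whenever a non-null parentId is supplied
// (the current parent is only read inside the transaction, so an unchanged parentId is checked too).
// This is done outside the transaction since getDescendantIds uses its own db context.
// The cycle check is a best-effort safety check: because it is not atomic with the update below,
// it may reject a valid operation or, if the tree is modified in between, miss a newly possible cycle.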
if (dto.parentId !== undefined && dto.parentId !== null) {
const descendants = await this.getDescendantIds(id)
if (descendants.includes(dto.parentId)) {
throw DataApiErrorFactory.invalidOperation('move message', 'would create cycle')
}
}
// Build update object
const updates: Partial<typeof messageTable.$inferInsert> = {}
return await db.transaction(async (tx) => {
// Get existing message within transaction
const [existingRow] = await tx.select().from(messageTable).where(eq(messageTable.id, id)).limit(1)
if (dto.data !== undefined) updates.data = dto.data
if (dto.parentId !== undefined) updates.parentId = dto.parentId
if (dto.siblingsGroupId !== undefined) updates.siblingsGroupId = dto.siblingsGroupId
if (dto.status !== undefined) updates.status = dto.status
if (!existingRow) {
throw DataApiErrorFactory.notFound('Message', id)
}
const [row] = await db.update(messageTable).set(updates).where(eq(messageTable.id, id)).returning()
const existing = rowToMessage(existingRow)
logger.info('Updated message', { id, changes: Object.keys(dto) })
// Verify new parent exists if changing parent
if (dto.parentId !== undefined && dto.parentId !== existing.parentId && dto.parentId !== null) {
const [parent] = await tx.select().from(messageTable).where(eq(messageTable.id, dto.parentId)).limit(1)
return rowToMessage(row)
if (!parent) {
throw DataApiErrorFactory.notFound('Message', dto.parentId)
}
}
// Build update object
const updates: Partial<typeof messageTable.$inferInsert> = {}
if (dto.data !== undefined) updates.data = dto.data
if (dto.parentId !== undefined) updates.parentId = dto.parentId
if (dto.siblingsGroupId !== undefined) updates.siblingsGroupId = dto.siblingsGroupId
if (dto.status !== undefined) updates.status = dto.status
const [row] = await tx.update(messageTable).set(updates).where(eq(messageTable.id, id)).returning()
logger.info('Updated message', { id, changes: Object.keys(dto) })
return rowToMessage(row)
})
}
/**
@ -573,18 +706,30 @@ export class MessageService {
/**
* Get path from root to a node
*
* Uses recursive CTE to fetch all ancestors in a single query,
* avoiding N+1 query problem for deep message trees.
*/
async getPathToNode(nodeId: string): Promise<Message[]> {
const path: Message[] = []
let currentId: string | null = nodeId
const db = dbService.getDb()
while (currentId) {
const message = await this.getById(currentId)
path.unshift(message)
currentId = message.parentId
// Use recursive CTE to get all ancestors in one query
const result = await db.all<typeof messageTable.$inferSelect>(sql`
WITH RECURSIVE ancestors AS (
SELECT * FROM message WHERE id = ${nodeId}
UNION ALL
SELECT m.* FROM message m
INNER JOIN ancestors a ON m.id = a.parent_id
)
SELECT * FROM ancestors
`)
if (result.length === 0) {
throw DataApiErrorFactory.notFound('Message', nodeId)
}
return path
// Result is from nodeId to root, reverse to get root to nodeId
return result.reverse().map(rowToMessage)
}
}