From e4fd1af1b8845c59f189857bc405129370f983c4 Mon Sep 17 00:00:00 2001 From: fullex <0xfullex@gmail.com> Date: Mon, 29 Dec 2025 16:59:00 +0800 Subject: [PATCH] feat(MessageService): optimize message retrieval with CTEs for improved performance - Enhanced `getTree` method to utilize Common Table Expressions (CTEs) for fetching active paths and tree structures in a single query, reducing database load. - Updated `getBranchMessages` to implement a similar optimization, allowing for efficient retrieval of message paths without loading all messages. - Refactored `getPathToNode` to use a recursive CTE for fetching ancestors, addressing the N+1 query problem in deep message trees. - Introduced transaction handling in `create` and `update` methods to ensure atomic operations and data integrity during message modifications. --- src/main/data/services/MessageService.ts | 439 +++++++++++++++-------- 1 file changed, 292 insertions(+), 147 deletions(-) diff --git a/src/main/data/services/MessageService.ts b/src/main/data/services/MessageService.ts index 9956983b6a..5f1c05e749 100644 --- a/src/main/data/services/MessageService.ts +++ b/src/main/data/services/MessageService.ts @@ -27,7 +27,7 @@ import type { TreeNode, TreeResponse } from '@shared/data/types/message' -import { eq, inArray, sql } from 'drizzle-orm' +import { and, eq, inArray, or, sql } from 'drizzle-orm' const logger = loggerService.withContext('MessageService') @@ -112,6 +112,11 @@ export class MessageService { /** * Get tree structure for visualization + * + * Optimized to avoid loading all messages: + * 1. Uses CTE to get active path (single query) + * 2. Uses CTE to get tree nodes within depth limit (single query) + * 3. Fetches additional nodes for active path if beyond depth limit */ async getTree( topicId: string, @@ -129,19 +134,108 @@ export class MessageService { const activeNodeId = options.nodeId || topic.activeNodeId - // Get all messages for this topic - const allMessages = await db.select().from(messageTable).where(eq(messageTable.topicId, topicId)) + // Find root node if not specified + let rootId = options.rootId + if (!rootId) { + const [root] = await db + .select({ id: messageTable.id }) + .from(messageTable) + .where(and(eq(messageTable.topicId, topicId), sql`${messageTable.parentId} IS NULL`)) + .limit(1) + rootId = root?.id + } - if (allMessages.length === 0) { + if (!rootId) { return { nodes: [], siblingsGroups: [], activeNodeId: null } } + // Build active path via CTE (single query) + const activePath = new Set() + if (activeNodeId) { + const pathRows = await db.all<{ id: string }>(sql` + WITH RECURSIVE path AS ( + SELECT id, parent_id FROM message WHERE id = ${activeNodeId} + UNION ALL + SELECT m.id, m.parent_id FROM message m + INNER JOIN path p ON m.id = p.parent_id + ) + SELECT id FROM path + `) + pathRows.forEach((r) => activePath.add(r.id)) + } + + // Get tree with depth limit via CTE + // Use a large depth for unlimited (-1) + const maxDepth = depth === -1 ? 999 : depth + + const treeRows = await db.all(sql` + WITH RECURSIVE tree AS ( + SELECT *, 0 as tree_depth FROM message WHERE id = ${rootId} + UNION ALL + SELECT m.*, t.tree_depth + 1 FROM message m + INNER JOIN tree t ON m.parent_id = t.id + WHERE t.tree_depth < ${maxDepth} + ) + SELECT * FROM tree + `) + + // Also fetch active path nodes that might be beyond depth limit + const treeNodeIds = new Set(treeRows.map((r) => r.id)) + const missingActivePathIds = [...activePath].filter((id) => !treeNodeIds.has(id)) + + if (missingActivePathIds.length > 0) { + const additionalRows = await db.select().from(messageTable).where(inArray(messageTable.id, missingActivePathIds)) + treeRows.push(...additionalRows.map((r) => ({ ...r, tree_depth: maxDepth + 1 }))) + } + + // Also need children of active path nodes for proper tree building + // Get all children of active path nodes that we haven't loaded yet + const activePathArray = [...activePath] + if (activePathArray.length > 0 && treeNodeIds.size > 0) { + const childrenRows = await db + .select() + .from(messageTable) + .where( + and( + inArray(messageTable.parentId, activePathArray), + sql`${messageTable.id} NOT IN (${sql.join( + [...treeNodeIds].map((id) => sql`${id}`), + sql`, ` + )})` + ) + ) + + for (const row of childrenRows) { + if (!treeNodeIds.has(row.id)) { + treeRows.push({ ...row, tree_depth: maxDepth + 1 }) + treeNodeIds.add(row.id) + } + } + } else if (activePathArray.length > 0) { + // No tree nodes loaded yet, just get all children of active path + const childrenRows = await db.select().from(messageTable).where(inArray(messageTable.parentId, activePathArray)) + + for (const row of childrenRows) { + if (!treeNodeIds.has(row.id)) { + treeRows.push({ ...row, tree_depth: maxDepth + 1 }) + treeNodeIds.add(row.id) + } + } + } + + if (treeRows.length === 0) { + return { nodes: [], siblingsGroups: [], activeNodeId: null } + } + + // Build maps for tree processing const messagesById = new Map() const childrenMap = new Map() + const depthMap = new Map() - for (const row of allMessages) { + for (const row of treeRows) { const message = rowToMessage(row) messagesById.set(message.id, message) + depthMap.set(message.id, row.tree_depth) const parentId = message.parentId || 'root' if (!childrenMap.has(parentId)) { @@ -150,21 +244,6 @@ export class MessageService { childrenMap.get(parentId)!.push(message.id) } - // Find root node(s) and build active path - const rootIds = childrenMap.get('root') || [] - const rootId = options.rootId || rootIds[0] - - // Build path from rootId to activeNodeId - const activePath = new Set() - if (activeNodeId) { - let currentId: string | null = activeNodeId - while (currentId) { - activePath.add(currentId) - const message = messagesById.get(currentId) - currentId = message?.parentId || null - } - } - // Collect nodes based on depth const resultNodes: TreeNode[] = [] const siblingsGroups: SiblingsGroup[] = [] @@ -187,7 +266,7 @@ export class MessageService { const parentChildren = childrenMap.get(message.parentId || 'root') || [] const groupMembers = parentChildren .map((id) => messagesById.get(id)!) - .filter((m) => m.siblingsGroupId === message.siblingsGroupId) + .filter((m) => m && m.siblingsGroupId === message.siblingsGroupId) if (groupMembers.length > 1) { siblingsGroups.push({ @@ -221,9 +300,7 @@ export class MessageService { } // Start from root - if (rootId) { - collectNodes(rootId, 0, activePath.has(rootId)) - } + collectNodes(rootId, 0, activePath.has(rootId)) return { nodes: resultNodes, @@ -234,6 +311,10 @@ export class MessageService { /** * Get branch messages for conversation view + * + * Optimized implementation using recursive CTE to fetch only the path + * from nodeId to root, avoiding loading all messages for large topics. + * Siblings are batch-queried in a single additional query. */ async getBranchMessages( topicId: string, @@ -249,76 +330,108 @@ export class MessageService { throw DataApiErrorFactory.notFound('Topic', topicId) } - // Get all messages for this topic - const allMessages = await db.select().from(messageTable).where(eq(messageTable.topicId, topicId)) + const nodeId = options.nodeId || topic.activeNodeId - if (allMessages.length === 0) { + // Return empty if no active node + if (!nodeId) { return { messages: [], activeNodeId: null } } - // Check for data inconsistency - if (!topic.activeNodeId) { - throw DataApiErrorFactory.dataInconsistent('Topic', 'has messages but no active node') + // Use recursive CTE to get path from nodeId to root (single query) + const pathMessages = await db.all(sql` + WITH RECURSIVE path AS ( + SELECT * FROM message WHERE id = ${nodeId} + UNION ALL + SELECT m.* FROM message m + INNER JOIN path p ON m.id = p.parent_id + ) + SELECT * FROM path + `) + + if (pathMessages.length === 0) { + throw DataApiErrorFactory.notFound('Message', nodeId) } - const nodeId = options.nodeId || topic.activeNodeId - const messagesById = new Map() - - for (const row of allMessages) { - messagesById.set(row.id, rowToMessage(row)) - } - - // Build path from root to nodeId - const path: string[] = [] - let currentId: string | null = nodeId - while (currentId) { - path.unshift(currentId) - const message = messagesById.get(currentId) - if (!message) { - throw DataApiErrorFactory.notFound('Message', currentId) - } - currentId = message.parentId - } + // Reverse to get root->nodeId order + const fullPath = pathMessages.reverse() // Apply pagination let startIndex = 0 + let endIndex = fullPath.length + if (options.beforeNodeId) { - const beforeIndex = path.indexOf(options.beforeNodeId) + const beforeIndex = fullPath.findIndex((m) => m.id === options.beforeNodeId) if (beforeIndex === -1) { throw DataApiErrorFactory.notFound('Message', options.beforeNodeId) } startIndex = Math.max(0, beforeIndex - limit) + endIndex = beforeIndex } else { - startIndex = Math.max(0, path.length - limit) + startIndex = Math.max(0, fullPath.length - limit) } - const endIndex = options.beforeNodeId ? path.indexOf(options.beforeNodeId) : path.length - - const resultPath = path.slice(startIndex, endIndex) + const paginatedPath = fullPath.slice(startIndex, endIndex) // Build result with optional siblings const result: BranchMessage[] = [] - for (const msgId of resultPath) { - const message = messagesById.get(msgId)! + if (includeSiblings) { + // Collect unique (parentId, siblingsGroupId) pairs that need siblings + const uniqueGroups = new Set() + const groupsToQuery: Array<{ parentId: string; siblingsGroupId: number }> = [] - let siblingsGroup: Message[] | undefined - if (includeSiblings && message.siblingsGroupId !== 0) { - // Find siblings with same parentId and siblingsGroupId - siblingsGroup = allMessages - .filter( - (row) => - row.parentId === message.parentId && - row.siblingsGroupId === message.siblingsGroupId && - row.id !== message.id - ) - .map(rowToMessage) + for (const msg of paginatedPath) { + if (msg.siblingsGroupId && msg.siblingsGroupId !== 0 && msg.parentId) { + const key = `${msg.parentId}-${msg.siblingsGroupId}` + if (!uniqueGroups.has(key)) { + uniqueGroups.add(key) + groupsToQuery.push({ parentId: msg.parentId, siblingsGroupId: msg.siblingsGroupId }) + } + } } - result.push({ - message, - siblingsGroup - }) + // Batch query all siblings if needed + const siblingsMap = new Map() + + if (groupsToQuery.length > 0) { + // Build OR conditions for batch query + const orConditions = groupsToQuery.map((g) => + and(eq(messageTable.parentId, g.parentId), eq(messageTable.siblingsGroupId, g.siblingsGroupId)) + ) + + const siblingsRows = await db + .select() + .from(messageTable) + .where(or(...orConditions)) + + // Group results by parentId-siblingsGroupId + for (const row of siblingsRows) { + const key = `${row.parentId}-${row.siblingsGroupId}` + if (!siblingsMap.has(key)) siblingsMap.set(key, []) + siblingsMap.get(key)!.push(rowToMessage(row)) + } + } + + // Build result with siblings from map + for (const msg of paginatedPath) { + const message = rowToMessage(msg) + let siblingsGroup: Message[] | undefined + + if (msg.siblingsGroupId !== 0 && msg.parentId) { + const key = `${msg.parentId}-${msg.siblingsGroupId}` + const group = siblingsMap.get(key) + if (group && group.length > 1) { + siblingsGroup = group.filter((m) => m.id !== message.id) + } + } + + result.push({ message, siblingsGroup }) + } + } else { + // No siblings needed, just map messages + for (const msg of paginatedPath) { + result.push({ message: rowToMessage(msg) }) + } } return { @@ -344,94 +457,114 @@ export class MessageService { /** * Create a new message + * + * Uses transaction to ensure atomicity of: + * - Topic existence validation + * - Parent message validation (if specified) + * - Message insertion + * - Topic activeNodeId update */ async create(topicId: string, dto: CreateMessageDto): Promise { const db = dbService.getDb() - // Verify topic exists - const [topic] = await db.select().from(topicTable).where(eq(topicTable.id, topicId)).limit(1) + return await db.transaction(async (tx) => { + // Verify topic exists + const [topic] = await tx.select().from(topicTable).where(eq(topicTable.id, topicId)).limit(1) - if (!topic) { - throw DataApiErrorFactory.notFound('Topic', topicId) - } - - // Verify parent exists if specified - if (dto.parentId) { - const [parent] = await db.select().from(messageTable).where(eq(messageTable.id, dto.parentId)).limit(1) - - if (!parent) { - throw DataApiErrorFactory.notFound('Message', dto.parentId) + if (!topic) { + throw DataApiErrorFactory.notFound('Topic', topicId) } - } - const [row] = await db - .insert(messageTable) - .values({ - topicId, - parentId: dto.parentId, - role: dto.role, - data: dto.data, - status: dto.status ?? 'pending', - siblingsGroupId: dto.siblingsGroupId ?? 0, - assistantId: dto.assistantId, - assistantMeta: dto.assistantMeta, - modelId: dto.modelId, - modelMeta: dto.modelMeta, - traceId: dto.traceId, - stats: dto.stats - }) - .returning() - - // Update activeNodeId if setAsActive is not explicitly false - if (dto.setAsActive !== false) { - await db.update(topicTable).set({ activeNodeId: row.id }).where(eq(topicTable.id, topicId)) - } - - logger.info('Created message', { id: row.id, topicId, role: dto.role, setAsActive: dto.setAsActive !== false }) - - return rowToMessage(row) - } - - /** - * Update a message - */ - async update(id: string, dto: UpdateMessageDto): Promise { - const db = dbService.getDb() - - // Get existing message - const existing = await this.getById(id) - - // Check for cycle if moving to new parent - if (dto.parentId !== undefined && dto.parentId !== existing.parentId) { - if (dto.parentId !== null) { - // Check that new parent is not a descendant - const descendants = await this.getDescendantIds(id) - if (descendants.includes(dto.parentId)) { - throw DataApiErrorFactory.invalidOperation('move message', 'would create cycle') - } - - // Verify new parent exists - const [parent] = await db.select().from(messageTable).where(eq(messageTable.id, dto.parentId)).limit(1) + // Verify parent exists if specified + if (dto.parentId) { + const [parent] = await tx.select().from(messageTable).where(eq(messageTable.id, dto.parentId)).limit(1) if (!parent) { throw DataApiErrorFactory.notFound('Message', dto.parentId) } } + + const [row] = await tx + .insert(messageTable) + .values({ + topicId, + parentId: dto.parentId, + role: dto.role, + data: dto.data, + status: dto.status ?? 'pending', + siblingsGroupId: dto.siblingsGroupId ?? 0, + assistantId: dto.assistantId, + assistantMeta: dto.assistantMeta, + modelId: dto.modelId, + modelMeta: dto.modelMeta, + traceId: dto.traceId, + stats: dto.stats + }) + .returning() + + // Update activeNodeId if setAsActive is not explicitly false + if (dto.setAsActive !== false) { + await tx.update(topicTable).set({ activeNodeId: row.id }).where(eq(topicTable.id, topicId)) + } + + logger.info('Created message', { id: row.id, topicId, role: dto.role, setAsActive: dto.setAsActive !== false }) + + return rowToMessage(row) + }) + } + + /** + * Update a message + * + * Uses transaction to ensure atomicity of validation and update. + * Cycle check is performed outside transaction as a read-only safety check. + */ + async update(id: string, dto: UpdateMessageDto): Promise { + const db = dbService.getDb() + + // Pre-transaction: Check for cycle if moving to new parent + // This is done outside transaction since getDescendantIds uses its own db context + // and cycle check is a safety check (worst case: reject valid operation) + if (dto.parentId !== undefined && dto.parentId !== null) { + const descendants = await this.getDescendantIds(id) + if (descendants.includes(dto.parentId)) { + throw DataApiErrorFactory.invalidOperation('move message', 'would create cycle') + } } - // Build update object - const updates: Partial = {} + return await db.transaction(async (tx) => { + // Get existing message within transaction + const [existingRow] = await tx.select().from(messageTable).where(eq(messageTable.id, id)).limit(1) - if (dto.data !== undefined) updates.data = dto.data - if (dto.parentId !== undefined) updates.parentId = dto.parentId - if (dto.siblingsGroupId !== undefined) updates.siblingsGroupId = dto.siblingsGroupId - if (dto.status !== undefined) updates.status = dto.status + if (!existingRow) { + throw DataApiErrorFactory.notFound('Message', id) + } - const [row] = await db.update(messageTable).set(updates).where(eq(messageTable.id, id)).returning() + const existing = rowToMessage(existingRow) - logger.info('Updated message', { id, changes: Object.keys(dto) }) + // Verify new parent exists if changing parent + if (dto.parentId !== undefined && dto.parentId !== existing.parentId && dto.parentId !== null) { + const [parent] = await tx.select().from(messageTable).where(eq(messageTable.id, dto.parentId)).limit(1) - return rowToMessage(row) + if (!parent) { + throw DataApiErrorFactory.notFound('Message', dto.parentId) + } + } + + // Build update object + const updates: Partial = {} + + if (dto.data !== undefined) updates.data = dto.data + if (dto.parentId !== undefined) updates.parentId = dto.parentId + if (dto.siblingsGroupId !== undefined) updates.siblingsGroupId = dto.siblingsGroupId + if (dto.status !== undefined) updates.status = dto.status + + const [row] = await tx.update(messageTable).set(updates).where(eq(messageTable.id, id)).returning() + + logger.info('Updated message', { id, changes: Object.keys(dto) }) + + return rowToMessage(row) + }) } /** @@ -573,18 +706,30 @@ export class MessageService { /** * Get path from root to a node + * + * Uses recursive CTE to fetch all ancestors in a single query, + * avoiding N+1 query problem for deep message trees. */ async getPathToNode(nodeId: string): Promise { - const path: Message[] = [] - let currentId: string | null = nodeId + const db = dbService.getDb() - while (currentId) { - const message = await this.getById(currentId) - path.unshift(message) - currentId = message.parentId + // Use recursive CTE to get all ancestors in one query + const result = await db.all(sql` + WITH RECURSIVE ancestors AS ( + SELECT * FROM message WHERE id = ${nodeId} + UNION ALL + SELECT m.* FROM message m + INNER JOIN ancestors a ON m.id = a.parent_id + ) + SELECT * FROM ancestors + `) + + if (result.length === 0) { + throw DataApiErrorFactory.notFound('Message', nodeId) } - return path + // Result is from nodeId to root, reverse to get root to nodeId + return result.reverse().map(rowToMessage) } }