From a2d8beafcfc8c32750ea20ffbb4faefacec30bad Mon Sep 17 00:00:00 2001 From: one Date: Fri, 9 May 2025 14:17:28 +0800 Subject: [PATCH] fix: user message usage (#5657) * fix: user message usage estimation * fix: estimate usage on resending edited user message * refactor: renaming * refactor: renaming --- .../src/hooks/useMessageOperations.ts | 8 +++ .../src/pages/home/Inputbar/Inputbar.tsx | 4 +- src/renderer/src/services/MessagesService.ts | 9 ++- src/renderer/src/services/TokenService.ts | 72 ++++++++++++++----- 4 files changed, 70 insertions(+), 23 deletions(-) diff --git a/src/renderer/src/hooks/useMessageOperations.ts b/src/renderer/src/hooks/useMessageOperations.ts index 1e08c4934e..40cda0213f 100644 --- a/src/renderer/src/hooks/useMessageOperations.ts +++ b/src/renderer/src/hooks/useMessageOperations.ts @@ -1,5 +1,6 @@ import { createSelector } from '@reduxjs/toolkit' import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService' +import { estimateUserPromptUsage } from '@renderer/services/TokenService' import store, { type RootState, useAppDispatch, useAppSelector } from '@renderer/store' import { messageBlocksSelectors, updateOneBlock } from '@renderer/store/messageBlock' import { newMessagesActions, selectMessagesForTopic } from '@renderer/store/newMessage' @@ -20,6 +21,7 @@ import type { Assistant, Model, Topic } from '@renderer/types' import type { Message, MessageBlock } from '@renderer/types/newMessage' import { MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage' import { abortCompletion } from '@renderer/utils/abortController' +import { findFileBlocks } from '@renderer/utils/messageUtils/find' import { useCallback } from 'react' const findMainTextBlockId = (message: Message): string | undefined => { @@ -128,6 +130,12 @@ export function useMessageOperations(topic: Topic) { return } + const files = findFileBlocks(message).map((block) => block.file) + + const usage = await estimateUserPromptUsage({ content: 
editedContent, files }) + + await dispatch(updateMessageAndBlocksThunk(topic.id, { id: message.id, usage }, [])) + await dispatch(resendUserMessageWithEditThunk(topic.id, message, mainTextBlockId, editedContent, assistant)) }, [dispatch, topic.id] diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx index e2dd58a730..ace06a98d6 100644 --- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx @@ -22,7 +22,7 @@ import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService' import FileManager from '@renderer/services/FileManager' import { checkRateLimit, getUserMessage } from '@renderer/services/MessagesService' import { getModelUniqId } from '@renderer/services/ModelService' -import { estimateMessageUsage, estimateTextTokens as estimateTxtTokens } from '@renderer/services/TokenService' +import { estimateTextTokens as estimateTxtTokens, estimateUserPromptUsage } from '@renderer/services/TokenService' import { translateText } from '@renderer/services/TranslateService' import WebSearchService from '@renderer/services/WebSearchService' import { useAppDispatch } from '@renderer/store' @@ -215,7 +215,7 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) = ) } - baseUserMessage.usage = await estimateMessageUsage(baseUserMessage) + baseUserMessage.usage = await estimateUserPromptUsage(baseUserMessage) const { message, blocks } = getUserMessage(baseUserMessage) diff --git a/src/renderer/src/services/MessagesService.ts b/src/renderer/src/services/MessagesService.ts index cef91352bd..d1dea18b9f 100644 --- a/src/renderer/src/services/MessagesService.ts +++ b/src/renderer/src/services/MessagesService.ts @@ -6,7 +6,7 @@ import { fetchMessagesSummary } from '@renderer/services/ApiService' import store from '@renderer/store' import { messageBlocksSelectors, removeManyBlocks } from '@renderer/store/messageBlock' import { 
selectMessagesForTopic } from '@renderer/store/newMessage' -import type { Assistant, FileType, MCPServer, Model, Topic } from '@renderer/types' +import type { Assistant, FileType, MCPServer, Model, Topic, Usage } from '@renderer/types' import { FileTypes } from '@renderer/types' import type { Message, MessageBlock } from '@renderer/types/newMessage' import { AssistantMessageStatus, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage' @@ -110,7 +110,8 @@ export function getUserMessage({ // Keep other potential params if needed by createMessage knowledgeBaseIds, mentions, - enabledMCPs + enabledMCPs, + usage }: { assistant: Assistant topic: Topic @@ -120,6 +121,7 @@ export function getUserMessage({ knowledgeBaseIds?: string[] mentions?: Model[] enabledMCPs?: MCPServer[] + usage?: Usage }): { message: Message; blocks: MessageBlock[] } { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel @@ -163,7 +165,8 @@ export function getUserMessage({ // 移除knowledgeBaseIds mentions, enabledMCPs, - type + type, + usage } ) diff --git a/src/renderer/src/services/TokenService.ts b/src/renderer/src/services/TokenService.ts index ceac89154c..0247d042ee 100644 --- a/src/renderer/src/services/TokenService.ts +++ b/src/renderer/src/services/TokenService.ts @@ -1,5 +1,5 @@ import { Assistant, FileType, FileTypes, Usage } from '@renderer/types' -import type { Message, MessageInputBaseParams } from '@renderer/types/newMessage' +import type { Message } from '@renderer/types/newMessage' import { findFileBlocks, getMainTextContent, getThinkingContent } from '@renderer/utils/messageUtils/find' import { flatten, takeRight } from 'lodash' import { approximateTokenSize } from 'tokenx' @@ -56,16 +56,59 @@ export function estimateImageTokens(file: FileType) { return Math.floor(file.size / 100) } -export async function estimateMessageUsage(message: Partial, params?: MessageInputBaseParams): Promise { +/** + * Estimates the token usage of user input 
(text and files). + * + * 
The estimate is based only on the given content (text) and files (file list); + * it does not depend on a full Message structure and does not process message blocks, context, or similar data. + * + * @param {Object} params - Input parameters + * @param {string} [params.content] - Text entered by the user + * @param {FileType[]} [params.files] - Files uploaded by the user (only image files are counted here) + * @returns {Promise} A Usage object containing prompt_tokens, completion_tokens, and total_tokens + */ +export async function estimateUserPromptUsage({ + content, + files +}: { + content?: string + files?: FileType[] +}): Promise { let imageTokens = 0 - let files: FileType[] = [] - if (params?.files) { - files = params.files - } else { - const fileBlocks = findFileBlocks(message as Message) - files = fileBlocks.map((f) => f.file) + + if (files && files.length > 0) { + const images = files.filter((f) => f.type === FileTypes.IMAGE) + if (images.length > 0) { + for (const image of images) { + imageTokens = estimateImageTokens(image) + imageTokens + } + } } + const tokens = estimateTextTokens(content || '') + + return { + prompt_tokens: tokens, + completion_tokens: tokens, + total_tokens: tokens + (imageTokens ? 
imageTokens - 7 : 0) + } +} + +/** + * Estimates the token usage of a complete message (Message). + * + * Extracts the main text content, the reasoning content (reasoningContent), and all file blocks from the message, + * then counts text and image tokens; intended for usage estimation on message objects. + * + * @param {Partial} message - The message object; may be a complete or partial Message + * @returns {Promise} A Usage object containing prompt_tokens, completion_tokens, and total_tokens + */ +export async function estimateMessageUsage(message: Partial): Promise { + const fileBlocks = findFileBlocks(message as Message) + const files = fileBlocks.map((f) => f.file) + + let imageTokens = 0 + if (files.length > 0) { const images = files.filter((f) => f.type === FileTypes.IMAGE) if (images.length > 0) { @@ -74,16 +117,9 @@ export async function estimateMessageUsage(message: Partial, params?: M } } } - let content = '' - if (params?.content) { - content = params.content - } else { - content = getMainTextContent(message as Message) - } - let reasoningContent = '' - if (!params) { - reasoningContent = getThinkingContent(message as Message) - } + + const content = getMainTextContent(message as Message) + const reasoningContent = getThinkingContent(message as Message) const combinedContent = [content, reasoningContent].filter((s) => s !== undefined).join(' ') const tokens = estimateTextTokens(combinedContent)