diff --git a/src/main/apiServer/routes/messages.ts b/src/main/apiServer/routes/messages.ts index 1e18c86118..bcba993915 100644 --- a/src/main/apiServer/routes/messages.ts +++ b/src/main/apiServer/routes/messages.ts @@ -5,6 +5,7 @@ import { getAiSdkProviderId } from '@shared/provider' import type { Provider } from '@types' import type { Request, Response } from 'express' import express from 'express' +import { approximateTokenSize } from 'tokenx' import { messagesService } from '../services/messages' import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages' @@ -45,25 +46,34 @@ const providerRouter = express.Router({ mergeParams: true }) /** * Estimate token count from messages - * Simple approximation: ~4 characters per token for English text + * Uses tokenx library for accurate token estimation and supports images */ interface CountTokensInput { - messages: Array<{ role: string; content: string | Array<{ type: string; text?: string }> }> + messages: Array<{ + role: string + content: + | string + | Array<{ + type: string + text?: string + source?: { type: string; media_type?: string; data?: string } + }> + }> system?: string | Array<{ type: string; text?: string }> } function estimateTokenCount(input: CountTokensInput): number { const { messages, system } = input - let totalChars = 0 + let totalTokens = 0 - // Count system message tokens + // Count system message tokens using tokenx if (system) { if (typeof system === 'string') { - totalChars += system.length + totalTokens += approximateTokenSize(system) } else if (Array.isArray(system)) { for (const block of system) { if (block.type === 'text' && block.text) { - totalChars += block.text.length + totalTokens += approximateTokenSize(block.text) } } } @@ -72,20 +82,29 @@ function estimateTokenCount(input: CountTokensInput): number { // Count message tokens for (const msg of messages) { if (typeof msg.content === 'string') { - totalChars += msg.content.length + totalTokens += approximateTokenSize(msg.content) } else if (Array.isArray(msg.content)) { for (const block of msg.content) { if (block.type === 'text' && block.text) { - totalChars += block.text.length + totalTokens += approximateTokenSize(block.text) + } else if (block.type === 'image') { + // Image token estimation (consistent with TokenService) + // Base64 images: estimate from data length + if (block.source?.data) { + const dataSize = block.source.data.length * 0.75 // base64 to bytes + totalTokens += Math.floor(dataSize / 100) + } else { + // Default image token estimate + totalTokens += 1000 + } } } } - // Add overhead for role - totalChars += 10 + // Add role overhead + totalTokens += 3 } - // Estimate tokens (~4 chars per token, with some overhead) - return Math.ceil(totalChars / 4) + messages.length * 3 + return totalTokens } // Helper function for basic request validation