/**
 * Content accepted by the token estimator, used for both the system prompt and
 * each message body: either a plain string or a list of content blocks.
 */
type CountTokensContent = string | Array<{ type: string; text?: string }>

/**
 * Input for {@link estimateTokenCount}.
 */
interface CountTokensInput {
  /** Conversation messages; every message adds a fixed overhead on top of its text. */
  messages: Array<{ role: string; content: CountTokensContent }>
  /** Optional system prompt. */
  system?: CountTokensContent
}

/**
 * Count the characters of text carried by `content`.
 *
 * A string counts its own length. A block list counts only blocks whose `type`
 * is `'text'` and whose `text` is non-empty; every other block contributes
 * nothing. Any other value (including `undefined`) counts as zero.
 */
function countTextChars(content: CountTokensContent | undefined): number {
  if (typeof content === 'string') {
    return content.length
  }
  if (!Array.isArray(content)) {
    return 0
  }

  let chars = 0
  for (const block of content) {
    if (block.type === 'text' && block.text) {
      chars += block.text.length
    }
  }
  return chars
}

/**
 * Estimate the token count of a request from its system prompt and messages.
 *
 * This is a simple approximation (~4 characters per token for English text),
 * not a real tokenizer. It is shared by the `/count_tokens` handlers of both
 * `router` and `providerRouter`, which call it as
 * `estimateTokenCount({ messages, system })`.
 *
 * @param input - The `messages` array and optional `system` prompt to measure.
 * @returns The estimated number of tokens; `0` when there are no messages and no system text.
 */
function estimateTokenCount(input: CountTokensInput): number {
  const { messages, system } = input

  // Approximate characters per token for English text.
  const charsPerToken = 4
  // Characters added per message to account for its role.
  const roleOverheadChars = 10
  // Extra tokens added per message on top of the character-based estimate.
  const perMessageTokenOverhead = 3

  let totalChars = countTextChars(system)
  for (const msg of messages) {
    totalChars += countTextChars(msg.content) + roleOverheadChars
  }

  return Math.ceil(totalChars / charsPerToken) + messages.length * perMessageTokenOverhead
}