feat(token): enhance token estimation using tokenx library for improved accuracy and image support

This commit is contained in:
suyao 2025-12-18 15:59:27 +08:00
parent 5304b585b9
commit 2c910322f8
No known key found for this signature in database

View File

@ -5,6 +5,7 @@ import { getAiSdkProviderId } from '@shared/provider'
import type { Provider } from '@types' import type { Provider } from '@types'
import type { Request, Response } from 'express' import type { Request, Response } from 'express'
import express from 'express' import express from 'express'
import { approximateTokenSize } from 'tokenx'
import { messagesService } from '../services/messages' import { messagesService } from '../services/messages'
import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages' import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages'
@ -45,25 +46,34 @@ const providerRouter = express.Router({ mergeParams: true })
/**
 * Request payload accepted by the token-count estimator.
 *
 * - `messages`: the conversation. Each message's `content` is either a plain
 *   string or a list of content blocks. A block with `type: 'text'` carries
 *   `text`; a block with `type: 'image'` may carry a `source` whose `data`
 *   the estimator treats as a base64-encoded payload.
 * - `system`: optional system prompt, either a plain string or a list of
 *   text blocks.
 */
interface CountTokensInput {
  messages: Array<{
    role: string
    content:
      | string
      | Array<{
          type: string
          text?: string
          source?: { type: string; media_type?: string; data?: string }
        }>
  }>
  system?: string | Array<{ type: string; text?: string }>
}
/**
 * Estimate the token count of a count-tokens request.
 *
 * Text (the system prompt and every text block) is measured with tokenx's
 * `approximateTokenSize`. Image blocks are estimated from the length of their
 * inline base64 payload, or a fixed default when no inline data is present.
 * Each message additionally contributes a fixed overhead for its role.
 *
 * @param input - System prompt and messages to measure.
 * @returns Estimated total number of tokens.
 */
function estimateTokenCount(input: CountTokensInput): number {
  // Fixed per-message overhead added for the role
  const roleOverheadTokens = 3
  // Estimate used for an image block that carries no inline data
  const defaultImageTokens = 1000
  // Rough ratio used to turn decoded image bytes into tokens
  const imageBytesPerToken = 100
  // Every 4 base64 characters decode to 3 bytes
  const base64BytesPerChar = 0.75

  const { messages, system } = input
  let totalTokens = 0

  // Count system prompt tokens using tokenx
  if (system) {
    if (typeof system === 'string') {
      totalTokens += approximateTokenSize(system)
    } else if (Array.isArray(system)) {
      for (const block of system) {
        if (block.type === 'text' && block.text) {
          totalTokens += approximateTokenSize(block.text)
        }
      }
    }
  }

  // A missing or malformed `messages` field counts as an empty conversation
  // instead of throwing while iterating.
  const messageList = Array.isArray(messages) ? messages : []

  // Count message tokens
  for (const msg of messageList) {
    if (typeof msg.content === 'string') {
      totalTokens += approximateTokenSize(msg.content)
    } else if (Array.isArray(msg.content)) {
      for (const block of msg.content) {
        if (block.type === 'text' && block.text) {
          totalTokens += approximateTokenSize(block.text)
        } else if (block.type === 'image') {
          // Image token estimation (consistent with TokenService)
          if (block.source?.data) {
            // Base64 images: estimate from the decoded byte size
            const dataSize = block.source.data.length * base64BytesPerChar
            totalTokens += Math.floor(dataSize / imageBytesPerToken)
          } else {
            totalTokens += defaultImageTokens
          }
        }
      }
    }

    // Add role overhead
    totalTokens += roleOverheadTokens
  }

  return totalTokens
}
// Helper function for basic request validation // Helper function for basic request validation