feat(tokens): enhance token estimation and refactor count_tokens endpoint for improved handling

This commit is contained in:
suyao 2025-12-18 16:19:17 +08:00
parent 2c910322f8
commit 45d404e127
No known key found for this signature in database

View File

@ -46,20 +46,11 @@ const providerRouter = express.Router({ mergeParams: true })
/**
* Estimate token count from messages
* Uses tokenx library for accurate token estimation and supports images
* Uses tokenx library for accurate token estimation and supports images, tools
*/
interface CountTokensInput {
messages: Array<{
role: string
content:
| string
| Array<{
type: string
text?: string
source?: { type: string; media_type?: string; data?: string }
}>
}>
system?: string | Array<{ type: string; text?: string }>
messages: MessageCreateParams['messages']
system?: MessageCreateParams['system']
}
function estimateTokenCount(input: CountTokensInput): number {
@ -89,14 +80,40 @@ function estimateTokenCount(input: CountTokensInput): number {
totalTokens += approximateTokenSize(block.text)
} else if (block.type === 'image') {
// Image token estimation (consistent with TokenService)
// Base64 images: estimate from data length
if (block.source?.data) {
if (block.source.type === 'base64') {
// Base64 images: estimate from data length
const dataSize = block.source.data.length * 0.75 // base64 to bytes
totalTokens += Math.floor(dataSize / 100)
} else {
// Default image token estimate
// URL images: use default estimate
totalTokens += 1000
}
} else if (block.type === 'tool_use') {
// Tool use token estimation: name + input JSON
if (block.name) {
totalTokens += approximateTokenSize(block.name)
}
if (block.input) {
const inputJson = JSON.stringify(block.input)
totalTokens += approximateTokenSize(inputJson)
}
// Add overhead for tool use structure
totalTokens += 10
} else if (block.type === 'tool_result') {
// Tool result token estimation
if (typeof block.content === 'string') {
totalTokens += approximateTokenSize(block.content)
} else if (Array.isArray(block.content)) {
for (const item of block.content) {
if (typeof item === 'string') {
totalTokens += approximateTokenSize(item)
} else if (item.type === 'text' && item.text) {
totalTokens += approximateTokenSize(item.text)
}
}
}
// Add overhead for tool result structure
totalTokens += 10
}
}
}
@ -127,6 +144,70 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
return { valid: true }
}
/**
* Shared handler for count_tokens endpoint
* Validates request and returns token count estimation
*/
async function handleCountTokens(
req: Request,
res: Response,
options: {
requireModel?: boolean
logContext?: Record<string, any>
} = {}
): Promise<Response> {
try {
const { model, messages, system } = req.body
const { requireModel = false, logContext = {} } = options
// Validate model parameter if required
if (requireModel && !model) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'model parameter is required'
}
})
}
// Validate messages parameter
if (!messages || !Array.isArray(messages)) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'messages parameter is required'
}
})
}
// Estimate token count
const estimatedTokens = estimateTokenCount({ messages, system })
// Log with context
logger.debug('Token count estimated', {
model,
messageCount: messages.length,
estimatedTokens,
...logContext
})
return res.json({
input_tokens: estimatedTokens
})
} catch (error: any) {
logger.error('Token counting error', { error })
return res.status(500).json({
type: 'error',
error: {
type: 'api_error',
message: error.message || 'Internal server error'
}
})
}
}
interface HandleMessageProcessingOptions {
res: Response
provider: Provider
@ -650,91 +731,17 @@ providerRouter.post('/', async (req: Request, res: Response) => {
* description: Bad request
*/
router.post('/count_tokens', async (req: Request, res: Response) => {
try {
const { model, messages, system } = req.body
if (!model) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'model parameter is required'
}
})
}
if (!messages || !Array.isArray(messages)) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'messages parameter is required'
}
})
}
const estimatedTokens = estimateTokenCount({ messages, system })
logger.debug('Token count estimated', {
model,
messageCount: messages.length,
estimatedTokens
})
return res.json({
input_tokens: estimatedTokens
})
} catch (error: any) {
logger.error('Token counting error', { error })
return res.status(500).json({
type: 'error',
error: {
type: 'api_error',
message: error.message || 'Internal server error'
}
})
}
return handleCountTokens(req, res, { requireModel: true })
})
/**
* Provider-specific count_tokens endpoint
*/
providerRouter.post('/count_tokens', async (req: Request, res: Response) => {
try {
const { model, messages, system } = req.body
if (!messages || !Array.isArray(messages)) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'messages parameter is required'
}
})
}
const estimatedTokens = estimateTokenCount({ messages, system })
logger.debug('Token count estimated (provider route)', {
providerId: req.params.provider,
model,
messageCount: messages.length,
estimatedTokens
})
return res.json({
input_tokens: estimatedTokens
})
} catch (error: any) {
logger.error('Token counting error', { error })
return res.status(500).json({
type: 'error',
error: {
type: 'api_error',
message: error.message || 'Internal server error'
}
})
}
return handleCountTokens(req, res, {
requireModel: false,
logContext: { providerId: req.params.provider }
})
})
export { providerRouter as messagesProviderRoutes, router as messagesRoutes }