mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-30 07:39:06 +08:00
feat(token): enhance token estimation using tokenx library for improved accuracy and image support
This commit is contained in:
parent
5304b585b9
commit
2c910322f8
@ -5,6 +5,7 @@ import { getAiSdkProviderId } from '@shared/provider'
|
|||||||
import type { Provider } from '@types'
|
import type { Provider } from '@types'
|
||||||
import type { Request, Response } from 'express'
|
import type { Request, Response } from 'express'
|
||||||
import express from 'express'
|
import express from 'express'
|
||||||
|
import { approximateTokenSize } from 'tokenx'
|
||||||
|
|
||||||
import { messagesService } from '../services/messages'
|
import { messagesService } from '../services/messages'
|
||||||
import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages'
|
import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages'
|
||||||
@ -45,25 +46,34 @@ const providerRouter = express.Router({ mergeParams: true })
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Estimate token count from messages
|
* Estimate token count from messages
|
||||||
* Simple approximation: ~4 characters per token for English text
|
* Uses tokenx library for accurate token estimation and supports images
|
||||||
*/
|
*/
|
||||||
interface CountTokensInput {
|
interface CountTokensInput {
|
||||||
messages: Array<{ role: string; content: string | Array<{ type: string; text?: string }> }>
|
messages: Array<{
|
||||||
|
role: string
|
||||||
|
content:
|
||||||
|
| string
|
||||||
|
| Array<{
|
||||||
|
type: string
|
||||||
|
text?: string
|
||||||
|
source?: { type: string; media_type?: string; data?: string }
|
||||||
|
}>
|
||||||
|
}>
|
||||||
system?: string | Array<{ type: string; text?: string }>
|
system?: string | Array<{ type: string; text?: string }>
|
||||||
}
|
}
|
||||||
|
|
||||||
function estimateTokenCount(input: CountTokensInput): number {
|
function estimateTokenCount(input: CountTokensInput): number {
|
||||||
const { messages, system } = input
|
const { messages, system } = input
|
||||||
let totalChars = 0
|
let totalTokens = 0
|
||||||
|
|
||||||
// Count system message tokens
|
// Count system message tokens using tokenx
|
||||||
if (system) {
|
if (system) {
|
||||||
if (typeof system === 'string') {
|
if (typeof system === 'string') {
|
||||||
totalChars += system.length
|
totalTokens += approximateTokenSize(system)
|
||||||
} else if (Array.isArray(system)) {
|
} else if (Array.isArray(system)) {
|
||||||
for (const block of system) {
|
for (const block of system) {
|
||||||
if (block.type === 'text' && block.text) {
|
if (block.type === 'text' && block.text) {
|
||||||
totalChars += block.text.length
|
totalTokens += approximateTokenSize(block.text)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -72,20 +82,29 @@ function estimateTokenCount(input: CountTokensInput): number {
|
|||||||
// Count message tokens
|
// Count message tokens
|
||||||
for (const msg of messages) {
|
for (const msg of messages) {
|
||||||
if (typeof msg.content === 'string') {
|
if (typeof msg.content === 'string') {
|
||||||
totalChars += msg.content.length
|
totalTokens += approximateTokenSize(msg.content)
|
||||||
} else if (Array.isArray(msg.content)) {
|
} else if (Array.isArray(msg.content)) {
|
||||||
for (const block of msg.content) {
|
for (const block of msg.content) {
|
||||||
if (block.type === 'text' && block.text) {
|
if (block.type === 'text' && block.text) {
|
||||||
totalChars += block.text.length
|
totalTokens += approximateTokenSize(block.text)
|
||||||
|
} else if (block.type === 'image') {
|
||||||
|
// Image token estimation (consistent with TokenService)
|
||||||
|
// Base64 images: estimate from data length
|
||||||
|
if (block.source?.data) {
|
||||||
|
const dataSize = block.source.data.length * 0.75 // base64 to bytes
|
||||||
|
totalTokens += Math.floor(dataSize / 100)
|
||||||
|
} else {
|
||||||
|
// Default image token estimate
|
||||||
|
totalTokens += 1000
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Add overhead for role
|
// Add role overhead
|
||||||
totalChars += 10
|
totalTokens += 3
|
||||||
}
|
}
|
||||||
|
|
||||||
// Estimate tokens (~4 chars per token, with some overhead)
|
return totalTokens
|
||||||
return Math.ceil(totalChars / 4) + messages.length * 3
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function for basic request validation
|
// Helper function for basic request validation
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user