mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-07 13:59:28 +08:00
refactor: extract shared token counting logic in messages routes
Extract duplicated token estimation code from both count_tokens endpoints into a shared `estimateTokenCount` function to improve maintainability. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
f163c4d3ee
commit
77c1b77113
@ -41,6 +41,51 @@ const logger = loggerService.withContext('ApiServerMessagesRoutes')
|
||||
const router = express.Router()
|
||||
const providerRouter = express.Router({ mergeParams: true })
|
||||
|
||||
/**
|
||||
* Estimate token count from messages
|
||||
* Simple approximation: ~4 characters per token for English text
|
||||
*/
|
||||
interface CountTokensInput {
|
||||
messages: Array<{ role: string; content: string | Array<{ type: string; text?: string }> }>
|
||||
system?: string | Array<{ type: string; text?: string }>
|
||||
}
|
||||
|
||||
function estimateTokenCount(input: CountTokensInput): number {
|
||||
const { messages, system } = input
|
||||
let totalChars = 0
|
||||
|
||||
// Count system message tokens
|
||||
if (system) {
|
||||
if (typeof system === 'string') {
|
||||
totalChars += system.length
|
||||
} else if (Array.isArray(system)) {
|
||||
for (const block of system) {
|
||||
if (block.type === 'text' && block.text) {
|
||||
totalChars += block.text.length
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count message tokens
|
||||
for (const msg of messages) {
|
||||
if (typeof msg.content === 'string') {
|
||||
totalChars += msg.content.length
|
||||
} else if (Array.isArray(msg.content)) {
|
||||
for (const block of msg.content) {
|
||||
if (block.type === 'text' && block.text) {
|
||||
totalChars += block.text.length
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add overhead for role
|
||||
totalChars += 10
|
||||
}
|
||||
|
||||
// Estimate tokens (~4 chars per token, with some overhead)
|
||||
return Math.ceil(totalChars / 4) + messages.length * 3
|
||||
}
|
||||
|
||||
// Helper function for basic request validation
|
||||
async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> {
|
||||
const request: MessageCreateParams = req.body
|
||||
@ -589,45 +634,11 @@ router.post('/count_tokens', async (req: Request, res: Response) => {
|
||||
})
|
||||
}
|
||||
|
||||
// Simple token estimation based on character count
|
||||
// This is a rough approximation: ~4 characters per token for English text
|
||||
let totalChars = 0
|
||||
|
||||
// Count system message tokens
|
||||
if (system) {
|
||||
if (typeof system === 'string') {
|
||||
totalChars += system.length
|
||||
} else if (Array.isArray(system)) {
|
||||
for (const block of system) {
|
||||
if (block.type === 'text' && block.text) {
|
||||
totalChars += block.text.length
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count message tokens
|
||||
for (const msg of messages) {
|
||||
if (typeof msg.content === 'string') {
|
||||
totalChars += msg.content.length
|
||||
} else if (Array.isArray(msg.content)) {
|
||||
for (const block of msg.content) {
|
||||
if (block.type === 'text' && block.text) {
|
||||
totalChars += block.text.length
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add overhead for role
|
||||
totalChars += 10
|
||||
}
|
||||
|
||||
// Estimate tokens (~4 chars per token, with some overhead)
|
||||
const estimatedTokens = Math.ceil(totalChars / 4) + messages.length * 3
|
||||
const estimatedTokens = estimateTokenCount({ messages, system })
|
||||
|
||||
logger.debug('Token count estimated', {
|
||||
model,
|
||||
messageCount: messages.length,
|
||||
totalChars,
|
||||
estimatedTokens
|
||||
})
|
||||
|
||||
@ -663,35 +674,7 @@ providerRouter.post('/count_tokens', async (req: Request, res: Response) => {
|
||||
})
|
||||
}
|
||||
|
||||
// Simple token estimation
|
||||
let totalChars = 0
|
||||
|
||||
if (system) {
|
||||
if (typeof system === 'string') {
|
||||
totalChars += system.length
|
||||
} else if (Array.isArray(system)) {
|
||||
for (const block of system) {
|
||||
if (block.type === 'text' && block.text) {
|
||||
totalChars += block.text.length
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const msg of messages) {
|
||||
if (typeof msg.content === 'string') {
|
||||
totalChars += msg.content.length
|
||||
} else if (Array.isArray(msg.content)) {
|
||||
for (const block of msg.content) {
|
||||
if (block.type === 'text' && block.text) {
|
||||
totalChars += block.text.length
|
||||
}
|
||||
}
|
||||
}
|
||||
totalChars += 10
|
||||
}
|
||||
|
||||
const estimatedTokens = Math.ceil(totalChars / 4) + messages.length * 3
|
||||
const estimatedTokens = estimateTokenCount({ messages, system })
|
||||
|
||||
logger.debug('Token count estimated (provider route)', {
|
||||
providerId: req.params.provider,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user