mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
feat(tokens): enhance token estimation and refactor count_tokens endpoint for improved handling
This commit is contained in:
parent
2c910322f8
commit
45d404e127
@ -46,20 +46,11 @@ const providerRouter = express.Router({ mergeParams: true })
|
||||
|
||||
/**
|
||||
* Estimate token count from messages
|
||||
* Uses tokenx library for accurate token estimation and supports images
|
||||
* Uses tokenx library for accurate token estimation and supports images, tools
|
||||
*/
|
||||
interface CountTokensInput {
|
||||
messages: Array<{
|
||||
role: string
|
||||
content:
|
||||
| string
|
||||
| Array<{
|
||||
type: string
|
||||
text?: string
|
||||
source?: { type: string; media_type?: string; data?: string }
|
||||
}>
|
||||
}>
|
||||
system?: string | Array<{ type: string; text?: string }>
|
||||
messages: MessageCreateParams['messages']
|
||||
system?: MessageCreateParams['system']
|
||||
}
|
||||
|
||||
function estimateTokenCount(input: CountTokensInput): number {
|
||||
@ -89,14 +80,40 @@ function estimateTokenCount(input: CountTokensInput): number {
|
||||
totalTokens += approximateTokenSize(block.text)
|
||||
} else if (block.type === 'image') {
|
||||
// Image token estimation (consistent with TokenService)
|
||||
// Base64 images: estimate from data length
|
||||
if (block.source?.data) {
|
||||
if (block.source.type === 'base64') {
|
||||
// Base64 images: estimate from data length
|
||||
const dataSize = block.source.data.length * 0.75 // base64 to bytes
|
||||
totalTokens += Math.floor(dataSize / 100)
|
||||
} else {
|
||||
// Default image token estimate
|
||||
// URL images: use default estimate
|
||||
totalTokens += 1000
|
||||
}
|
||||
} else if (block.type === 'tool_use') {
|
||||
// Tool use token estimation: name + input JSON
|
||||
if (block.name) {
|
||||
totalTokens += approximateTokenSize(block.name)
|
||||
}
|
||||
if (block.input) {
|
||||
const inputJson = JSON.stringify(block.input)
|
||||
totalTokens += approximateTokenSize(inputJson)
|
||||
}
|
||||
// Add overhead for tool use structure
|
||||
totalTokens += 10
|
||||
} else if (block.type === 'tool_result') {
|
||||
// Tool result token estimation
|
||||
if (typeof block.content === 'string') {
|
||||
totalTokens += approximateTokenSize(block.content)
|
||||
} else if (Array.isArray(block.content)) {
|
||||
for (const item of block.content) {
|
||||
if (typeof item === 'string') {
|
||||
totalTokens += approximateTokenSize(item)
|
||||
} else if (item.type === 'text' && item.text) {
|
||||
totalTokens += approximateTokenSize(item.text)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add overhead for tool result structure
|
||||
totalTokens += 10
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -127,6 +144,70 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
|
||||
return { valid: true }
|
||||
}
|
||||
|
||||
/**
|
||||
* Shared handler for count_tokens endpoint
|
||||
* Validates request and returns token count estimation
|
||||
*/
|
||||
async function handleCountTokens(
|
||||
req: Request,
|
||||
res: Response,
|
||||
options: {
|
||||
requireModel?: boolean
|
||||
logContext?: Record<string, any>
|
||||
} = {}
|
||||
): Promise<Response> {
|
||||
try {
|
||||
const { model, messages, system } = req.body
|
||||
const { requireModel = false, logContext = {} } = options
|
||||
|
||||
// Validate model parameter if required
|
||||
if (requireModel && !model) {
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: 'model parameter is required'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Validate messages parameter
|
||||
if (!messages || !Array.isArray(messages)) {
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: 'messages parameter is required'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Estimate token count
|
||||
const estimatedTokens = estimateTokenCount({ messages, system })
|
||||
|
||||
// Log with context
|
||||
logger.debug('Token count estimated', {
|
||||
model,
|
||||
messageCount: messages.length,
|
||||
estimatedTokens,
|
||||
...logContext
|
||||
})
|
||||
|
||||
return res.json({
|
||||
input_tokens: estimatedTokens
|
||||
})
|
||||
} catch (error: any) {
|
||||
logger.error('Token counting error', { error })
|
||||
return res.status(500).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'api_error',
|
||||
message: error.message || 'Internal server error'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
interface HandleMessageProcessingOptions {
|
||||
res: Response
|
||||
provider: Provider
|
||||
@ -650,91 +731,17 @@ providerRouter.post('/', async (req: Request, res: Response) => {
|
||||
* description: Bad request
|
||||
*/
|
||||
router.post('/count_tokens', async (req: Request, res: Response) => {
|
||||
try {
|
||||
const { model, messages, system } = req.body
|
||||
|
||||
if (!model) {
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: 'model parameter is required'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if (!messages || !Array.isArray(messages)) {
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: 'messages parameter is required'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const estimatedTokens = estimateTokenCount({ messages, system })
|
||||
|
||||
logger.debug('Token count estimated', {
|
||||
model,
|
||||
messageCount: messages.length,
|
||||
estimatedTokens
|
||||
})
|
||||
|
||||
return res.json({
|
||||
input_tokens: estimatedTokens
|
||||
})
|
||||
} catch (error: any) {
|
||||
logger.error('Token counting error', { error })
|
||||
return res.status(500).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'api_error',
|
||||
message: error.message || 'Internal server error'
|
||||
}
|
||||
})
|
||||
}
|
||||
return handleCountTokens(req, res, { requireModel: true })
|
||||
})
|
||||
|
||||
/**
|
||||
* Provider-specific count_tokens endpoint
|
||||
*/
|
||||
providerRouter.post('/count_tokens', async (req: Request, res: Response) => {
|
||||
try {
|
||||
const { model, messages, system } = req.body
|
||||
|
||||
if (!messages || !Array.isArray(messages)) {
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: 'messages parameter is required'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const estimatedTokens = estimateTokenCount({ messages, system })
|
||||
|
||||
logger.debug('Token count estimated (provider route)', {
|
||||
providerId: req.params.provider,
|
||||
model,
|
||||
messageCount: messages.length,
|
||||
estimatedTokens
|
||||
})
|
||||
|
||||
return res.json({
|
||||
input_tokens: estimatedTokens
|
||||
})
|
||||
} catch (error: any) {
|
||||
logger.error('Token counting error', { error })
|
||||
return res.status(500).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'api_error',
|
||||
message: error.message || 'Internal server error'
|
||||
}
|
||||
})
|
||||
}
|
||||
return handleCountTokens(req, res, {
|
||||
requireModel: false,
|
||||
logContext: { providerId: req.params.provider }
|
||||
})
|
||||
})
|
||||
|
||||
export { providerRouter as messagesProviderRoutes, router as messagesRoutes }
|
||||
|
||||
Loading…
Reference in New Issue
Block a user