mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-08 22:39:36 +08:00
feat(tokens): enhance token estimation and refactor count_tokens endpoint for improved handling
This commit is contained in:
parent
2c910322f8
commit
45d404e127
@ -46,20 +46,11 @@ const providerRouter = express.Router({ mergeParams: true })
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Estimate token count from messages
|
* Estimate token count from messages
|
||||||
* Uses tokenx library for accurate token estimation and supports images
|
* Uses tokenx library for accurate token estimation and supports images, tools
|
||||||
*/
|
*/
|
||||||
interface CountTokensInput {
|
interface CountTokensInput {
|
||||||
messages: Array<{
|
messages: MessageCreateParams['messages']
|
||||||
role: string
|
system?: MessageCreateParams['system']
|
||||||
content:
|
|
||||||
| string
|
|
||||||
| Array<{
|
|
||||||
type: string
|
|
||||||
text?: string
|
|
||||||
source?: { type: string; media_type?: string; data?: string }
|
|
||||||
}>
|
|
||||||
}>
|
|
||||||
system?: string | Array<{ type: string; text?: string }>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function estimateTokenCount(input: CountTokensInput): number {
|
function estimateTokenCount(input: CountTokensInput): number {
|
||||||
@ -89,14 +80,40 @@ function estimateTokenCount(input: CountTokensInput): number {
|
|||||||
totalTokens += approximateTokenSize(block.text)
|
totalTokens += approximateTokenSize(block.text)
|
||||||
} else if (block.type === 'image') {
|
} else if (block.type === 'image') {
|
||||||
// Image token estimation (consistent with TokenService)
|
// Image token estimation (consistent with TokenService)
|
||||||
// Base64 images: estimate from data length
|
if (block.source.type === 'base64') {
|
||||||
if (block.source?.data) {
|
// Base64 images: estimate from data length
|
||||||
const dataSize = block.source.data.length * 0.75 // base64 to bytes
|
const dataSize = block.source.data.length * 0.75 // base64 to bytes
|
||||||
totalTokens += Math.floor(dataSize / 100)
|
totalTokens += Math.floor(dataSize / 100)
|
||||||
} else {
|
} else {
|
||||||
// Default image token estimate
|
// URL images: use default estimate
|
||||||
totalTokens += 1000
|
totalTokens += 1000
|
||||||
}
|
}
|
||||||
|
} else if (block.type === 'tool_use') {
|
||||||
|
// Tool use token estimation: name + input JSON
|
||||||
|
if (block.name) {
|
||||||
|
totalTokens += approximateTokenSize(block.name)
|
||||||
|
}
|
||||||
|
if (block.input) {
|
||||||
|
const inputJson = JSON.stringify(block.input)
|
||||||
|
totalTokens += approximateTokenSize(inputJson)
|
||||||
|
}
|
||||||
|
// Add overhead for tool use structure
|
||||||
|
totalTokens += 10
|
||||||
|
} else if (block.type === 'tool_result') {
|
||||||
|
// Tool result token estimation
|
||||||
|
if (typeof block.content === 'string') {
|
||||||
|
totalTokens += approximateTokenSize(block.content)
|
||||||
|
} else if (Array.isArray(block.content)) {
|
||||||
|
for (const item of block.content) {
|
||||||
|
if (typeof item === 'string') {
|
||||||
|
totalTokens += approximateTokenSize(item)
|
||||||
|
} else if (item.type === 'text' && item.text) {
|
||||||
|
totalTokens += approximateTokenSize(item.text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Add overhead for tool result structure
|
||||||
|
totalTokens += 10
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -127,6 +144,70 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
|
|||||||
return { valid: true }
|
return { valid: true }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Shared handler for count_tokens endpoint
|
||||||
|
* Validates request and returns token count estimation
|
||||||
|
*/
|
||||||
|
async function handleCountTokens(
|
||||||
|
req: Request,
|
||||||
|
res: Response,
|
||||||
|
options: {
|
||||||
|
requireModel?: boolean
|
||||||
|
logContext?: Record<string, any>
|
||||||
|
} = {}
|
||||||
|
): Promise<Response> {
|
||||||
|
try {
|
||||||
|
const { model, messages, system } = req.body
|
||||||
|
const { requireModel = false, logContext = {} } = options
|
||||||
|
|
||||||
|
// Validate model parameter if required
|
||||||
|
if (requireModel && !model) {
|
||||||
|
return res.status(400).json({
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: 'invalid_request_error',
|
||||||
|
message: 'model parameter is required'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate messages parameter
|
||||||
|
if (!messages || !Array.isArray(messages)) {
|
||||||
|
return res.status(400).json({
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: 'invalid_request_error',
|
||||||
|
message: 'messages parameter is required'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Estimate token count
|
||||||
|
const estimatedTokens = estimateTokenCount({ messages, system })
|
||||||
|
|
||||||
|
// Log with context
|
||||||
|
logger.debug('Token count estimated', {
|
||||||
|
model,
|
||||||
|
messageCount: messages.length,
|
||||||
|
estimatedTokens,
|
||||||
|
...logContext
|
||||||
|
})
|
||||||
|
|
||||||
|
return res.json({
|
||||||
|
input_tokens: estimatedTokens
|
||||||
|
})
|
||||||
|
} catch (error: any) {
|
||||||
|
logger.error('Token counting error', { error })
|
||||||
|
return res.status(500).json({
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: 'api_error',
|
||||||
|
message: error.message || 'Internal server error'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
interface HandleMessageProcessingOptions {
|
interface HandleMessageProcessingOptions {
|
||||||
res: Response
|
res: Response
|
||||||
provider: Provider
|
provider: Provider
|
||||||
@ -650,91 +731,17 @@ providerRouter.post('/', async (req: Request, res: Response) => {
|
|||||||
* description: Bad request
|
* description: Bad request
|
||||||
*/
|
*/
|
||||||
router.post('/count_tokens', async (req: Request, res: Response) => {
|
router.post('/count_tokens', async (req: Request, res: Response) => {
|
||||||
try {
|
return handleCountTokens(req, res, { requireModel: true })
|
||||||
const { model, messages, system } = req.body
|
|
||||||
|
|
||||||
if (!model) {
|
|
||||||
return res.status(400).json({
|
|
||||||
type: 'error',
|
|
||||||
error: {
|
|
||||||
type: 'invalid_request_error',
|
|
||||||
message: 'model parameter is required'
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!messages || !Array.isArray(messages)) {
|
|
||||||
return res.status(400).json({
|
|
||||||
type: 'error',
|
|
||||||
error: {
|
|
||||||
type: 'invalid_request_error',
|
|
||||||
message: 'messages parameter is required'
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const estimatedTokens = estimateTokenCount({ messages, system })
|
|
||||||
|
|
||||||
logger.debug('Token count estimated', {
|
|
||||||
model,
|
|
||||||
messageCount: messages.length,
|
|
||||||
estimatedTokens
|
|
||||||
})
|
|
||||||
|
|
||||||
return res.json({
|
|
||||||
input_tokens: estimatedTokens
|
|
||||||
})
|
|
||||||
} catch (error: any) {
|
|
||||||
logger.error('Token counting error', { error })
|
|
||||||
return res.status(500).json({
|
|
||||||
type: 'error',
|
|
||||||
error: {
|
|
||||||
type: 'api_error',
|
|
||||||
message: error.message || 'Internal server error'
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Provider-specific count_tokens endpoint
|
* Provider-specific count_tokens endpoint
|
||||||
*/
|
*/
|
||||||
providerRouter.post('/count_tokens', async (req: Request, res: Response) => {
|
providerRouter.post('/count_tokens', async (req: Request, res: Response) => {
|
||||||
try {
|
return handleCountTokens(req, res, {
|
||||||
const { model, messages, system } = req.body
|
requireModel: false,
|
||||||
|
logContext: { providerId: req.params.provider }
|
||||||
if (!messages || !Array.isArray(messages)) {
|
})
|
||||||
return res.status(400).json({
|
|
||||||
type: 'error',
|
|
||||||
error: {
|
|
||||||
type: 'invalid_request_error',
|
|
||||||
message: 'messages parameter is required'
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const estimatedTokens = estimateTokenCount({ messages, system })
|
|
||||||
|
|
||||||
logger.debug('Token count estimated (provider route)', {
|
|
||||||
providerId: req.params.provider,
|
|
||||||
model,
|
|
||||||
messageCount: messages.length,
|
|
||||||
estimatedTokens
|
|
||||||
})
|
|
||||||
|
|
||||||
return res.json({
|
|
||||||
input_tokens: estimatedTokens
|
|
||||||
})
|
|
||||||
} catch (error: any) {
|
|
||||||
logger.error('Token counting error', { error })
|
|
||||||
return res.status(500).json({
|
|
||||||
type: 'error',
|
|
||||||
error: {
|
|
||||||
type: 'api_error',
|
|
||||||
message: error.message || 'Internal server error'
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
export { providerRouter as messagesProviderRoutes, router as messagesRoutes }
|
export { providerRouter as messagesProviderRoutes, router as messagesRoutes }
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user