Mirror of https://github.com/CherryHQ/cherry-studio.git (synced 2026-01-07 05:39:05 +08:00)
feat(models): enhance models endpoint with filtering and pagination support
parent 38076babcf
commit 73380d76df
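For orientation before the diff: once this change lands, /v1/models should accept the new query parameters roughly as below. This is a minimal client sketch; the base URL, port, and error-handling style are assumptions, not part of this commit.

// Minimal client sketch for the enhanced endpoint; BASE_URL is an assumption.
const BASE_URL = 'http://localhost:23333'

async function listModels(params: { provider?: 'openai' | 'anthropic'; offset?: number; limit?: number } = {}) {
  const qs = new URLSearchParams(Object.entries(params).map(([k, v]) => [k, String(v)])).toString()
  const res = await fetch(`${BASE_URL}/v1/models${qs ? `?${qs}` : ''}`)
  // Per the diff: 400 on invalid query parameters, 503 when providers fail.
  if (!res.ok) throw new Error(`models request failed: HTTP ${res.status}`)
  // Response shape per the diff: { object: 'list', data: Model[], total?, offset?, limit? }
  return res.json()
}

// e.g. first 10 models from OpenAI-type providers:
// await listModels({ provider: 'openai', offset: 0, limit: 10 })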
@@ -1,73 +1,129 @@
 import express, { Request, Response } from 'express'
 
 import { loggerService } from '../../services/LoggerService'
-import { chatCompletionService } from '../services/chat-completion'
+import { ModelsFilterSchema, modelsService } from '../services/models'
 
 const logger = loggerService.withContext('ApiServerModelsRoutes')
 
-const router = express.Router()
+const router = express
+  .Router()
 
 /**
  * @swagger
  * /v1/models:
  *   get:
  *     summary: List available models
- *     description: Returns a list of available AI models from all configured providers
+ *     description: Returns a list of available AI models from all configured providers with optional filtering
  *     tags: [Models]
+ *     parameters:
+ *       - in: query
+ *         name: provider
+ *         schema:
+ *           type: string
+ *           enum: [openai, anthropic]
+ *         description: Filter models by provider type
+ *       - in: query
+ *         name: offset
+ *         schema:
+ *           type: integer
+ *           minimum: 0
+ *           default: 0
+ *         description: Pagination offset
+ *       - in: query
+ *         name: limit
+ *         schema:
+ *           type: integer
+ *           minimum: 1
+ *         description: Maximum number of models to return
  *     responses:
  *       200:
  *         description: List of available models
  *         content:
  *           application/json:
  *             schema:
  *               type: object
  *               properties:
  *                 object:
  *                   type: string
  *                   example: list
  *                 data:
  *                   type: array
  *                   items:
  *                     $ref: '#/components/schemas/Model'
+ *                 total:
+ *                   type: integer
+ *                   description: Total number of models (when using pagination)
+ *                 offset:
+ *                   type: integer
+ *                   description: Current offset (when using pagination)
+ *                 limit:
+ *                   type: integer
+ *                   description: Current limit (when using pagination)
+ *       400:
+ *         description: Invalid query parameters
+ *         content:
+ *           application/json:
+ *             schema:
+ *               $ref: '#/components/schemas/Error'
 *       503:
 *         description: Service unavailable
 *         content:
 *           application/json:
 *             schema:
 *               $ref: '#/components/schemas/Error'
 */
-router.get('/', async (_req: Request, res: Response) => {
-  try {
-    logger.info('Models list request received')
-
-    const models = await chatCompletionService.getModels()
-
-    if (models.length === 0) {
-      logger.warn(
-        'No models available from providers. This may be because no OpenAI providers are configured or enabled.'
-      )
-    }
-
-    logger.info(`Returning ${models.length} models (OpenAI providers only)`)
-    logger.debug(
-      'Model IDs:',
-      models.map((m) => m.id)
-    )
-
-    return res.json({
-      object: 'list',
-      data: models
-    })
-  } catch (error: any) {
-    logger.error('Error fetching models:', error)
-    return res.status(503).json({
-      error: {
-        message: 'Failed to retrieve models from available providers',
-        type: 'service_unavailable',
-        code: 'models_unavailable'
-      }
-    })
-  }
-})
+  .get('/', async (req: Request, res: Response) => {
+    try {
+      logger.info('Models list request received', { query: req.query })
+
+      // Validate query parameters using Zod schema
+      const filterResult = ModelsFilterSchema.safeParse(req.query)
+
+      if (!filterResult.success) {
+        logger.warn('Invalid query parameters:', filterResult.error.issues)
+        return res.status(400).json({
+          error: {
+            message: 'Invalid query parameters',
+            type: 'invalid_request_error',
+            code: 'invalid_parameters',
+            details: filterResult.error.issues.map((issue) => ({
+              field: issue.path.join('.'),
+              message: issue.message
+            }))
+          }
+        })
+      }
+
+      const filter = filterResult.data
+      const response = await modelsService.getModels(filter)
+
+      if (response.data.length === 0) {
+        logger.warn(
+          'No models available from providers. This may be because no OpenAI/Anthropic providers are configured or enabled.',
+          { filter }
+        )
+      }
+
+      logger.info(`Returning ${response.data.length} models`, {
+        filter,
+        total: response.total
+      })
+      logger.debug(
+        'Model IDs:',
+        response.data.map((m) => m.id)
+      )
+
+      return res.json(response)
+    } catch (error: any) {
+      logger.error('Error fetching models:', error)
+      return res.status(503).json({
+        error: {
+          message: 'Failed to retrieve models from available providers',
+          type: 'service_unavailable',
+          code: 'models_unavailable'
+        }
+      })
+    }
+  })
 
 export { router as modelsRoutes }
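To make the new 400 path concrete: a request that violates the schema, say GET /v1/models?limit=0, is rejected before any provider is queried. A sketch of the resulting body follows; the exact detail message text comes from Zod and may differ.

// Illustrative 400 response body for GET /v1/models?limit=0; Zod's wording may vary.
const badRequestBody = {
  error: {
    message: 'Invalid query parameters',
    type: 'invalid_request_error',
    code: 'invalid_parameters',
    details: [{ field: 'limit', message: 'Number must be greater than or equal to 1' }]
  }
}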
@@ -2,70 +2,16 @@ import OpenAI from 'openai'
 import { ChatCompletionCreateParams } from 'openai/resources'
 
 import { loggerService } from '../../services/LoggerService'
-import {
-  getProviderByModel,
-  getRealProviderModel,
-  listAllAvailableModels,
-  OpenAICompatibleModel,
-  transformModelToOpenAI,
-  validateProvider
-} from '../utils'
+import { getProviderByModel, getRealProviderModel, validateProvider } from '../utils'
 
 const logger = loggerService.withContext('ChatCompletionService')
 
-export interface ModelData extends OpenAICompatibleModel {
-  provider_id: string
-  model_id: string
-  name: string
-}
-
 export interface ValidationResult {
   isValid: boolean
   errors: string[]
 }
 
 export class ChatCompletionService {
-  async getModels(): Promise<ModelData[]> {
-    try {
-      logger.info('Getting available models from providers')
-
-      const models = await listAllAvailableModels()
-
-      // Use Map to deduplicate models by their full ID (provider:model_id)
-      const uniqueModels = new Map<string, ModelData>()
-
-      for (const model of models) {
-        const openAIModel = transformModelToOpenAI(model)
-        const fullModelId = openAIModel.id // This is already in format "provider:model_id"
-
-        // Only add if not already present (first occurrence wins)
-        if (!uniqueModels.has(fullModelId)) {
-          uniqueModels.set(fullModelId, {
-            ...openAIModel,
-            provider_id: model.provider,
-            model_id: model.id,
-            name: model.name
-          })
-        } else {
-          logger.debug(`Skipping duplicate model: ${fullModelId}`)
-        }
-      }
-
-      const modelData = Array.from(uniqueModels.values())
-
-      logger.info(`Successfully retrieved ${modelData.length} unique models from ${models.length} total models`)
-
-      if (models.length > modelData.length) {
-        logger.debug(`Filtered out ${models.length - modelData.length} duplicate models`)
-      }
-
-      return modelData
-    } catch (error: any) {
-      logger.error('Error getting models:', error)
-      return []
-    }
-  }
-
   validateRequest(request: ChatCompletionCreateParams): ValidationResult {
     const errors: string[] = []
 
@@ -98,17 +44,6 @@ export class ChatCompletionService {
     }
 
     // Validate optional parameters
-    if (request.temperature !== undefined) {
-      if (typeof request.temperature !== 'number' || request.temperature < 0 || request.temperature > 2) {
-        errors.push('Temperature must be a number between 0 and 2')
-      }
-    }
-
-    if (request.max_tokens !== undefined) {
-      if (typeof request.max_tokens !== 'number' || request.max_tokens < 1) {
-        errors.push('max_tokens must be a positive number')
-      }
-    }
-
     return {
       isValid: errors.length === 0,
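After this hunk, validateRequest no longer range-checks temperature or max_tokens locally; presumably those are left for the downstream provider to reject. A sketch of the remaining call shape, assuming standard OpenAI request fields (not taken from this diff):

// Sketch: out-of-range sampling parameters now pass local validation.
const result = chatCompletionService.validateRequest({
  model: 'provider:model-id', // hypothetical "provider:model" identifier
  messages: [{ role: 'user', content: 'Hello' }],
  temperature: 9 // previously rejected ("between 0 and 2"); now forwarded as-is
})
if (!result.isValid) {
  console.warn('Request rejected locally:', result.errors)
}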
src/main/apiServer/services/models.ts (new file, 112 lines)
@@ -0,0 +1,112 @@
+import { z } from 'zod'
+
+import { loggerService } from '../../services/LoggerService'
+import { getAvailableProviders, listAllAvailableModels, OpenAICompatibleModel, transformModelToOpenAI } from '../utils'
+
+const logger = loggerService.withContext('ModelsService')
+
+// Zod schema for models filtering
+export const ModelsFilterSchema = z.object({
+  provider: z.enum(['openai', 'anthropic']).optional(),
+  offset: z.coerce.number().min(0).default(0).optional(),
+  limit: z.coerce.number().min(1).optional()
+})
+
+export type ModelsFilter = z.infer<typeof ModelsFilterSchema>
+
+export interface ModelsResponse {
+  object: 'list'
+  data: OpenAICompatibleModel[]
+  total?: number
+  offset?: number
+  limit?: number
+}
+
+export class ModelsService {
+  async getModels(filter?: ModelsFilter): Promise<ModelsResponse> {
+    try {
+      logger.info('Getting available models from providers', { filter })
+
+      const models = await listAllAvailableModels()
+      const providers = await getAvailableProviders()
+
+      // Use Map to deduplicate models by their full ID (provider:model_id)
+      const uniqueModels = new Map<string, OpenAICompatibleModel>()
+
+      for (const model of models) {
+        const openAIModel = transformModelToOpenAI(model)
+        const fullModelId = openAIModel.id // This is already in format "provider:model_id"
+
+        // Only add if not already present (first occurrence wins)
+        if (!uniqueModels.has(fullModelId)) {
+          uniqueModels.set(fullModelId, {
+            ...openAIModel,
+            name: model.name
+          })
+        } else {
+          logger.debug(`Skipping duplicate model: ${fullModelId}`)
+        }
+      }
+
+      let modelData = Array.from(uniqueModels.values())
+
+      // Apply filters
+      if (filter?.provider) {
+        const providerType = filter.provider
+        modelData = modelData.filter((model) => {
+          // Find the provider for this model and check its type
+          const provider = providers.find((p) => p.id === model.provider)
+          return provider && provider.type === providerType
+        })
+        logger.debug(`Filtered by provider type '${providerType}': ${modelData.length} models`)
+      }
+
+      const total = modelData.length
+
+      // Apply pagination
+      const offset = filter?.offset || 0
+      const limit = filter?.limit
+
+      if (limit !== undefined) {
+        modelData = modelData.slice(offset, offset + limit)
+        logger.debug(
+          `Applied pagination: offset=${offset}, limit=${limit}, showing ${modelData.length} of ${total} models`
+        )
+      } else if (offset > 0) {
+        modelData = modelData.slice(offset)
+        logger.debug(`Applied offset: offset=${offset}, showing ${modelData.length} of ${total} models`)
+      }
+
+      logger.info(`Successfully retrieved ${modelData.length} models from ${models.length} total models`)
+
+      if (models.length > total) {
+        logger.debug(`Filtered out ${models.length - total} models after deduplication and filtering`)
+      }
+
+      const response: ModelsResponse = {
+        object: 'list',
+        data: modelData
+      }
+
+      // Add pagination metadata if applicable
+      if (filter?.limit !== undefined || filter?.offset !== undefined) {
+        response.total = total
+        response.offset = offset
+        if (filter?.limit !== undefined) {
+          response.limit = filter.limit
+        }
+      }
+
+      return response
+    } catch (error: any) {
+      logger.error('Error getting models:', error)
+      return {
+        object: 'list',
+        data: []
+      }
+    }
+  }
+}
+
+// Export singleton instance
+export const modelsService = new ModelsService()
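A minimal usage sketch for the new service, assuming it is called from the main process after reading raw query-string values (z.coerce in the schema is what lets strings like '25' through):

// Sketch: parse raw query-string values, then fetch a filtered, paginated page.
async function firstPageOfOpenAIModels() {
  const parsed = ModelsFilterSchema.safeParse({ provider: 'openai', offset: '0', limit: '25' })
  if (!parsed.success) throw new Error('invalid filter')
  const page = await modelsService.getModels(parsed.data)
  // page.total/offset/limit are set because pagination params were supplied
  return page.data.map((m) => m.id)
}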
@@ -9,6 +9,7 @@ export interface OpenAICompatibleModel {
   id: string
   object: 'model'
   created: number
+  name: string
   owned_by: string
   provider?: string
   provider_model_id?: string
@@ -185,6 +186,7 @@ export function transformModelToOpenAI(model: Model): OpenAICompatibleModel {
   return {
     id: `${model.provider}:${model.id}`,
     object: 'model',
+    name: model.name,
     created: Math.floor(Date.now() / 1000),
     owned_by: model.owned_by || model.provider,
     provider: model.provider,
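With name now part of OpenAICompatibleModel, a transformed entry looks roughly like this; all values are illustrative, not taken from the diff.

// Illustrative result of transformModelToOpenAI for a hypothetical provider entry.
const example: OpenAICompatibleModel = {
  id: 'my-provider:gpt-4o-mini', // "provider:model_id", also the deduplication key
  object: 'model',
  name: 'GPT-4o mini', // the newly added display name
  created: 1736206745, // Math.floor(Date.now() / 1000) at transform time
  owned_by: 'my-provider',
  provider: 'my-provider',
  provider_model_id: 'gpt-4o-mini'
}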