mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-24 10:40:07 +08:00
feat(models): update models filtering to use providerType and enhance API schemas
This commit is contained in:
parent
92ba1e4fc3
commit
49f9dff9da
@ -1,7 +1,8 @@
|
||||
import { ApiModelsFilterSchema } from '@types'
|
||||
import express, { Request, Response } from 'express'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { ModelsFilterSchema, modelsService } from '../services/models'
|
||||
import { modelsService } from '../services/models'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerModelsRoutes')
|
||||
|
||||
@ -17,10 +18,10 @@ const router = express
|
||||
* tags: [Models]
|
||||
* parameters:
|
||||
* - in: query
|
||||
* name: provider
|
||||
* name: providerType
|
||||
* schema:
|
||||
* type: string
|
||||
* enum: [openai, anthropic]
|
||||
* enum: [openai, openai-response, anthropic, gemini]
|
||||
* description: Filter models by provider type
|
||||
* - in: query
|
||||
* name: offset
|
||||
@ -77,7 +78,7 @@ const router = express
|
||||
logger.info('Models list request received', { query: req.query })
|
||||
|
||||
// Validate query parameters using Zod schema
|
||||
const filterResult = ModelsFilterSchema.safeParse(req.query)
|
||||
const filterResult = ApiModelsFilterSchema.safeParse(req.query)
|
||||
|
||||
if (!filterResult.success) {
|
||||
logger.warn('Invalid query parameters:', filterResult.error.issues)
|
||||
|
||||
@ -1,53 +1,43 @@
|
||||
import {
|
||||
ApiModelsRequest,
|
||||
ApiModelsRequestSchema,
|
||||
ApiModelsResponse,
|
||||
OpenAICompatibleModel
|
||||
} from '../../../renderer/src/types/apiModels'
|
||||
import { ApiModel, ApiModelsRequest, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils'
|
||||
|
||||
const logger = loggerService.withContext('ModelsService')
|
||||
|
||||
// Re-export for backward compatibility
|
||||
export const ModelsFilterSchema = ApiModelsRequestSchema
|
||||
|
||||
export type ModelsFilter = ApiModelsRequest
|
||||
|
||||
export class ModelsService {
|
||||
async getModels(filter?: ModelsFilter): Promise<ApiModelsResponse> {
|
||||
async getModels(filter: ModelsFilter): Promise<ApiModelsResponse> {
|
||||
try {
|
||||
logger.info('Getting available models from providers', { filter })
|
||||
logger.debug('Getting available models from providers', { filter })
|
||||
|
||||
const models = await listAllAvailableModels()
|
||||
const providers = await getAvailableProviders()
|
||||
|
||||
// Use Map to deduplicate models by their full ID (provider:model_id)
|
||||
const uniqueModels = new Map<string, OpenAICompatibleModel>()
|
||||
const uniqueModels = new Map<string, ApiModel>()
|
||||
|
||||
for (const model of models) {
|
||||
const openAIModel = transformModelToOpenAI(model)
|
||||
const openAIModel = transformModelToOpenAI(model, providers)
|
||||
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
|
||||
|
||||
// Only add if not already present (first occurrence wins)
|
||||
if (!uniqueModels.has(fullModelId)) {
|
||||
uniqueModels.set(fullModelId, {
|
||||
...openAIModel,
|
||||
name: model.name
|
||||
})
|
||||
uniqueModels.set(fullModelId, openAIModel)
|
||||
} else {
|
||||
logger.debug(`Skipping duplicate model: ${fullModelId}`)
|
||||
}
|
||||
}
|
||||
|
||||
let modelData = Array.from(uniqueModels.values())
|
||||
|
||||
// Apply filters
|
||||
if (filter?.provider) {
|
||||
const providerType = filter.provider
|
||||
if (filter.providerType) {
|
||||
// Apply filters
|
||||
const providerType = filter.providerType
|
||||
modelData = modelData.filter((model) => {
|
||||
// Find the provider for this model and check its type
|
||||
const provider = providers.find((p) => p.id === model.provider)
|
||||
return provider && provider.type === providerType
|
||||
return model.provider_type === providerType
|
||||
})
|
||||
logger.debug(`Filtered by provider type '${providerType}': ${modelData.length} models`)
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { loggerService } from '@main/services/LoggerService'
|
||||
import { reduxService } from '@main/services/ReduxService'
|
||||
import { Model, OpenAICompatibleModel, Provider } from '@types'
|
||||
import { ApiModel, Model, Provider } from '@types'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerUtils')
|
||||
|
||||
@ -173,7 +173,8 @@ export async function validateModelId(
|
||||
}
|
||||
}
|
||||
|
||||
export function transformModelToOpenAI(model: Model): OpenAICompatibleModel {
|
||||
export function transformModelToOpenAI(model: Model, providers: Provider[]): ApiModel {
|
||||
const provider = providers.find((p) => p.id === model.provider)
|
||||
return {
|
||||
id: `${model.provider}:${model.id}`,
|
||||
object: 'model',
|
||||
@ -181,6 +182,7 @@ export function transformModelToOpenAI(model: Model): OpenAICompatibleModel {
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
owned_by: model.owned_by || model.provider,
|
||||
provider: model.provider,
|
||||
provider_type: provider?.type,
|
||||
provider_model_id: model.id
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,33 +1,36 @@
|
||||
import { z } from 'zod'
|
||||
|
||||
import { ProviderTypeSchema } from './provider'
|
||||
|
||||
// Request schema for /v1/models
|
||||
export const ApiModelsRequestSchema = z.object({
|
||||
provider: z.enum(['openai', 'anthropic']).optional(),
|
||||
export const ApiModelsFilterSchema = z.object({
|
||||
providerType: ProviderTypeSchema.optional(),
|
||||
offset: z.coerce.number().min(0).default(0).optional(),
|
||||
limit: z.coerce.number().min(1).optional()
|
||||
limit: z.coerce.number().min(1).default(20).optional()
|
||||
})
|
||||
|
||||
// OpenAI compatible model schema
|
||||
export const OpenAICompatibleModelSchema = z.object({
|
||||
export const ApiModelSchema = z.object({
|
||||
id: z.string(),
|
||||
object: z.literal('model'),
|
||||
created: z.number(),
|
||||
name: z.string(),
|
||||
owned_by: z.string(),
|
||||
provider: z.string().optional(),
|
||||
provider_type: ProviderTypeSchema.optional(),
|
||||
provider_model_id: z.string().optional()
|
||||
})
|
||||
|
||||
// Response schema for /v1/models
|
||||
export const ApiModelsResponseSchema = z.object({
|
||||
object: z.literal('list'),
|
||||
data: z.array(OpenAICompatibleModelSchema),
|
||||
data: z.array(ApiModelSchema),
|
||||
total: z.number().optional(),
|
||||
offset: z.number().optional(),
|
||||
limit: z.number().optional()
|
||||
})
|
||||
|
||||
// Inferred TypeScript types
|
||||
export type ApiModelsRequest = z.infer<typeof ApiModelsRequestSchema>
|
||||
export type OpenAICompatibleModel = z.infer<typeof OpenAICompatibleModelSchema>
|
||||
export type ApiModel = z.infer<typeof ApiModelSchema>
|
||||
export type ApiModelsRequest = z.infer<typeof ApiModelsFilterSchema>
|
||||
export type ApiModelsResponse = z.infer<typeof ApiModelsResponseSchema>
|
||||
|
||||
@ -21,6 +21,7 @@ export * from './knowledge'
|
||||
export * from './mcp'
|
||||
export * from './notification'
|
||||
export * from './ocr'
|
||||
export * from './provider'
|
||||
|
||||
export type Assistant = {
|
||||
id: string
|
||||
@ -216,158 +217,6 @@ export type User = {
|
||||
email: string
|
||||
}
|
||||
|
||||
// undefined 视为支持,默认支持
|
||||
export type ProviderApiOptions = {
|
||||
/** 是否不支持 message 的 content 为数组类型 */
|
||||
isNotSupportArrayContent?: boolean
|
||||
/** 是否不支持 stream_options 参数 */
|
||||
isNotSupportStreamOptions?: boolean
|
||||
/**
|
||||
* @deprecated
|
||||
* 是否不支持 message 的 role 为 developer */
|
||||
isNotSupportDeveloperRole?: boolean
|
||||
/* 是否支持 message 的 role 为 developer */
|
||||
isSupportDeveloperRole?: boolean
|
||||
/**
|
||||
* @deprecated
|
||||
* 是否不支持 service_tier 参数. Only for OpenAI Models. */
|
||||
isNotSupportServiceTier?: boolean
|
||||
/* 是否支持 service_tier 参数. Only for OpenAI Models. */
|
||||
isSupportServiceTier?: boolean
|
||||
/** 是否不支持 enable_thinking 参数 */
|
||||
isNotSupportEnableThinking?: boolean
|
||||
}
|
||||
|
||||
export type Provider = {
|
||||
id: string
|
||||
type: ProviderType
|
||||
name: string
|
||||
apiKey: string
|
||||
apiHost: string
|
||||
apiVersion?: string
|
||||
models: Model[]
|
||||
enabled?: boolean
|
||||
isSystem?: boolean
|
||||
isAuthed?: boolean
|
||||
rateLimit?: number
|
||||
|
||||
// API options
|
||||
apiOptions?: ProviderApiOptions
|
||||
serviceTier?: ServiceTier
|
||||
|
||||
/** @deprecated */
|
||||
isNotSupportArrayContent?: boolean
|
||||
/** @deprecated */
|
||||
isNotSupportStreamOptions?: boolean
|
||||
/** @deprecated */
|
||||
isNotSupportDeveloperRole?: boolean
|
||||
/** @deprecated */
|
||||
isNotSupportServiceTier?: boolean
|
||||
|
||||
authType?: 'apiKey' | 'oauth'
|
||||
isVertex?: boolean
|
||||
notes?: string
|
||||
extra_headers?: Record<string, string>
|
||||
}
|
||||
|
||||
export const SystemProviderIds = {
|
||||
cherryin: 'cherryin',
|
||||
silicon: 'silicon',
|
||||
aihubmix: 'aihubmix',
|
||||
ocoolai: 'ocoolai',
|
||||
deepseek: 'deepseek',
|
||||
ppio: 'ppio',
|
||||
alayanew: 'alayanew',
|
||||
qiniu: 'qiniu',
|
||||
dmxapi: 'dmxapi',
|
||||
burncloud: 'burncloud',
|
||||
tokenflux: 'tokenflux',
|
||||
'302ai': '302ai',
|
||||
cephalon: 'cephalon',
|
||||
lanyun: 'lanyun',
|
||||
ph8: 'ph8',
|
||||
openrouter: 'openrouter',
|
||||
ollama: 'ollama',
|
||||
'new-api': 'new-api',
|
||||
lmstudio: 'lmstudio',
|
||||
anthropic: 'anthropic',
|
||||
openai: 'openai',
|
||||
'azure-openai': 'azure-openai',
|
||||
gemini: 'gemini',
|
||||
vertexai: 'vertexai',
|
||||
github: 'github',
|
||||
copilot: 'copilot',
|
||||
zhipu: 'zhipu',
|
||||
yi: 'yi',
|
||||
moonshot: 'moonshot',
|
||||
baichuan: 'baichuan',
|
||||
dashscope: 'dashscope',
|
||||
stepfun: 'stepfun',
|
||||
doubao: 'doubao',
|
||||
infini: 'infini',
|
||||
minimax: 'minimax',
|
||||
groq: 'groq',
|
||||
together: 'together',
|
||||
fireworks: 'fireworks',
|
||||
nvidia: 'nvidia',
|
||||
grok: 'grok',
|
||||
hyperbolic: 'hyperbolic',
|
||||
mistral: 'mistral',
|
||||
jina: 'jina',
|
||||
perplexity: 'perplexity',
|
||||
modelscope: 'modelscope',
|
||||
xirang: 'xirang',
|
||||
hunyuan: 'hunyuan',
|
||||
'tencent-cloud-ti': 'tencent-cloud-ti',
|
||||
'baidu-cloud': 'baidu-cloud',
|
||||
gpustack: 'gpustack',
|
||||
voyageai: 'voyageai',
|
||||
'aws-bedrock': 'aws-bedrock',
|
||||
poe: 'poe'
|
||||
} as const
|
||||
|
||||
export type SystemProviderId = keyof typeof SystemProviderIds
|
||||
|
||||
export const isSystemProviderId = (id: string): id is SystemProviderId => {
|
||||
return Object.hasOwn(SystemProviderIds, id)
|
||||
}
|
||||
|
||||
export type SystemProvider = Provider & {
|
||||
id: SystemProviderId
|
||||
isSystem: true
|
||||
apiOptions?: never
|
||||
}
|
||||
|
||||
export type VertexProvider = Provider & {
|
||||
googleCredentials: {
|
||||
privateKey: string
|
||||
clientEmail: string
|
||||
}
|
||||
project: string
|
||||
location: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 判断是否为系统内置的提供商。比直接使用`provider.isSystem`更好,因为该数据字段不会随着版本更新而变化。
|
||||
* @param provider - Provider对象,包含提供商的信息
|
||||
* @returns 是否为系统内置提供商
|
||||
*/
|
||||
export const isSystemProvider = (provider: Provider): provider is SystemProvider => {
|
||||
return isSystemProviderId(provider.id) && !!provider.isSystem
|
||||
}
|
||||
|
||||
export type ProviderType =
|
||||
| 'openai'
|
||||
| 'openai-response'
|
||||
| 'anthropic'
|
||||
| 'gemini'
|
||||
| 'qwenlm'
|
||||
| 'azure-openai'
|
||||
| 'vertexai'
|
||||
| 'mistral'
|
||||
| 'aws-bedrock'
|
||||
| 'vertex-anthropic'
|
||||
|
||||
export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search' | 'rerank'
|
||||
|
||||
export type ModelTag = Exclude<ModelType, 'text'> | 'free'
|
||||
@ -976,39 +825,6 @@ export type OpenAIVerbosity = 'high' | 'medium' | 'low'
|
||||
|
||||
export type OpenAISummaryText = 'auto' | 'concise' | 'detailed' | 'off'
|
||||
|
||||
export const OpenAIServiceTiers = {
|
||||
auto: 'auto',
|
||||
default: 'default',
|
||||
flex: 'flex',
|
||||
priority: 'priority'
|
||||
} as const
|
||||
|
||||
export type OpenAIServiceTier = keyof typeof OpenAIServiceTiers
|
||||
|
||||
export function isOpenAIServiceTier(tier: string): tier is OpenAIServiceTier {
|
||||
return Object.hasOwn(OpenAIServiceTiers, tier)
|
||||
}
|
||||
|
||||
export const GroqServiceTiers = {
|
||||
auto: 'auto',
|
||||
on_demand: 'on_demand',
|
||||
flex: 'flex',
|
||||
performance: 'performance'
|
||||
} as const
|
||||
|
||||
// 从 GroqServiceTiers 对象中提取类型
|
||||
export type GroqServiceTier = keyof typeof GroqServiceTiers
|
||||
|
||||
export function isGroqServiceTier(tier: string): tier is GroqServiceTier {
|
||||
return Object.hasOwn(GroqServiceTiers, tier)
|
||||
}
|
||||
|
||||
export type ServiceTier = OpenAIServiceTier | GroqServiceTier
|
||||
|
||||
export function isServiceTier(tier: string): tier is ServiceTier {
|
||||
return isGroqServiceTier(tier) || isOpenAIServiceTier(tier)
|
||||
}
|
||||
|
||||
export type S3Config = {
|
||||
endpoint: string
|
||||
region: string
|
||||
|
||||
190
src/renderer/src/types/provider.ts
Normal file
190
src/renderer/src/types/provider.ts
Normal file
@ -0,0 +1,190 @@
|
||||
import { Model } from '@types'
|
||||
import z from 'zod'
|
||||
|
||||
export const ProviderTypeSchema = z.enum([
|
||||
'openai',
|
||||
'openai-response',
|
||||
'anthropic',
|
||||
'gemini',
|
||||
'qwenlm',
|
||||
'azure-openai',
|
||||
'vertexai',
|
||||
'mistral',
|
||||
'aws-bedrock',
|
||||
'vertex-anthropic'
|
||||
])
|
||||
|
||||
export type ProviderType = z.infer<typeof ProviderTypeSchema>
|
||||
|
||||
// undefined 视为支持,默认支持
|
||||
export type ProviderApiOptions = {
|
||||
/** 是否不支持 message 的 content 为数组类型 */
|
||||
isNotSupportArrayContent?: boolean
|
||||
/** 是否不支持 stream_options 参数 */
|
||||
isNotSupportStreamOptions?: boolean
|
||||
/**
|
||||
* @deprecated
|
||||
* 是否不支持 message 的 role 为 developer */
|
||||
isNotSupportDeveloperRole?: boolean
|
||||
/* 是否支持 message 的 role 为 developer */
|
||||
isSupportDeveloperRole?: boolean
|
||||
/**
|
||||
* @deprecated
|
||||
* 是否不支持 service_tier 参数. Only for OpenAI Models. */
|
||||
isNotSupportServiceTier?: boolean
|
||||
/* 是否支持 service_tier 参数. Only for OpenAI Models. */
|
||||
isSupportServiceTier?: boolean
|
||||
/** 是否不支持 enable_thinking 参数 */
|
||||
isNotSupportEnableThinking?: boolean
|
||||
}
|
||||
|
||||
export const OpenAIServiceTiers = {
|
||||
auto: 'auto',
|
||||
default: 'default',
|
||||
flex: 'flex',
|
||||
priority: 'priority'
|
||||
} as const
|
||||
|
||||
export type OpenAIServiceTier = keyof typeof OpenAIServiceTiers
|
||||
|
||||
export function isOpenAIServiceTier(tier: string): tier is OpenAIServiceTier {
|
||||
return Object.hasOwn(OpenAIServiceTiers, tier)
|
||||
}
|
||||
|
||||
export const GroqServiceTiers = {
|
||||
auto: 'auto',
|
||||
on_demand: 'on_demand',
|
||||
flex: 'flex',
|
||||
performance: 'performance'
|
||||
} as const
|
||||
|
||||
// 从 GroqServiceTiers 对象中提取类型
|
||||
export type GroqServiceTier = keyof typeof GroqServiceTiers
|
||||
|
||||
export function isGroqServiceTier(tier: string): tier is GroqServiceTier {
|
||||
return Object.hasOwn(GroqServiceTiers, tier)
|
||||
}
|
||||
|
||||
export type ServiceTier = OpenAIServiceTier | GroqServiceTier
|
||||
|
||||
export function isServiceTier(tier: string): tier is ServiceTier {
|
||||
return isGroqServiceTier(tier) || isOpenAIServiceTier(tier)
|
||||
}
|
||||
|
||||
export type Provider = {
|
||||
id: string
|
||||
type: ProviderType
|
||||
name: string
|
||||
apiKey: string
|
||||
apiHost: string
|
||||
apiVersion?: string
|
||||
models: Model[]
|
||||
enabled?: boolean
|
||||
isSystem?: boolean
|
||||
isAuthed?: boolean
|
||||
rateLimit?: number
|
||||
|
||||
// API options
|
||||
apiOptions?: ProviderApiOptions
|
||||
serviceTier?: ServiceTier
|
||||
|
||||
/** @deprecated */
|
||||
isNotSupportArrayContent?: boolean
|
||||
/** @deprecated */
|
||||
isNotSupportStreamOptions?: boolean
|
||||
/** @deprecated */
|
||||
isNotSupportDeveloperRole?: boolean
|
||||
/** @deprecated */
|
||||
isNotSupportServiceTier?: boolean
|
||||
|
||||
authType?: 'apiKey' | 'oauth'
|
||||
isVertex?: boolean
|
||||
notes?: string
|
||||
extra_headers?: Record<string, string>
|
||||
}
|
||||
|
||||
export const SystemProviderIds = {
|
||||
cherryin: 'cherryin',
|
||||
silicon: 'silicon',
|
||||
aihubmix: 'aihubmix',
|
||||
ocoolai: 'ocoolai',
|
||||
deepseek: 'deepseek',
|
||||
ppio: 'ppio',
|
||||
alayanew: 'alayanew',
|
||||
qiniu: 'qiniu',
|
||||
dmxapi: 'dmxapi',
|
||||
burncloud: 'burncloud',
|
||||
tokenflux: 'tokenflux',
|
||||
'302ai': '302ai',
|
||||
cephalon: 'cephalon',
|
||||
lanyun: 'lanyun',
|
||||
ph8: 'ph8',
|
||||
openrouter: 'openrouter',
|
||||
ollama: 'ollama',
|
||||
'new-api': 'new-api',
|
||||
lmstudio: 'lmstudio',
|
||||
anthropic: 'anthropic',
|
||||
openai: 'openai',
|
||||
'azure-openai': 'azure-openai',
|
||||
gemini: 'gemini',
|
||||
vertexai: 'vertexai',
|
||||
github: 'github',
|
||||
copilot: 'copilot',
|
||||
zhipu: 'zhipu',
|
||||
yi: 'yi',
|
||||
moonshot: 'moonshot',
|
||||
baichuan: 'baichuan',
|
||||
dashscope: 'dashscope',
|
||||
stepfun: 'stepfun',
|
||||
doubao: 'doubao',
|
||||
infini: 'infini',
|
||||
minimax: 'minimax',
|
||||
groq: 'groq',
|
||||
together: 'together',
|
||||
fireworks: 'fireworks',
|
||||
nvidia: 'nvidia',
|
||||
grok: 'grok',
|
||||
hyperbolic: 'hyperbolic',
|
||||
mistral: 'mistral',
|
||||
jina: 'jina',
|
||||
perplexity: 'perplexity',
|
||||
modelscope: 'modelscope',
|
||||
xirang: 'xirang',
|
||||
hunyuan: 'hunyuan',
|
||||
'tencent-cloud-ti': 'tencent-cloud-ti',
|
||||
'baidu-cloud': 'baidu-cloud',
|
||||
gpustack: 'gpustack',
|
||||
voyageai: 'voyageai',
|
||||
'aws-bedrock': 'aws-bedrock',
|
||||
poe: 'poe'
|
||||
} as const
|
||||
|
||||
export type SystemProviderId = keyof typeof SystemProviderIds
|
||||
|
||||
export const isSystemProviderId = (id: string): id is SystemProviderId => {
|
||||
return Object.hasOwn(SystemProviderIds, id)
|
||||
}
|
||||
|
||||
export type SystemProvider = Provider & {
|
||||
id: SystemProviderId
|
||||
isSystem: true
|
||||
apiOptions?: never
|
||||
}
|
||||
|
||||
export type VertexProvider = Provider & {
|
||||
googleCredentials: {
|
||||
privateKey: string
|
||||
clientEmail: string
|
||||
}
|
||||
project: string
|
||||
location: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 判断是否为系统内置的提供商。比直接使用`provider.isSystem`更好,因为该数据字段不会随着版本更新而变化。
|
||||
* @param provider - Provider对象,包含提供商的信息
|
||||
* @returns 是否为系统内置提供商
|
||||
*/
|
||||
export const isSystemProvider = (provider: Provider): provider is SystemProvider => {
|
||||
return isSystemProviderId(provider.id) && !!provider.isSystem
|
||||
}
|
||||
@ -9,7 +9,7 @@ Authorization: Bearer {{token}}
|
||||
|
||||
|
||||
### List Models With Filters
|
||||
GET {{host}}/v1/models?provider=anthropic&limit=5
|
||||
GET {{host}}/v1/models?providerType=anthropic&limit=5
|
||||
Authorization: Bearer {{token}}
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user