From 49f9dff9da6f5dc08712ddff768d64a8a78d9d5f Mon Sep 17 00:00:00 2001 From: Vaayne Date: Fri, 19 Sep 2025 17:04:27 +0800 Subject: [PATCH] feat(models): update models filtering to use providerType and enhance API schemas --- src/main/apiServer/routes/models.ts | 9 +- src/main/apiServer/services/models.ts | 32 ++--- src/main/apiServer/utils/index.ts | 6 +- src/renderer/src/types/apiModels.ts | 17 ++- src/renderer/src/types/index.ts | 186 +------------------------ src/renderer/src/types/provider.ts | 190 ++++++++++++++++++++++++++ tests/apis/chat.http | 2 +- 7 files changed, 222 insertions(+), 220 deletions(-) create mode 100644 src/renderer/src/types/provider.ts diff --git a/src/main/apiServer/routes/models.ts b/src/main/apiServer/routes/models.ts index 197f5ccb5a..37588150a5 100644 --- a/src/main/apiServer/routes/models.ts +++ b/src/main/apiServer/routes/models.ts @@ -1,7 +1,8 @@ +import { ApiModelsFilterSchema } from '@types' import express, { Request, Response } from 'express' import { loggerService } from '../../services/LoggerService' -import { ModelsFilterSchema, modelsService } from '../services/models' +import { modelsService } from '../services/models' const logger = loggerService.withContext('ApiServerModelsRoutes') @@ -17,10 +18,10 @@ const router = express * tags: [Models] * parameters: * - in: query - * name: provider + * name: providerType * schema: * type: string - * enum: [openai, anthropic] + * enum: [openai, openai-response, anthropic, gemini] * description: Filter models by provider type * - in: query * name: offset @@ -77,7 +78,7 @@ const router = express logger.info('Models list request received', { query: req.query }) // Validate query parameters using Zod schema - const filterResult = ModelsFilterSchema.safeParse(req.query) + const filterResult = ApiModelsFilterSchema.safeParse(req.query) if (!filterResult.success) { logger.warn('Invalid query parameters:', filterResult.error.issues) diff --git a/src/main/apiServer/services/models.ts b/src/main/apiServer/services/models.ts index 98445081c6..3e281d68ab 100644 --- a/src/main/apiServer/services/models.ts +++ b/src/main/apiServer/services/models.ts @@ -1,53 +1,43 @@ -import { - ApiModelsRequest, - ApiModelsRequestSchema, - ApiModelsResponse, - OpenAICompatibleModel -} from '../../../renderer/src/types/apiModels' +import { ApiModel, ApiModelsRequest, ApiModelsResponse } from '../../../renderer/src/types/apiModels' import { loggerService } from '../../services/LoggerService' import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils' const logger = loggerService.withContext('ModelsService') // Re-export for backward compatibility -export const ModelsFilterSchema = ApiModelsRequestSchema + export type ModelsFilter = ApiModelsRequest export class ModelsService { - async getModels(filter?: ModelsFilter): Promise { + async getModels(filter: ModelsFilter): Promise { try { - logger.info('Getting available models from providers', { filter }) + logger.debug('Getting available models from providers', { filter }) const models = await listAllAvailableModels() const providers = await getAvailableProviders() // Use Map to deduplicate models by their full ID (provider:model_id) - const uniqueModels = new Map() + const uniqueModels = new Map() for (const model of models) { - const openAIModel = transformModelToOpenAI(model) + const openAIModel = transformModelToOpenAI(model, providers) const fullModelId = openAIModel.id // This is already in format "provider:model_id" // Only add if not already present (first occurrence wins) if (!uniqueModels.has(fullModelId)) { - uniqueModels.set(fullModelId, { - ...openAIModel, - name: model.name - }) + uniqueModels.set(fullModelId, openAIModel) } else { logger.debug(`Skipping duplicate model: ${fullModelId}`) } } let modelData = Array.from(uniqueModels.values()) - - // Apply filters - if (filter?.provider) { - const providerType = filter.provider + if (filter.providerType) { + // Apply filters + const providerType = filter.providerType modelData = modelData.filter((model) => { // Find the provider for this model and check its type - const provider = providers.find((p) => p.id === model.provider) - return provider && provider.type === providerType + return model.provider_type === providerType }) logger.debug(`Filtered by provider type '${providerType}': ${modelData.length} models`) } diff --git a/src/main/apiServer/utils/index.ts b/src/main/apiServer/utils/index.ts index d3137d3a1f..f1e0d68454 100644 --- a/src/main/apiServer/utils/index.ts +++ b/src/main/apiServer/utils/index.ts @@ -1,6 +1,6 @@ import { loggerService } from '@main/services/LoggerService' import { reduxService } from '@main/services/ReduxService' -import { Model, OpenAICompatibleModel, Provider } from '@types' +import { ApiModel, Model, Provider } from '@types' const logger = loggerService.withContext('ApiServerUtils') @@ -173,7 +173,8 @@ export async function validateModelId( } } -export function transformModelToOpenAI(model: Model): OpenAICompatibleModel { +export function transformModelToOpenAI(model: Model, providers: Provider[]): ApiModel { + const provider = providers.find((p) => p.id === model.provider) return { id: `${model.provider}:${model.id}`, object: 'model', @@ -181,6 +182,7 @@ export function transformModelToOpenAI(model: Model): OpenAICompatibleModel { created: Math.floor(Date.now() / 1000), owned_by: model.owned_by || model.provider, provider: model.provider, + provider_type: provider?.type, provider_model_id: model.id } } diff --git a/src/renderer/src/types/apiModels.ts b/src/renderer/src/types/apiModels.ts index f023085bd8..0df2787320 100644 --- a/src/renderer/src/types/apiModels.ts +++ b/src/renderer/src/types/apiModels.ts @@ -1,33 +1,36 @@ import { z } from 'zod' +import { ProviderTypeSchema } from './provider' + // Request schema for /v1/models -export const ApiModelsRequestSchema = z.object({ - provider: z.enum(['openai', 'anthropic']).optional(), +export const ApiModelsFilterSchema = z.object({ + providerType: ProviderTypeSchema.optional(), offset: z.coerce.number().min(0).default(0).optional(), - limit: z.coerce.number().min(1).optional() + limit: z.coerce.number().min(1).default(20).optional() }) // OpenAI compatible model schema -export const OpenAICompatibleModelSchema = z.object({ +export const ApiModelSchema = z.object({ id: z.string(), object: z.literal('model'), created: z.number(), name: z.string(), owned_by: z.string(), provider: z.string().optional(), + provider_type: ProviderTypeSchema.optional(), provider_model_id: z.string().optional() }) // Response schema for /v1/models export const ApiModelsResponseSchema = z.object({ object: z.literal('list'), - data: z.array(OpenAICompatibleModelSchema), + data: z.array(ApiModelSchema), total: z.number().optional(), offset: z.number().optional(), limit: z.number().optional() }) // Inferred TypeScript types -export type ApiModelsRequest = z.infer -export type OpenAICompatibleModel = z.infer +export type ApiModel = z.infer +export type ApiModelsRequest = z.infer export type ApiModelsResponse = z.infer diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index b54e701074..733672d043 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -21,6 +21,7 @@ export * from './knowledge' export * from './mcp' export * from './notification' export * from './ocr' +export * from './provider' export type Assistant = { id: string @@ -216,158 +217,6 @@ export type User = { email: string } -// undefined 视为支持,默认支持 -export type ProviderApiOptions = { - /** 是否不支持 message 的 content 为数组类型 */ - isNotSupportArrayContent?: boolean - /** 是否不支持 stream_options 参数 */ - isNotSupportStreamOptions?: boolean - /** - * @deprecated - * 是否不支持 message 的 role 为 developer */ - isNotSupportDeveloperRole?: boolean - /* 是否支持 message 的 role 为 developer */ - isSupportDeveloperRole?: boolean - /** - * @deprecated - * 是否不支持 service_tier 参数. Only for OpenAI Models. */ - isNotSupportServiceTier?: boolean - /* 是否支持 service_tier 参数. Only for OpenAI Models. */ - isSupportServiceTier?: boolean - /** 是否不支持 enable_thinking 参数 */ - isNotSupportEnableThinking?: boolean -} - -export type Provider = { - id: string - type: ProviderType - name: string - apiKey: string - apiHost: string - apiVersion?: string - models: Model[] - enabled?: boolean - isSystem?: boolean - isAuthed?: boolean - rateLimit?: number - - // API options - apiOptions?: ProviderApiOptions - serviceTier?: ServiceTier - - /** @deprecated */ - isNotSupportArrayContent?: boolean - /** @deprecated */ - isNotSupportStreamOptions?: boolean - /** @deprecated */ - isNotSupportDeveloperRole?: boolean - /** @deprecated */ - isNotSupportServiceTier?: boolean - - authType?: 'apiKey' | 'oauth' - isVertex?: boolean - notes?: string - extra_headers?: Record -} - -export const SystemProviderIds = { - cherryin: 'cherryin', - silicon: 'silicon', - aihubmix: 'aihubmix', - ocoolai: 'ocoolai', - deepseek: 'deepseek', - ppio: 'ppio', - alayanew: 'alayanew', - qiniu: 'qiniu', - dmxapi: 'dmxapi', - burncloud: 'burncloud', - tokenflux: 'tokenflux', - '302ai': '302ai', - cephalon: 'cephalon', - lanyun: 'lanyun', - ph8: 'ph8', - openrouter: 'openrouter', - ollama: 'ollama', - 'new-api': 'new-api', - lmstudio: 'lmstudio', - anthropic: 'anthropic', - openai: 'openai', - 'azure-openai': 'azure-openai', - gemini: 'gemini', - vertexai: 'vertexai', - github: 'github', - copilot: 'copilot', - zhipu: 'zhipu', - yi: 'yi', - moonshot: 'moonshot', - baichuan: 'baichuan', - dashscope: 'dashscope', - stepfun: 'stepfun', - doubao: 'doubao', - infini: 'infini', - minimax: 'minimax', - groq: 'groq', - together: 'together', - fireworks: 'fireworks', - nvidia: 'nvidia', - grok: 'grok', - hyperbolic: 'hyperbolic', - mistral: 'mistral', - jina: 'jina', - perplexity: 'perplexity', - modelscope: 'modelscope', - xirang: 'xirang', - hunyuan: 'hunyuan', - 'tencent-cloud-ti': 'tencent-cloud-ti', - 'baidu-cloud': 'baidu-cloud', - gpustack: 'gpustack', - voyageai: 'voyageai', - 'aws-bedrock': 'aws-bedrock', - poe: 'poe' -} as const - -export type SystemProviderId = keyof typeof SystemProviderIds - -export const isSystemProviderId = (id: string): id is SystemProviderId => { - return Object.hasOwn(SystemProviderIds, id) -} - -export type SystemProvider = Provider & { - id: SystemProviderId - isSystem: true - apiOptions?: never -} - -export type VertexProvider = Provider & { - googleCredentials: { - privateKey: string - clientEmail: string - } - project: string - location: string -} - -/** - * 判断是否为系统内置的提供商。比直接使用`provider.isSystem`更好,因为该数据字段不会随着版本更新而变化。 - * @param provider - Provider对象,包含提供商的信息 - * @returns 是否为系统内置提供商 - */ -export const isSystemProvider = (provider: Provider): provider is SystemProvider => { - return isSystemProviderId(provider.id) && !!provider.isSystem -} - -export type ProviderType = - | 'openai' - | 'openai-response' - | 'anthropic' - | 'gemini' - | 'qwenlm' - | 'azure-openai' - | 'vertexai' - | 'mistral' - | 'aws-bedrock' - | 'vertex-anthropic' - export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search' | 'rerank' export type ModelTag = Exclude | 'free' @@ -976,39 +825,6 @@ export type OpenAIVerbosity = 'high' | 'medium' | 'low' export type OpenAISummaryText = 'auto' | 'concise' | 'detailed' | 'off' -export const OpenAIServiceTiers = { - auto: 'auto', - default: 'default', - flex: 'flex', - priority: 'priority' -} as const - -export type OpenAIServiceTier = keyof typeof OpenAIServiceTiers - -export function isOpenAIServiceTier(tier: string): tier is OpenAIServiceTier { - return Object.hasOwn(OpenAIServiceTiers, tier) -} - -export const GroqServiceTiers = { - auto: 'auto', - on_demand: 'on_demand', - flex: 'flex', - performance: 'performance' -} as const - -// 从 GroqServiceTiers 对象中提取类型 -export type GroqServiceTier = keyof typeof GroqServiceTiers - -export function isGroqServiceTier(tier: string): tier is GroqServiceTier { - return Object.hasOwn(GroqServiceTiers, tier) -} - -export type ServiceTier = OpenAIServiceTier | GroqServiceTier - -export function isServiceTier(tier: string): tier is ServiceTier { - return isGroqServiceTier(tier) || isOpenAIServiceTier(tier) -} - export type S3Config = { endpoint: string region: string diff --git a/src/renderer/src/types/provider.ts b/src/renderer/src/types/provider.ts new file mode 100644 index 0000000000..02e15a4b66 --- /dev/null +++ b/src/renderer/src/types/provider.ts @@ -0,0 +1,190 @@ +import { Model } from '@types' +import z from 'zod' + +export const ProviderTypeSchema = z.enum([ + 'openai', + 'openai-response', + 'anthropic', + 'gemini', + 'qwenlm', + 'azure-openai', + 'vertexai', + 'mistral', + 'aws-bedrock', + 'vertex-anthropic' +]) + +export type ProviderType = z.infer + +// undefined 视为支持,默认支持 +export type ProviderApiOptions = { + /** 是否不支持 message 的 content 为数组类型 */ + isNotSupportArrayContent?: boolean + /** 是否不支持 stream_options 参数 */ + isNotSupportStreamOptions?: boolean + /** + * @deprecated + * 是否不支持 message 的 role 为 developer */ + isNotSupportDeveloperRole?: boolean + /* 是否支持 message 的 role 为 developer */ + isSupportDeveloperRole?: boolean + /** + * @deprecated + * 是否不支持 service_tier 参数. Only for OpenAI Models. */ + isNotSupportServiceTier?: boolean + /* 是否支持 service_tier 参数. Only for OpenAI Models. */ + isSupportServiceTier?: boolean + /** 是否不支持 enable_thinking 参数 */ + isNotSupportEnableThinking?: boolean +} + +export const OpenAIServiceTiers = { + auto: 'auto', + default: 'default', + flex: 'flex', + priority: 'priority' +} as const + +export type OpenAIServiceTier = keyof typeof OpenAIServiceTiers + +export function isOpenAIServiceTier(tier: string): tier is OpenAIServiceTier { + return Object.hasOwn(OpenAIServiceTiers, tier) +} + +export const GroqServiceTiers = { + auto: 'auto', + on_demand: 'on_demand', + flex: 'flex', + performance: 'performance' +} as const + +// 从 GroqServiceTiers 对象中提取类型 +export type GroqServiceTier = keyof typeof GroqServiceTiers + +export function isGroqServiceTier(tier: string): tier is GroqServiceTier { + return Object.hasOwn(GroqServiceTiers, tier) +} + +export type ServiceTier = OpenAIServiceTier | GroqServiceTier + +export function isServiceTier(tier: string): tier is ServiceTier { + return isGroqServiceTier(tier) || isOpenAIServiceTier(tier) +} + +export type Provider = { + id: string + type: ProviderType + name: string + apiKey: string + apiHost: string + apiVersion?: string + models: Model[] + enabled?: boolean + isSystem?: boolean + isAuthed?: boolean + rateLimit?: number + + // API options + apiOptions?: ProviderApiOptions + serviceTier?: ServiceTier + + /** @deprecated */ + isNotSupportArrayContent?: boolean + /** @deprecated */ + isNotSupportStreamOptions?: boolean + /** @deprecated */ + isNotSupportDeveloperRole?: boolean + /** @deprecated */ + isNotSupportServiceTier?: boolean + + authType?: 'apiKey' | 'oauth' + isVertex?: boolean + notes?: string + extra_headers?: Record +} + +export const SystemProviderIds = { + cherryin: 'cherryin', + silicon: 'silicon', + aihubmix: 'aihubmix', + ocoolai: 'ocoolai', + deepseek: 'deepseek', + ppio: 'ppio', + alayanew: 'alayanew', + qiniu: 'qiniu', + dmxapi: 'dmxapi', + burncloud: 'burncloud', + tokenflux: 'tokenflux', + '302ai': '302ai', + cephalon: 'cephalon', + lanyun: 'lanyun', + ph8: 'ph8', + openrouter: 'openrouter', + ollama: 'ollama', + 'new-api': 'new-api', + lmstudio: 'lmstudio', + anthropic: 'anthropic', + openai: 'openai', + 'azure-openai': 'azure-openai', + gemini: 'gemini', + vertexai: 'vertexai', + github: 'github', + copilot: 'copilot', + zhipu: 'zhipu', + yi: 'yi', + moonshot: 'moonshot', + baichuan: 'baichuan', + dashscope: 'dashscope', + stepfun: 'stepfun', + doubao: 'doubao', + infini: 'infini', + minimax: 'minimax', + groq: 'groq', + together: 'together', + fireworks: 'fireworks', + nvidia: 'nvidia', + grok: 'grok', + hyperbolic: 'hyperbolic', + mistral: 'mistral', + jina: 'jina', + perplexity: 'perplexity', + modelscope: 'modelscope', + xirang: 'xirang', + hunyuan: 'hunyuan', + 'tencent-cloud-ti': 'tencent-cloud-ti', + 'baidu-cloud': 'baidu-cloud', + gpustack: 'gpustack', + voyageai: 'voyageai', + 'aws-bedrock': 'aws-bedrock', + poe: 'poe' +} as const + +export type SystemProviderId = keyof typeof SystemProviderIds + +export const isSystemProviderId = (id: string): id is SystemProviderId => { + return Object.hasOwn(SystemProviderIds, id) +} + +export type SystemProvider = Provider & { + id: SystemProviderId + isSystem: true + apiOptions?: never +} + +export type VertexProvider = Provider & { + googleCredentials: { + privateKey: string + clientEmail: string + } + project: string + location: string +} + +/** + * 判断是否为系统内置的提供商。比直接使用`provider.isSystem`更好,因为该数据字段不会随着版本更新而变化。 + * @param provider - Provider对象,包含提供商的信息 + * @returns 是否为系统内置提供商 + */ +export const isSystemProvider = (provider: Provider): provider is SystemProvider => { + return isSystemProviderId(provider.id) && !!provider.isSystem +} diff --git a/tests/apis/chat.http b/tests/apis/chat.http index 3025bb3b9d..d07fa6bfac 100644 --- a/tests/apis/chat.http +++ b/tests/apis/chat.http @@ -9,7 +9,7 @@ Authorization: Bearer {{token}} ### List Models With Filters -GET {{host}}/v1/models?provider=anthropic&limit=5 +GET {{host}}/v1/models?providerType=anthropic&limit=5 Authorization: Bearer {{token}}