From a566cd65f4773825c8750f99bbc7ca983053deea Mon Sep 17 00:00:00 2001 From: Phantom Date: Fri, 5 Dec 2025 00:29:38 +0800 Subject: [PATCH] fix: normalize provider model data (#11580) * fix: normalize provider model data * fix(tests): correct provider type in ModelAdapter test --- src/renderer/src/aiCore/index_new.ts | 17 +- .../ModelList/ManageModelsPopup.tsx | 22 +-- src/renderer/src/services/ApiService.ts | 3 +- .../services/__tests__/ModelAdapter.test.ts | 102 ++++++++++ .../src/services/models/ModelAdapter.ts | 180 ++++++++++++++++++ src/renderer/src/types/index.ts | 12 +- 6 files changed, 301 insertions(+), 35 deletions(-) create mode 100644 src/renderer/src/services/__tests__/ModelAdapter.test.ts create mode 100644 src/renderer/src/services/models/ModelAdapter.ts diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index 8c031f7754..090b3fc9e1 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -7,10 +7,10 @@ * 2. 暂时保持接口兼容性 */ -import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway' import { createExecutor } from '@cherrystudio/ai-core' import { loggerService } from '@logger' import { getEnableDeveloperMode } from '@renderer/hooks/useSettings' +import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/models/ModelAdapter' import { addSpan, endSpan } from '@renderer/services/SpanManagerService' import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity' import { type Assistant, type GenerateImageParams, type Model, type Provider, SystemProviderIds } from '@renderer/types' @@ -481,18 +481,11 @@ export default class ModernAiProvider { // 代理其他方法到原有实现 public async models() { if (this.actualProvider.id === SystemProviderIds.gateway) { - const formatModel = function (models: GatewayLanguageModelEntry[]): Model[] { - return models.map((m) => ({ - id: m.id, - name: m.name, - provider: 'gateway', - group: m.id.split('/')[0], - description: m.description ?? undefined - })) - } - return formatModel((await gateway.getAvailableModels()).models) + const gatewayModels = (await gateway.getAvailableModels()).models + return normalizeGatewayModels(this.actualProvider, gatewayModels) } - return this.legacyProvider.models() + const sdkModels = await this.legacyProvider.models() + return normalizeSdkModels(this.actualProvider, sdkModels) } public async getEmbeddingDimensions(model: Model): Promise { diff --git a/src/renderer/src/pages/settings/ProviderSettings/ModelList/ManageModelsPopup.tsx b/src/renderer/src/pages/settings/ProviderSettings/ModelList/ManageModelsPopup.tsx index 4e4307b8fb..4de4f3104b 100644 --- a/src/renderer/src/pages/settings/ProviderSettings/ModelList/ManageModelsPopup.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings/ModelList/ManageModelsPopup.tsx @@ -18,7 +18,7 @@ import NewApiAddModelPopup from '@renderer/pages/settings/ProviderSettings/Model import NewApiBatchAddModelPopup from '@renderer/pages/settings/ProviderSettings/ModelList/NewApiBatchAddModelPopup' import { fetchModels } from '@renderer/services/ApiService' import type { Model, Provider } from '@renderer/types' -import { filterModelsByKeywords, getDefaultGroupName, getFancyProviderName } from '@renderer/utils' +import { filterModelsByKeywords, getFancyProviderName } from '@renderer/utils' import { isFreeModel } from '@renderer/utils/model' import { isNewApiProvider } from '@renderer/utils/provider' import { Button, Empty, Flex, Modal, Spin, Tabs, Tooltip } from 'antd' @@ -183,25 +183,7 @@ const PopupContainer: React.FC = ({ providerId, resolve }) => { setLoadingModels(true) try { const models = await fetchModels(provider) - // TODO: More robust conversion - const filteredModels = models - .map((model) => ({ - // @ts-ignore modelId - id: model?.id || model?.name, - // @ts-ignore name - name: model?.display_name || model?.displayName || model?.name || model?.id, - provider: provider.id, - // @ts-ignore group - group: getDefaultGroupName(model?.id || model?.name, provider.id), - // @ts-ignore description - description: model?.description || '', - // @ts-ignore owned_by - owned_by: model?.owned_by || '', - // @ts-ignore supported_endpoint_types - supported_endpoint_types: model?.supported_endpoint_types - })) - .filter((model) => !isEmpty(model.name)) - + const filteredModels = models.filter((model) => !isEmpty(model.name)) setListModels(filteredModels) } catch (error) { logger.error(`Failed to load models for provider ${getFancyProviderName(provider)}`, error as Error) diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index d265d5cb48..c49e88f1ff 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -13,7 +13,6 @@ import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/t import type { StreamTextParams } from '@renderer/types/aiCoreTypes' import { type Chunk, ChunkType } from '@renderer/types/chunk' import type { Message, ResponseError } from '@renderer/types/newMessage' -import type { SdkModel } from '@renderer/types/sdk' import { removeSpecialCharactersForTopicName, uuid } from '@renderer/utils' import { abortCompletion, readyToAbort } from '@renderer/utils/abortController' import { isToolUseModeFunction } from '@renderer/utils/assistant' @@ -424,7 +423,7 @@ export function hasApiKey(provider: Provider) { // return undefined // } -export async function fetchModels(provider: Provider): Promise { +export async function fetchModels(provider: Provider): Promise { const AI = new AiProviderNew(provider) try { diff --git a/src/renderer/src/services/__tests__/ModelAdapter.test.ts b/src/renderer/src/services/__tests__/ModelAdapter.test.ts new file mode 100644 index 0000000000..29dd6b1625 --- /dev/null +++ b/src/renderer/src/services/__tests__/ModelAdapter.test.ts @@ -0,0 +1,102 @@ +import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway' +import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/models/ModelAdapter' +import type { Model, Provider } from '@renderer/types' +import type { EndpointType } from '@renderer/types/index' +import type { SdkModel } from '@renderer/types/sdk' +import { describe, expect, it } from 'vitest' + +const createProvider = (overrides: Partial = {}): Provider => ({ + id: 'openai', + type: 'openai', + name: 'OpenAI', + apiKey: 'test-key', + apiHost: 'https://example.com/v1', + models: [], + ...overrides +}) + +describe('ModelAdapter', () => { + it('adapts generic SDK models into internal models', () => { + const provider = createProvider({ id: 'openai' }) + const models = normalizeSdkModels(provider, [ + { + id: 'gpt-4o-mini', + display_name: 'GPT-4o mini', + description: 'General purpose model', + owned_by: 'openai' + } as unknown as SdkModel + ]) + + expect(models).toHaveLength(1) + expect(models[0]).toMatchObject({ + id: 'gpt-4o-mini', + name: 'GPT-4o mini', + provider: 'openai', + group: 'gpt-4o', + description: 'General purpose model', + owned_by: 'openai' + } as Partial) + }) + + it('preserves supported endpoint types for New API models', () => { + const provider = createProvider({ id: 'new-api' }) + const endpointTypes: EndpointType[] = ['openai', 'image-generation'] + const [model] = normalizeSdkModels(provider, [ + { + id: 'new-api-model', + name: 'New API Model', + supported_endpoint_types: endpointTypes + } as unknown as SdkModel + ]) + + expect(model.supported_endpoint_types).toEqual(endpointTypes) + }) + + it('filters unsupported endpoint types while keeping valid ones', () => { + const provider = createProvider({ id: 'new-api' }) + const [model] = normalizeSdkModels(provider, [ + { + id: 'another-model', + name: 'Another Model', + supported_endpoint_types: ['openai', 'unknown-endpoint', 'gemini'] + } as unknown as SdkModel + ]) + + expect(model.supported_endpoint_types).toEqual(['openai', 'gemini']) + }) + + it('adapts ai-gateway entries through the same adapter', () => { + const provider = createProvider({ id: 'ai-gateway', type: 'gateway' }) + const [model] = normalizeGatewayModels(provider, [ + { + id: 'openai/gpt-4o', + name: 'OpenAI GPT-4o', + description: 'Gateway entry', + specification: { + specificationVersion: 'v2', + provider: 'openai', + modelId: 'gpt-4o' + } + } as GatewayLanguageModelEntry + ]) + + expect(model).toMatchObject({ + id: 'openai/gpt-4o', + group: 'openai', + provider: 'ai-gateway', + description: 'Gateway entry' + }) + }) + + it('drops invalid entries without ids or names', () => { + const provider = createProvider() + const models = normalizeSdkModels(provider, [ + { + id: '', + name: '' + } as unknown as SdkModel + ]) + + expect(models).toHaveLength(0) + }) +}) diff --git a/src/renderer/src/services/models/ModelAdapter.ts b/src/renderer/src/services/models/ModelAdapter.ts new file mode 100644 index 0000000000..deea631693 --- /dev/null +++ b/src/renderer/src/services/models/ModelAdapter.ts @@ -0,0 +1,180 @@ +import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway' +import { loggerService } from '@logger' +import { type EndpointType, EndPointTypeSchema, type Model, type Provider } from '@renderer/types' +import type { NewApiModel, SdkModel } from '@renderer/types/sdk' +import { getDefaultGroupName } from '@renderer/utils/naming' +import * as z from 'zod' + +const logger = loggerService.withContext('ModelAdapter') + +const EndpointTypeArraySchema = z.array(EndPointTypeSchema).nonempty() + +const NormalizedModelSchema = z.object({ + id: z.string().trim().min(1), + name: z.string().trim().min(1), + provider: z.string().trim().min(1), + group: z.string().trim().min(1), + description: z.string().optional(), + owned_by: z.string().optional(), + supported_endpoint_types: EndpointTypeArraySchema.optional() +}) + +type NormalizedModelInput = z.input + +export function normalizeSdkModels(provider: Provider, models: SdkModel[]): Model[] { + return normalizeModels(models, (entry) => adaptSdkModel(provider, entry)) +} + +export function normalizeGatewayModels(provider: Provider, models: GatewayLanguageModelEntry[]): Model[] { + return normalizeModels(models, (entry) => adaptGatewayModel(provider, entry)) +} + +function normalizeModels(models: T[], transformer: (entry: T) => Model | null): Model[] { + const uniqueModels: Model[] = [] + const seen = new Set() + + for (const entry of models) { + const normalized = transformer(entry) + if (!normalized) continue + if (seen.has(normalized.id)) continue + seen.add(normalized.id) + uniqueModels.push(normalized) + } + + return uniqueModels +} + +function adaptSdkModel(provider: Provider, model: SdkModel): Model | null { + const id = pickPreferredString([(model as any)?.id, (model as any)?.modelId]) + const name = pickPreferredString([ + (model as any)?.display_name, + (model as any)?.displayName, + (model as any)?.name, + id + ]) + + if (!id || !name) { + logger.warn('Skip SDK model with missing id or name', { + providerId: provider.id, + modelSnippet: summarizeModel(model) + }) + return null + } + + const candidate: NormalizedModelInput = { + id, + name, + provider: provider.id, + group: getDefaultGroupName(id, provider.id), + description: pickPreferredString([(model as any)?.description, (model as any)?.summary]), + owned_by: pickPreferredString([(model as any)?.owned_by, (model as any)?.publisher]) + } + + const supportedEndpointTypes = pickSupportedEndpointTypes(provider.id, model) + if (supportedEndpointTypes) { + candidate.supported_endpoint_types = supportedEndpointTypes + } + + return validateModel(candidate, model) +} + +function adaptGatewayModel(provider: Provider, model: GatewayLanguageModelEntry): Model | null { + const id = model?.id?.trim() + const name = model?.name?.trim() || id + + if (!id || !name) { + logger.warn('Skip gateway model with missing id or name', { + providerId: provider.id, + modelSnippet: summarizeModel(model) + }) + return null + } + + const candidate: NormalizedModelInput = { + id, + name, + provider: provider.id, + group: getDefaultGroupName(id, provider.id), + description: model.description ?? undefined + } + + return validateModel(candidate, model) +} + +function pickPreferredString(values: Array): string | undefined { + for (const value of values) { + if (typeof value === 'string') { + const trimmed = value.trim() + if (trimmed.length > 0) { + return trimmed + } + } + } + return undefined +} + +function pickSupportedEndpointTypes(providerId: string, model: SdkModel): EndpointType[] | undefined { + const candidate = + (model as Partial).supported_endpoint_types ?? + ((model as Record).supported_endpoint_types as EndpointType[] | undefined) + + if (!Array.isArray(candidate) || candidate.length === 0) { + return undefined + } + + const supported: EndpointType[] = [] + const unsupported: unknown[] = [] + + for (const value of candidate) { + const parsed = EndPointTypeSchema.safeParse(value) + if (parsed.success) { + supported.push(parsed.data) + } else { + unsupported.push(value) + } + } + + if (unsupported.length > 0) { + logger.warn('Pruned unsupported endpoint types', { + providerId, + values: unsupported, + modelSnippet: summarizeModel(model) + }) + } + + return supported.length > 0 ? supported : undefined +} + +function validateModel(candidate: NormalizedModelInput, source: unknown): Model | null { + const parsed = NormalizedModelSchema.safeParse(candidate) + if (!parsed.success) { + logger.warn('Discard invalid model entry', { + providerId: candidate.provider, + issues: parsed.error.issues, + modelSnippet: summarizeModel(source) + }) + return null + } + + return parsed.data +} + +function summarizeModel(model: unknown) { + if (!model || typeof model !== 'object') { + return model + } + const { id, name, display_name, displayName, description, owned_by, supported_endpoint_types } = model as Record< + string, + unknown + > + + return { + id, + name, + display_name, + displayName, + description, + owned_by, + supported_endpoint_types + } +} diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index ad9beb5d5a..5b72a4181c 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -7,6 +7,8 @@ import type { CSSProperties } from 'react' export * from './file' export * from './note' +import * as z from 'zod' + import type { StreamTextParams } from './aiCoreTypes' import type { Chunk } from './chunk' import type { FileMetadata } from './file' @@ -240,7 +242,15 @@ export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'functio export type ModelTag = Exclude | 'free' // "image-generation" is also openai endpoint, but specifically for image generation. -export type EndpointType = 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank' +export const EndPointTypeSchema = z.enum([ + 'openai', + 'openai-response', + 'anthropic', + 'gemini', + 'image-generation', + 'jina-rerank' +]) +export type EndpointType = z.infer export type ModelPricing = { input_per_million_tokens: number