mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-09 14:59:27 +08:00
fix: normalize provider model data (#11580)
* fix: normalize provider model data * fix(tests): correct provider type in ModelAdapter test
This commit is contained in:
parent
cd699825ed
commit
a566cd65f4
@ -7,10 +7,10 @@
|
|||||||
* 2. 暂时保持接口兼容性
|
* 2. 暂时保持接口兼容性
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway'
|
|
||||||
import { createExecutor } from '@cherrystudio/ai-core'
|
import { createExecutor } from '@cherrystudio/ai-core'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||||
|
import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/models/ModelAdapter'
|
||||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||||
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||||
import { type Assistant, type GenerateImageParams, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
import { type Assistant, type GenerateImageParams, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
||||||
@ -481,18 +481,11 @@ export default class ModernAiProvider {
|
|||||||
// 代理其他方法到原有实现
|
// 代理其他方法到原有实现
|
||||||
public async models() {
|
public async models() {
|
||||||
if (this.actualProvider.id === SystemProviderIds.gateway) {
|
if (this.actualProvider.id === SystemProviderIds.gateway) {
|
||||||
const formatModel = function (models: GatewayLanguageModelEntry[]): Model[] {
|
const gatewayModels = (await gateway.getAvailableModels()).models
|
||||||
return models.map((m) => ({
|
return normalizeGatewayModels(this.actualProvider, gatewayModels)
|
||||||
id: m.id,
|
|
||||||
name: m.name,
|
|
||||||
provider: 'gateway',
|
|
||||||
group: m.id.split('/')[0],
|
|
||||||
description: m.description ?? undefined
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
return formatModel((await gateway.getAvailableModels()).models)
|
|
||||||
}
|
}
|
||||||
return this.legacyProvider.models()
|
const sdkModels = await this.legacyProvider.models()
|
||||||
|
return normalizeSdkModels(this.actualProvider, sdkModels)
|
||||||
}
|
}
|
||||||
|
|
||||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||||
|
|||||||
@ -18,7 +18,7 @@ import NewApiAddModelPopup from '@renderer/pages/settings/ProviderSettings/Model
|
|||||||
import NewApiBatchAddModelPopup from '@renderer/pages/settings/ProviderSettings/ModelList/NewApiBatchAddModelPopup'
|
import NewApiBatchAddModelPopup from '@renderer/pages/settings/ProviderSettings/ModelList/NewApiBatchAddModelPopup'
|
||||||
import { fetchModels } from '@renderer/services/ApiService'
|
import { fetchModels } from '@renderer/services/ApiService'
|
||||||
import type { Model, Provider } from '@renderer/types'
|
import type { Model, Provider } from '@renderer/types'
|
||||||
import { filterModelsByKeywords, getDefaultGroupName, getFancyProviderName } from '@renderer/utils'
|
import { filterModelsByKeywords, getFancyProviderName } from '@renderer/utils'
|
||||||
import { isFreeModel } from '@renderer/utils/model'
|
import { isFreeModel } from '@renderer/utils/model'
|
||||||
import { isNewApiProvider } from '@renderer/utils/provider'
|
import { isNewApiProvider } from '@renderer/utils/provider'
|
||||||
import { Button, Empty, Flex, Modal, Spin, Tabs, Tooltip } from 'antd'
|
import { Button, Empty, Flex, Modal, Spin, Tabs, Tooltip } from 'antd'
|
||||||
@ -183,25 +183,7 @@ const PopupContainer: React.FC<Props> = ({ providerId, resolve }) => {
|
|||||||
setLoadingModels(true)
|
setLoadingModels(true)
|
||||||
try {
|
try {
|
||||||
const models = await fetchModels(provider)
|
const models = await fetchModels(provider)
|
||||||
// TODO: More robust conversion
|
const filteredModels = models.filter((model) => !isEmpty(model.name))
|
||||||
const filteredModels = models
|
|
||||||
.map((model) => ({
|
|
||||||
// @ts-ignore modelId
|
|
||||||
id: model?.id || model?.name,
|
|
||||||
// @ts-ignore name
|
|
||||||
name: model?.display_name || model?.displayName || model?.name || model?.id,
|
|
||||||
provider: provider.id,
|
|
||||||
// @ts-ignore group
|
|
||||||
group: getDefaultGroupName(model?.id || model?.name, provider.id),
|
|
||||||
// @ts-ignore description
|
|
||||||
description: model?.description || '',
|
|
||||||
// @ts-ignore owned_by
|
|
||||||
owned_by: model?.owned_by || '',
|
|
||||||
// @ts-ignore supported_endpoint_types
|
|
||||||
supported_endpoint_types: model?.supported_endpoint_types
|
|
||||||
}))
|
|
||||||
.filter((model) => !isEmpty(model.name))
|
|
||||||
|
|
||||||
setListModels(filteredModels)
|
setListModels(filteredModels)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(`Failed to load models for provider ${getFancyProviderName(provider)}`, error as Error)
|
logger.error(`Failed to load models for provider ${getFancyProviderName(provider)}`, error as Error)
|
||||||
|
|||||||
@ -13,7 +13,6 @@ import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/t
|
|||||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||||
import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
||||||
import type { Message, ResponseError } from '@renderer/types/newMessage'
|
import type { Message, ResponseError } from '@renderer/types/newMessage'
|
||||||
import type { SdkModel } from '@renderer/types/sdk'
|
|
||||||
import { removeSpecialCharactersForTopicName, uuid } from '@renderer/utils'
|
import { removeSpecialCharactersForTopicName, uuid } from '@renderer/utils'
|
||||||
import { abortCompletion, readyToAbort } from '@renderer/utils/abortController'
|
import { abortCompletion, readyToAbort } from '@renderer/utils/abortController'
|
||||||
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||||
@ -424,7 +423,7 @@ export function hasApiKey(provider: Provider) {
|
|||||||
// return undefined
|
// return undefined
|
||||||
// }
|
// }
|
||||||
|
|
||||||
export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
|
export async function fetchModels(provider: Provider): Promise<Model[]> {
|
||||||
const AI = new AiProviderNew(provider)
|
const AI = new AiProviderNew(provider)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|||||||
102
src/renderer/src/services/__tests__/ModelAdapter.test.ts
Normal file
102
src/renderer/src/services/__tests__/ModelAdapter.test.ts
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway'
|
||||||
|
import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/models/ModelAdapter'
|
||||||
|
import type { Model, Provider } from '@renderer/types'
|
||||||
|
import type { EndpointType } from '@renderer/types/index'
|
||||||
|
import type { SdkModel } from '@renderer/types/sdk'
|
||||||
|
import { describe, expect, it } from 'vitest'
|
||||||
|
|
||||||
|
const createProvider = (overrides: Partial<Provider> = {}): Provider => ({
|
||||||
|
id: 'openai',
|
||||||
|
type: 'openai',
|
||||||
|
name: 'OpenAI',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://example.com/v1',
|
||||||
|
models: [],
|
||||||
|
...overrides
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('ModelAdapter', () => {
|
||||||
|
it('adapts generic SDK models into internal models', () => {
|
||||||
|
const provider = createProvider({ id: 'openai' })
|
||||||
|
const models = normalizeSdkModels(provider, [
|
||||||
|
{
|
||||||
|
id: 'gpt-4o-mini',
|
||||||
|
display_name: 'GPT-4o mini',
|
||||||
|
description: 'General purpose model',
|
||||||
|
owned_by: 'openai'
|
||||||
|
} as unknown as SdkModel
|
||||||
|
])
|
||||||
|
|
||||||
|
expect(models).toHaveLength(1)
|
||||||
|
expect(models[0]).toMatchObject({
|
||||||
|
id: 'gpt-4o-mini',
|
||||||
|
name: 'GPT-4o mini',
|
||||||
|
provider: 'openai',
|
||||||
|
group: 'gpt-4o',
|
||||||
|
description: 'General purpose model',
|
||||||
|
owned_by: 'openai'
|
||||||
|
} as Partial<Model>)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('preserves supported endpoint types for New API models', () => {
|
||||||
|
const provider = createProvider({ id: 'new-api' })
|
||||||
|
const endpointTypes: EndpointType[] = ['openai', 'image-generation']
|
||||||
|
const [model] = normalizeSdkModels(provider, [
|
||||||
|
{
|
||||||
|
id: 'new-api-model',
|
||||||
|
name: 'New API Model',
|
||||||
|
supported_endpoint_types: endpointTypes
|
||||||
|
} as unknown as SdkModel
|
||||||
|
])
|
||||||
|
|
||||||
|
expect(model.supported_endpoint_types).toEqual(endpointTypes)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('filters unsupported endpoint types while keeping valid ones', () => {
|
||||||
|
const provider = createProvider({ id: 'new-api' })
|
||||||
|
const [model] = normalizeSdkModels(provider, [
|
||||||
|
{
|
||||||
|
id: 'another-model',
|
||||||
|
name: 'Another Model',
|
||||||
|
supported_endpoint_types: ['openai', 'unknown-endpoint', 'gemini']
|
||||||
|
} as unknown as SdkModel
|
||||||
|
])
|
||||||
|
|
||||||
|
expect(model.supported_endpoint_types).toEqual(['openai', 'gemini'])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('adapts ai-gateway entries through the same adapter', () => {
|
||||||
|
const provider = createProvider({ id: 'ai-gateway', type: 'gateway' })
|
||||||
|
const [model] = normalizeGatewayModels(provider, [
|
||||||
|
{
|
||||||
|
id: 'openai/gpt-4o',
|
||||||
|
name: 'OpenAI GPT-4o',
|
||||||
|
description: 'Gateway entry',
|
||||||
|
specification: {
|
||||||
|
specificationVersion: 'v2',
|
||||||
|
provider: 'openai',
|
||||||
|
modelId: 'gpt-4o'
|
||||||
|
}
|
||||||
|
} as GatewayLanguageModelEntry
|
||||||
|
])
|
||||||
|
|
||||||
|
expect(model).toMatchObject({
|
||||||
|
id: 'openai/gpt-4o',
|
||||||
|
group: 'openai',
|
||||||
|
provider: 'ai-gateway',
|
||||||
|
description: 'Gateway entry'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('drops invalid entries without ids or names', () => {
|
||||||
|
const provider = createProvider()
|
||||||
|
const models = normalizeSdkModels(provider, [
|
||||||
|
{
|
||||||
|
id: '',
|
||||||
|
name: ''
|
||||||
|
} as unknown as SdkModel
|
||||||
|
])
|
||||||
|
|
||||||
|
expect(models).toHaveLength(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
180
src/renderer/src/services/models/ModelAdapter.ts
Normal file
180
src/renderer/src/services/models/ModelAdapter.ts
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway'
|
||||||
|
import { loggerService } from '@logger'
|
||||||
|
import { type EndpointType, EndPointTypeSchema, type Model, type Provider } from '@renderer/types'
|
||||||
|
import type { NewApiModel, SdkModel } from '@renderer/types/sdk'
|
||||||
|
import { getDefaultGroupName } from '@renderer/utils/naming'
|
||||||
|
import * as z from 'zod'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('ModelAdapter')
|
||||||
|
|
||||||
|
const EndpointTypeArraySchema = z.array(EndPointTypeSchema).nonempty()
|
||||||
|
|
||||||
|
const NormalizedModelSchema = z.object({
|
||||||
|
id: z.string().trim().min(1),
|
||||||
|
name: z.string().trim().min(1),
|
||||||
|
provider: z.string().trim().min(1),
|
||||||
|
group: z.string().trim().min(1),
|
||||||
|
description: z.string().optional(),
|
||||||
|
owned_by: z.string().optional(),
|
||||||
|
supported_endpoint_types: EndpointTypeArraySchema.optional()
|
||||||
|
})
|
||||||
|
|
||||||
|
type NormalizedModelInput = z.input<typeof NormalizedModelSchema>
|
||||||
|
|
||||||
|
export function normalizeSdkModels(provider: Provider, models: SdkModel[]): Model[] {
|
||||||
|
return normalizeModels(models, (entry) => adaptSdkModel(provider, entry))
|
||||||
|
}
|
||||||
|
|
||||||
|
export function normalizeGatewayModels(provider: Provider, models: GatewayLanguageModelEntry[]): Model[] {
|
||||||
|
return normalizeModels(models, (entry) => adaptGatewayModel(provider, entry))
|
||||||
|
}
|
||||||
|
|
||||||
|
function normalizeModels<T>(models: T[], transformer: (entry: T) => Model | null): Model[] {
|
||||||
|
const uniqueModels: Model[] = []
|
||||||
|
const seen = new Set<string>()
|
||||||
|
|
||||||
|
for (const entry of models) {
|
||||||
|
const normalized = transformer(entry)
|
||||||
|
if (!normalized) continue
|
||||||
|
if (seen.has(normalized.id)) continue
|
||||||
|
seen.add(normalized.id)
|
||||||
|
uniqueModels.push(normalized)
|
||||||
|
}
|
||||||
|
|
||||||
|
return uniqueModels
|
||||||
|
}
|
||||||
|
|
||||||
|
function adaptSdkModel(provider: Provider, model: SdkModel): Model | null {
|
||||||
|
const id = pickPreferredString([(model as any)?.id, (model as any)?.modelId])
|
||||||
|
const name = pickPreferredString([
|
||||||
|
(model as any)?.display_name,
|
||||||
|
(model as any)?.displayName,
|
||||||
|
(model as any)?.name,
|
||||||
|
id
|
||||||
|
])
|
||||||
|
|
||||||
|
if (!id || !name) {
|
||||||
|
logger.warn('Skip SDK model with missing id or name', {
|
||||||
|
providerId: provider.id,
|
||||||
|
modelSnippet: summarizeModel(model)
|
||||||
|
})
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
const candidate: NormalizedModelInput = {
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
provider: provider.id,
|
||||||
|
group: getDefaultGroupName(id, provider.id),
|
||||||
|
description: pickPreferredString([(model as any)?.description, (model as any)?.summary]),
|
||||||
|
owned_by: pickPreferredString([(model as any)?.owned_by, (model as any)?.publisher])
|
||||||
|
}
|
||||||
|
|
||||||
|
const supportedEndpointTypes = pickSupportedEndpointTypes(provider.id, model)
|
||||||
|
if (supportedEndpointTypes) {
|
||||||
|
candidate.supported_endpoint_types = supportedEndpointTypes
|
||||||
|
}
|
||||||
|
|
||||||
|
return validateModel(candidate, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
function adaptGatewayModel(provider: Provider, model: GatewayLanguageModelEntry): Model | null {
|
||||||
|
const id = model?.id?.trim()
|
||||||
|
const name = model?.name?.trim() || id
|
||||||
|
|
||||||
|
if (!id || !name) {
|
||||||
|
logger.warn('Skip gateway model with missing id or name', {
|
||||||
|
providerId: provider.id,
|
||||||
|
modelSnippet: summarizeModel(model)
|
||||||
|
})
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
const candidate: NormalizedModelInput = {
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
provider: provider.id,
|
||||||
|
group: getDefaultGroupName(id, provider.id),
|
||||||
|
description: model.description ?? undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
return validateModel(candidate, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
function pickPreferredString(values: Array<unknown>): string | undefined {
|
||||||
|
for (const value of values) {
|
||||||
|
if (typeof value === 'string') {
|
||||||
|
const trimmed = value.trim()
|
||||||
|
if (trimmed.length > 0) {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
function pickSupportedEndpointTypes(providerId: string, model: SdkModel): EndpointType[] | undefined {
|
||||||
|
const candidate =
|
||||||
|
(model as Partial<NewApiModel>).supported_endpoint_types ??
|
||||||
|
((model as Record<string, unknown>).supported_endpoint_types as EndpointType[] | undefined)
|
||||||
|
|
||||||
|
if (!Array.isArray(candidate) || candidate.length === 0) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
const supported: EndpointType[] = []
|
||||||
|
const unsupported: unknown[] = []
|
||||||
|
|
||||||
|
for (const value of candidate) {
|
||||||
|
const parsed = EndPointTypeSchema.safeParse(value)
|
||||||
|
if (parsed.success) {
|
||||||
|
supported.push(parsed.data)
|
||||||
|
} else {
|
||||||
|
unsupported.push(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (unsupported.length > 0) {
|
||||||
|
logger.warn('Pruned unsupported endpoint types', {
|
||||||
|
providerId,
|
||||||
|
values: unsupported,
|
||||||
|
modelSnippet: summarizeModel(model)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return supported.length > 0 ? supported : undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
function validateModel(candidate: NormalizedModelInput, source: unknown): Model | null {
|
||||||
|
const parsed = NormalizedModelSchema.safeParse(candidate)
|
||||||
|
if (!parsed.success) {
|
||||||
|
logger.warn('Discard invalid model entry', {
|
||||||
|
providerId: candidate.provider,
|
||||||
|
issues: parsed.error.issues,
|
||||||
|
modelSnippet: summarizeModel(source)
|
||||||
|
})
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsed.data
|
||||||
|
}
|
||||||
|
|
||||||
|
function summarizeModel(model: unknown) {
|
||||||
|
if (!model || typeof model !== 'object') {
|
||||||
|
return model
|
||||||
|
}
|
||||||
|
const { id, name, display_name, displayName, description, owned_by, supported_endpoint_types } = model as Record<
|
||||||
|
string,
|
||||||
|
unknown
|
||||||
|
>
|
||||||
|
|
||||||
|
return {
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
display_name,
|
||||||
|
displayName,
|
||||||
|
description,
|
||||||
|
owned_by,
|
||||||
|
supported_endpoint_types
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -7,6 +7,8 @@ import type { CSSProperties } from 'react'
|
|||||||
export * from './file'
|
export * from './file'
|
||||||
export * from './note'
|
export * from './note'
|
||||||
|
|
||||||
|
import * as z from 'zod'
|
||||||
|
|
||||||
import type { StreamTextParams } from './aiCoreTypes'
|
import type { StreamTextParams } from './aiCoreTypes'
|
||||||
import type { Chunk } from './chunk'
|
import type { Chunk } from './chunk'
|
||||||
import type { FileMetadata } from './file'
|
import type { FileMetadata } from './file'
|
||||||
@ -240,7 +242,15 @@ export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'functio
|
|||||||
export type ModelTag = Exclude<ModelType, 'text'> | 'free'
|
export type ModelTag = Exclude<ModelType, 'text'> | 'free'
|
||||||
|
|
||||||
// "image-generation" is also openai endpoint, but specifically for image generation.
|
// "image-generation" is also openai endpoint, but specifically for image generation.
|
||||||
export type EndpointType = 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank'
|
export const EndPointTypeSchema = z.enum([
|
||||||
|
'openai',
|
||||||
|
'openai-response',
|
||||||
|
'anthropic',
|
||||||
|
'gemini',
|
||||||
|
'image-generation',
|
||||||
|
'jina-rerank'
|
||||||
|
])
|
||||||
|
export type EndpointType = z.infer<typeof EndPointTypeSchema>
|
||||||
|
|
||||||
export type ModelPricing = {
|
export type ModelPricing = {
|
||||||
input_per_million_tokens: number
|
input_per_million_tokens: number
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user