mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-25 03:10:08 +08:00
Merge 4571d01a92 into a6ba5d34e0
This commit is contained in:
commit
cce25a5931
118
src/main/apiServer/utils/__tests__/provider-alias.test.ts
Normal file
118
src/main/apiServer/utils/__tests__/provider-alias.test.ts
Normal file
@ -0,0 +1,118 @@
|
||||
import { CacheService } from '@main/services/CacheService'
|
||||
import type { Model, Provider } from '@types'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const { selectMock } = vi.hoisted(() => ({
|
||||
selectMock: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@main/services/ReduxService', () => ({
|
||||
reduxService: {
|
||||
select: selectMock
|
||||
}
|
||||
}))
|
||||
|
||||
import { getProviderByModel, transformModelToOpenAI, validateModelId } from '..'
|
||||
|
||||
describe('api server provider alias', () => {
|
||||
beforeEach(() => {
|
||||
CacheService.clear()
|
||||
selectMock.mockReset()
|
||||
})
|
||||
|
||||
it('formats model id using provider apiIdentifier when present', () => {
|
||||
const provider: Provider = {
|
||||
id: 'test-provider-uuid',
|
||||
apiIdentifier: 'short',
|
||||
type: 'openai',
|
||||
name: 'Custom Provider',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://example.com/v1',
|
||||
models: [],
|
||||
enabled: true
|
||||
}
|
||||
|
||||
const model: Model = {
|
||||
id: 'glm-4.6',
|
||||
provider: provider.id,
|
||||
name: 'glm-4.6',
|
||||
group: 'glm'
|
||||
}
|
||||
|
||||
const apiModel = transformModelToOpenAI(model, provider)
|
||||
expect(apiModel.id).toBe('short:glm-4.6')
|
||||
expect(apiModel.provider).toBe('test-provider-uuid')
|
||||
})
|
||||
|
||||
it('resolves provider by apiIdentifier for model routing', async () => {
|
||||
const provider: Provider = {
|
||||
id: 'test-provider-uuid',
|
||||
apiIdentifier: 'short',
|
||||
type: 'openai',
|
||||
name: 'Custom Provider',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://example.com/v1',
|
||||
models: [],
|
||||
enabled: true
|
||||
}
|
||||
|
||||
selectMock.mockResolvedValue([provider])
|
||||
|
||||
const resolved = await getProviderByModel('short:glm-4.6')
|
||||
expect(resolved?.id).toBe('test-provider-uuid')
|
||||
})
|
||||
|
||||
it('validates model ids with apiIdentifier prefix', async () => {
|
||||
const provider: Provider = {
|
||||
id: 'test-provider-uuid',
|
||||
apiIdentifier: 'short',
|
||||
type: 'openai',
|
||||
name: 'Custom Provider',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://example.com/v1',
|
||||
models: [
|
||||
{
|
||||
id: 'glm-4.6',
|
||||
provider: 'test-provider-uuid',
|
||||
name: 'glm-4.6',
|
||||
group: 'glm'
|
||||
}
|
||||
],
|
||||
enabled: true
|
||||
}
|
||||
|
||||
selectMock.mockResolvedValue([provider])
|
||||
|
||||
const result = await validateModelId('short:glm-4.6')
|
||||
expect(result.valid).toBe(true)
|
||||
expect(result.provider?.id).toBe('test-provider-uuid')
|
||||
expect(result.modelId).toBe('glm-4.6')
|
||||
})
|
||||
|
||||
it('still supports provider.id prefix', async () => {
|
||||
const provider: Provider = {
|
||||
id: 'test-provider-uuid',
|
||||
apiIdentifier: 'short',
|
||||
type: 'openai',
|
||||
name: 'Custom Provider',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://example.com/v1',
|
||||
models: [
|
||||
{
|
||||
id: 'glm-4.6',
|
||||
provider: 'test-provider-uuid',
|
||||
name: 'glm-4.6',
|
||||
group: 'glm'
|
||||
}
|
||||
],
|
||||
enabled: true
|
||||
}
|
||||
|
||||
selectMock.mockResolvedValue([provider])
|
||||
|
||||
const result = await validateModelId('test-provider-uuid:glm-4.6')
|
||||
expect(result.valid).toBe(true)
|
||||
expect(result.provider?.id).toBe('test-provider-uuid')
|
||||
expect(result.modelId).toBe('glm-4.6')
|
||||
})
|
||||
})
|
||||
@ -82,16 +82,24 @@ export async function getProviderByModel(model: string): Promise<Provider | unde
|
||||
}
|
||||
|
||||
const providerId = modelInfo[0]
|
||||
const provider = providers.find((p: Provider) => p.id === providerId)
|
||||
const matchingProviders = providers.filter((p: Provider) => p.id === providerId || p.apiIdentifier === providerId)
|
||||
const provider = matchingProviders.find((p) => p.id === providerId) ?? matchingProviders[0]
|
||||
|
||||
if (!provider) {
|
||||
logger.warn('Provider not found for model', {
|
||||
providerId,
|
||||
available: providers.map((p) => p.id)
|
||||
available: providers.map((p) => ({ id: p.id, apiIdentifier: p.apiIdentifier }))
|
||||
})
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (matchingProviders.length > 1) {
|
||||
logger.warn('Multiple providers matched for model prefix', {
|
||||
providerId,
|
||||
matches: matchingProviders.map((p) => ({ id: p.id, apiIdentifier: p.apiIdentifier }))
|
||||
})
|
||||
}
|
||||
|
||||
logger.debug('Provider resolved for model', { providerId, model })
|
||||
return provider
|
||||
} catch (error: any) {
|
||||
@ -200,8 +208,9 @@ export async function validateModelId(model: string): Promise<{
|
||||
|
||||
export function transformModelToOpenAI(model: Model, provider?: Provider): ApiModel {
|
||||
const providerDisplayName = provider?.name
|
||||
const providerModelPrefix = provider?.apiIdentifier?.trim() || provider?.id || model.provider
|
||||
return {
|
||||
id: `${model.provider}:${model.id}`,
|
||||
id: `${providerModelPrefix}:${model.id}`,
|
||||
object: 'model',
|
||||
name: model.name,
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
|
||||
@ -98,6 +98,11 @@ export type Provider = {
|
||||
id: string
|
||||
type: ProviderType
|
||||
name: string
|
||||
/**
|
||||
* Optional short identifier used by the built-in API relay.
|
||||
* When set, models can be referenced as `${apiIdentifier}:${modelId}` instead of `${id}:${modelId}`.
|
||||
*/
|
||||
apiIdentifier?: string
|
||||
apiKey: string
|
||||
apiHost: string
|
||||
anthropicApiHost?: string
|
||||
|
||||
Loading…
Reference in New Issue
Block a user