diff --git a/src/main/apiServer/utils/__tests__/provider-alias.test.ts b/src/main/apiServer/utils/__tests__/provider-alias.test.ts new file mode 100644 index 0000000000..d7aad83b23 --- /dev/null +++ b/src/main/apiServer/utils/__tests__/provider-alias.test.ts @@ -0,0 +1,118 @@ +import { CacheService } from '@main/services/CacheService' +import type { Model, Provider } from '@types' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { selectMock } = vi.hoisted(() => ({ + selectMock: vi.fn() +})) + +vi.mock('@main/services/ReduxService', () => ({ + reduxService: { + select: selectMock + } +})) + +import { getProviderByModel, transformModelToOpenAI, validateModelId } from '..' + +describe('api server provider alias', () => { + beforeEach(() => { + CacheService.clear() + selectMock.mockReset() + }) + + it('formats model id using provider apiIdentifier when present', () => { + const provider: Provider = { + id: 'test-provider-uuid', + apiIdentifier: 'short', + type: 'openai', + name: 'Custom Provider', + apiKey: 'test-key', + apiHost: 'https://example.com/v1', + models: [], + enabled: true + } + + const model: Model = { + id: 'glm-4.6', + provider: provider.id, + name: 'glm-4.6', + group: 'glm' + } + + const apiModel = transformModelToOpenAI(model, provider) + expect(apiModel.id).toBe('short:glm-4.6') + expect(apiModel.provider).toBe('test-provider-uuid') + }) + + it('resolves provider by apiIdentifier for model routing', async () => { + const provider: Provider = { + id: 'test-provider-uuid', + apiIdentifier: 'short', + type: 'openai', + name: 'Custom Provider', + apiKey: 'test-key', + apiHost: 'https://example.com/v1', + models: [], + enabled: true + } + + selectMock.mockResolvedValue([provider]) + + const resolved = await getProviderByModel('short:glm-4.6') + expect(resolved?.id).toBe('test-provider-uuid') + }) + + it('validates model ids with apiIdentifier prefix', async () => { + const provider: Provider = { + id: 'test-provider-uuid', + apiIdentifier: 'short', + type: 'openai', + name: 'Custom Provider', + apiKey: 'test-key', + apiHost: 'https://example.com/v1', + models: [ + { + id: 'glm-4.6', + provider: 'test-provider-uuid', + name: 'glm-4.6', + group: 'glm' + } + ], + enabled: true + } + + selectMock.mockResolvedValue([provider]) + + const result = await validateModelId('short:glm-4.6') + expect(result.valid).toBe(true) + expect(result.provider?.id).toBe('test-provider-uuid') + expect(result.modelId).toBe('glm-4.6') + }) + + it('still supports provider.id prefix', async () => { + const provider: Provider = { + id: 'test-provider-uuid', + apiIdentifier: 'short', + type: 'openai', + name: 'Custom Provider', + apiKey: 'test-key', + apiHost: 'https://example.com/v1', + models: [ + { + id: 'glm-4.6', + provider: 'test-provider-uuid', + name: 'glm-4.6', + group: 'glm' + } + ], + enabled: true + } + + selectMock.mockResolvedValue([provider]) + + const result = await validateModelId('test-provider-uuid:glm-4.6') + expect(result.valid).toBe(true) + expect(result.provider?.id).toBe('test-provider-uuid') + expect(result.modelId).toBe('glm-4.6') + }) +}) diff --git a/src/main/apiServer/utils/index.ts b/src/main/apiServer/utils/index.ts index e25b49e750..072d9ecf54 100644 --- a/src/main/apiServer/utils/index.ts +++ b/src/main/apiServer/utils/index.ts @@ -82,16 +82,24 @@ export async function getProviderByModel(model: string): Promise p.id === providerId) + const matchingProviders = providers.filter((p: Provider) => p.id === providerId || p.apiIdentifier === providerId) + const provider = matchingProviders.find((p) => p.id === providerId) ?? matchingProviders[0] if (!provider) { logger.warn('Provider not found for model', { providerId, - available: providers.map((p) => p.id) + available: providers.map((p) => ({ id: p.id, apiIdentifier: p.apiIdentifier })) }) return undefined } + if (matchingProviders.length > 1) { + logger.warn('Multiple providers matched for model prefix', { + providerId, + matches: matchingProviders.map((p) => ({ id: p.id, apiIdentifier: p.apiIdentifier })) + }) + } + logger.debug('Provider resolved for model', { providerId, model }) return provider } catch (error: any) { @@ -200,8 +208,9 @@ export async function validateModelId(model: string): Promise<{ export function transformModelToOpenAI(model: Model, provider?: Provider): ApiModel { const providerDisplayName = provider?.name + const providerModelPrefix = provider?.apiIdentifier?.trim() || provider?.id || model.provider return { - id: `${model.provider}:${model.id}`, + id: `${providerModelPrefix}:${model.id}`, object: 'model', name: model.name, created: Math.floor(Date.now() / 1000), diff --git a/src/renderer/src/types/provider.ts b/src/renderer/src/types/provider.ts index edab3a7305..313e0eb7a2 100644 --- a/src/renderer/src/types/provider.ts +++ b/src/renderer/src/types/provider.ts @@ -98,6 +98,11 @@ export type Provider = { id: string type: ProviderType name: string + /** + * Optional short identifier used by the built-in API relay. + * When set, models can be referenced as `${apiIdentifier}:${modelId}` instead of `${id}:${modelId}`. + */ + apiIdentifier?: string apiKey: string apiHost: string anthropicApiHost?: string