mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 14:41:24 +08:00
Merge 812028823a into a6ba5d34e0
This commit is contained in:
commit
2e6fba94b9
118
src/main/apiServer/utils/__tests__/provider-alias.test.ts
Normal file
118
src/main/apiServer/utils/__tests__/provider-alias.test.ts
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
import { CacheService } from '@main/services/CacheService'
|
||||||
|
import type { Model, Provider } from '@types'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
const { selectMock } = vi.hoisted(() => ({
|
||||||
|
selectMock: vi.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@main/services/ReduxService', () => ({
|
||||||
|
reduxService: {
|
||||||
|
select: selectMock
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
import { getProviderByModel, transformModelToOpenAI, validateModelId } from '..'
|
||||||
|
|
||||||
|
describe('api server provider alias', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
CacheService.clear()
|
||||||
|
selectMock.mockReset()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('formats model id using provider apiIdentifier when present', () => {
|
||||||
|
const provider: Provider = {
|
||||||
|
id: 'test-provider-uuid',
|
||||||
|
apiIdentifier: 'short',
|
||||||
|
type: 'openai',
|
||||||
|
name: 'Custom Provider',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://example.com/v1',
|
||||||
|
models: [],
|
||||||
|
enabled: true
|
||||||
|
}
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'glm-4.6',
|
||||||
|
provider: provider.id,
|
||||||
|
name: 'glm-4.6',
|
||||||
|
group: 'glm'
|
||||||
|
}
|
||||||
|
|
||||||
|
const apiModel = transformModelToOpenAI(model, provider)
|
||||||
|
expect(apiModel.id).toBe('short:glm-4.6')
|
||||||
|
expect(apiModel.provider).toBe('test-provider-uuid')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('resolves provider by apiIdentifier for model routing', async () => {
|
||||||
|
const provider: Provider = {
|
||||||
|
id: 'test-provider-uuid',
|
||||||
|
apiIdentifier: 'short',
|
||||||
|
type: 'openai',
|
||||||
|
name: 'Custom Provider',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://example.com/v1',
|
||||||
|
models: [],
|
||||||
|
enabled: true
|
||||||
|
}
|
||||||
|
|
||||||
|
selectMock.mockResolvedValue([provider])
|
||||||
|
|
||||||
|
const resolved = await getProviderByModel('short:glm-4.6')
|
||||||
|
expect(resolved?.id).toBe('test-provider-uuid')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('validates model ids with apiIdentifier prefix', async () => {
|
||||||
|
const provider: Provider = {
|
||||||
|
id: 'test-provider-uuid',
|
||||||
|
apiIdentifier: 'short',
|
||||||
|
type: 'openai',
|
||||||
|
name: 'Custom Provider',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://example.com/v1',
|
||||||
|
models: [
|
||||||
|
{
|
||||||
|
id: 'glm-4.6',
|
||||||
|
provider: 'test-provider-uuid',
|
||||||
|
name: 'glm-4.6',
|
||||||
|
group: 'glm'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
enabled: true
|
||||||
|
}
|
||||||
|
|
||||||
|
selectMock.mockResolvedValue([provider])
|
||||||
|
|
||||||
|
const result = await validateModelId('short:glm-4.6')
|
||||||
|
expect(result.valid).toBe(true)
|
||||||
|
expect(result.provider?.id).toBe('test-provider-uuid')
|
||||||
|
expect(result.modelId).toBe('glm-4.6')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('still supports provider.id prefix', async () => {
|
||||||
|
const provider: Provider = {
|
||||||
|
id: 'test-provider-uuid',
|
||||||
|
apiIdentifier: 'short',
|
||||||
|
type: 'openai',
|
||||||
|
name: 'Custom Provider',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://example.com/v1',
|
||||||
|
models: [
|
||||||
|
{
|
||||||
|
id: 'glm-4.6',
|
||||||
|
provider: 'test-provider-uuid',
|
||||||
|
name: 'glm-4.6',
|
||||||
|
group: 'glm'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
enabled: true
|
||||||
|
}
|
||||||
|
|
||||||
|
selectMock.mockResolvedValue([provider])
|
||||||
|
|
||||||
|
const result = await validateModelId('test-provider-uuid:glm-4.6')
|
||||||
|
expect(result.valid).toBe(true)
|
||||||
|
expect(result.provider?.id).toBe('test-provider-uuid')
|
||||||
|
expect(result.modelId).toBe('glm-4.6')
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -82,16 +82,24 @@ export async function getProviderByModel(model: string): Promise<Provider | unde
|
|||||||
}
|
}
|
||||||
|
|
||||||
const providerId = modelInfo[0]
|
const providerId = modelInfo[0]
|
||||||
const provider = providers.find((p: Provider) => p.id === providerId)
|
const matchingProviders = providers.filter((p: Provider) => p.id === providerId || p.apiIdentifier === providerId)
|
||||||
|
const provider = matchingProviders.find((p) => p.id === providerId) ?? matchingProviders[0]
|
||||||
|
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
logger.warn('Provider not found for model', {
|
logger.warn('Provider not found for model', {
|
||||||
providerId,
|
providerId,
|
||||||
available: providers.map((p) => p.id)
|
available: providers.map((p) => ({ id: p.id, apiIdentifier: p.apiIdentifier }))
|
||||||
})
|
})
|
||||||
return undefined
|
return undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (matchingProviders.length > 1) {
|
||||||
|
logger.warn('Multiple providers matched for model prefix', {
|
||||||
|
providerId,
|
||||||
|
matches: matchingProviders.map((p) => ({ id: p.id, apiIdentifier: p.apiIdentifier }))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
logger.debug('Provider resolved for model', { providerId, model })
|
logger.debug('Provider resolved for model', { providerId, model })
|
||||||
return provider
|
return provider
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
@ -200,8 +208,9 @@ export async function validateModelId(model: string): Promise<{
|
|||||||
|
|
||||||
export function transformModelToOpenAI(model: Model, provider?: Provider): ApiModel {
|
export function transformModelToOpenAI(model: Model, provider?: Provider): ApiModel {
|
||||||
const providerDisplayName = provider?.name
|
const providerDisplayName = provider?.name
|
||||||
|
const providerModelPrefix = provider?.apiIdentifier?.trim() || provider?.id || model.provider
|
||||||
return {
|
return {
|
||||||
id: `${model.provider}:${model.id}`,
|
id: `${providerModelPrefix}:${model.id}`,
|
||||||
object: 'model',
|
object: 'model',
|
||||||
name: model.name,
|
name: model.name,
|
||||||
created: Math.floor(Date.now() / 1000),
|
created: Math.floor(Date.now() / 1000),
|
||||||
|
|||||||
@ -4469,6 +4469,16 @@
|
|||||||
"api_host_no_valid": "API address is invalid",
|
"api_host_no_valid": "API address is invalid",
|
||||||
"api_host_preview": "Preview: {{url}}",
|
"api_host_preview": "Preview: {{url}}",
|
||||||
"api_host_tooltip": "Override only when your provider requires a custom OpenAI-compatible endpoint.",
|
"api_host_tooltip": "Override only when your provider requires a custom OpenAI-compatible endpoint.",
|
||||||
|
"api_identifier": {
|
||||||
|
"error": {
|
||||||
|
"duplicate": "Identifier is already used by another provider",
|
||||||
|
"invalid": "Only letters, numbers, '-' and '_' are allowed, and ':' is not allowed"
|
||||||
|
},
|
||||||
|
"label": "API Identifier",
|
||||||
|
"placeholder": "Example: my-provider",
|
||||||
|
"preview": "Example: {{model}}",
|
||||||
|
"tip": "Used as the model prefix in API relay. Leave empty to use the internal UUID."
|
||||||
|
},
|
||||||
"api_key": {
|
"api_key": {
|
||||||
"label": "API Key",
|
"label": "API Key",
|
||||||
"tip": "Use commas to separate multiple keys"
|
"tip": "Use commas to separate multiple keys"
|
||||||
|
|||||||
@ -4469,6 +4469,16 @@
|
|||||||
"api_host_no_valid": "API 地址不合法",
|
"api_host_no_valid": "API 地址不合法",
|
||||||
"api_host_preview": "预览:{{url}}",
|
"api_host_preview": "预览:{{url}}",
|
||||||
"api_host_tooltip": "仅在服务商需要自定义的 OpenAI 兼容地址时覆盖。",
|
"api_host_tooltip": "仅在服务商需要自定义的 OpenAI 兼容地址时覆盖。",
|
||||||
|
"api_identifier": {
|
||||||
|
"error": {
|
||||||
|
"duplicate": "该标识符已被其他提供商使用",
|
||||||
|
"invalid": "仅支持字母、数字、“-”、“_”,且不能包含“:”"
|
||||||
|
},
|
||||||
|
"label": "API 标识符",
|
||||||
|
"placeholder": "例如:my-provider",
|
||||||
|
"preview": "示例:{{model}}",
|
||||||
|
"tip": "用于 API Relay 的模型前缀。留空则使用内部 UUID。"
|
||||||
|
},
|
||||||
"api_key": {
|
"api_key": {
|
||||||
"label": "API 密钥",
|
"label": "API 密钥",
|
||||||
"tip": "多个密钥使用逗号分隔"
|
"tip": "多个密钥使用逗号分隔"
|
||||||
|
|||||||
@ -4469,6 +4469,16 @@
|
|||||||
"api_host_no_valid": "API 位址不合法",
|
"api_host_no_valid": "API 位址不合法",
|
||||||
"api_host_preview": "預覽:{{url}}",
|
"api_host_preview": "預覽:{{url}}",
|
||||||
"api_host_tooltip": "僅在供應商需要自訂的 OpenAI 相容端點時才覆蓋。",
|
"api_host_tooltip": "僅在供應商需要自訂的 OpenAI 相容端點時才覆蓋。",
|
||||||
|
"api_identifier": {
|
||||||
|
"error": {
|
||||||
|
"duplicate": "此識別碼已被其他供應商使用",
|
||||||
|
"invalid": "僅支援字母、數字、“-”、“_”,且不能包含“:”"
|
||||||
|
},
|
||||||
|
"label": "API 識別碼",
|
||||||
|
"placeholder": "例如:my-provider",
|
||||||
|
"preview": "範例:{{model}}",
|
||||||
|
"tip": "用於 API Relay 的模型前綴。留空則使用內部 UUID。"
|
||||||
|
},
|
||||||
"api_key": {
|
"api_key": {
|
||||||
"label": "API 金鑰",
|
"label": "API 金鑰",
|
||||||
"tip": "多個金鑰使用逗號分隔"
|
"tip": "多個金鑰使用逗號分隔"
|
||||||
|
|||||||
@ -92,6 +92,8 @@ const isAnthropicCompatibleProviderId = (id: string): id is AnthropicCompatibleP
|
|||||||
|
|
||||||
type HostField = 'apiHost' | 'anthropicApiHost'
|
type HostField = 'apiHost' | 'anthropicApiHost'
|
||||||
|
|
||||||
|
const API_IDENTIFIER_PATTERN = /^[a-zA-Z0-9][a-zA-Z0-9_-]{0,31}$/
|
||||||
|
|
||||||
const ProviderSetting: FC<Props> = ({ providerId }) => {
|
const ProviderSetting: FC<Props> = ({ providerId }) => {
|
||||||
const { provider, updateProvider, models } = useProvider(providerId)
|
const { provider, updateProvider, models } = useProvider(providerId)
|
||||||
const allProviders = useAllProviders()
|
const allProviders = useAllProviders()
|
||||||
@ -122,6 +124,7 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
|
|||||||
const fancyProviderName = getFancyProviderName(provider)
|
const fancyProviderName = getFancyProviderName(provider)
|
||||||
|
|
||||||
const [localApiKey, setLocalApiKey] = useState(provider.apiKey)
|
const [localApiKey, setLocalApiKey] = useState(provider.apiKey)
|
||||||
|
const [apiIdentifier, setApiIdentifier] = useState(provider.apiIdentifier ?? '')
|
||||||
const [apiKeyConnectivity, setApiKeyConnectivity] = useState<ApiKeyConnectivity>({
|
const [apiKeyConnectivity, setApiKeyConnectivity] = useState<ApiKeyConnectivity>({
|
||||||
status: HealthStatus.NOT_CHECKED,
|
status: HealthStatus.NOT_CHECKED,
|
||||||
checking: false
|
checking: false
|
||||||
@ -147,6 +150,10 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
|
|||||||
setApiKeyConnectivity({ status: HealthStatus.NOT_CHECKED })
|
setApiKeyConnectivity({ status: HealthStatus.NOT_CHECKED })
|
||||||
}, [provider.apiKey])
|
}, [provider.apiKey])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
setApiIdentifier(provider.apiIdentifier ?? '')
|
||||||
|
}, [provider.apiIdentifier])
|
||||||
|
|
||||||
// 同步 localApiKey 到 provider.apiKey(防抖)
|
// 同步 localApiKey 到 provider.apiKey(防抖)
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (localApiKey !== provider.apiKey) {
|
if (localApiKey !== provider.apiKey) {
|
||||||
@ -385,6 +392,41 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
|
|||||||
|
|
||||||
const isAnthropicOAuth = () => provider.id === 'anthropic' && provider.authType === 'oauth'
|
const isAnthropicOAuth = () => provider.id === 'anthropic' && provider.authType === 'oauth'
|
||||||
|
|
||||||
|
const onUpdateApiIdentifier = useCallback(() => {
|
||||||
|
const normalizedIdentifier = apiIdentifier.trim()
|
||||||
|
|
||||||
|
if (!normalizedIdentifier) {
|
||||||
|
updateProvider({ apiIdentifier: undefined })
|
||||||
|
setApiIdentifier('')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!API_IDENTIFIER_PATTERN.test(normalizedIdentifier) || normalizedIdentifier.includes(':')) {
|
||||||
|
window.toast.error(t('settings.provider.api_identifier.error.invalid'))
|
||||||
|
setApiIdentifier(provider.apiIdentifier ?? '')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const conflictProvider = allProviders.find((p) => {
|
||||||
|
if (p.id === provider.id) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if (p.id === normalizedIdentifier) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return p.apiIdentifier?.trim() === normalizedIdentifier
|
||||||
|
})
|
||||||
|
|
||||||
|
if (conflictProvider) {
|
||||||
|
window.toast.error(t('settings.provider.api_identifier.error.duplicate'))
|
||||||
|
setApiIdentifier(provider.apiIdentifier ?? '')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updateProvider({ apiIdentifier: normalizedIdentifier })
|
||||||
|
setApiIdentifier(normalizedIdentifier)
|
||||||
|
}, [allProviders, apiIdentifier, provider.apiIdentifier, provider.id, t, updateProvider])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<SettingContainer theme={theme} style={{ background: 'var(--color-background)' }}>
|
<SettingContainer theme={theme} style={{ background: 'var(--color-background)' }}>
|
||||||
<SettingTitle>
|
<SettingTitle>
|
||||||
@ -418,6 +460,34 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
|
|||||||
/>
|
/>
|
||||||
</SettingTitle>
|
</SettingTitle>
|
||||||
<Divider style={{ width: '100%', margin: '10px 0' }} />
|
<Divider style={{ width: '100%', margin: '10px 0' }} />
|
||||||
|
{!isSystemProvider(provider) && (
|
||||||
|
<>
|
||||||
|
<SettingSubtitle style={{ marginTop: 5, display: 'flex', alignItems: 'center', gap: 6 }}>
|
||||||
|
{t('settings.provider.api_identifier.label')}
|
||||||
|
<HelpTooltip title={t('settings.provider.api_identifier.tip')}></HelpTooltip>
|
||||||
|
</SettingSubtitle>
|
||||||
|
<Input
|
||||||
|
value={apiIdentifier}
|
||||||
|
placeholder={t('settings.provider.api_identifier.placeholder')}
|
||||||
|
onChange={(e) => setApiIdentifier(e.target.value)}
|
||||||
|
onBlur={onUpdateApiIdentifier}
|
||||||
|
onKeyDown={(e) => {
|
||||||
|
if (e.key === 'Enter' && !e.nativeEvent.isComposing) {
|
||||||
|
onUpdateApiIdentifier()
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
spellCheck={false}
|
||||||
|
maxLength={32}
|
||||||
|
/>
|
||||||
|
<SettingHelpTextRow style={{ justifyContent: 'space-between' }}>
|
||||||
|
<SettingHelpText>
|
||||||
|
{t('settings.provider.api_identifier.preview', {
|
||||||
|
model: `${apiIdentifier.trim() || provider.id}:glm-4.6`
|
||||||
|
})}
|
||||||
|
</SettingHelpText>
|
||||||
|
</SettingHelpTextRow>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
{isProviderSupportAuth(provider) && <ProviderOAuth providerId={provider.id} />}
|
{isProviderSupportAuth(provider) && <ProviderOAuth providerId={provider.id} />}
|
||||||
{provider.id === 'openai' && <OpenAIAlert />}
|
{provider.id === 'openai' && <OpenAIAlert />}
|
||||||
{provider.id === 'ovms' && <OVMSSettings />}
|
{provider.id === 'ovms' && <OVMSSettings />}
|
||||||
|
|||||||
@ -98,6 +98,11 @@ export type Provider = {
|
|||||||
id: string
|
id: string
|
||||||
type: ProviderType
|
type: ProviderType
|
||||||
name: string
|
name: string
|
||||||
|
/**
|
||||||
|
* Optional short identifier used by the built-in API relay.
|
||||||
|
* When set, models can be referenced as `${apiIdentifier}:${modelId}` instead of `${id}:${modelId}`.
|
||||||
|
*/
|
||||||
|
apiIdentifier?: string
|
||||||
apiKey: string
|
apiKey: string
|
||||||
apiHost: string
|
apiHost: string
|
||||||
anthropicApiHost?: string
|
anthropicApiHost?: string
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user