mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 14:41:24 +08:00
Merge 812028823a into a6ba5d34e0
This commit is contained in:
commit
2e6fba94b9
118
src/main/apiServer/utils/__tests__/provider-alias.test.ts
Normal file
118
src/main/apiServer/utils/__tests__/provider-alias.test.ts
Normal file
@ -0,0 +1,118 @@
|
||||
import { CacheService } from '@main/services/CacheService'
|
||||
import type { Model, Provider } from '@types'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const { selectMock } = vi.hoisted(() => ({
|
||||
selectMock: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@main/services/ReduxService', () => ({
|
||||
reduxService: {
|
||||
select: selectMock
|
||||
}
|
||||
}))
|
||||
|
||||
import { getProviderByModel, transformModelToOpenAI, validateModelId } from '..'
|
||||
|
||||
describe('api server provider alias', () => {
|
||||
beforeEach(() => {
|
||||
CacheService.clear()
|
||||
selectMock.mockReset()
|
||||
})
|
||||
|
||||
it('formats model id using provider apiIdentifier when present', () => {
|
||||
const provider: Provider = {
|
||||
id: 'test-provider-uuid',
|
||||
apiIdentifier: 'short',
|
||||
type: 'openai',
|
||||
name: 'Custom Provider',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://example.com/v1',
|
||||
models: [],
|
||||
enabled: true
|
||||
}
|
||||
|
||||
const model: Model = {
|
||||
id: 'glm-4.6',
|
||||
provider: provider.id,
|
||||
name: 'glm-4.6',
|
||||
group: 'glm'
|
||||
}
|
||||
|
||||
const apiModel = transformModelToOpenAI(model, provider)
|
||||
expect(apiModel.id).toBe('short:glm-4.6')
|
||||
expect(apiModel.provider).toBe('test-provider-uuid')
|
||||
})
|
||||
|
||||
it('resolves provider by apiIdentifier for model routing', async () => {
|
||||
const provider: Provider = {
|
||||
id: 'test-provider-uuid',
|
||||
apiIdentifier: 'short',
|
||||
type: 'openai',
|
||||
name: 'Custom Provider',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://example.com/v1',
|
||||
models: [],
|
||||
enabled: true
|
||||
}
|
||||
|
||||
selectMock.mockResolvedValue([provider])
|
||||
|
||||
const resolved = await getProviderByModel('short:glm-4.6')
|
||||
expect(resolved?.id).toBe('test-provider-uuid')
|
||||
})
|
||||
|
||||
it('validates model ids with apiIdentifier prefix', async () => {
|
||||
const provider: Provider = {
|
||||
id: 'test-provider-uuid',
|
||||
apiIdentifier: 'short',
|
||||
type: 'openai',
|
||||
name: 'Custom Provider',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://example.com/v1',
|
||||
models: [
|
||||
{
|
||||
id: 'glm-4.6',
|
||||
provider: 'test-provider-uuid',
|
||||
name: 'glm-4.6',
|
||||
group: 'glm'
|
||||
}
|
||||
],
|
||||
enabled: true
|
||||
}
|
||||
|
||||
selectMock.mockResolvedValue([provider])
|
||||
|
||||
const result = await validateModelId('short:glm-4.6')
|
||||
expect(result.valid).toBe(true)
|
||||
expect(result.provider?.id).toBe('test-provider-uuid')
|
||||
expect(result.modelId).toBe('glm-4.6')
|
||||
})
|
||||
|
||||
it('still supports provider.id prefix', async () => {
|
||||
const provider: Provider = {
|
||||
id: 'test-provider-uuid',
|
||||
apiIdentifier: 'short',
|
||||
type: 'openai',
|
||||
name: 'Custom Provider',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://example.com/v1',
|
||||
models: [
|
||||
{
|
||||
id: 'glm-4.6',
|
||||
provider: 'test-provider-uuid',
|
||||
name: 'glm-4.6',
|
||||
group: 'glm'
|
||||
}
|
||||
],
|
||||
enabled: true
|
||||
}
|
||||
|
||||
selectMock.mockResolvedValue([provider])
|
||||
|
||||
const result = await validateModelId('test-provider-uuid:glm-4.6')
|
||||
expect(result.valid).toBe(true)
|
||||
expect(result.provider?.id).toBe('test-provider-uuid')
|
||||
expect(result.modelId).toBe('glm-4.6')
|
||||
})
|
||||
})
|
||||
@ -82,16 +82,24 @@ export async function getProviderByModel(model: string): Promise<Provider | unde
|
||||
}
|
||||
|
||||
const providerId = modelInfo[0]
|
||||
const provider = providers.find((p: Provider) => p.id === providerId)
|
||||
const matchingProviders = providers.filter((p: Provider) => p.id === providerId || p.apiIdentifier === providerId)
|
||||
const provider = matchingProviders.find((p) => p.id === providerId) ?? matchingProviders[0]
|
||||
|
||||
if (!provider) {
|
||||
logger.warn('Provider not found for model', {
|
||||
providerId,
|
||||
available: providers.map((p) => p.id)
|
||||
available: providers.map((p) => ({ id: p.id, apiIdentifier: p.apiIdentifier }))
|
||||
})
|
||||
return undefined
|
||||
}
|
||||
|
||||
if (matchingProviders.length > 1) {
|
||||
logger.warn('Multiple providers matched for model prefix', {
|
||||
providerId,
|
||||
matches: matchingProviders.map((p) => ({ id: p.id, apiIdentifier: p.apiIdentifier }))
|
||||
})
|
||||
}
|
||||
|
||||
logger.debug('Provider resolved for model', { providerId, model })
|
||||
return provider
|
||||
} catch (error: any) {
|
||||
@ -200,8 +208,9 @@ export async function validateModelId(model: string): Promise<{
|
||||
|
||||
export function transformModelToOpenAI(model: Model, provider?: Provider): ApiModel {
|
||||
const providerDisplayName = provider?.name
|
||||
const providerModelPrefix = provider?.apiIdentifier?.trim() || provider?.id || model.provider
|
||||
return {
|
||||
id: `${model.provider}:${model.id}`,
|
||||
id: `${providerModelPrefix}:${model.id}`,
|
||||
object: 'model',
|
||||
name: model.name,
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
|
||||
@ -4469,6 +4469,16 @@
|
||||
"api_host_no_valid": "API address is invalid",
|
||||
"api_host_preview": "Preview: {{url}}",
|
||||
"api_host_tooltip": "Override only when your provider requires a custom OpenAI-compatible endpoint.",
|
||||
"api_identifier": {
|
||||
"error": {
|
||||
"duplicate": "Identifier is already used by another provider",
|
||||
"invalid": "Only letters, numbers, '-' and '_' are allowed, and ':' is not allowed"
|
||||
},
|
||||
"label": "API Identifier",
|
||||
"placeholder": "Example: my-provider",
|
||||
"preview": "Example: {{model}}",
|
||||
"tip": "Used as the model prefix in API relay. Leave empty to use the internal UUID."
|
||||
},
|
||||
"api_key": {
|
||||
"label": "API Key",
|
||||
"tip": "Use commas to separate multiple keys"
|
||||
|
||||
@ -4469,6 +4469,16 @@
|
||||
"api_host_no_valid": "API 地址不合法",
|
||||
"api_host_preview": "预览:{{url}}",
|
||||
"api_host_tooltip": "仅在服务商需要自定义的 OpenAI 兼容地址时覆盖。",
|
||||
"api_identifier": {
|
||||
"error": {
|
||||
"duplicate": "该标识符已被其他提供商使用",
|
||||
"invalid": "仅支持字母、数字、“-”、“_”,且不能包含“:”"
|
||||
},
|
||||
"label": "API 标识符",
|
||||
"placeholder": "例如:my-provider",
|
||||
"preview": "示例:{{model}}",
|
||||
"tip": "用于 API Relay 的模型前缀。留空则使用内部 UUID。"
|
||||
},
|
||||
"api_key": {
|
||||
"label": "API 密钥",
|
||||
"tip": "多个密钥使用逗号分隔"
|
||||
|
||||
@ -4469,6 +4469,16 @@
|
||||
"api_host_no_valid": "API 位址不合法",
|
||||
"api_host_preview": "預覽:{{url}}",
|
||||
"api_host_tooltip": "僅在供應商需要自訂的 OpenAI 相容端點時才覆蓋。",
|
||||
"api_identifier": {
|
||||
"error": {
|
||||
"duplicate": "此識別碼已被其他供應商使用",
|
||||
"invalid": "僅支援字母、數字、“-”、“_”,且不能包含“:”"
|
||||
},
|
||||
"label": "API 識別碼",
|
||||
"placeholder": "例如:my-provider",
|
||||
"preview": "範例:{{model}}",
|
||||
"tip": "用於 API Relay 的模型前綴。留空則使用內部 UUID。"
|
||||
},
|
||||
"api_key": {
|
||||
"label": "API 金鑰",
|
||||
"tip": "多個金鑰使用逗號分隔"
|
||||
|
||||
@ -92,6 +92,8 @@ const isAnthropicCompatibleProviderId = (id: string): id is AnthropicCompatibleP
|
||||
|
||||
type HostField = 'apiHost' | 'anthropicApiHost'
|
||||
|
||||
const API_IDENTIFIER_PATTERN = /^[a-zA-Z0-9][a-zA-Z0-9_-]{0,31}$/
|
||||
|
||||
const ProviderSetting: FC<Props> = ({ providerId }) => {
|
||||
const { provider, updateProvider, models } = useProvider(providerId)
|
||||
const allProviders = useAllProviders()
|
||||
@ -122,6 +124,7 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
|
||||
const fancyProviderName = getFancyProviderName(provider)
|
||||
|
||||
const [localApiKey, setLocalApiKey] = useState(provider.apiKey)
|
||||
const [apiIdentifier, setApiIdentifier] = useState(provider.apiIdentifier ?? '')
|
||||
const [apiKeyConnectivity, setApiKeyConnectivity] = useState<ApiKeyConnectivity>({
|
||||
status: HealthStatus.NOT_CHECKED,
|
||||
checking: false
|
||||
@ -147,6 +150,10 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
|
||||
setApiKeyConnectivity({ status: HealthStatus.NOT_CHECKED })
|
||||
}, [provider.apiKey])
|
||||
|
||||
useEffect(() => {
|
||||
setApiIdentifier(provider.apiIdentifier ?? '')
|
||||
}, [provider.apiIdentifier])
|
||||
|
||||
// 同步 localApiKey 到 provider.apiKey(防抖)
|
||||
useEffect(() => {
|
||||
if (localApiKey !== provider.apiKey) {
|
||||
@ -385,6 +392,41 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
|
||||
|
||||
const isAnthropicOAuth = () => provider.id === 'anthropic' && provider.authType === 'oauth'
|
||||
|
||||
const onUpdateApiIdentifier = useCallback(() => {
|
||||
const normalizedIdentifier = apiIdentifier.trim()
|
||||
|
||||
if (!normalizedIdentifier) {
|
||||
updateProvider({ apiIdentifier: undefined })
|
||||
setApiIdentifier('')
|
||||
return
|
||||
}
|
||||
|
||||
if (!API_IDENTIFIER_PATTERN.test(normalizedIdentifier) || normalizedIdentifier.includes(':')) {
|
||||
window.toast.error(t('settings.provider.api_identifier.error.invalid'))
|
||||
setApiIdentifier(provider.apiIdentifier ?? '')
|
||||
return
|
||||
}
|
||||
|
||||
const conflictProvider = allProviders.find((p) => {
|
||||
if (p.id === provider.id) {
|
||||
return false
|
||||
}
|
||||
if (p.id === normalizedIdentifier) {
|
||||
return true
|
||||
}
|
||||
return p.apiIdentifier?.trim() === normalizedIdentifier
|
||||
})
|
||||
|
||||
if (conflictProvider) {
|
||||
window.toast.error(t('settings.provider.api_identifier.error.duplicate'))
|
||||
setApiIdentifier(provider.apiIdentifier ?? '')
|
||||
return
|
||||
}
|
||||
|
||||
updateProvider({ apiIdentifier: normalizedIdentifier })
|
||||
setApiIdentifier(normalizedIdentifier)
|
||||
}, [allProviders, apiIdentifier, provider.apiIdentifier, provider.id, t, updateProvider])
|
||||
|
||||
return (
|
||||
<SettingContainer theme={theme} style={{ background: 'var(--color-background)' }}>
|
||||
<SettingTitle>
|
||||
@ -418,6 +460,34 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
|
||||
/>
|
||||
</SettingTitle>
|
||||
<Divider style={{ width: '100%', margin: '10px 0' }} />
|
||||
{!isSystemProvider(provider) && (
|
||||
<>
|
||||
<SettingSubtitle style={{ marginTop: 5, display: 'flex', alignItems: 'center', gap: 6 }}>
|
||||
{t('settings.provider.api_identifier.label')}
|
||||
<HelpTooltip title={t('settings.provider.api_identifier.tip')}></HelpTooltip>
|
||||
</SettingSubtitle>
|
||||
<Input
|
||||
value={apiIdentifier}
|
||||
placeholder={t('settings.provider.api_identifier.placeholder')}
|
||||
onChange={(e) => setApiIdentifier(e.target.value)}
|
||||
onBlur={onUpdateApiIdentifier}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' && !e.nativeEvent.isComposing) {
|
||||
onUpdateApiIdentifier()
|
||||
}
|
||||
}}
|
||||
spellCheck={false}
|
||||
maxLength={32}
|
||||
/>
|
||||
<SettingHelpTextRow style={{ justifyContent: 'space-between' }}>
|
||||
<SettingHelpText>
|
||||
{t('settings.provider.api_identifier.preview', {
|
||||
model: `${apiIdentifier.trim() || provider.id}:glm-4.6`
|
||||
})}
|
||||
</SettingHelpText>
|
||||
</SettingHelpTextRow>
|
||||
</>
|
||||
)}
|
||||
{isProviderSupportAuth(provider) && <ProviderOAuth providerId={provider.id} />}
|
||||
{provider.id === 'openai' && <OpenAIAlert />}
|
||||
{provider.id === 'ovms' && <OVMSSettings />}
|
||||
|
||||
@ -98,6 +98,11 @@ export type Provider = {
|
||||
id: string
|
||||
type: ProviderType
|
||||
name: string
|
||||
/**
|
||||
* Optional short identifier used by the built-in API relay.
|
||||
* When set, models can be referenced as `${apiIdentifier}:${modelId}` instead of `${id}:${modelId}`.
|
||||
*/
|
||||
apiIdentifier?: string
|
||||
apiKey: string
|
||||
apiHost: string
|
||||
anthropicApiHost?: string
|
||||
|
||||
Loading…
Reference in New Issue
Block a user