mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-02 02:09:03 +08:00
refactor(aiCore): 重构provider选项构建逻辑以支持更多provider类型
重构buildProviderOptions函数,使用schema验证providerId并分为基础provider和自定义provider处理 新增对deepseek、openai-compatible等provider的支持
This commit is contained in:
parent
005cd730b0
commit
bdbb2c2c75
@ -1,3 +1,4 @@
|
|||||||
|
import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-core/provider'
|
||||||
import { isOpenAIModel, isSupportFlexServiceTierModel } from '@renderer/config/models'
|
import { isOpenAIModel, isSupportFlexServiceTierModel } from '@renderer/config/models'
|
||||||
import { isSupportServiceTierProvider } from '@renderer/config/providers'
|
import { isSupportServiceTierProvider } from '@renderer/config/providers'
|
||||||
import {
|
import {
|
||||||
@ -67,41 +68,66 @@ export function buildProviderOptions(
|
|||||||
enableGenerateImage: boolean
|
enableGenerateImage: boolean
|
||||||
}
|
}
|
||||||
): Record<string, any> {
|
): Record<string, any> {
|
||||||
const providerId = getAiSdkProviderId(actualProvider)
|
const rawProviderId = getAiSdkProviderId(actualProvider)
|
||||||
// 构建 provider 特定的选项
|
// 构建 provider 特定的选项
|
||||||
let providerSpecificOptions: Record<string, any> = {}
|
let providerSpecificOptions: Record<string, any> = {}
|
||||||
const serviceTierSetting = getServiceTier(model, actualProvider)
|
const serviceTierSetting = getServiceTier(model, actualProvider)
|
||||||
providerSpecificOptions.serviceTier = serviceTierSetting
|
providerSpecificOptions.serviceTier = serviceTierSetting
|
||||||
// 根据 provider 类型分离构建逻辑
|
// 根据 provider 类型分离构建逻辑
|
||||||
switch (providerId) {
|
const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId)
|
||||||
case 'openai':
|
if (success) {
|
||||||
case 'azure':
|
// 应该覆盖所有类型
|
||||||
providerSpecificOptions = {
|
switch (baseProviderId) {
|
||||||
...buildOpenAIProviderOptions(assistant, model, capabilities, actualProvider),
|
case 'openai':
|
||||||
serviceTier: serviceTierSetting
|
case 'azure':
|
||||||
|
providerSpecificOptions = {
|
||||||
|
...buildOpenAIProviderOptions(assistant, model, capabilities, actualProvider),
|
||||||
|
serviceTier: serviceTierSetting
|
||||||
|
}
|
||||||
|
break
|
||||||
|
|
||||||
|
case 'anthropic':
|
||||||
|
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||||
|
break
|
||||||
|
|
||||||
|
case 'google':
|
||||||
|
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
|
||||||
|
break
|
||||||
|
|
||||||
|
case 'xai':
|
||||||
|
providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities)
|
||||||
|
break
|
||||||
|
case 'deepseek':
|
||||||
|
case 'openai-compatible':
|
||||||
|
case 'openai-responses':
|
||||||
|
// 对于其他 provider,使用通用的构建逻辑
|
||||||
|
providerSpecificOptions = {
|
||||||
|
...buildGenericProviderOptions(assistant, model, capabilities),
|
||||||
|
serviceTier: serviceTierSetting
|
||||||
|
}
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
throw new Error(`Unsupported base provider ${baseProviderId}`)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 处理自定义 provider
|
||||||
|
const { data: providerId, success, error } = customProviderIdSchema.safeParse(rawProviderId)
|
||||||
|
if (success) {
|
||||||
|
switch (providerId) {
|
||||||
|
// 非 base provider 的单独处理逻辑
|
||||||
|
case 'google-vertex':
|
||||||
|
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
// 对于其他 provider,使用通用的构建逻辑
|
||||||
|
providerSpecificOptions = {
|
||||||
|
...buildGenericProviderOptions(assistant, model, capabilities),
|
||||||
|
serviceTier: serviceTierSetting
|
||||||
|
}
|
||||||
}
|
}
|
||||||
break
|
} else {
|
||||||
|
throw error
|
||||||
case 'anthropic':
|
}
|
||||||
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
|
|
||||||
break
|
|
||||||
|
|
||||||
case 'google':
|
|
||||||
case 'google-vertex':
|
|
||||||
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
|
|
||||||
break
|
|
||||||
|
|
||||||
case 'xai':
|
|
||||||
providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities)
|
|
||||||
break
|
|
||||||
|
|
||||||
default:
|
|
||||||
// 对于其他 provider,使用通用的构建逻辑
|
|
||||||
providerSpecificOptions = {
|
|
||||||
...buildGenericProviderOptions(assistant, model, capabilities),
|
|
||||||
serviceTier: serviceTierSetting
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 合并自定义参数到 provider 特定的选项中
|
// 合并自定义参数到 provider 特定的选项中
|
||||||
@ -112,7 +138,7 @@ export function buildProviderOptions(
|
|||||||
|
|
||||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions }
|
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions }
|
||||||
return {
|
return {
|
||||||
[providerId]: providerSpecificOptions
|
[rawProviderId]: providerSpecificOptions
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user