refactor(aiCore): 重构provider选项构建逻辑以支持更多provider类型

重构buildProviderOptions函数,使用schema验证providerId并分为基础provider和自定义provider处理
新增对deepseek、openai-compatible等provider的支持
This commit is contained in:
icarus 2025-08-29 18:43:03 +08:00
parent 005cd730b0
commit bdbb2c2c75

View File

@ -1,3 +1,4 @@
import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-core/provider'
import { isOpenAIModel, isSupportFlexServiceTierModel } from '@renderer/config/models' import { isOpenAIModel, isSupportFlexServiceTierModel } from '@renderer/config/models'
import { isSupportServiceTierProvider } from '@renderer/config/providers' import { isSupportServiceTierProvider } from '@renderer/config/providers'
import { import {
@ -67,41 +68,66 @@ export function buildProviderOptions(
enableGenerateImage: boolean enableGenerateImage: boolean
} }
): Record<string, any> { ): Record<string, any> {
const providerId = getAiSdkProviderId(actualProvider) const rawProviderId = getAiSdkProviderId(actualProvider)
// 构建 provider 特定的选项 // 构建 provider 特定的选项
let providerSpecificOptions: Record<string, any> = {} let providerSpecificOptions: Record<string, any> = {}
const serviceTierSetting = getServiceTier(model, actualProvider) const serviceTierSetting = getServiceTier(model, actualProvider)
providerSpecificOptions.serviceTier = serviceTierSetting providerSpecificOptions.serviceTier = serviceTierSetting
// 根据 provider 类型分离构建逻辑 // 根据 provider 类型分离构建逻辑
switch (providerId) { const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId)
case 'openai': if (success) {
case 'azure': // 应该覆盖所有类型
providerSpecificOptions = { switch (baseProviderId) {
...buildOpenAIProviderOptions(assistant, model, capabilities, actualProvider), case 'openai':
serviceTier: serviceTierSetting case 'azure':
providerSpecificOptions = {
...buildOpenAIProviderOptions(assistant, model, capabilities, actualProvider),
serviceTier: serviceTierSetting
}
break
case 'anthropic':
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
break
case 'google':
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
break
case 'xai':
providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities)
break
case 'deepseek':
case 'openai-compatible':
case 'openai-responses':
// 对于其他 provider使用通用的构建逻辑
providerSpecificOptions = {
...buildGenericProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
}
break
default:
throw new Error(`Unsupported base provider ${baseProviderId}`)
}
} else {
// 处理自定义 provider
const { data: providerId, success, error } = customProviderIdSchema.safeParse(rawProviderId)
if (success) {
switch (providerId) {
// 非 base provider 的单独处理逻辑
case 'google-vertex':
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
break
default:
// 对于其他 provider使用通用的构建逻辑
providerSpecificOptions = {
...buildGenericProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
}
} }
break } else {
throw error
case 'anthropic': }
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
break
case 'google':
case 'google-vertex':
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
break
case 'xai':
providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities)
break
default:
// 对于其他 provider使用通用的构建逻辑
providerSpecificOptions = {
...buildGenericProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
}
break
} }
// 合并自定义参数到 provider 特定的选项中 // 合并自定义参数到 provider 特定的选项中
@ -112,7 +138,7 @@ export function buildProviderOptions(
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } // 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions }
return { return {
[providerId]: providerSpecificOptions [rawProviderId]: providerSpecificOptions
} }
} }