From bdbb2c2c75509ce4606d71ac21d9feed3516cd01 Mon Sep 17 00:00:00 2001 From: icarus Date: Fri, 29 Aug 2025 18:43:03 +0800 Subject: [PATCH] =?UTF-8?q?refactor(aiCore):=20=E9=87=8D=E6=9E=84provider?= =?UTF-8?q?=E9=80=89=E9=A1=B9=E6=9E=84=E5=BB=BA=E9=80=BB=E8=BE=91=E4=BB=A5?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E6=9B=B4=E5=A4=9Aprovider=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 重构buildProviderOptions函数,使用schema验证providerId并分为基础provider和自定义provider处理 新增对deepseek、openai-compatible等provider的支持 --- src/renderer/src/aiCore/utils/options.ts | 86 +++++++++++++++--------- 1 file changed, 56 insertions(+), 30 deletions(-) diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts index acca70df74..65ba6503f6 100644 --- a/src/renderer/src/aiCore/utils/options.ts +++ b/src/renderer/src/aiCore/utils/options.ts @@ -1,3 +1,4 @@ +import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-core/provider' import { isOpenAIModel, isSupportFlexServiceTierModel } from '@renderer/config/models' import { isSupportServiceTierProvider } from '@renderer/config/providers' import { @@ -67,41 +68,66 @@ export function buildProviderOptions( enableGenerateImage: boolean } ): Record { - const providerId = getAiSdkProviderId(actualProvider) + const rawProviderId = getAiSdkProviderId(actualProvider) // 构建 provider 特定的选项 let providerSpecificOptions: Record = {} const serviceTierSetting = getServiceTier(model, actualProvider) providerSpecificOptions.serviceTier = serviceTierSetting // 根据 provider 类型分离构建逻辑 - switch (providerId) { - case 'openai': - case 'azure': - providerSpecificOptions = { - ...buildOpenAIProviderOptions(assistant, model, capabilities, actualProvider), - serviceTier: serviceTierSetting + const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId) + if (success) { + // 应该覆盖所有类型 + switch (baseProviderId) { + case 'openai': + case 'azure': + providerSpecificOptions = { + ...buildOpenAIProviderOptions(assistant, model, capabilities, actualProvider), + serviceTier: serviceTierSetting + } + break + + case 'anthropic': + providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities) + break + + case 'google': + providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities) + break + + case 'xai': + providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities) + break + case 'deepseek': + case 'openai-compatible': + case 'openai-responses': + // 对于其他 provider,使用通用的构建逻辑 + providerSpecificOptions = { + ...buildGenericProviderOptions(assistant, model, capabilities), + serviceTier: serviceTierSetting + } + break + default: + throw new Error(`Unsupported base provider ${baseProviderId}`) + } + } else { + // 处理自定义 provider + const { data: providerId, success, error } = customProviderIdSchema.safeParse(rawProviderId) + if (success) { + switch (providerId) { + // 非 base provider 的单独处理逻辑 + case 'google-vertex': + providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities) + break + default: + // 对于其他 provider,使用通用的构建逻辑 + providerSpecificOptions = { + ...buildGenericProviderOptions(assistant, model, capabilities), + serviceTier: serviceTierSetting + } } - break - - case 'anthropic': - providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities) - break - - case 'google': - case 'google-vertex': - providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities) - break - - case 'xai': - providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities) - break - - default: - // 对于其他 provider,使用通用的构建逻辑 - providerSpecificOptions = { - ...buildGenericProviderOptions(assistant, model, capabilities), - serviceTier: serviceTierSetting - } - break + } else { + throw error + } } // 合并自定义参数到 provider 特定的选项中 @@ -112,7 +138,7 @@ export function buildProviderOptions( // 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } return { - [providerId]: providerSpecificOptions + [rawProviderId]: providerSpecificOptions } }