From 6fa82533d5c1fc5b1e695f0e9536a29c6ae02852 Mon Sep 17 00:00:00 2001 From: MyPrototypeWhat Date: Fri, 29 Aug 2025 19:15:06 +0800 Subject: [PATCH] refactor(aiCore): improve provider registration and model resolution logic - Enhanced the model resolution logic to support dynamic provider IDs for chat modes. - Updated provider registration to handle both OpenAI and Azure with a unified approach for chat variants. - Refactored the `buildOpenAIProviderOptions` function to streamline parameter handling and removed unnecessary parameters. - Added configuration options for Azure to manage API versions and deployment URLs effectively. --- packages/aiCore/src/core/models/ModelResolver.ts | 5 +++-- packages/aiCore/src/core/providers/registry.ts | 14 ++++++++++++-- src/renderer/src/aiCore/provider/providerConfig.ts | 11 +++++++++++ src/renderer/src/aiCore/utils/options.ts | 11 ++--------- 4 files changed, 28 insertions(+), 13 deletions(-) diff --git a/packages/aiCore/src/core/models/ModelResolver.ts b/packages/aiCore/src/core/models/ModelResolver.ts index 9d336e819e..0f1bde95c6 100644 --- a/packages/aiCore/src/core/models/ModelResolver.ts +++ b/packages/aiCore/src/core/models/ModelResolver.ts @@ -28,8 +28,8 @@ export class ModelResolver { let finalProviderId = fallbackProviderId let model: LanguageModelV2 // 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移) - if (fallbackProviderId === 'openai' && providerOptions?.mode === 'chat') { - finalProviderId = 'openai-chat' + if ((fallbackProviderId === 'openai' || fallbackProviderId === 'azure') && providerOptions?.mode === 'chat') { + finalProviderId = `${fallbackProviderId}-chat` } // 检查是否是命名空间格式 @@ -84,6 +84,7 @@ export class ModelResolver { */ private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 { const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}` + console.log('fullModelId', fullModelId) return globalRegistryManagement.languageModel(fullModelId as any) } diff --git a/packages/aiCore/src/core/providers/registry.ts b/packages/aiCore/src/core/providers/registry.ts index 050333d650..e8cd46770a 100644 --- a/packages/aiCore/src/core/providers/registry.ts +++ b/packages/aiCore/src/core/providers/registry.ts @@ -176,7 +176,7 @@ export function registerProvider(providerId: string, provider: any): boolean { // 处理特殊provider逻辑 if (providerId === 'openai') { // 注册默认 openai - globalRegistryManagement.registerProvider('openai', provider, aliases) + globalRegistryManagement.registerProvider(providerId, provider, aliases) // 创建并注册 openai-chat 变体 const openaiChatProvider = customProvider({ @@ -185,7 +185,17 @@ export function registerProvider(providerId: string, provider: any): boolean { languageModel: (modelId: string) => provider.chat(modelId) } }) - globalRegistryManagement.registerProvider('openai-chat', openaiChatProvider) + globalRegistryManagement.registerProvider(`${providerId}-chat`, openaiChatProvider) + } else if (providerId === 'azure') { + globalRegistryManagement.registerProvider(`${providerId}-chat`, provider, aliases) + // 跟上面相反,creator产出的默认会调用chat + const azureResponsesProvider = customProvider({ + fallbackProvider: { + ...provider, + languageModel: (modelId: string) => provider.responses(modelId) + } + }) + globalRegistryManagement.registerProvider(providerId, azureResponsesProvider) } else { // 其他provider直接注册 globalRegistryManagement.registerProvider(providerId, provider, aliases) diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts index e1b7ba0e00..51f9c52db7 100644 --- a/src/renderer/src/aiCore/provider/providerConfig.ts +++ b/src/renderer/src/aiCore/provider/providerConfig.ts @@ -106,6 +106,17 @@ export function providerToAiSdkConfig( 'copilot-vision-request': 'true' } } + // azure + if (aiSdkProviderId === 'azure' || actualProvider.type === 'azure-openai') { + extraOptions.apiVersion = actualProvider.apiVersion + baseConfig.baseURL += '/openai' + if (actualProvider.apiVersion === 'preview') { + extraOptions.mode = 'responses' + } else { + extraOptions.mode = 'chat' + extraOptions.useDeploymentBasedUrls = true + } + } // 如果AI SDK支持该provider,使用原生配置 if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts index 65ba6503f6..a9c42f0f3e 100644 --- a/src/renderer/src/aiCore/utils/options.ts +++ b/src/renderer/src/aiCore/utils/options.ts @@ -81,7 +81,7 @@ export function buildProviderOptions( case 'openai': case 'azure': providerSpecificOptions = { - ...buildOpenAIProviderOptions(assistant, model, capabilities, actualProvider), + ...buildOpenAIProviderOptions(assistant, model, capabilities), serviceTier: serviceTierSetting } break @@ -152,8 +152,7 @@ function buildOpenAIProviderOptions( enableReasoning: boolean enableWebSearch: boolean enableGenerateImage: boolean - }, - actualProvider: Provider + } ): Record { const { enableReasoning } = capabilities let providerOptions: Record = {} @@ -165,12 +164,6 @@ function buildOpenAIProviderOptions( ...reasoningParams } } - - if (actualProvider.id === 'azure') { - providerOptions.apiVersion = actualProvider.apiVersion - providerOptions.useDeploymentBasedUrls = true - } - return providerOptions }