mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
refactor(aiCore): improve provider registration and model resolution logic
- Enhanced the model resolution logic to support dynamic provider IDs for chat modes. - Updated provider registration to handle both OpenAI and Azure with a unified approach for chat variants. - Refactored the `buildOpenAIProviderOptions` function to streamline parameter handling and removed unnecessary parameters. - Added configuration options for Azure to manage API versions and deployment URLs effectively.
This commit is contained in:
parent
a7b8b40301
commit
6fa82533d5
@ -28,8 +28,8 @@ export class ModelResolver {
|
||||
let finalProviderId = fallbackProviderId
|
||||
let model: LanguageModelV2
|
||||
// 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移)
|
||||
if (fallbackProviderId === 'openai' && providerOptions?.mode === 'chat') {
|
||||
finalProviderId = 'openai-chat'
|
||||
if ((fallbackProviderId === 'openai' || fallbackProviderId === 'azure') && providerOptions?.mode === 'chat') {
|
||||
finalProviderId = `${fallbackProviderId}-chat`
|
||||
}
|
||||
|
||||
// 检查是否是命名空间格式
|
||||
@ -84,6 +84,7 @@ export class ModelResolver {
|
||||
*/
|
||||
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
|
||||
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||
console.log('fullModelId', fullModelId)
|
||||
return globalRegistryManagement.languageModel(fullModelId as any)
|
||||
}
|
||||
|
||||
|
||||
@ -176,7 +176,7 @@ export function registerProvider(providerId: string, provider: any): boolean {
|
||||
// 处理特殊provider逻辑
|
||||
if (providerId === 'openai') {
|
||||
// 注册默认 openai
|
||||
globalRegistryManagement.registerProvider('openai', provider, aliases)
|
||||
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
||||
|
||||
// 创建并注册 openai-chat 变体
|
||||
const openaiChatProvider = customProvider({
|
||||
@ -185,7 +185,17 @@ export function registerProvider(providerId: string, provider: any): boolean {
|
||||
languageModel: (modelId: string) => provider.chat(modelId)
|
||||
}
|
||||
})
|
||||
globalRegistryManagement.registerProvider('openai-chat', openaiChatProvider)
|
||||
globalRegistryManagement.registerProvider(`${providerId}-chat`, openaiChatProvider)
|
||||
} else if (providerId === 'azure') {
|
||||
globalRegistryManagement.registerProvider(`${providerId}-chat`, provider, aliases)
|
||||
// 跟上面相反,creator产出的默认会调用chat
|
||||
const azureResponsesProvider = customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.responses(modelId)
|
||||
}
|
||||
})
|
||||
globalRegistryManagement.registerProvider(providerId, azureResponsesProvider)
|
||||
} else {
|
||||
// 其他provider直接注册
|
||||
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
||||
|
||||
@ -106,6 +106,17 @@ export function providerToAiSdkConfig(
|
||||
'copilot-vision-request': 'true'
|
||||
}
|
||||
}
|
||||
// azure
|
||||
if (aiSdkProviderId === 'azure' || actualProvider.type === 'azure-openai') {
|
||||
extraOptions.apiVersion = actualProvider.apiVersion
|
||||
baseConfig.baseURL += '/openai'
|
||||
if (actualProvider.apiVersion === 'preview') {
|
||||
extraOptions.mode = 'responses'
|
||||
} else {
|
||||
extraOptions.mode = 'chat'
|
||||
extraOptions.useDeploymentBasedUrls = true
|
||||
}
|
||||
}
|
||||
|
||||
// 如果AI SDK支持该provider,使用原生配置
|
||||
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
|
||||
@ -81,7 +81,7 @@ export function buildProviderOptions(
|
||||
case 'openai':
|
||||
case 'azure':
|
||||
providerSpecificOptions = {
|
||||
...buildOpenAIProviderOptions(assistant, model, capabilities, actualProvider),
|
||||
...buildOpenAIProviderOptions(assistant, model, capabilities),
|
||||
serviceTier: serviceTierSetting
|
||||
}
|
||||
break
|
||||
@ -152,8 +152,7 @@ function buildOpenAIProviderOptions(
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
},
|
||||
actualProvider: Provider
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
@ -165,12 +164,6 @@ function buildOpenAIProviderOptions(
|
||||
...reasoningParams
|
||||
}
|
||||
}
|
||||
|
||||
if (actualProvider.id === 'azure') {
|
||||
providerOptions.apiVersion = actualProvider.apiVersion
|
||||
providerOptions.useDeploymentBasedUrls = true
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user