Merge branch 'feat/aisdk-package' of github.com:CherryHQ/cherry-studio into feat/aisdk-package

This commit is contained in:
icarus 2025-08-29 19:19:59 +08:00
commit e56218f3ac
4 changed files with 28 additions and 13 deletions

View File

@ -28,8 +28,8 @@ export class ModelResolver {
let finalProviderId = fallbackProviderId let finalProviderId = fallbackProviderId
let model: LanguageModelV2 let model: LanguageModelV2
// 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移) // 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移)
if (fallbackProviderId === 'openai' && providerOptions?.mode === 'chat') { if ((fallbackProviderId === 'openai' || fallbackProviderId === 'azure') && providerOptions?.mode === 'chat') {
finalProviderId = 'openai-chat' finalProviderId = `${fallbackProviderId}-chat`
} }
// 检查是否是命名空间格式 // 检查是否是命名空间格式
@ -84,6 +84,7 @@ export class ModelResolver {
*/ */
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 { private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}` const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
console.log('fullModelId', fullModelId)
return globalRegistryManagement.languageModel(fullModelId as any) return globalRegistryManagement.languageModel(fullModelId as any)
} }

View File

@ -176,7 +176,7 @@ export function registerProvider(providerId: string, provider: any): boolean {
// 处理特殊provider逻辑 // 处理特殊provider逻辑
if (providerId === 'openai') { if (providerId === 'openai') {
// 注册默认 openai // 注册默认 openai
globalRegistryManagement.registerProvider('openai', provider, aliases) globalRegistryManagement.registerProvider(providerId, provider, aliases)
// 创建并注册 openai-chat 变体 // 创建并注册 openai-chat 变体
const openaiChatProvider = customProvider({ const openaiChatProvider = customProvider({
@ -185,7 +185,17 @@ export function registerProvider(providerId: string, provider: any): boolean {
languageModel: (modelId: string) => provider.chat(modelId) languageModel: (modelId: string) => provider.chat(modelId)
} }
}) })
globalRegistryManagement.registerProvider('openai-chat', openaiChatProvider) globalRegistryManagement.registerProvider(`${providerId}-chat`, openaiChatProvider)
} else if (providerId === 'azure') {
globalRegistryManagement.registerProvider(`${providerId}-chat`, provider, aliases)
// 跟上面相反,creator产出的默认会调用chat
const azureResponsesProvider = customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.responses(modelId)
}
})
globalRegistryManagement.registerProvider(providerId, azureResponsesProvider)
} else { } else {
// 其他provider直接注册 // 其他provider直接注册
globalRegistryManagement.registerProvider(providerId, provider, aliases) globalRegistryManagement.registerProvider(providerId, provider, aliases)

View File

@ -106,6 +106,17 @@ export function providerToAiSdkConfig(
'copilot-vision-request': 'true' 'copilot-vision-request': 'true'
} }
} }
// azure
if (aiSdkProviderId === 'azure' || actualProvider.type === 'azure-openai') {
extraOptions.apiVersion = actualProvider.apiVersion
baseConfig.baseURL += '/openai'
if (actualProvider.apiVersion === 'preview') {
extraOptions.mode = 'responses'
} else {
extraOptions.mode = 'chat'
extraOptions.useDeploymentBasedUrls = true
}
}
// 如果AI SDK支持该provider使用原生配置 // 如果AI SDK支持该provider使用原生配置
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {

View File

@ -81,7 +81,7 @@ export function buildProviderOptions(
case 'openai': case 'openai':
case 'azure': case 'azure':
providerSpecificOptions = { providerSpecificOptions = {
...buildOpenAIProviderOptions(assistant, model, capabilities, actualProvider), ...buildOpenAIProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting serviceTier: serviceTierSetting
} }
break break
@ -152,8 +152,7 @@ function buildOpenAIProviderOptions(
enableReasoning: boolean enableReasoning: boolean
enableWebSearch: boolean enableWebSearch: boolean
enableGenerateImage: boolean enableGenerateImage: boolean
}, }
actualProvider: Provider
): Record<string, any> { ): Record<string, any> {
const { enableReasoning } = capabilities const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {} let providerOptions: Record<string, any> = {}
@ -165,12 +164,6 @@ function buildOpenAIProviderOptions(
...reasoningParams ...reasoningParams
} }
} }
if (actualProvider.id === 'azure') {
providerOptions.apiVersion = actualProvider.apiVersion
providerOptions.useDeploymentBasedUrls = true
}
return providerOptions return providerOptions
} }