mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-09 23:10:20 +08:00
Merge branch 'feat/aisdk-package' of github.com:CherryHQ/cherry-studio into feat/aisdk-package
This commit is contained in:
commit
e56218f3ac
@ -28,8 +28,8 @@ export class ModelResolver {
|
|||||||
let finalProviderId = fallbackProviderId
|
let finalProviderId = fallbackProviderId
|
||||||
let model: LanguageModelV2
|
let model: LanguageModelV2
|
||||||
// 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移)
|
// 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移)
|
||||||
if (fallbackProviderId === 'openai' && providerOptions?.mode === 'chat') {
|
if ((fallbackProviderId === 'openai' || fallbackProviderId === 'azure') && providerOptions?.mode === 'chat') {
|
||||||
finalProviderId = 'openai-chat'
|
finalProviderId = `${fallbackProviderId}-chat`
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否是命名空间格式
|
// 检查是否是命名空间格式
|
||||||
@ -84,6 +84,7 @@ export class ModelResolver {
|
|||||||
*/
|
*/
|
||||||
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
|
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
|
||||||
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||||
|
console.log('fullModelId', fullModelId)
|
||||||
return globalRegistryManagement.languageModel(fullModelId as any)
|
return globalRegistryManagement.languageModel(fullModelId as any)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -176,7 +176,7 @@ export function registerProvider(providerId: string, provider: any): boolean {
|
|||||||
// 处理特殊provider逻辑
|
// 处理特殊provider逻辑
|
||||||
if (providerId === 'openai') {
|
if (providerId === 'openai') {
|
||||||
// 注册默认 openai
|
// 注册默认 openai
|
||||||
globalRegistryManagement.registerProvider('openai', provider, aliases)
|
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
||||||
|
|
||||||
// 创建并注册 openai-chat 变体
|
// 创建并注册 openai-chat 变体
|
||||||
const openaiChatProvider = customProvider({
|
const openaiChatProvider = customProvider({
|
||||||
@ -185,7 +185,17 @@ export function registerProvider(providerId: string, provider: any): boolean {
|
|||||||
languageModel: (modelId: string) => provider.chat(modelId)
|
languageModel: (modelId: string) => provider.chat(modelId)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
globalRegistryManagement.registerProvider('openai-chat', openaiChatProvider)
|
globalRegistryManagement.registerProvider(`${providerId}-chat`, openaiChatProvider)
|
||||||
|
} else if (providerId === 'azure') {
|
||||||
|
globalRegistryManagement.registerProvider(`${providerId}-chat`, provider, aliases)
|
||||||
|
// 跟上面相反,creator产出的默认会调用chat
|
||||||
|
const azureResponsesProvider = customProvider({
|
||||||
|
fallbackProvider: {
|
||||||
|
...provider,
|
||||||
|
languageModel: (modelId: string) => provider.responses(modelId)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
globalRegistryManagement.registerProvider(providerId, azureResponsesProvider)
|
||||||
} else {
|
} else {
|
||||||
// 其他provider直接注册
|
// 其他provider直接注册
|
||||||
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
||||||
|
|||||||
@ -106,6 +106,17 @@ export function providerToAiSdkConfig(
|
|||||||
'copilot-vision-request': 'true'
|
'copilot-vision-request': 'true'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// azure
|
||||||
|
if (aiSdkProviderId === 'azure' || actualProvider.type === 'azure-openai') {
|
||||||
|
extraOptions.apiVersion = actualProvider.apiVersion
|
||||||
|
baseConfig.baseURL += '/openai'
|
||||||
|
if (actualProvider.apiVersion === 'preview') {
|
||||||
|
extraOptions.mode = 'responses'
|
||||||
|
} else {
|
||||||
|
extraOptions.mode = 'chat'
|
||||||
|
extraOptions.useDeploymentBasedUrls = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 如果AI SDK支持该provider,使用原生配置
|
// 如果AI SDK支持该provider,使用原生配置
|
||||||
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||||
|
|||||||
@ -81,7 +81,7 @@ export function buildProviderOptions(
|
|||||||
case 'openai':
|
case 'openai':
|
||||||
case 'azure':
|
case 'azure':
|
||||||
providerSpecificOptions = {
|
providerSpecificOptions = {
|
||||||
...buildOpenAIProviderOptions(assistant, model, capabilities, actualProvider),
|
...buildOpenAIProviderOptions(assistant, model, capabilities),
|
||||||
serviceTier: serviceTierSetting
|
serviceTier: serviceTierSetting
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
@ -152,8 +152,7 @@ function buildOpenAIProviderOptions(
|
|||||||
enableReasoning: boolean
|
enableReasoning: boolean
|
||||||
enableWebSearch: boolean
|
enableWebSearch: boolean
|
||||||
enableGenerateImage: boolean
|
enableGenerateImage: boolean
|
||||||
},
|
}
|
||||||
actualProvider: Provider
|
|
||||||
): Record<string, any> {
|
): Record<string, any> {
|
||||||
const { enableReasoning } = capabilities
|
const { enableReasoning } = capabilities
|
||||||
let providerOptions: Record<string, any> = {}
|
let providerOptions: Record<string, any> = {}
|
||||||
@ -165,12 +164,6 @@ function buildOpenAIProviderOptions(
|
|||||||
...reasoningParams
|
...reasoningParams
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (actualProvider.id === 'azure') {
|
|
||||||
providerOptions.apiVersion = actualProvider.apiVersion
|
|
||||||
providerOptions.useDeploymentBasedUrls = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return providerOptions
|
return providerOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user