diff --git a/packages/aiCore/src/core/models/ProviderCreator.ts b/packages/aiCore/src/core/models/ProviderCreator.ts
index 3cc9c07f09..6fc37b432f 100644
--- a/packages/aiCore/src/core/models/ProviderCreator.ts
+++ b/packages/aiCore/src/core/models/ProviderCreator.ts
@@ -90,8 +90,12 @@ export async function createBaseModel({
   let provider = creatorFunction(providerSettings)
 
   // Add a special case
-  if (providerConfig.id === 'openai' && !isOpenAIChatCompletionOnlyModel(modelId)) {
-    provider = provider.responses
+  if (providerConfig.id === 'openai') {
+    if (!isOpenAIChatCompletionOnlyModel(modelId)) {
+      provider = provider.responses
+    } else {
+      provider = provider.chat
+    }
   }
   // Return the model instance
   if (typeof provider === 'function') {
diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts
index 159f6114a6..6d91f6a011 100644
--- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts
+++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts
@@ -23,6 +23,8 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
   transformParams: async (params: any, context: AiRequestContext) => {
     const { providerId } = context
     // console.log('providerId', providerId)
+    // const modelToProviderId = getModelToProviderId(modelId)
+    // console.log('modelToProviderId', modelToProviderId)
     switch (providerId) {
       case 'openai': {
         if (config.openai) {
@@ -69,6 +71,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
       //   break
       // }
     }
+    // console.log('params', params)
 
     return params
   }
diff --git a/packages/aiCore/src/utils/model.ts b/packages/aiCore/src/utils/model.ts
index 5ceaa27351..8bed8ab443 100644
--- a/packages/aiCore/src/utils/model.ts
+++ b/packages/aiCore/src/utils/model.ts
@@ -10,3 +10,38 @@ export function isOpenAIChatCompletionOnlyModel(modelId: string): boolean {
     modelId.includes('o1-preview')
   )
 }
+
+export function isOpenAIReasoningModel(modelId: string): boolean {
+  return modelId.includes('o1') || modelId.includes('o3') || modelId.includes('o4')
+}
+
+export function isOpenAILLMModel(modelId: string): boolean {
+  if (modelId.includes('gpt-4o-image')) {
+    return false
+  }
+  if (isOpenAIReasoningModel(modelId)) {
+    return true
+  }
+  if (modelId.includes('gpt')) {
+    return true
+  }
+  return false
+}
+
+export function getModelToProviderId(modelId: string): string | 'openai-compatible' {
+  const id = modelId.toLowerCase()
+
+  if (id.startsWith('claude')) {
+    return 'anthropic'
+  }
+
+  if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
+    return 'google'
+  }
+
+  if (isOpenAILLMModel(modelId)) {
+    return 'openai'
+  }
+
+  return 'openai-compatible'
+}
diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts
index ced3095f7b..3a586e2fae 100644
--- a/src/renderer/src/aiCore/index_new.ts
+++ b/src/renderer/src/aiCore/index_new.ts
@@ -46,8 +46,8 @@ function getActualProvider(model: Model): Provider {
 
   if (provider.id === 'aihubmix') {
     actualProvider = createAihubmixProvider(model, actualProvider)
+    console.log('actualProvider', actualProvider)
   }
-
   if (actualProvider.type === 'gemini') {
     actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
   } else {
@@ -63,15 +63,21 @@ function providerToAiSdkConfig(actualProvider: Provider): {
   providerId: ProviderId | 'openai-compatible'
   options: ProviderSettingsMap[keyof ProviderSettingsMap]
 } {
+  // console.log('actualProvider', actualProvider)
   const aiSdkProviderId = getAiSdkProviderId(actualProvider)
-
+  // console.log('aiSdkProviderId', aiSdkProviderId)
   // If the provider is openai, use strict mode and default to the responses API
+  const actualProviderId = actualProvider.id
   const openaiResponseOptions =
-    aiSdkProviderId === 'openai'
+    actualProviderId === 'openai'
       ? {
           compatibility: 'strict'
         }
-      : undefined
+      : aiSdkProviderId === 'openai'
+        ? {
+            compatibility: 'compatible'
+          }
+        : undefined
 
   if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
     const options = ProviderConfigFactory.fromProvider(
diff --git a/src/renderer/src/aiCore/provider/aihubmix.ts b/src/renderer/src/aiCore/provider/aihubmix.ts
index 349c15cadc..66ca52f1b6 100644
--- a/src/renderer/src/aiCore/provider/aihubmix.ts
+++ b/src/renderer/src/aiCore/provider/aihubmix.ts
@@ -3,13 +3,16 @@ import { isOpenAILLMModel } from '@renderer/config/models'
 import { Model, Provider } from '@renderer/types'
 
 export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'openai-compatible' {
+  console.log('getAiSdkProviderIdForAihubmix', model)
   const id = model.id.toLowerCase()
 
   if (id.startsWith('claude')) {
     return 'anthropic'
   }
 
+  // TODO: temporarily commented out; it is unclear why these suffixes were excluded, and the exclusion made gemini models fall through to the openai logic during webSearch
+  // if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
-  if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
+  if (id.startsWith('gemini') || id.startsWith('imagen')) {
     return 'google'
   }
 
diff --git a/src/renderer/src/aiCore/provider/factory.ts b/src/renderer/src/aiCore/provider/factory.ts
index d7edb92b73..2711df003f 100644
--- a/src/renderer/src/aiCore/provider/factory.ts
+++ b/src/renderer/src/aiCore/provider/factory.ts
@@ -26,6 +26,9 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
   if (AiCore.isSupported(provider.id)) {
     return provider.id as ProviderId
   }
+  if (AiCore.isSupported(provider.type)) {
+    return provider.type as ProviderId
+  }
 
   return provider.id as ProviderId
 }
diff --git a/src/renderer/src/aiCore/transformParameters.ts b/src/renderer/src/aiCore/transformParameters.ts
index b70e3d7a85..af8618f69e 100644
--- a/src/renderer/src/aiCore/transformParameters.ts
+++ b/src/renderer/src/aiCore/transformParameters.ts
@@ -299,13 +299,15 @@ export async function buildStreamTextParams(
     maxOutputTokens: maxTokens || DEFAULT_MAX_TOKENS,
     temperature: getTemperature(assistant, model),
     topP: getTopP(assistant, model),
-    system: assistant.prompt || '',
     abortSignal: options.requestOptions?.signal,
     headers: options.requestOptions?.headers,
     providerOptions,
     tools,
     stopWhen: stepCountIs(10)
   }
+  if (assistant.prompt) {
+    params.system = assistant.prompt
+  }
 
   return { params, modelId: model.id, capabilities: { enableReasoning, enableWebSearch, enableGenerateImage } }
 }
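
For reference, the ProviderCreator.ts hunk picks the OpenAI Responses API by default and falls back to Chat Completions for models that only support that API. A minimal standalone sketch of the same branching, assuming @ai-sdk/openai is the underlying provider package (in the patch, the provider actually comes from the resolved creatorFunction(), and the import path below is illustrative):

// Sketch only, not part of the patch.
import { createOpenAI } from '@ai-sdk/openai'
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model' // illustrative path

const openai = createOpenAI({ apiKey: process.env.OPENAI_API_KEY })

export function openaiModelFor(modelId: string) {
  // Chat-completion-only models (e.g. o1-preview) cannot use the Responses API
  return isOpenAIChatCompletionOnlyModel(modelId) ? openai.chat(modelId) : openai.responses(modelId)
}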
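The new getModelToProviderId() helper in packages/aiCore/src/utils/model.ts routes a raw model id to a provider id via prefix and suffix checks. A quick usage sketch with made-up sample ids; the expected values follow directly from the string checks in the patch:

// Sketch only, not part of the patch; sample model ids are illustrative.
import { getModelToProviderId } from './utils/model' // i.e. packages/aiCore/src/utils/model

const samples: Record<string, string> = {
  'claude-3-5-sonnet': 'anthropic', // claude* prefix
  'gemini-2.0-flash': 'google', // gemini*/imagen* prefix without -nothink/-search suffix
  'gemini-2.0-flash-search': 'openai-compatible', // excluded by the -search suffix
  'o3-mini': 'openai', // matched by isOpenAIReasoningModel()
  'gpt-4o-image': 'openai-compatible', // explicitly rejected by isOpenAILLMModel()
  'deepseek-chat': 'openai-compatible' // default fallback
}

for (const [modelId, expected] of Object.entries(samples)) {
  console.assert(getModelToProviderId(modelId) === expected, modelId)
}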
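In providerToAiSdkConfig(), the compatibility flag now distinguishes the configured openai provider from providers that merely resolve to the openai SDK id (for example aihubmix-routed models). A condensed sketch of that decision, extracted into a hypothetical helper for readability; the real code computes it inline as a nested ternary:

// Sketch only, not part of the patch.
type OpenAICompatibility = { compatibility: 'strict' | 'compatible' } | undefined

export function pickOpenAICompatibility(actualProviderId: string, aiSdkProviderId: string): OpenAICompatibility {
  if (actualProviderId === 'openai') {
    // The configured provider is OpenAI itself: strict mode, defaults to the Responses API
    return { compatibility: 'strict' }
  }
  if (aiSdkProviderId === 'openai') {
    // Only resolved to the openai SDK id (e.g. via aihubmix): relaxed compatibility
    return { compatibility: 'compatible' }
  }
  return undefined
}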