mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
feat: enhance model handling and provider integration
- Updated `createBaseModel` to differentiate between OpenAI chat and response models. - Introduced new utility functions for model identification: `isOpenAIReasoningModel`, `isOpenAILLMModel`, and `getModelToProviderId`. - Improved `transformParameters` to conditionally set the system prompt based on the assistant's prompt. - Refactored `getAiSdkProviderIdForAihubmix` to simplify provider identification logic. - Enhanced `getAiSdkProviderId` to support provider type checks.
This commit is contained in:
parent
a0623f2187
commit
42c7ebd193
@ -90,8 +90,12 @@ export async function createBaseModel({
|
||||
let provider = creatorFunction(providerSettings)
|
||||
|
||||
// 加一个特判
|
||||
if (providerConfig.id === 'openai' && !isOpenAIChatCompletionOnlyModel(modelId)) {
|
||||
provider = provider.responses
|
||||
if (providerConfig.id === 'openai') {
|
||||
if (!isOpenAIChatCompletionOnlyModel(modelId)) {
|
||||
provider = provider.responses
|
||||
} else {
|
||||
provider = provider.chat
|
||||
}
|
||||
}
|
||||
// 返回模型实例
|
||||
if (typeof provider === 'function') {
|
||||
|
||||
@ -23,6 +23,8 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
const { providerId } = context
|
||||
// console.log('providerId', providerId)
|
||||
// const modelToProviderId = getModelToProviderId(modelId)
|
||||
// console.log('modelToProviderId', modelToProviderId)
|
||||
switch (providerId) {
|
||||
case 'openai': {
|
||||
if (config.openai) {
|
||||
@ -69,6 +71,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
|
||||
// break
|
||||
// }
|
||||
}
|
||||
// console.log('params', params)
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
@ -10,3 +10,38 @@ export function isOpenAIChatCompletionOnlyModel(modelId: string): boolean {
|
||||
modelId.includes('o1-preview')
|
||||
)
|
||||
}
|
||||
|
||||
export function isOpenAIReasoningModel(modelId: string): boolean {
|
||||
return modelId.includes('o1') || modelId.includes('o3') || modelId.includes('o4')
|
||||
}
|
||||
|
||||
export function isOpenAILLMModel(modelId: string): boolean {
|
||||
if (modelId.includes('gpt-4o-image')) {
|
||||
return false
|
||||
}
|
||||
if (isOpenAIReasoningModel(modelId)) {
|
||||
return true
|
||||
}
|
||||
if (modelId.includes('gpt')) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
export function getModelToProviderId(modelId: string): string | 'openai-compatible' {
|
||||
const id = modelId.toLowerCase()
|
||||
|
||||
if (id.startsWith('claude')) {
|
||||
return 'anthropic'
|
||||
}
|
||||
|
||||
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
|
||||
return 'google'
|
||||
}
|
||||
|
||||
if (isOpenAILLMModel(modelId)) {
|
||||
return 'openai'
|
||||
}
|
||||
|
||||
return 'openai-compatible'
|
||||
}
|
||||
|
||||
@ -46,8 +46,8 @@ function getActualProvider(model: Model): Provider {
|
||||
|
||||
if (provider.id === 'aihubmix') {
|
||||
actualProvider = createAihubmixProvider(model, actualProvider)
|
||||
console.log('actualProvider', actualProvider)
|
||||
}
|
||||
|
||||
if (actualProvider.type === 'gemini') {
|
||||
actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
|
||||
} else {
|
||||
@ -63,15 +63,21 @@ function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
providerId: ProviderId | 'openai-compatible'
|
||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
} {
|
||||
// console.log('actualProvider', actualProvider)
|
||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||
|
||||
// console.log('aiSdkProviderId', aiSdkProviderId)
|
||||
// 如果provider是openai,则使用strict模式并且默认responses api
|
||||
const actualProviderId = actualProvider.id
|
||||
const openaiResponseOptions =
|
||||
aiSdkProviderId === 'openai'
|
||||
actualProviderId === 'openai'
|
||||
? {
|
||||
compatibility: 'strict'
|
||||
}
|
||||
: undefined
|
||||
: aiSdkProviderId === 'openai'
|
||||
? {
|
||||
compatibility: 'compatible'
|
||||
}
|
||||
: undefined
|
||||
|
||||
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
const options = ProviderConfigFactory.fromProvider(
|
||||
|
||||
@ -3,13 +3,16 @@ import { isOpenAILLMModel } from '@renderer/config/models'
|
||||
import { Model, Provider } from '@renderer/types'
|
||||
|
||||
export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'openai-compatible' {
|
||||
console.log('getAiSdkProviderIdForAihubmix', model)
|
||||
const id = model.id.toLowerCase()
|
||||
|
||||
if (id.startsWith('claude')) {
|
||||
return 'anthropic'
|
||||
}
|
||||
// TODO:暂时注释,不清楚为什么排除,webSearch时会导致gemini模型走openai的逻辑
|
||||
// if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
|
||||
|
||||
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
|
||||
if (id.startsWith('gemini') || id.startsWith('imagen')) {
|
||||
return 'google'
|
||||
}
|
||||
|
||||
|
||||
@ -26,6 +26,9 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
|
||||
if (AiCore.isSupported(provider.id)) {
|
||||
return provider.id as ProviderId
|
||||
}
|
||||
if (AiCore.isSupported(provider.type)) {
|
||||
return provider.type as ProviderId
|
||||
}
|
||||
|
||||
return provider.id as ProviderId
|
||||
}
|
||||
|
||||
@ -299,13 +299,15 @@ export async function buildStreamTextParams(
|
||||
maxOutputTokens: maxTokens || DEFAULT_MAX_TOKENS,
|
||||
temperature: getTemperature(assistant, model),
|
||||
topP: getTopP(assistant, model),
|
||||
system: assistant.prompt || '',
|
||||
abortSignal: options.requestOptions?.signal,
|
||||
headers: options.requestOptions?.headers,
|
||||
providerOptions,
|
||||
tools,
|
||||
stopWhen: stepCountIs(10)
|
||||
}
|
||||
if (assistant.prompt) {
|
||||
params.system = assistant.prompt
|
||||
}
|
||||
|
||||
return { params, modelId: model.id, capabilities: { enableReasoning, enableWebSearch, enableGenerateImage } }
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user