mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-29 14:31:35 +08:00
chore(aiCore): update version to 1.0.0-alpha.11 and refactor model resolution logic
- Bumped the version of the ai-core package to 1.0.0-alpha.11. - Removed the `isOpenAIChatCompletionOnlyModel` utility function to simplify model resolution. - Adjusted the `providerToAiSdkConfig` function to accept a model parameter for improved configuration handling. - Updated the `ModernAiProvider` class to utilize the new model parameter in its configuration. - Cleaned up deprecated code related to search keyword extraction and reasoning parameters.
This commit is contained in:
parent
1735a9efb6
commit
49cd9d6723
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@cherrystudio/ai-core",
|
||||
"version": "1.0.0-alpha.10",
|
||||
"version": "1.0.0-alpha.11",
|
||||
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/index.mjs",
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
|
||||
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
|
||||
import { wrapModelWithMiddlewares } from '../middleware/wrapper'
|
||||
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../providers/RegistryManagement'
|
||||
|
||||
@ -28,14 +27,9 @@ export class ModelResolver {
|
||||
): Promise<LanguageModelV2> {
|
||||
let finalProviderId = fallbackProviderId
|
||||
let model: LanguageModelV2
|
||||
|
||||
// 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移)
|
||||
if (fallbackProviderId === 'openai' && providerOptions?.mode === 'chat') {
|
||||
// 检查是否支持 chat 模式且不是只支持 chat 的模型
|
||||
if (!isOpenAIChatCompletionOnlyModel(modelId)) {
|
||||
finalProviderId = 'openai-chat'
|
||||
}
|
||||
// 否则使用默认的 openai (responses 模式)
|
||||
finalProviderId = 'openai-chat'
|
||||
}
|
||||
|
||||
// 检查是否是命名空间格式
|
||||
|
||||
@ -1,49 +0,0 @@
|
||||
import { type ProviderId } from '../core/providers/types'
|
||||
|
||||
export function isOpenAIChatCompletionOnlyModel(modelId: string): boolean {
|
||||
if (!modelId) {
|
||||
return false
|
||||
}
|
||||
|
||||
return (
|
||||
modelId.includes('gpt-4o-search-preview') ||
|
||||
modelId.includes('gpt-4o-mini-search-preview') ||
|
||||
modelId.includes('o1-mini') ||
|
||||
modelId.includes('o1-preview')
|
||||
)
|
||||
}
|
||||
|
||||
export function isOpenAIReasoningModel(modelId: string): boolean {
|
||||
return modelId.includes('o1') || modelId.includes('o3') || modelId.includes('o4')
|
||||
}
|
||||
|
||||
export function isOpenAILLMModel(modelId: string): boolean {
|
||||
if (modelId.includes('gpt-4o-image')) {
|
||||
return false
|
||||
}
|
||||
if (isOpenAIReasoningModel(modelId)) {
|
||||
return true
|
||||
}
|
||||
if (modelId.includes('gpt')) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
export function getModelToProviderId(modelId: string): ProviderId | 'openai-compatible' {
|
||||
const id = modelId.toLowerCase()
|
||||
|
||||
if (id.startsWith('claude')) {
|
||||
return 'anthropic'
|
||||
}
|
||||
|
||||
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
|
||||
return 'google'
|
||||
}
|
||||
|
||||
if (isOpenAILLMModel(modelId)) {
|
||||
return 'openai'
|
||||
}
|
||||
|
||||
return 'openai-compatible'
|
||||
}
|
||||
@ -37,7 +37,7 @@ export default class ModernAiProvider {
|
||||
this.legacyProvider = new LegacyAiProvider(this.actualProvider)
|
||||
|
||||
// 只保存配置,不预先创建executor
|
||||
this.config = providerToAiSdkConfig(this.actualProvider)
|
||||
this.config = providerToAiSdkConfig(this.actualProvider, model)
|
||||
}
|
||||
|
||||
public getActualProvider() {
|
||||
|
||||
@ -4,6 +4,7 @@ import {
|
||||
type ProviderId,
|
||||
type ProviderSettingsMap
|
||||
} from '@cherrystudio/ai-core/provider'
|
||||
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { loggerService } from '@renderer/services/LoggerService'
|
||||
@ -68,7 +69,10 @@ export function getActualProvider(model: Model): Provider {
|
||||
* 将 Provider 配置转换为新 AI SDK 格式
|
||||
* 简化版:利用新的别名映射系统
|
||||
*/
|
||||
export function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
export function providerToAiSdkConfig(
|
||||
actualProvider: Provider,
|
||||
model: Model
|
||||
): {
|
||||
providerId: ProviderId | 'openai-compatible'
|
||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
} {
|
||||
@ -80,10 +84,9 @@ export function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
baseURL: actualProvider.apiHost,
|
||||
apiKey: actualProvider.apiKey
|
||||
}
|
||||
|
||||
// 处理OpenAI模式(简化逻辑)
|
||||
const extraOptions: any = {}
|
||||
if (actualProvider.type === 'openai-response') {
|
||||
if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
|
||||
extraOptions.mode = 'responses'
|
||||
} else if (aiSdkProviderId === 'openai') {
|
||||
extraOptions.mode = 'chat'
|
||||
|
||||
@ -395,105 +395,3 @@ export async function buildGenerateTextParams(
|
||||
// 复用流式参数的构建逻辑
|
||||
return await buildStreamTextParams(messages, assistant, provider, options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 提取外部工具搜索关键词和问题
|
||||
* 从用户消息中提取用于网络搜索和知识库搜索的关键词
|
||||
* @deprecated
|
||||
*/
|
||||
// export async function extractSearchKeywords(
|
||||
// lastUserMessage: Message,
|
||||
// assistant: Assistant,
|
||||
// options: {
|
||||
// shouldWebSearch?: boolean
|
||||
// shouldKnowledgeSearch?: boolean
|
||||
// lastAnswer?: Message
|
||||
// } = {}
|
||||
// ): Promise<ExtractResults | undefined> {
|
||||
// const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer } = options
|
||||
|
||||
// if (!lastUserMessage) return undefined
|
||||
|
||||
// // 根据配置决定是否需要提取
|
||||
// const needWebExtract = shouldWebSearch
|
||||
// const needKnowledgeExtract = shouldKnowledgeSearch
|
||||
|
||||
// if (!needWebExtract && !needKnowledgeExtract) return undefined
|
||||
|
||||
// // 选择合适的提示词
|
||||
// let prompt: string
|
||||
// if (needWebExtract && !needKnowledgeExtract) {
|
||||
// prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||
// } else if (!needWebExtract && needKnowledgeExtract) {
|
||||
// prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
|
||||
// } else {
|
||||
// prompt = SEARCH_SUMMARY_PROMPT
|
||||
// }
|
||||
|
||||
// // 构建用于提取的助手配置
|
||||
// const summaryAssistant = getDefaultAssistant()
|
||||
// summaryAssistant.model = assistant.model || getDefaultModel()
|
||||
// summaryAssistant.prompt = prompt
|
||||
|
||||
// try {
|
||||
// const result = await fetchSearchSummary({
|
||||
// messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage],
|
||||
// assistant: summaryAssistant
|
||||
// })
|
||||
|
||||
// if (!result) return getFallbackResult()
|
||||
|
||||
// const extracted = extractInfoFromXML(result.getText())
|
||||
// // 根据需求过滤结果
|
||||
// return {
|
||||
// websearch: needWebExtract ? extracted?.websearch : undefined,
|
||||
// knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined
|
||||
// }
|
||||
// } catch (e: any) {
|
||||
// console.error('extract error', e)
|
||||
// return getFallbackResult()
|
||||
// }
|
||||
|
||||
// function getFallbackResult(): ExtractResults {
|
||||
// const fallbackContent = getMainTextContent(lastUserMessage)
|
||||
// return {
|
||||
// websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
|
||||
// knowledge: shouldKnowledgeSearch
|
||||
// ? {
|
||||
// question: [fallbackContent || 'search'],
|
||||
// rewrite: fallbackContent || 'search'
|
||||
// }
|
||||
// : undefined
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
/**
|
||||
* 获取搜索摘要 - 内部辅助函数
|
||||
* @deprecated
|
||||
*/
|
||||
// async function fetchSearchSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
|
||||
// const model = assistant.model || getDefaultModel()
|
||||
// const provider = getProviderByModel(model)
|
||||
|
||||
// if (!hasApiKey(provider)) {
|
||||
// return null
|
||||
// }
|
||||
|
||||
// const AI = new AiProvider(provider)
|
||||
|
||||
// const params: CompletionsParams = {
|
||||
// callType: 'search',
|
||||
// messages: messages,
|
||||
// assistant,
|
||||
// streamOutput: false
|
||||
// }
|
||||
|
||||
// return await AI.completions(params)
|
||||
// }
|
||||
|
||||
// function hasApiKey(provider: Provider) {
|
||||
// if (!provider) return false
|
||||
// if (provider.id === 'ollama' || provider.id === 'lmstudio' || provider.type === 'vertexai') return true
|
||||
// return !isEmpty(provider.apiKey)
|
||||
// }
|
||||
|
||||
@ -1,4 +1,15 @@
|
||||
import { Assistant, Model, Provider } from '@renderer/types'
|
||||
import { isOpenAIModel, isSupportFlexServiceTierModel } from '@renderer/config/models'
|
||||
import { isSupportServiceTierProvider } from '@renderer/config/providers'
|
||||
import {
|
||||
Assistant,
|
||||
GroqServiceTiers,
|
||||
isGroqServiceTier,
|
||||
isOpenAIServiceTier,
|
||||
Model,
|
||||
OpenAIServiceTiers,
|
||||
Provider,
|
||||
SystemProviderIds
|
||||
} from '@renderer/types'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import { buildGeminiGenerateImageParams } from './image'
|
||||
@ -12,6 +23,35 @@ import {
|
||||
} from './reasoning'
|
||||
import { getWebSearchParams } from './websearch'
|
||||
|
||||
// copy from BaseApiClient.ts
|
||||
const getServiceTier = (model: Model, provider: Provider) => {
|
||||
const serviceTierSetting = provider.serviceTier
|
||||
|
||||
if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 处理不同供应商需要 fallback 到默认值的情况
|
||||
if (provider.id === SystemProviderIds.groq) {
|
||||
if (
|
||||
!isGroqServiceTier(serviceTierSetting) ||
|
||||
(serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
||||
) {
|
||||
return undefined
|
||||
}
|
||||
} else {
|
||||
// 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同
|
||||
if (
|
||||
!isOpenAIServiceTier(serviceTierSetting) ||
|
||||
(serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
||||
) {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
return serviceTierSetting
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 AI SDK 的 providerOptions
|
||||
* 按 provider 类型分离,保持类型安全
|
||||
@ -28,6 +68,7 @@ export function buildProviderOptions(
|
||||
}
|
||||
): Record<string, any> {
|
||||
const providerId = getAiSdkProviderId(actualProvider)
|
||||
const serviceTierSetting = getServiceTier(model, actualProvider)
|
||||
// 构建 provider 特定的选项
|
||||
let providerSpecificOptions: Record<string, any> = {}
|
||||
|
||||
@ -62,6 +103,7 @@ export function buildProviderOptions(
|
||||
// 合并自定义参数到 provider 特定的选项中
|
||||
providerSpecificOptions = {
|
||||
...providerSpecificOptions,
|
||||
serviceTier: serviceTierSetting,
|
||||
...getCustomParameters(assistant)
|
||||
}
|
||||
|
||||
|
||||
@ -13,7 +13,9 @@ import {
|
||||
isSupportedThinkingTokenModel,
|
||||
isSupportedThinkingTokenQwenModel
|
||||
} from '@renderer/config/models'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { SettingsState } from '@renderer/store/settings'
|
||||
import { Assistant, EFFORT_RATIO, Model } from '@renderer/types'
|
||||
import { ReasoningEffortOptionalParams } from '@renderer/types/sdk'
|
||||
|
||||
@ -205,6 +207,16 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
|
||||
const summaryText = openAI?.summaryText || 'off'
|
||||
|
||||
let reasoningSummary: string | undefined = undefined
|
||||
|
||||
if (summaryText === 'off' || model.id.includes('o1-pro')) {
|
||||
reasoningSummary = undefined
|
||||
} else {
|
||||
reasoningSummary = summaryText
|
||||
}
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
@ -215,7 +227,8 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
|
||||
// OpenAI 推理参数
|
||||
if (isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
return {
|
||||
reasoningEffort
|
||||
reasoningEffort,
|
||||
reasoningSummary
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user