diff --git a/packages/aiCore/package.json b/packages/aiCore/package.json index 3cfca454c8..2e02700563 100644 --- a/packages/aiCore/package.json +++ b/packages/aiCore/package.json @@ -1,6 +1,6 @@ { "name": "@cherrystudio/ai-core", - "version": "1.0.0-alpha.10", + "version": "1.0.0-alpha.11", "description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK", "main": "dist/index.js", "module": "dist/index.mjs", diff --git a/packages/aiCore/src/core/models/ModelResolver.ts b/packages/aiCore/src/core/models/ModelResolver.ts index e0c001aec1..9d336e819e 100644 --- a/packages/aiCore/src/core/models/ModelResolver.ts +++ b/packages/aiCore/src/core/models/ModelResolver.ts @@ -7,7 +7,6 @@ import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider' -import { isOpenAIChatCompletionOnlyModel } from '../../utils/model' import { wrapModelWithMiddlewares } from '../middleware/wrapper' import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../providers/RegistryManagement' @@ -28,14 +27,9 @@ export class ModelResolver { ): Promise { let finalProviderId = fallbackProviderId let model: LanguageModelV2 - // 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移) if (fallbackProviderId === 'openai' && providerOptions?.mode === 'chat') { - // 检查是否支持 chat 模式且不是只支持 chat 的模型 - if (!isOpenAIChatCompletionOnlyModel(modelId)) { - finalProviderId = 'openai-chat' - } - // 否则使用默认的 openai (responses 模式) + finalProviderId = 'openai-chat' } // 检查是否是命名空间格式 diff --git a/packages/aiCore/src/utils/model.ts b/packages/aiCore/src/utils/model.ts deleted file mode 100644 index 4a37fe1237..0000000000 --- a/packages/aiCore/src/utils/model.ts +++ /dev/null @@ -1,49 +0,0 @@ -import { type ProviderId } from '../core/providers/types' - -export function isOpenAIChatCompletionOnlyModel(modelId: string): boolean { - if (!modelId) { - return false - } - - return ( - modelId.includes('gpt-4o-search-preview') || - modelId.includes('gpt-4o-mini-search-preview') || - modelId.includes('o1-mini') || - modelId.includes('o1-preview') - ) -} - -export function isOpenAIReasoningModel(modelId: string): boolean { - return modelId.includes('o1') || modelId.includes('o3') || modelId.includes('o4') -} - -export function isOpenAILLMModel(modelId: string): boolean { - if (modelId.includes('gpt-4o-image')) { - return false - } - if (isOpenAIReasoningModel(modelId)) { - return true - } - if (modelId.includes('gpt')) { - return true - } - return false -} - -export function getModelToProviderId(modelId: string): ProviderId | 'openai-compatible' { - const id = modelId.toLowerCase() - - if (id.startsWith('claude')) { - return 'anthropic' - } - - if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) { - return 'google' - } - - if (isOpenAILLMModel(modelId)) { - return 'openai' - } - - return 'openai-compatible' -} diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index b281dd1a52..849ec54c63 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -11,6 +11,7 @@ import { createExecutor, generateImage } from '@cherrystudio/ai-core' import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider' import { loggerService } from '@logger' import { isNotSupportedImageSizeModel } from '@renderer/config/models' +import { getEnableDeveloperMode } from '@renderer/hooks/useSettings' import { addSpan, endSpan } from '@renderer/services/SpanManagerService' import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity' import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types' @@ -36,7 +37,7 @@ export default class ModernAiProvider { this.legacyProvider = new LegacyAiProvider(this.actualProvider) // 只保存配置,不预先创建executor - this.config = providerToAiSdkConfig(this.actualProvider) + this.config = providerToAiSdkConfig(this.actualProvider, model) } public getActualProvider() { @@ -52,6 +53,28 @@ export default class ModernAiProvider { topicId?: string callType: string } + ) { + if (config.topicId && getEnableDeveloperMode()) { + // TypeScript类型窄化:确保topicId是string类型 + const traceConfig = { + ...config, + topicId: config.topicId + } + return await this._completionsForTrace(modelId, params, traceConfig) + } else { + return await this._completions(modelId, params, config) + } + } + + private async _completions( + modelId: string, + params: StreamTextParams, + config: AiSdkMiddlewareConfig & { + assistant: Assistant + // topicId for tracing + topicId?: string + callType: string + } ): Promise { // 初始化 provider 到全局管理器 try { @@ -79,7 +102,7 @@ export default class ModernAiProvider { * 带trace支持的completions方法 * 类似于legacy的completionsForTrace,确保AI SDK spans在正确的trace上下文中 */ - public async completionsForTrace( + private async _completionsForTrace( modelId: string, params: StreamTextParams, config: AiSdkMiddlewareConfig & { @@ -114,7 +137,7 @@ export default class ModernAiProvider { modelId, traceName }) - return await this.completions(modelId, params, config) + return await this._completions(modelId, params, config) } try { @@ -126,7 +149,7 @@ export default class ModernAiProvider { parentSpanCreated: true }) - const result = await this.completions(modelId, params, config) + const result = await this._completions(modelId, params, config) logger.info('Completions finished, ending parent span', { spanId: span.spanContext().spanId, @@ -172,7 +195,6 @@ export default class ModernAiProvider { params: StreamTextParams, config: AiSdkMiddlewareConfig & { assistant: Assistant - // topicId for tracing topicId?: string callType: string } diff --git a/src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts b/src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts index 4c3f9b5c98..a72f7506e5 100644 --- a/src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts +++ b/src/renderer/src/aiCore/provider/ProviderConfigProcessor.ts @@ -4,6 +4,7 @@ import { type ProviderId, type ProviderSettingsMap } from '@cherrystudio/ai-core/provider' +import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models' import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import { getProviderByModel } from '@renderer/services/AssistantService' import { loggerService } from '@renderer/services/LoggerService' @@ -68,7 +69,10 @@ export function getActualProvider(model: Model): Provider { * 将 Provider 配置转换为新 AI SDK 格式 * 简化版:利用新的别名映射系统 */ -export function providerToAiSdkConfig(actualProvider: Provider): { +export function providerToAiSdkConfig( + actualProvider: Provider, + model: Model +): { providerId: ProviderId | 'openai-compatible' options: ProviderSettingsMap[keyof ProviderSettingsMap] } { @@ -80,10 +84,9 @@ export function providerToAiSdkConfig(actualProvider: Provider): { baseURL: actualProvider.apiHost, apiKey: actualProvider.apiKey } - // 处理OpenAI模式(简化逻辑) const extraOptions: any = {} - if (actualProvider.type === 'openai-response') { + if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) { extraOptions.mode = 'responses' } else if (aiSdkProviderId === 'openai') { extraOptions.mode = 'chat' diff --git a/src/renderer/src/aiCore/transformParameters.ts b/src/renderer/src/aiCore/transformParameters.ts index df791dfbf9..d5fdace254 100644 --- a/src/renderer/src/aiCore/transformParameters.ts +++ b/src/renderer/src/aiCore/transformParameters.ts @@ -395,105 +395,3 @@ export async function buildGenerateTextParams( // 复用流式参数的构建逻辑 return await buildStreamTextParams(messages, assistant, provider, options) } - -/** - * 提取外部工具搜索关键词和问题 - * 从用户消息中提取用于网络搜索和知识库搜索的关键词 - * @deprecated - */ -// export async function extractSearchKeywords( -// lastUserMessage: Message, -// assistant: Assistant, -// options: { -// shouldWebSearch?: boolean -// shouldKnowledgeSearch?: boolean -// lastAnswer?: Message -// } = {} -// ): Promise { -// const { shouldWebSearch = false, shouldKnowledgeSearch = false, lastAnswer } = options - -// if (!lastUserMessage) return undefined - -// // 根据配置决定是否需要提取 -// const needWebExtract = shouldWebSearch -// const needKnowledgeExtract = shouldKnowledgeSearch - -// if (!needWebExtract && !needKnowledgeExtract) return undefined - -// // 选择合适的提示词 -// let prompt: string -// if (needWebExtract && !needKnowledgeExtract) { -// prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY -// } else if (!needWebExtract && needKnowledgeExtract) { -// prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY -// } else { -// prompt = SEARCH_SUMMARY_PROMPT -// } - -// // 构建用于提取的助手配置 -// const summaryAssistant = getDefaultAssistant() -// summaryAssistant.model = assistant.model || getDefaultModel() -// summaryAssistant.prompt = prompt - -// try { -// const result = await fetchSearchSummary({ -// messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage], -// assistant: summaryAssistant -// }) - -// if (!result) return getFallbackResult() - -// const extracted = extractInfoFromXML(result.getText()) -// // 根据需求过滤结果 -// return { -// websearch: needWebExtract ? extracted?.websearch : undefined, -// knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined -// } -// } catch (e: any) { -// console.error('extract error', e) -// return getFallbackResult() -// } - -// function getFallbackResult(): ExtractResults { -// const fallbackContent = getMainTextContent(lastUserMessage) -// return { -// websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined, -// knowledge: shouldKnowledgeSearch -// ? { -// question: [fallbackContent || 'search'], -// rewrite: fallbackContent || 'search' -// } -// : undefined -// } -// } -// } - -/** - * 获取搜索摘要 - 内部辅助函数 - * @deprecated - */ -// async function fetchSearchSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) { -// const model = assistant.model || getDefaultModel() -// const provider = getProviderByModel(model) - -// if (!hasApiKey(provider)) { -// return null -// } - -// const AI = new AiProvider(provider) - -// const params: CompletionsParams = { -// callType: 'search', -// messages: messages, -// assistant, -// streamOutput: false -// } - -// return await AI.completions(params) -// } - -// function hasApiKey(provider: Provider) { -// if (!provider) return false -// if (provider.id === 'ollama' || provider.id === 'lmstudio' || provider.type === 'vertexai') return true -// return !isEmpty(provider.apiKey) -// } diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts index d8b1e853cc..65890d5ba7 100644 --- a/src/renderer/src/aiCore/utils/options.ts +++ b/src/renderer/src/aiCore/utils/options.ts @@ -1,4 +1,15 @@ -import { Assistant, Model, Provider } from '@renderer/types' +import { isOpenAIModel, isSupportFlexServiceTierModel } from '@renderer/config/models' +import { isSupportServiceTierProvider } from '@renderer/config/providers' +import { + Assistant, + GroqServiceTiers, + isGroqServiceTier, + isOpenAIServiceTier, + Model, + OpenAIServiceTiers, + Provider, + SystemProviderIds +} from '@renderer/types' import { getAiSdkProviderId } from '../provider/factory' import { buildGeminiGenerateImageParams } from './image' @@ -12,6 +23,35 @@ import { } from './reasoning' import { getWebSearchParams } from './websearch' +// copy from BaseApiClient.ts +const getServiceTier = (model: Model, provider: Provider) => { + const serviceTierSetting = provider.serviceTier + + if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) { + return undefined + } + + // 处理不同供应商需要 fallback 到默认值的情况 + if (provider.id === SystemProviderIds.groq) { + if ( + !isGroqServiceTier(serviceTierSetting) || + (serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model)) + ) { + return undefined + } + } else { + // 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同 + if ( + !isOpenAIServiceTier(serviceTierSetting) || + (serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model)) + ) { + return undefined + } + } + + return serviceTierSetting +} + /** * 构建 AI SDK 的 providerOptions * 按 provider 类型分离,保持类型安全 @@ -28,6 +68,7 @@ export function buildProviderOptions( } ): Record { const providerId = getAiSdkProviderId(actualProvider) + const serviceTierSetting = getServiceTier(model, actualProvider) // 构建 provider 特定的选项 let providerSpecificOptions: Record = {} @@ -62,6 +103,7 @@ export function buildProviderOptions( // 合并自定义参数到 provider 特定的选项中 providerSpecificOptions = { ...providerSpecificOptions, + serviceTier: serviceTierSetting, ...getCustomParameters(assistant) } diff --git a/src/renderer/src/aiCore/utils/reasoning.ts b/src/renderer/src/aiCore/utils/reasoning.ts index 6eb0db358b..07f1958fe9 100644 --- a/src/renderer/src/aiCore/utils/reasoning.ts +++ b/src/renderer/src/aiCore/utils/reasoning.ts @@ -13,7 +13,9 @@ import { isSupportedThinkingTokenModel, isSupportedThinkingTokenQwenModel } from '@renderer/config/models' +import { getStoreSetting } from '@renderer/hooks/useSettings' import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService' +import { SettingsState } from '@renderer/store/settings' import { Assistant, EFFORT_RATIO, Model } from '@renderer/types' import { ReasoningEffortOptionalParams } from '@renderer/types/sdk' @@ -205,6 +207,16 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re if (!isReasoningModel(model)) { return {} } + const openAI = getStoreSetting('openAI') as SettingsState['openAI'] + const summaryText = openAI?.summaryText || 'off' + + let reasoningSummary: string | undefined = undefined + + if (summaryText === 'off' || model.id.includes('o1-pro')) { + reasoningSummary = undefined + } else { + reasoningSummary = summaryText + } const reasoningEffort = assistant?.settings?.reasoning_effort @@ -215,7 +227,8 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re // OpenAI 推理参数 if (isSupportedReasoningEffortOpenAIModel(model)) { return { - reasoningEffort + reasoningEffort, + reasoningSummary } } diff --git a/src/renderer/src/pages/home/Messages/Blocks/ErrorBlock.tsx b/src/renderer/src/pages/home/Messages/Blocks/ErrorBlock.tsx index 58dbd7e7d4..aab4ee3c63 100644 --- a/src/renderer/src/pages/home/Messages/Blocks/ErrorBlock.tsx +++ b/src/renderer/src/pages/home/Messages/Blocks/ErrorBlock.tsx @@ -181,6 +181,7 @@ ${t('error.stack')}: ${error.stack || 'N/A'} return (