diff --git a/packages/aiCore/src/core/options/xai.ts b/packages/aiCore/src/core/options/xai.ts
index 13e3e7dd6b..8d82f587e8 100644
--- a/packages/aiCore/src/core/options/xai.ts
+++ b/packages/aiCore/src/core/options/xai.ts
@@ -1,7 +1,7 @@
 // copy from @ai-sdk/xai/xai-chat-options.ts
 // Delete this file once @ai-sdk/xai exposes xaiProviderOptions
-import { z } from 'zod'
+import * as z from 'zod/v4'
 
 const webSourceSchema = z.object({
   type: z.literal('web'),
@@ -25,7 +25,7 @@ const newsSourceSchema = z.object({
 
 const rssSourceSchema = z.object({
   type: z.literal('rss'),
-  links: z.array(z.string().url()).max(1) // currently only supports one RSS link
+  links: z.array(z.url()).max(1) // currently only supports one RSS link
 })
 
 const searchSourceSchema = z.discriminatedUnion('type', [
diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts
index 2b08a59175..091f725435 100644
--- a/src/renderer/src/aiCore/index_new.ts
+++ b/src/renderer/src/aiCore/index_new.ts
@@ -30,14 +30,11 @@ import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk
 import { CompletionsResult } from './middleware/schemas'
 import reasoningTimePlugin from './plugins/reasoningTimePlugin'
 import { getAiSdkProviderId } from './provider/factory'
+import { getProviderByModel } from '@renderer/services/AssistantService'
+import { createAihubmixProvider } from './provider/aihubmix'
 
-/**
- * Convert a Provider config into the new AI SDK format
- */
-function providerToAiSdkConfig(provider: Provider): {
-  providerId: ProviderId | 'openai-compatible'
-  options: ProviderSettingsMap[keyof ProviderSettingsMap]
-} {
+function getActualProvider(model: Model): Provider {
+  const provider = getProviderByModel(model)
   // If the type is vertexai and there are no googleCredentials, convert it to a VertexProvider
   let actualProvider = cloneDeep(provider)
   if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
@@ -47,18 +44,25 @@ function providerToAiSdkConfig(provider: Provider): {
     actualProvider = createVertexProvider(provider)
   }
 
-  if (
-    actualProvider.type === 'openai' ||
-    actualProvider.type === 'anthropic' ||
-    actualProvider.type === 'openai-response'
-  ) {
-    actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
+  if (provider.id === 'aihubmix') {
+    actualProvider = createAihubmixProvider(model, actualProvider)
   }
 
   if (actualProvider.type === 'gemini') {
     actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
+  } else {
+    actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
   }
 
+  return actualProvider
+}
+/**
+ * Convert a Provider config into the new AI SDK format
+ */
+function providerToAiSdkConfig(actualProvider: Provider): {
+  providerId: ProviderId | 'openai-compatible'
+  options: ProviderSettingsMap[keyof ProviderSettingsMap]
+} {
   const aiSdkProviderId = getAiSdkProviderId(actualProvider)
 
   // If the provider is openai, use strict mode and default to the Responses API
@@ -126,14 +130,18 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean {
 export default class ModernAiProvider {
   private legacyProvider: LegacyAiProvider
   private config: ReturnType<typeof providerToAiSdkConfig>
+  private actualProvider: Provider
 
-  constructor(provider: Provider) {
-    this.legacyProvider = new LegacyAiProvider(provider)
+  constructor(model: Model) {
+    this.actualProvider = getActualProvider(model)
+    this.legacyProvider = new LegacyAiProvider(this.actualProvider)
     // Only store the config; do not create the executor up front
-    this.config = providerToAiSdkConfig(provider)
+    this.config = providerToAiSdkConfig(this.actualProvider)
+  }
 
-    console.log('[Modern AI Provider] Creating executor with MCP Prompt plugin enabled')
+  public getActualProvider() {
+    return this.actualProvider
   }
 
   /**
diff --git a/src/renderer/src/aiCore/provider/aihubmix.ts b/src/renderer/src/aiCore/provider/aihubmix.ts
new file mode 100644
index 0000000000..349c15cadc
--- /dev/null
+++ b/src/renderer/src/aiCore/provider/aihubmix.ts
@@ -0,0 +1,55 @@
+import { ProviderId } from '@cherrystudio/ai-core/types'
+import { isOpenAILLMModel } from '@renderer/config/models'
+import { Model, Provider } from '@renderer/types'
+
+export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'openai-compatible' {
+  const id = model.id.toLowerCase()
+
+  if (id.startsWith('claude')) {
+    return 'anthropic'
+  }
+
+  if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
+    return 'google'
+  }
+
+  if (isOpenAILLMModel(model)) {
+    return 'openai'
+  }
+
+  return 'openai-compatible'
+}
+
+export function createAihubmixProvider(model: Model, provider: Provider): Provider {
+  const providerId = getAiSdkProviderIdForAihubmix(model)
+  provider = {
+    ...provider,
+    extra_headers: {
+      ...provider.extra_headers,
+      'APP-Code': 'MLTG2087'
+    }
+  }
+  if (providerId === 'google') {
+    return {
+      ...provider,
+      type: 'gemini',
+      apiHost: 'https://aihubmix.com/gemini'
+    }
+  }
+
+  if (providerId === 'openai') {
+    return {
+      ...provider,
+      type: 'openai'
+    }
+  }
+
+  if (providerId === 'anthropic') {
+    return {
+      ...provider,
+      type: 'anthropic'
+    }
+  }
+
+  return provider
+}
diff --git a/src/renderer/src/aiCore/transformParameters.ts b/src/renderer/src/aiCore/transformParameters.ts
index 986f628bdf..b70e3d7a85 100644
--- a/src/renderer/src/aiCore/transformParameters.ts
+++ b/src/renderer/src/aiCore/transformParameters.ts
@@ -27,7 +27,7 @@ import {
   isWebSearchModel
 } from '@renderer/config/models'
 import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
-import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
+import type { Assistant, MCPTool, Message, Model, Provider } from '@renderer/types'
 import { FileTypes } from '@renderer/types'
 import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
 import {
@@ -241,6 +241,7 @@ export async function convertMessagesToSdkMessages(
 export async function buildStreamTextParams(
   sdkMessages: StreamTextParams['messages'],
   assistant: Assistant,
+  provider: Provider,
   options: {
     mcpTools?: MCPTool[]
     enableTools?: boolean
@@ -285,14 +286,8 @@ export async function buildStreamTextParams(
     enableToolUse: enableTools
   })
 
-  // Add web search tools if enabled
-  // if (enableWebSearch) {
-  //   const webSearchTools = getWebSearchTools(model)
-  //   tools = { ...tools, ...webSearchTools }
-  // }
-
   // Build the real providerOptions
-  const providerOptions = buildProviderOptions(assistant, model, {
+  const providerOptions = buildProviderOptions(assistant, model, provider, {
     enableReasoning,
     enableWebSearch,
     enableGenerateImage
@@ -321,11 +316,12 @@ export async function buildStreamTextParams(
 export async function buildGenerateTextParams(
   messages: ModelMessage[],
   assistant: Assistant,
+  provider: Provider,
   options: {
     mcpTools?: MCPTool[]
     enableTools?: boolean
   } = {}
 ): Promise<any> {
   // Reuse the stream-parameter building logic
-  return await buildStreamTextParams(messages, assistant, options)
+  return await buildStreamTextParams(messages, assistant, provider, options)
 }
diff --git a/src/renderer/src/aiCore/utils/options.ts b/src/renderer/src/aiCore/utils/options.ts
index 2a23a9de9d..fd0198c4e3 100644
--- a/src/renderer/src/aiCore/utils/options.ts
+++ b/src/renderer/src/aiCore/utils/options.ts
@@ -1,5 +1,5 @@
-import { getProviderByModel } from '@renderer/services/AssistantService'
-import { Assistant, Model } from '@renderer/types'
+import { getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService'
+import { Assistant, Model, Provider } from '@renderer/types'
 
 import { getAiSdkProviderId } from '../provider/factory'
 import {
@@ -7,8 +7,10 @@ import {
   getCustomParameters,
   getGeminiReasoningParams,
   getOpenAIReasoningParams,
-  getReasoningEffort
+  getReasoningEffort,
+  getXAIReasoningParams
 } from './reasoning'
+import { getWebSearchParams } from './websearch'
 
 /**
  * Build providerOptions for the AI SDK
@@ -18,25 +20,22 @@ import {
 export function buildProviderOptions(
   assistant: Assistant,
   model: Model,
+  actualProvider: Provider,
   capabilities: {
     enableReasoning: boolean
     enableWebSearch: boolean
     enableGenerateImage: boolean
   }
 ): Record<string, any> {
-  const provider = getProviderByModel(model)
-  const providerId = getAiSdkProviderId(provider)
+  const providerId = getAiSdkProviderId(actualProvider)
 
   // Build provider-specific options
   let providerSpecificOptions: Record<string, any> = {}
 
-  console.log('buildProviderOptions', providerId)
-  console.log('buildProviderOptions', provider)
-
   // Split the build logic by provider type
-  switch (provider.type) {
-    case 'openai-response':
-    case 'azure-openai':
+  switch (providerId) {
+    case 'openai':
+    case 'azure':
       providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities)
       break
 
@@ -44,11 +43,15 @@ export function buildProviderOptions(
       providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
       break
 
-    case 'gemini':
-    case 'vertexai':
+    case 'google':
+    case 'google-vertex':
       providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
       break
 
+    case 'xai':
+      providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities)
+      break
+
     default:
       // For other providers, use the generic build logic
      providerSpecificOptions = buildGenericProviderOptions(assistant, model, capabilities)
@@ -79,7 +82,7 @@ function buildOpenAIProviderOptions(
   assistant: Assistant,
   model: Model,
   capabilities: {
     enableReasoning: boolean
     enableWebSearch: boolean
     enableGenerateImage: boolean
   }
 ): Record<string, any> {
-  const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
+  const { enableReasoning } = capabilities
   let providerOptions: Record<string, any> = {}
 
   // OpenAI reasoning parameters
@@ -91,15 +94,6 @@
     }
   }
 
-  // Web search and image generation use the generic format for now
-  if (enableWebSearch) {
-    providerOptions.webSearch = { enabled: true }
-  }
-
-  if (enableGenerateImage) {
-    providerOptions.generateImage = { enabled: true }
-  }
-
   return providerOptions
 }
@@ -115,7 +109,7 @@ function buildAnthropicProviderOptions(
   assistant: Assistant,
   model: Model,
   capabilities: {
     enableReasoning: boolean
     enableWebSearch: boolean
     enableGenerateImage: boolean
   }
 ): Record<string, any> {
-  const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
+  const { enableReasoning } = capabilities
   let providerOptions: Record<string, any> = {}
 
   // Anthropic reasoning parameters
@@ -127,14 +121,6 @@
     }
   }
 
-  if (enableWebSearch) {
-    providerOptions.webSearch = { enabled: true }
-  }
-
-  if (enableGenerateImage) {
-    providerOptions.generateImage = { enabled: true }
-  }
-
   return providerOptions
 }
@@ -150,21 +136,39 @@ function buildGeminiProviderOptions(
   assistant: Assistant,
   model: Model,
   capabilities: {
     enableReasoning: boolean
     enableWebSearch: boolean
     enableGenerateImage: boolean
   }
 ): Record<string, any> {
-  const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
-  const providerOptions: Record<string, any> = {}
+  const { enableReasoning } = capabilities
+  let providerOptions: Record<string, any> = {}
 
   // Gemini reasoning parameters
   if (enableReasoning) {
     const reasoningParams = getGeminiReasoningParams(assistant, model)
-    Object.assign(providerOptions, reasoningParams)
+    providerOptions = {
+      ...providerOptions,
+      ...reasoningParams
+    }
   }
 
-  if (enableWebSearch) {
-    providerOptions.webSearch = { enabled: true }
-  }
+  return providerOptions
+}
 
-  if (enableGenerateImage) {
-    providerOptions.generateImage = { enabled: true }
+function buildXAIProviderOptions(
+  assistant: Assistant,
+  model: Model,
+  capabilities: {
+    enableReasoning: boolean
+    enableWebSearch: boolean
+    enableGenerateImage: boolean
+  }
+): Record<string, any> {
+  const { enableReasoning } = capabilities
+  let providerOptions: Record<string, any> = {}
+
+  if (enableReasoning) {
+    const reasoningParams = getXAIReasoningParams(assistant, model)
+    providerOptions = {
+      ...providerOptions,
+      ...reasoningParams
+    }
   }
 
   return providerOptions
@@ -182,7 +186,7 @@ function buildGenericProviderOptions(
   assistant: Assistant,
   model: Model,
   capabilities: {
     enableReasoning: boolean
     enableWebSearch: boolean
     enableGenerateImage: boolean
   }
 ): Record<string, any> {
-  const { enableWebSearch, enableGenerateImage } = capabilities
+  const { enableWebSearch } = capabilities
   let providerOptions: Record<string, any> = {}
 
   const reasoningParams = getReasoningEffort(assistant, model)
@@ -192,11 +196,11 @@
   }
 
   if (enableWebSearch) {
-    providerOptions.webSearch = { enabled: true }
-  }
-
-  if (enableGenerateImage) {
-    providerOptions.generateImage = { enabled: true }
+    const webSearchParams = getWebSearchParams(model)
+    providerOptions = {
+      ...providerOptions,
+      ...webSearchParams
+    }
   }
 
   return providerOptions
diff --git a/src/renderer/src/aiCore/utils/reasoning.ts b/src/renderer/src/aiCore/utils/reasoning.ts
index 6a241aadfd..6eb0db358b 100644
--- a/src/renderer/src/aiCore/utils/reasoning.ts
+++ b/src/renderer/src/aiCore/utils/reasoning.ts
@@ -314,6 +314,18 @@ export function getGeminiReasoningParams(assistant: Assistant, model: Model): Re
   return {}
 }
 
+export function getXAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
+  if (!isSupportedReasoningEffortGrokModel(model)) {
+    return {}
+  }
+
+  const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
+
+  return {
+    reasoningEffort
+  }
+}
+
 /**
  * Get custom parameters
  * Extract custom parameters from the assistant settings
diff --git a/src/renderer/src/aiCore/utils/websearch.ts b/src/renderer/src/aiCore/utils/websearch.ts
index a59cbf5fbd..d2d0345826 100644
--- a/src/renderer/src/aiCore/utils/websearch.ts
+++ b/src/renderer/src/aiCore/utils/websearch.ts
@@ -1,37 +1,31 @@
-// import { isWebSearchModel } from '@renderer/config/models'
-// import { Model } from '@renderer/types'
-// // import {} from '@cherrystudio/ai-core'
+import { isOpenAIWebSearchChatCompletionOnlyModel } from '@renderer/config/models'
+import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from '@renderer/config/prompts'
+import { Model } from '@renderer/types'
 
-// // The tool name for Gemini search can be arbitrary, but let's use a descriptive one.
-// const GEMINI_SEARCH_TOOL_NAME = 'google_search'
+export function getWebSearchParams(model: Model): Record<string, any> {
+  if (model.provider === 'hunyuan') {
+    return { enable_enhancement: true, citation: true, search_info: true }
+  }
 
-// export function getWebSearchTools(model: Model): Record<string, any> {
-//   if (!isWebSearchModel(model)) {
-//     return {}
-//   }
+  if (model.provider === 'dashscope') {
+    return {
+      enable_search: true,
+      search_options: {
+        forced_search: true
+      }
+    }
+  }
 
-//   // Use provider from model if available, otherwise fallback to parsing model id.
-//   const provider = model.provider || model.id.split('/')[0]
+  if (isOpenAIWebSearchChatCompletionOnlyModel(model)) {
+    return {
+      web_search_options: {}
+    }
+  }
 
-//   switch (provider) {
-//     case 'anthropic':
-//       return {
-//         web_search: {
-//           type: 'web_search_20250305',
-//           name: 'web_search',
-//           max_uses: 5
-//         }
-//       }
-//     case 'google':
-//     case 'gemini':
-//       return {
-//         [GEMINI_SEARCH_TOOL_NAME]: {
-//           googleSearch: {}
-//         }
-//       }
-//     default:
-//       // For OpenAI and others, web search is often a parameter, not a tool.
-//       // The logic is handled in `buildProviderOptions`.
-//       return {}
-//   }
-// }
+  if (model.provider === 'openrouter') {
+    return {
+      plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }]
+    }
+  }
+  return {}
+}
diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts
index 18d24d9ba6..4a7db3ae70 100644
--- a/src/renderer/src/config/models.ts
+++ b/src/renderer/src/config/models.ts
@@ -2776,8 +2776,6 @@ export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boole
   return {
     tools: webSearchTools
   }
-
-  return {}
 }
 
 export function isGemmaModel(model?: Model): boolean {
@@ -2839,7 +2837,7 @@ export const THINKING_TOKEN_MAP: Record<string, { min: number; max: number }> =
   'gemini-.*-pro.*$': { min: 128, max: 32768 },
 
   // Qwen models
-  'qwen-plus-.*$': { min: 0, max: 38912 },
+  'qwen-plus(-.*)?$': { min: 0, max: 38912 },
   'qwen-turbo-.*$': { min: 0, max: 38912 },
   'qwen3-0\\.6b$': { min: 0, max: 30720 },
   'qwen3-1\\.7b$': { min: 0, max: 30720 },
diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts
index eda7bd19e9..58827216a0 100644
--- a/src/renderer/src/services/ApiService.ts
+++ b/src/renderer/src/services/ApiService.ts
@@ -300,8 +300,8 @@ export async function fetchChatCompletion({
   }
   onChunkReceived: (chunk: Chunk) => void
 }) {
-  const provider = getAssistantProvider(assistant)
-  const AI = new AiProviderNew(provider)
+  const AI = new AiProviderNew(assistant.model || getDefaultModel())
+  const provider = AI.getActualProvider()
 
   const mcpTools = await fetchMcpTools(assistant)
 
@@ -310,7 +310,7 @@ export async function fetchChatCompletion({
     params: aiSdkParams,
     modelId,
     capabilities
-  } = await buildStreamTextParams(messages, assistant, {
+  } = await buildStreamTextParams(messages, assistant, provider, {
     mcpTools: mcpTools,
     enableTools: isEnabledToolUse(assistant),
     requestOptions: options
diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts
index 8a9bc7cbc8..f851d85766 100644
--- a/src/renderer/src/types/index.ts
+++ b/src/renderer/src/types/index.ts
@@ -176,14 +176,7 @@ export interface VertexProvider extends BaseProvider {
   location: string
 }
 
-export type ProviderType =
-  | 'openai'
-  | 'openai-response'
-  | 'anthropic'
-  | 'gemini'
-  | 'qwenlm'
-  | 'azure-openai'
-  | 'vertexai'
+export type ProviderType = 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'azure-openai' | 'vertexai'
 
 export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search'
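For reference, a minimal usage sketch of the aihubmix routing introduced in the new `provider/aihubmix.ts`. Assumptions: the `@renderer` path alias resolves as in the renderer build, and the two model objects and the partial provider object below are hypothetical examples cast to the `Model`/`Provider` shapes, not real catalog entries.

```ts
// Sketch only: exercises getAiSdkProviderIdForAihubmix / createAihubmixProvider.
// The model ids and the partial Provider object are hypothetical.
import { createAihubmixProvider, getAiSdkProviderIdForAihubmix } from '@renderer/aiCore/provider/aihubmix'
import { Model, Provider } from '@renderer/types'

const aihubmix = { id: 'aihubmix', type: 'openai', apiHost: 'https://aihubmix.com' } as Provider

const claude = { id: 'claude-3-7-sonnet', provider: 'aihubmix', name: 'Claude 3.7 Sonnet', group: 'Claude' } as Model
const gemini = { id: 'gemini-2.5-pro', provider: 'aihubmix', name: 'Gemini 2.5 Pro', group: 'Gemini' } as Model

// claude* ids route to anthropic; gemini*/imagen* ids route to google unless they end in
// -nothink or -search; everything else falls back to isOpenAILLMModel(model) deciding
// between 'openai' and 'openai-compatible'.
console.log(getAiSdkProviderIdForAihubmix(claude)) // 'anthropic'
console.log(getAiSdkProviderIdForAihubmix(gemini)) // 'google'

// createAihubmixProvider always stamps the APP-Code extra header and, for google-routed
// models, rewrites the provider to type 'gemini' with the aihubmix Gemini host.
const actual = createAihubmixProvider(gemini, aihubmix)
console.log(actual.type, actual.apiHost, actual.extra_headers?.['APP-Code'])
// -> 'gemini' 'https://aihubmix.com/gemini' 'MLTG2087'
```

This mirrors the constructor change in `index_new.ts`: `ModernAiProvider` now takes a `Model`, resolves the concrete provider once via `getActualProvider`, and the same resolved provider is passed through to `buildStreamTextParams`/`buildProviderOptions`, so per-model routing (aihubmix, vertex, gemini hosts) happens in one place.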