feat: aihubmix support

This commit is contained in:
suyao 2025-07-08 03:47:25 +08:00
parent 3c955e69f1
commit 450d6228d4
No known key found for this signature in database
10 changed files with 181 additions and 121 deletions

View File

@ -1,7 +1,7 @@
// copy from @ai-sdk/xai/xai-chat-options.ts // copy from @ai-sdk/xai/xai-chat-options.ts
// 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件 // 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件
import { z } from 'zod' import * as z from 'zod/v4'
const webSourceSchema = z.object({ const webSourceSchema = z.object({
type: z.literal('web'), type: z.literal('web'),
@ -25,7 +25,7 @@ const newsSourceSchema = z.object({
const rssSourceSchema = z.object({ const rssSourceSchema = z.object({
type: z.literal('rss'), type: z.literal('rss'),
links: z.array(z.string().url()).max(1) // currently only supports one RSS link links: z.array(z.url()).max(1) // currently only supports one RSS link
}) })
const searchSourceSchema = z.discriminatedUnion('type', [ const searchSourceSchema = z.discriminatedUnion('type', [

View File

@ -30,14 +30,11 @@ import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk
import { CompletionsResult } from './middleware/schemas' import { CompletionsResult } from './middleware/schemas'
import reasoningTimePlugin from './plugins/reasoningTimePlugin' import reasoningTimePlugin from './plugins/reasoningTimePlugin'
import { getAiSdkProviderId } from './provider/factory' import { getAiSdkProviderId } from './provider/factory'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { createAihubmixProvider } from './provider/aihubmix'
/** function getActualProvider(model: Model): Provider {
* Provider AI SDK const provider = getProviderByModel(model)
*/
function providerToAiSdkConfig(provider: Provider): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
// 如果是 vertexai 类型且没有 googleCredentials转换为 VertexProvider // 如果是 vertexai 类型且没有 googleCredentials转换为 VertexProvider
let actualProvider = cloneDeep(provider) let actualProvider = cloneDeep(provider)
if (provider.type === 'vertexai' && !isVertexProvider(provider)) { if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
@ -47,18 +44,25 @@ function providerToAiSdkConfig(provider: Provider): {
actualProvider = createVertexProvider(provider) actualProvider = createVertexProvider(provider)
} }
if ( if (provider.id === 'aihubmix') {
actualProvider.type === 'openai' || actualProvider = createAihubmixProvider(model, actualProvider)
actualProvider.type === 'anthropic' ||
actualProvider.type === 'openai-response'
) {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
} }
if (actualProvider.type === 'gemini') { if (actualProvider.type === 'gemini') {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta') actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
} else {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
} }
return actualProvider
}
/**
* Provider AI SDK
*/
function providerToAiSdkConfig(actualProvider: Provider): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
const aiSdkProviderId = getAiSdkProviderId(actualProvider) const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// 如果provider是openai则使用strict模式并且默认responses api // 如果provider是openai则使用strict模式并且默认responses api
@ -126,14 +130,18 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean {
export default class ModernAiProvider { export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider private legacyProvider: LegacyAiProvider
private config: ReturnType<typeof providerToAiSdkConfig> private config: ReturnType<typeof providerToAiSdkConfig>
private actualProvider: Provider
constructor(provider: Provider) { constructor(model: Model) {
this.legacyProvider = new LegacyAiProvider(provider) this.actualProvider = getActualProvider(model)
this.legacyProvider = new LegacyAiProvider(this.actualProvider)
// 只保存配置不预先创建executor // 只保存配置不预先创建executor
this.config = providerToAiSdkConfig(provider) this.config = providerToAiSdkConfig(this.actualProvider)
}
console.log('[Modern AI Provider] Creating executor with MCP Prompt plugin enabled') public getActualProvider() {
return this.actualProvider
} }
/** /**

View File

@ -0,0 +1,55 @@
import { ProviderId } from '@cherrystudio/ai-core/types'
import { isOpenAILLMModel } from '@renderer/config/models'
import { Model, Provider } from '@renderer/types'
/**
 * Route an aihubmix model to the AI-SDK provider implementation that should
 * serve it, based purely on the model id.
 *
 * Routing rules (checked in order):
 *  - "claude*"                     -> 'anthropic'
 *  - "gemini*" / "imagen*"         -> 'google', EXCEPT the "-nothink" /
 *    "-search" variants, which fall through to the compatible endpoint
 *  - OpenAI LLM models             -> 'openai'
 *  - anything else                 -> 'openai-compatible'
 */
export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'openai-compatible' {
  const normalizedId = model.id.toLowerCase()

  if (normalizedId.startsWith('claude')) {
    return 'anthropic'
  }

  const isGoogleFamily = normalizedId.startsWith('gemini') || normalizedId.startsWith('imagen')
  const isExcludedVariant = normalizedId.endsWith('-nothink') || normalizedId.endsWith('-search')
  if (isGoogleFamily && !isExcludedVariant) {
    return 'google'
  }

  return isOpenAILLMModel(model) ? 'openai' : 'openai-compatible'
}
/**
 * Build the effective Provider for an aihubmix model.
 *
 * Always attaches the aihubmix 'APP-Code' header, then — depending on which
 * AI-SDK provider the model routes to — overrides the provider type (and,
 * for Google models, the API host) so the request hits the matching
 * aihubmix endpoint. The input provider object is not mutated.
 */
export function createAihubmixProvider(model: Model, provider: Provider): Provider {
  const routedId = getAiSdkProviderIdForAihubmix(model)

  // Shared base: copy of the provider with the affiliate header merged in.
  const base: Provider = {
    ...provider,
    extra_headers: {
      ...provider.extra_headers,
      'APP-Code': 'MLTG2087'
    }
  }

  switch (routedId) {
    case 'google':
      // Google models use a dedicated gemini-style endpoint on aihubmix.
      return { ...base, type: 'gemini', apiHost: 'https://aihubmix.com/gemini' }
    case 'openai':
      return { ...base, type: 'openai' }
    case 'anthropic':
      return { ...base, type: 'anthropic' }
    default:
      // 'openai-compatible' and anything unrecognized: keep the original type.
      return base
  }
}

View File

@ -27,7 +27,7 @@ import {
isWebSearchModel isWebSearchModel
} from '@renderer/config/models' } from '@renderer/config/models'
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService' import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import type { Assistant, MCPTool, Message, Model } from '@renderer/types' import type { Assistant, MCPTool, Message, Model, Provider } from '@renderer/types'
import { FileTypes } from '@renderer/types' import { FileTypes } from '@renderer/types'
import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage' import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
import { import {
@ -241,6 +241,7 @@ export async function convertMessagesToSdkMessages(
export async function buildStreamTextParams( export async function buildStreamTextParams(
sdkMessages: StreamTextParams['messages'], sdkMessages: StreamTextParams['messages'],
assistant: Assistant, assistant: Assistant,
provider: Provider,
options: { options: {
mcpTools?: MCPTool[] mcpTools?: MCPTool[]
enableTools?: boolean enableTools?: boolean
@ -285,14 +286,8 @@ export async function buildStreamTextParams(
enableToolUse: enableTools enableToolUse: enableTools
}) })
// Add web search tools if enabled
// if (enableWebSearch) {
// const webSearchTools = getWebSearchTools(model)
// tools = { ...tools, ...webSearchTools }
// }
// 构建真正的 providerOptions // 构建真正的 providerOptions
const providerOptions = buildProviderOptions(assistant, model, { const providerOptions = buildProviderOptions(assistant, model, provider, {
enableReasoning, enableReasoning,
enableWebSearch, enableWebSearch,
enableGenerateImage enableGenerateImage
@ -321,11 +316,12 @@ export async function buildStreamTextParams(
export async function buildGenerateTextParams( export async function buildGenerateTextParams(
messages: ModelMessage[], messages: ModelMessage[],
assistant: Assistant, assistant: Assistant,
provider: Provider,
options: { options: {
mcpTools?: MCPTool[] mcpTools?: MCPTool[]
enableTools?: boolean enableTools?: boolean
} = {} } = {}
): Promise<any> { ): Promise<any> {
// 复用流式参数的构建逻辑 // 复用流式参数的构建逻辑
return await buildStreamTextParams(messages, assistant, options) return await buildStreamTextParams(messages, assistant, provider, options)
} }

View File

@ -1,5 +1,5 @@
import { getProviderByModel } from '@renderer/services/AssistantService' import { getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService'
import { Assistant, Model } from '@renderer/types' import { Assistant, Model, Provider } from '@renderer/types'
import { getAiSdkProviderId } from '../provider/factory' import { getAiSdkProviderId } from '../provider/factory'
import { import {
@ -7,8 +7,10 @@ import {
getCustomParameters, getCustomParameters,
getGeminiReasoningParams, getGeminiReasoningParams,
getOpenAIReasoningParams, getOpenAIReasoningParams,
getReasoningEffort getReasoningEffort,
getXAIReasoningParams
} from './reasoning' } from './reasoning'
import { getWebSearchParams } from './websearch'
/** /**
* AI SDK providerOptions * AI SDK providerOptions
@ -18,25 +20,22 @@ import {
export function buildProviderOptions( export function buildProviderOptions(
assistant: Assistant, assistant: Assistant,
model: Model, model: Model,
actualProvider: Provider,
capabilities: { capabilities: {
enableReasoning: boolean enableReasoning: boolean
enableWebSearch: boolean enableWebSearch: boolean
enableGenerateImage: boolean enableGenerateImage: boolean
} }
): Record<string, any> { ): Record<string, any> {
const provider = getProviderByModel(model) const providerId = getAiSdkProviderId(actualProvider)
const providerId = getAiSdkProviderId(provider)
// 构建 provider 特定的选项 // 构建 provider 特定的选项
let providerSpecificOptions: Record<string, any> = {} let providerSpecificOptions: Record<string, any> = {}
console.log('buildProviderOptions', providerId)
console.log('buildProviderOptions', provider)
// 根据 provider 类型分离构建逻辑 // 根据 provider 类型分离构建逻辑
switch (provider.type) { switch (providerId) {
case 'openai-response': case 'openai':
case 'azure-openai': case 'azure':
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities) providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities)
break break
@ -44,11 +43,15 @@ export function buildProviderOptions(
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities) providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
break break
case 'gemini': case 'google':
case 'vertexai': case 'google-vertex':
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities) providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
break break
case 'xai':
providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities)
break
default: default:
// 对于其他 provider使用通用的构建逻辑 // 对于其他 provider使用通用的构建逻辑
providerSpecificOptions = buildGenericProviderOptions(assistant, model, capabilities) providerSpecificOptions = buildGenericProviderOptions(assistant, model, capabilities)
@ -79,7 +82,7 @@ function buildOpenAIProviderOptions(
enableGenerateImage: boolean enableGenerateImage: boolean
} }
): Record<string, any> { ): Record<string, any> {
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {} let providerOptions: Record<string, any> = {}
// OpenAI 推理参数 // OpenAI 推理参数
@ -91,15 +94,6 @@ function buildOpenAIProviderOptions(
} }
} }
// Web 搜索和图像生成暂时使用通用格式
if (enableWebSearch) {
providerOptions.webSearch = { enabled: true }
}
if (enableGenerateImage) {
providerOptions.generateImage = { enabled: true }
}
return providerOptions return providerOptions
} }
@ -115,7 +109,7 @@ function buildAnthropicProviderOptions(
enableGenerateImage: boolean enableGenerateImage: boolean
} }
): Record<string, any> { ): Record<string, any> {
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {} let providerOptions: Record<string, any> = {}
// Anthropic 推理参数 // Anthropic 推理参数
@ -127,14 +121,6 @@ function buildAnthropicProviderOptions(
} }
} }
if (enableWebSearch) {
providerOptions.webSearch = { enabled: true }
}
if (enableGenerateImage) {
providerOptions.generateImage = { enabled: true }
}
return providerOptions return providerOptions
} }
@ -150,21 +136,39 @@ function buildGeminiProviderOptions(
enableGenerateImage: boolean enableGenerateImage: boolean
} }
): Record<string, any> { ): Record<string, any> {
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities const { enableReasoning } = capabilities
const providerOptions: Record<string, any> = {} let providerOptions: Record<string, any> = {}
// Gemini 推理参数 // Gemini 推理参数
if (enableReasoning) { if (enableReasoning) {
const reasoningParams = getGeminiReasoningParams(assistant, model) const reasoningParams = getGeminiReasoningParams(assistant, model)
Object.assign(providerOptions, reasoningParams) providerOptions = {
...providerOptions,
...reasoningParams
}
} }
if (enableWebSearch) { return providerOptions
providerOptions.webSearch = { enabled: true } }
}
if (enableGenerateImage) { function buildXAIProviderOptions(
providerOptions.generateImage = { enabled: true } assistant: Assistant,
model: Model,
capabilities: {
enableReasoning: boolean
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
if (enableReasoning) {
const reasoningParams = getXAIReasoningParams(assistant, model)
providerOptions = {
...providerOptions,
...reasoningParams
}
} }
return providerOptions return providerOptions
@ -182,7 +186,7 @@ function buildGenericProviderOptions(
enableGenerateImage: boolean enableGenerateImage: boolean
} }
): Record<string, any> { ): Record<string, any> {
const { enableWebSearch, enableGenerateImage } = capabilities const { enableWebSearch } = capabilities
let providerOptions: Record<string, any> = {} let providerOptions: Record<string, any> = {}
const reasoningParams = getReasoningEffort(assistant, model) const reasoningParams = getReasoningEffort(assistant, model)
@ -192,11 +196,11 @@ function buildGenericProviderOptions(
} }
if (enableWebSearch) { if (enableWebSearch) {
providerOptions.webSearch = { enabled: true } const webSearchParams = getWebSearchParams(model)
} providerOptions = {
...providerOptions,
if (enableGenerateImage) { ...webSearchParams
providerOptions.generateImage = { enabled: true } }
} }
return providerOptions return providerOptions

View File

@ -314,6 +314,18 @@ export function getGeminiReasoningParams(assistant: Assistant, model: Model): Re
return {} return {}
} }
/**
 * Build xAI (Grok) reasoning provider options from the assistant settings.
 *
 * Returns { reasoningEffort } for Grok models that support a configurable
 * reasoning effort, and an empty object for every other model.
 */
export function getXAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
  if (isSupportedReasoningEffortGrokModel(model)) {
    const settings = getAssistantSettings(assistant)
    return { reasoningEffort: settings.reasoning_effort }
  }
  return {}
}
/** /**
* *
* assistant * assistant

View File

@ -1,37 +1,31 @@
// import { isWebSearchModel } from '@renderer/config/models' import { isOpenAIWebSearchChatCompletionOnlyModel } from '@renderer/config/models'
// import { Model } from '@renderer/types' import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from '@renderer/config/prompts'
// // import {} from '@cherrystudio/ai-core' import { Model } from '@renderer/types'
// // The tool name for Gemini search can be arbitrary, but let's use a descriptive one. export function getWebSearchParams(model: Model): Record<string, any> {
// const GEMINI_SEARCH_TOOL_NAME = 'google_search' if (model.provider === 'hunyuan') {
return { enable_enhancement: true, citation: true, search_info: true }
}
// export function getWebSearchTools(model: Model): Record<string, any> { if (model.provider === 'dashscope') {
// if (!isWebSearchModel(model)) { return {
// return {} enable_search: true,
// } search_options: {
forced_search: true
}
}
}
// // Use provider from model if available, otherwise fallback to parsing model id. if (isOpenAIWebSearchChatCompletionOnlyModel(model)) {
// const provider = model.provider || model.id.split('/')[0] return {
web_search_options: {}
}
}
// switch (provider) { if (model.provider === 'openrouter') {
// case 'anthropic': return {
// return { plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }]
// web_search: { }
// type: 'web_search_20250305', }
// name: 'web_search', return {}
// max_uses: 5 }
// }
// }
// case 'google':
// case 'gemini':
// return {
// [GEMINI_SEARCH_TOOL_NAME]: {
// googleSearch: {}
// }
// }
// default:
// // For OpenAI and others, web search is often a parameter, not a tool.
// // The logic is handled in `buildProviderOptions`.
// return {}
// }
// }

View File

@ -2776,8 +2776,6 @@ export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boole
return { return {
tools: webSearchTools tools: webSearchTools
} }
return {}
} }
export function isGemmaModel(model?: Model): boolean { export function isGemmaModel(model?: Model): boolean {
@ -2839,7 +2837,7 @@ export const THINKING_TOKEN_MAP: Record<string, { min: number; max: number }> =
'gemini-.*-pro.*$': { min: 128, max: 32768 }, 'gemini-.*-pro.*$': { min: 128, max: 32768 },
// Qwen models // Qwen models
'qwen-plus-.*$': { min: 0, max: 38912 }, 'qwen-plus(-.*)?$': { min: 0, max: 38912 },
'qwen-turbo-.*$': { min: 0, max: 38912 }, 'qwen-turbo-.*$': { min: 0, max: 38912 },
'qwen3-0\\.6b$': { min: 0, max: 30720 }, 'qwen3-0\\.6b$': { min: 0, max: 30720 },
'qwen3-1\\.7b$': { min: 0, max: 30720 }, 'qwen3-1\\.7b$': { min: 0, max: 30720 },

View File

@ -300,8 +300,8 @@ export async function fetchChatCompletion({
} }
onChunkReceived: (chunk: Chunk) => void onChunkReceived: (chunk: Chunk) => void
}) { }) {
const provider = getAssistantProvider(assistant) const AI = new AiProviderNew(assistant.model || getDefaultModel())
const AI = new AiProviderNew(provider) const provider = AI.getActualProvider()
const mcpTools = await fetchMcpTools(assistant) const mcpTools = await fetchMcpTools(assistant)
@ -310,7 +310,7 @@ export async function fetchChatCompletion({
params: aiSdkParams, params: aiSdkParams,
modelId, modelId,
capabilities capabilities
} = await buildStreamTextParams(messages, assistant, { } = await buildStreamTextParams(messages, assistant, provider, {
mcpTools: mcpTools, mcpTools: mcpTools,
enableTools: isEnabledToolUse(assistant), enableTools: isEnabledToolUse(assistant),
requestOptions: options requestOptions: options

View File

@ -176,14 +176,7 @@ export interface VertexProvider extends BaseProvider {
location: string location: string
} }
export type ProviderType = export type ProviderType = 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'azure-openai' | 'vertexai'
| 'openai'
| 'openai-response'
| 'anthropic'
| 'gemini'
| 'qwenlm'
| 'azure-openai'
| 'vertexai'
export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search' export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search'