feat: aihubmix support

This commit is contained in:
suyao 2025-07-08 03:47:25 +08:00
parent 3c955e69f1
commit 450d6228d4
No known key found for this signature in database
10 changed files with 181 additions and 121 deletions

View File

@ -1,7 +1,7 @@
// copy from @ai-sdk/xai/xai-chat-options.ts
// 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件
import { z } from 'zod'
import * as z from 'zod/v4'
const webSourceSchema = z.object({
type: z.literal('web'),
@ -25,7 +25,7 @@ const newsSourceSchema = z.object({
const rssSourceSchema = z.object({
type: z.literal('rss'),
links: z.array(z.string().url()).max(1) // currently only supports one RSS link
links: z.array(z.url()).max(1) // currently only supports one RSS link
})
const searchSourceSchema = z.discriminatedUnion('type', [

View File

@ -30,14 +30,11 @@ import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk
import { CompletionsResult } from './middleware/schemas'
import reasoningTimePlugin from './plugins/reasoningTimePlugin'
import { getAiSdkProviderId } from './provider/factory'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { createAihubmixProvider } from './provider/aihubmix'
/**
* Provider AI SDK
*/
function providerToAiSdkConfig(provider: Provider): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
function getActualProvider(model: Model): Provider {
const provider = getProviderByModel(model)
// 如果是 vertexai 类型且没有 googleCredentials转换为 VertexProvider
let actualProvider = cloneDeep(provider)
if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
@ -47,18 +44,25 @@ function providerToAiSdkConfig(provider: Provider): {
actualProvider = createVertexProvider(provider)
}
if (
actualProvider.type === 'openai' ||
actualProvider.type === 'anthropic' ||
actualProvider.type === 'openai-response'
) {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
if (provider.id === 'aihubmix') {
actualProvider = createAihubmixProvider(model, actualProvider)
}
if (actualProvider.type === 'gemini') {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
} else {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
}
return actualProvider
}
/**
* Provider AI SDK
*/
function providerToAiSdkConfig(actualProvider: Provider): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// 如果provider是openai则使用strict模式并且默认responses api
@ -126,14 +130,18 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean {
export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider
private config: ReturnType<typeof providerToAiSdkConfig>
private actualProvider: Provider
constructor(provider: Provider) {
this.legacyProvider = new LegacyAiProvider(provider)
constructor(model: Model) {
this.actualProvider = getActualProvider(model)
this.legacyProvider = new LegacyAiProvider(this.actualProvider)
// 只保存配置不预先创建executor
this.config = providerToAiSdkConfig(provider)
this.config = providerToAiSdkConfig(this.actualProvider)
}
console.log('[Modern AI Provider] Creating executor with MCP Prompt plugin enabled')
public getActualProvider() {
return this.actualProvider
}
/**

View File

@ -0,0 +1,55 @@
import { ProviderId } from '@cherrystudio/ai-core/types'
import { isOpenAILLMModel } from '@renderer/config/models'
import { Model, Provider } from '@renderer/types'
export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'openai-compatible' {
const id = model.id.toLowerCase()
if (id.startsWith('claude')) {
return 'anthropic'
}
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
return 'google'
}
if (isOpenAILLMModel(model)) {
return 'openai'
}
return 'openai-compatible'
}
/**
 * Adapt a generic AiHubMix provider config to the concrete provider type
 * implied by the model id (Anthropic / Google / OpenAI), attaching the
 * AiHubMix APP-Code header on every variant.
 *
 * @param model    The model that determines which backend AiHubMix proxies to.
 * @param provider The base AiHubMix provider config; not mutated.
 * @returns A new Provider tailored to the routed backend.
 */
export function createAihubmixProvider(model: Model, provider: Provider): Provider {
  const withAppCode: Provider = {
    ...provider,
    extra_headers: {
      ...provider.extra_headers,
      'APP-Code': 'MLTG2087'
    }
  }

  switch (getAiSdkProviderIdForAihubmix(model)) {
    case 'google':
      // Gemini/Imagen traffic goes through AiHubMix's dedicated Gemini endpoint.
      return { ...withAppCode, type: 'gemini', apiHost: 'https://aihubmix.com/gemini' }
    case 'openai':
      return { ...withAppCode, type: 'openai' }
    case 'anthropic':
      return { ...withAppCode, type: 'anthropic' }
    default:
      return withAppCode
  }
}

View File

@ -27,7 +27,7 @@ import {
isWebSearchModel
} from '@renderer/config/models'
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
import type { Assistant, MCPTool, Message, Model, Provider } from '@renderer/types'
import { FileTypes } from '@renderer/types'
import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
import {
@ -241,6 +241,7 @@ export async function convertMessagesToSdkMessages(
export async function buildStreamTextParams(
sdkMessages: StreamTextParams['messages'],
assistant: Assistant,
provider: Provider,
options: {
mcpTools?: MCPTool[]
enableTools?: boolean
@ -285,14 +286,8 @@ export async function buildStreamTextParams(
enableToolUse: enableTools
})
// Add web search tools if enabled
// if (enableWebSearch) {
// const webSearchTools = getWebSearchTools(model)
// tools = { ...tools, ...webSearchTools }
// }
// 构建真正的 providerOptions
const providerOptions = buildProviderOptions(assistant, model, {
const providerOptions = buildProviderOptions(assistant, model, provider, {
enableReasoning,
enableWebSearch,
enableGenerateImage
@ -321,11 +316,12 @@ export async function buildStreamTextParams(
export async function buildGenerateTextParams(
messages: ModelMessage[],
assistant: Assistant,
provider: Provider,
options: {
mcpTools?: MCPTool[]
enableTools?: boolean
} = {}
): Promise<any> {
// 复用流式参数的构建逻辑
return await buildStreamTextParams(messages, assistant, options)
return await buildStreamTextParams(messages, assistant, provider, options)
}

View File

@ -1,5 +1,5 @@
import { getProviderByModel } from '@renderer/services/AssistantService'
import { Assistant, Model } from '@renderer/types'
import { getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService'
import { Assistant, Model, Provider } from '@renderer/types'
import { getAiSdkProviderId } from '../provider/factory'
import {
@ -7,8 +7,10 @@ import {
getCustomParameters,
getGeminiReasoningParams,
getOpenAIReasoningParams,
getReasoningEffort
getReasoningEffort,
getXAIReasoningParams
} from './reasoning'
import { getWebSearchParams } from './websearch'
/**
* AI SDK providerOptions
@ -18,25 +20,22 @@ import {
export function buildProviderOptions(
assistant: Assistant,
model: Model,
actualProvider: Provider,
capabilities: {
enableReasoning: boolean
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
const provider = getProviderByModel(model)
const providerId = getAiSdkProviderId(provider)
const providerId = getAiSdkProviderId(actualProvider)
// 构建 provider 特定的选项
let providerSpecificOptions: Record<string, any> = {}
console.log('buildProviderOptions', providerId)
console.log('buildProviderOptions', provider)
// 根据 provider 类型分离构建逻辑
switch (provider.type) {
case 'openai-response':
case 'azure-openai':
switch (providerId) {
case 'openai':
case 'azure':
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities)
break
@ -44,11 +43,15 @@ export function buildProviderOptions(
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
break
case 'gemini':
case 'vertexai':
case 'google':
case 'google-vertex':
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
break
case 'xai':
providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities)
break
default:
// 对于其他 provider使用通用的构建逻辑
providerSpecificOptions = buildGenericProviderOptions(assistant, model, capabilities)
@ -79,7 +82,7 @@ function buildOpenAIProviderOptions(
enableGenerateImage: boolean
}
): Record<string, any> {
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
// OpenAI 推理参数
@ -91,15 +94,6 @@ function buildOpenAIProviderOptions(
}
}
// Web 搜索和图像生成暂时使用通用格式
if (enableWebSearch) {
providerOptions.webSearch = { enabled: true }
}
if (enableGenerateImage) {
providerOptions.generateImage = { enabled: true }
}
return providerOptions
}
@ -115,7 +109,7 @@ function buildAnthropicProviderOptions(
enableGenerateImage: boolean
}
): Record<string, any> {
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
// Anthropic 推理参数
@ -127,14 +121,6 @@ function buildAnthropicProviderOptions(
}
}
if (enableWebSearch) {
providerOptions.webSearch = { enabled: true }
}
if (enableGenerateImage) {
providerOptions.generateImage = { enabled: true }
}
return providerOptions
}
@ -150,21 +136,39 @@ function buildGeminiProviderOptions(
enableGenerateImage: boolean
}
): Record<string, any> {
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
const providerOptions: Record<string, any> = {}
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
// Gemini 推理参数
if (enableReasoning) {
const reasoningParams = getGeminiReasoningParams(assistant, model)
Object.assign(providerOptions, reasoningParams)
providerOptions = {
...providerOptions,
...reasoningParams
}
}
if (enableWebSearch) {
providerOptions.webSearch = { enabled: true }
}
return providerOptions
}
if (enableGenerateImage) {
providerOptions.generateImage = { enabled: true }
function buildXAIProviderOptions(
assistant: Assistant,
model: Model,
capabilities: {
enableReasoning: boolean
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
if (enableReasoning) {
const reasoningParams = getXAIReasoningParams(assistant, model)
providerOptions = {
...providerOptions,
...reasoningParams
}
}
return providerOptions
@ -182,7 +186,7 @@ function buildGenericProviderOptions(
enableGenerateImage: boolean
}
): Record<string, any> {
const { enableWebSearch, enableGenerateImage } = capabilities
const { enableWebSearch } = capabilities
let providerOptions: Record<string, any> = {}
const reasoningParams = getReasoningEffort(assistant, model)
@ -192,11 +196,11 @@ function buildGenericProviderOptions(
}
if (enableWebSearch) {
providerOptions.webSearch = { enabled: true }
}
if (enableGenerateImage) {
providerOptions.generateImage = { enabled: true }
const webSearchParams = getWebSearchParams(model)
providerOptions = {
...providerOptions,
...webSearchParams
}
}
return providerOptions

View File

@ -314,6 +314,18 @@ export function getGeminiReasoningParams(assistant: Assistant, model: Model): Re
return {}
}
/**
 * Build xAI (Grok) provider options controlling reasoning effort.
 *
 * @param assistant The assistant whose settings supply the reasoning effort.
 * @param model     The target model; must support Grok reasoning effort.
 * @returns `{ reasoningEffort }` for supported Grok models, otherwise `{}`.
 */
export function getXAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
  if (!isSupportedReasoningEffortGrokModel(model)) {
    return {}
  }

  const settings = getAssistantSettings(assistant)
  return { reasoningEffort: settings.reasoning_effort }
}
/**
*
* assistant

View File

@ -1,37 +1,31 @@
// import { isWebSearchModel } from '@renderer/config/models'
// import { Model } from '@renderer/types'
// // import {} from '@cherrystudio/ai-core'
import { isOpenAIWebSearchChatCompletionOnlyModel } from '@renderer/config/models'
import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from '@renderer/config/prompts'
import { Model } from '@renderer/types'
// // The tool name for Gemini search can be arbitrary, but let's use a descriptive one.
// const GEMINI_SEARCH_TOOL_NAME = 'google_search'
/**
 * Build provider-specific web-search request parameters for providers whose
 * search capability is toggled via plain request fields (rather than AI SDK
 * tools or providerOptions).
 *
 * @param model The target model; its `provider` id selects the parameter shape.
 * @returns A params object to merge into the request, or `{}` when the
 *          provider needs no extra fields.
 */
export function getWebSearchParams(model: Model): Record<string, any> {
  if (model.provider === 'hunyuan') {
    return { enable_enhancement: true, citation: true, search_info: true }
  }

  if (model.provider === 'dashscope') {
    return {
      enable_search: true,
      search_options: {
        forced_search: true
      }
    }
  }

  // Some OpenAI models expose web search only through the chat-completions API.
  if (isOpenAIWebSearchChatCompletionOnlyModel(model)) {
    return {
      web_search_options: {}
    }
  }

  if (model.provider === 'openrouter') {
    return {
      plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }]
    }
  }

  return {}
}

View File

@ -2776,8 +2776,6 @@ export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boole
return {
tools: webSearchTools
}
return {}
}
export function isGemmaModel(model?: Model): boolean {
@ -2839,7 +2837,7 @@ export const THINKING_TOKEN_MAP: Record<string, { min: number; max: number }> =
'gemini-.*-pro.*$': { min: 128, max: 32768 },
// Qwen models
'qwen-plus-.*$': { min: 0, max: 38912 },
'qwen-plus(-.*)?$': { min: 0, max: 38912 },
'qwen-turbo-.*$': { min: 0, max: 38912 },
'qwen3-0\\.6b$': { min: 0, max: 30720 },
'qwen3-1\\.7b$': { min: 0, max: 30720 },

View File

@ -300,8 +300,8 @@ export async function fetchChatCompletion({
}
onChunkReceived: (chunk: Chunk) => void
}) {
const provider = getAssistantProvider(assistant)
const AI = new AiProviderNew(provider)
const AI = new AiProviderNew(assistant.model || getDefaultModel())
const provider = AI.getActualProvider()
const mcpTools = await fetchMcpTools(assistant)
@ -310,7 +310,7 @@ export async function fetchChatCompletion({
params: aiSdkParams,
modelId,
capabilities
} = await buildStreamTextParams(messages, assistant, {
} = await buildStreamTextParams(messages, assistant, provider, {
mcpTools: mcpTools,
enableTools: isEnabledToolUse(assistant),
requestOptions: options

View File

@ -176,14 +176,7 @@ export interface VertexProvider extends BaseProvider {
location: string
}
export type ProviderType =
| 'openai'
| 'openai-response'
| 'anthropic'
| 'gemini'
| 'qwenlm'
| 'azure-openai'
| 'vertexai'
export type ProviderType = 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'azure-openai' | 'vertexai'
export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search'