mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 13:31:32 +08:00
feat: aihubmix support
This commit is contained in:
parent
3c955e69f1
commit
450d6228d4
@ -1,7 +1,7 @@
|
||||
// copy from @ai-sdk/xai/xai-chat-options.ts
|
||||
// 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件
|
||||
|
||||
import { z } from 'zod'
|
||||
import * as z from 'zod/v4'
|
||||
|
||||
const webSourceSchema = z.object({
|
||||
type: z.literal('web'),
|
||||
@ -25,7 +25,7 @@ const newsSourceSchema = z.object({
|
||||
|
||||
const rssSourceSchema = z.object({
|
||||
type: z.literal('rss'),
|
||||
links: z.array(z.string().url()).max(1) // currently only supports one RSS link
|
||||
links: z.array(z.url()).max(1) // currently only supports one RSS link
|
||||
})
|
||||
|
||||
const searchSourceSchema = z.discriminatedUnion('type', [
|
||||
|
||||
@ -30,14 +30,11 @@ import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk
|
||||
import { CompletionsResult } from './middleware/schemas'
|
||||
import reasoningTimePlugin from './plugins/reasoningTimePlugin'
|
||||
import { getAiSdkProviderId } from './provider/factory'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { createAihubmixProvider } from './provider/aihubmix'
|
||||
|
||||
/**
|
||||
* 将 Provider 配置转换为新 AI SDK 格式
|
||||
*/
|
||||
function providerToAiSdkConfig(provider: Provider): {
|
||||
providerId: ProviderId | 'openai-compatible'
|
||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
} {
|
||||
function getActualProvider(model: Model): Provider {
|
||||
const provider = getProviderByModel(model)
|
||||
// 如果是 vertexai 类型且没有 googleCredentials,转换为 VertexProvider
|
||||
let actualProvider = cloneDeep(provider)
|
||||
if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
|
||||
@ -47,18 +44,25 @@ function providerToAiSdkConfig(provider: Provider): {
|
||||
actualProvider = createVertexProvider(provider)
|
||||
}
|
||||
|
||||
if (
|
||||
actualProvider.type === 'openai' ||
|
||||
actualProvider.type === 'anthropic' ||
|
||||
actualProvider.type === 'openai-response'
|
||||
) {
|
||||
actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
|
||||
if (provider.id === 'aihubmix') {
|
||||
actualProvider = createAihubmixProvider(model, actualProvider)
|
||||
}
|
||||
|
||||
if (actualProvider.type === 'gemini') {
|
||||
actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
|
||||
} else {
|
||||
actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
|
||||
}
|
||||
return actualProvider
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 Provider 配置转换为新 AI SDK 格式
|
||||
*/
|
||||
function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
providerId: ProviderId | 'openai-compatible'
|
||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
} {
|
||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||
|
||||
// 如果provider是openai,则使用strict模式并且默认responses api
|
||||
@ -126,14 +130,18 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean {
|
||||
export default class ModernAiProvider {
|
||||
private legacyProvider: LegacyAiProvider
|
||||
private config: ReturnType<typeof providerToAiSdkConfig>
|
||||
private actualProvider: Provider
|
||||
|
||||
constructor(provider: Provider) {
|
||||
this.legacyProvider = new LegacyAiProvider(provider)
|
||||
constructor(model: Model) {
|
||||
this.actualProvider = getActualProvider(model)
|
||||
this.legacyProvider = new LegacyAiProvider(this.actualProvider)
|
||||
|
||||
// 只保存配置,不预先创建executor
|
||||
this.config = providerToAiSdkConfig(provider)
|
||||
this.config = providerToAiSdkConfig(this.actualProvider)
|
||||
}
|
||||
|
||||
console.log('[Modern AI Provider] Creating executor with MCP Prompt plugin enabled')
|
||||
public getActualProvider() {
|
||||
return this.actualProvider
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
55
src/renderer/src/aiCore/provider/aihubmix.ts
Normal file
55
src/renderer/src/aiCore/provider/aihubmix.ts
Normal file
@ -0,0 +1,55 @@
|
||||
import { ProviderId } from '@cherrystudio/ai-core/types'
|
||||
import { isOpenAILLMModel } from '@renderer/config/models'
|
||||
import { Model, Provider } from '@renderer/types'
|
||||
|
||||
export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'openai-compatible' {
|
||||
const id = model.id.toLowerCase()
|
||||
|
||||
if (id.startsWith('claude')) {
|
||||
return 'anthropic'
|
||||
}
|
||||
|
||||
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
|
||||
return 'google'
|
||||
}
|
||||
|
||||
if (isOpenAILLMModel(model)) {
|
||||
return 'openai'
|
||||
}
|
||||
|
||||
return 'openai-compatible'
|
||||
}
|
||||
|
||||
export function createAihubmixProvider(model: Model, provider: Provider): Provider {
|
||||
const providerId = getAiSdkProviderIdForAihubmix(model)
|
||||
provider = {
|
||||
...provider,
|
||||
extra_headers: {
|
||||
...provider.extra_headers,
|
||||
'APP-Code': 'MLTG2087'
|
||||
}
|
||||
}
|
||||
if (providerId === 'google') {
|
||||
return {
|
||||
...provider,
|
||||
type: 'gemini',
|
||||
apiHost: 'https://aihubmix.com/gemini'
|
||||
}
|
||||
}
|
||||
|
||||
if (providerId === 'openai') {
|
||||
return {
|
||||
...provider,
|
||||
type: 'openai'
|
||||
}
|
||||
}
|
||||
|
||||
if (providerId === 'anthropic') {
|
||||
return {
|
||||
...provider,
|
||||
type: 'anthropic'
|
||||
}
|
||||
}
|
||||
|
||||
return provider
|
||||
}
|
||||
@ -27,7 +27,7 @@ import {
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
|
||||
import type { Assistant, MCPTool, Message, Model, Provider } from '@renderer/types'
|
||||
import { FileTypes } from '@renderer/types'
|
||||
import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
|
||||
import {
|
||||
@ -241,6 +241,7 @@ export async function convertMessagesToSdkMessages(
|
||||
export async function buildStreamTextParams(
|
||||
sdkMessages: StreamTextParams['messages'],
|
||||
assistant: Assistant,
|
||||
provider: Provider,
|
||||
options: {
|
||||
mcpTools?: MCPTool[]
|
||||
enableTools?: boolean
|
||||
@ -285,14 +286,8 @@ export async function buildStreamTextParams(
|
||||
enableToolUse: enableTools
|
||||
})
|
||||
|
||||
// Add web search tools if enabled
|
||||
// if (enableWebSearch) {
|
||||
// const webSearchTools = getWebSearchTools(model)
|
||||
// tools = { ...tools, ...webSearchTools }
|
||||
// }
|
||||
|
||||
// 构建真正的 providerOptions
|
||||
const providerOptions = buildProviderOptions(assistant, model, {
|
||||
const providerOptions = buildProviderOptions(assistant, model, provider, {
|
||||
enableReasoning,
|
||||
enableWebSearch,
|
||||
enableGenerateImage
|
||||
@ -321,11 +316,12 @@ export async function buildStreamTextParams(
|
||||
export async function buildGenerateTextParams(
|
||||
messages: ModelMessage[],
|
||||
assistant: Assistant,
|
||||
provider: Provider,
|
||||
options: {
|
||||
mcpTools?: MCPTool[]
|
||||
enableTools?: boolean
|
||||
} = {}
|
||||
): Promise<any> {
|
||||
// 复用流式参数的构建逻辑
|
||||
return await buildStreamTextParams(messages, assistant, options)
|
||||
return await buildStreamTextParams(messages, assistant, provider, options)
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { Assistant, Model } from '@renderer/types'
|
||||
import { getDefaultModel, getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { Assistant, Model, Provider } from '@renderer/types'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import {
|
||||
@ -7,8 +7,10 @@ import {
|
||||
getCustomParameters,
|
||||
getGeminiReasoningParams,
|
||||
getOpenAIReasoningParams,
|
||||
getReasoningEffort
|
||||
getReasoningEffort,
|
||||
getXAIReasoningParams
|
||||
} from './reasoning'
|
||||
import { getWebSearchParams } from './websearch'
|
||||
|
||||
/**
|
||||
* 构建 AI SDK 的 providerOptions
|
||||
@ -18,25 +20,22 @@ import {
|
||||
export function buildProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
actualProvider: Provider,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const provider = getProviderByModel(model)
|
||||
const providerId = getAiSdkProviderId(provider)
|
||||
const providerId = getAiSdkProviderId(actualProvider)
|
||||
|
||||
// 构建 provider 特定的选项
|
||||
let providerSpecificOptions: Record<string, any> = {}
|
||||
|
||||
console.log('buildProviderOptions', providerId)
|
||||
console.log('buildProviderOptions', provider)
|
||||
|
||||
// 根据 provider 类型分离构建逻辑
|
||||
switch (provider.type) {
|
||||
case 'openai-response':
|
||||
case 'azure-openai':
|
||||
switch (providerId) {
|
||||
case 'openai':
|
||||
case 'azure':
|
||||
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
|
||||
@ -44,11 +43,15 @@ export function buildProviderOptions(
|
||||
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
|
||||
case 'gemini':
|
||||
case 'vertexai':
|
||||
case 'google':
|
||||
case 'google-vertex':
|
||||
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
|
||||
case 'xai':
|
||||
providerSpecificOptions = buildXAIProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
|
||||
default:
|
||||
// 对于其他 provider,使用通用的构建逻辑
|
||||
providerSpecificOptions = buildGenericProviderOptions(assistant, model, capabilities)
|
||||
@ -79,7 +82,7 @@ function buildOpenAIProviderOptions(
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
// OpenAI 推理参数
|
||||
@ -91,15 +94,6 @@ function buildOpenAIProviderOptions(
|
||||
}
|
||||
}
|
||||
|
||||
// Web 搜索和图像生成暂时使用通用格式
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
@ -115,7 +109,7 @@ function buildAnthropicProviderOptions(
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
// Anthropic 推理参数
|
||||
@ -127,14 +121,6 @@ function buildAnthropicProviderOptions(
|
||||
}
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
@ -150,21 +136,39 @@ function buildGeminiProviderOptions(
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
const providerOptions: Record<string, any> = {}
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
// Gemini 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getGeminiReasoningParams(assistant, model)
|
||||
Object.assign(providerOptions, reasoningParams)
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...reasoningParams
|
||||
}
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
function buildXAIProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getXAIReasoningParams(assistant, model)
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...reasoningParams
|
||||
}
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
@ -182,7 +186,7 @@ function buildGenericProviderOptions(
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableWebSearch, enableGenerateImage } = capabilities
|
||||
const { enableWebSearch } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
const reasoningParams = getReasoningEffort(assistant, model)
|
||||
@ -192,11 +196,11 @@ function buildGenericProviderOptions(
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
const webSearchParams = getWebSearchParams(model)
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...webSearchParams
|
||||
}
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
|
||||
@ -314,6 +314,18 @@ export function getGeminiReasoningParams(assistant: Assistant, model: Model): Re
|
||||
return {}
|
||||
}
|
||||
|
||||
export function getXAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
||||
if (!isSupportedReasoningEffortGrokModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
|
||||
|
||||
return {
|
||||
reasoningEffort
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取自定义参数
|
||||
* 从 assistant 设置中提取自定义参数
|
||||
|
||||
@ -1,37 +1,31 @@
|
||||
// import { isWebSearchModel } from '@renderer/config/models'
|
||||
// import { Model } from '@renderer/types'
|
||||
// // import {} from '@cherrystudio/ai-core'
|
||||
import { isOpenAIWebSearchChatCompletionOnlyModel } from '@renderer/config/models'
|
||||
import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from '@renderer/config/prompts'
|
||||
import { Model } from '@renderer/types'
|
||||
|
||||
// // The tool name for Gemini search can be arbitrary, but let's use a descriptive one.
|
||||
// const GEMINI_SEARCH_TOOL_NAME = 'google_search'
|
||||
export function getWebSearchParams(model: Model): Record<string, any> {
|
||||
if (model.provider === 'hunyuan') {
|
||||
return { enable_enhancement: true, citation: true, search_info: true }
|
||||
}
|
||||
|
||||
// export function getWebSearchTools(model: Model): Record<string, any> {
|
||||
// if (!isWebSearchModel(model)) {
|
||||
// return {}
|
||||
// }
|
||||
if (model.provider === 'dashscope') {
|
||||
return {
|
||||
enable_search: true,
|
||||
search_options: {
|
||||
forced_search: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// // Use provider from model if available, otherwise fallback to parsing model id.
|
||||
// const provider = model.provider || model.id.split('/')[0]
|
||||
if (isOpenAIWebSearchChatCompletionOnlyModel(model)) {
|
||||
return {
|
||||
web_search_options: {}
|
||||
}
|
||||
}
|
||||
|
||||
// switch (provider) {
|
||||
// case 'anthropic':
|
||||
// return {
|
||||
// web_search: {
|
||||
// type: 'web_search_20250305',
|
||||
// name: 'web_search',
|
||||
// max_uses: 5
|
||||
// }
|
||||
// }
|
||||
// case 'google':
|
||||
// case 'gemini':
|
||||
// return {
|
||||
// [GEMINI_SEARCH_TOOL_NAME]: {
|
||||
// googleSearch: {}
|
||||
// }
|
||||
// }
|
||||
// default:
|
||||
// // For OpenAI and others, web search is often a parameter, not a tool.
|
||||
// // The logic is handled in `buildProviderOptions`.
|
||||
// return {}
|
||||
// }
|
||||
// }
|
||||
if (model.provider === 'openrouter') {
|
||||
return {
|
||||
plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }]
|
||||
}
|
||||
}
|
||||
return {}
|
||||
}
|
||||
|
||||
@ -2776,8 +2776,6 @@ export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boole
|
||||
return {
|
||||
tools: webSearchTools
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
|
||||
export function isGemmaModel(model?: Model): boolean {
|
||||
@ -2839,7 +2837,7 @@ export const THINKING_TOKEN_MAP: Record<string, { min: number; max: number }> =
|
||||
'gemini-.*-pro.*$': { min: 128, max: 32768 },
|
||||
|
||||
// Qwen models
|
||||
'qwen-plus-.*$': { min: 0, max: 38912 },
|
||||
'qwen-plus(-.*)?$': { min: 0, max: 38912 },
|
||||
'qwen-turbo-.*$': { min: 0, max: 38912 },
|
||||
'qwen3-0\\.6b$': { min: 0, max: 30720 },
|
||||
'qwen3-1\\.7b$': { min: 0, max: 30720 },
|
||||
|
||||
@ -300,8 +300,8 @@ export async function fetchChatCompletion({
|
||||
}
|
||||
onChunkReceived: (chunk: Chunk) => void
|
||||
}) {
|
||||
const provider = getAssistantProvider(assistant)
|
||||
const AI = new AiProviderNew(provider)
|
||||
const AI = new AiProviderNew(assistant.model || getDefaultModel())
|
||||
const provider = AI.getActualProvider()
|
||||
|
||||
const mcpTools = await fetchMcpTools(assistant)
|
||||
|
||||
@ -310,7 +310,7 @@ export async function fetchChatCompletion({
|
||||
params: aiSdkParams,
|
||||
modelId,
|
||||
capabilities
|
||||
} = await buildStreamTextParams(messages, assistant, {
|
||||
} = await buildStreamTextParams(messages, assistant, provider, {
|
||||
mcpTools: mcpTools,
|
||||
enableTools: isEnabledToolUse(assistant),
|
||||
requestOptions: options
|
||||
|
||||
@ -176,14 +176,7 @@ export interface VertexProvider extends BaseProvider {
|
||||
location: string
|
||||
}
|
||||
|
||||
export type ProviderType =
|
||||
| 'openai'
|
||||
| 'openai-response'
|
||||
| 'anthropic'
|
||||
| 'gemini'
|
||||
| 'qwenlm'
|
||||
| 'azure-openai'
|
||||
| 'vertexai'
|
||||
export type ProviderType = 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'azure-openai' | 'vertexai'
|
||||
|
||||
export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search'
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user