mirror of https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-10 23:59:45 +08:00
feat/hunyuan-a13b (#8405)
* refactor(AiProvider): enhance client compatibility checks and middleware handling
  - Updated AiProvider to use a compatibility-type check for API clients, improving type safety and middleware management.
  - Implemented getClientCompatibilityType in AihubmixAPIClient, NewAPIClient, and OpenAIResponseAPIClient to return the actual client types.
  - Added support for Hunyuan models in various model checks and updated the ThinkingButton component to reflect these changes.
  - Improved logging for middleware construction in AiProvider.

* test(ApiService): add client compatibility type checks for mock API clients

* fix: minimax-m1 reasoning export

Co-authored-by: Pleasurecruise <3196812536@qq.com>
This commit is contained in:
parent 65b1d8819d
commit 71b527b67c
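The heart of the change is replacing `instanceof` checks with a string-based compatibility query that decorator-style clients forward to the client they actually route to. A minimal TypeScript sketch of the pattern, using simplified stand-in types (the routing rule inside `getClient` is illustrative, not the repository's actual logic):

```ts
// Minimal sketch of the compatibility-type pattern this commit introduces.
// Model and the client classes are simplified stand-ins.
interface Model {
  id: string
}

abstract class BaseApiClient {
  // Default: a client reports only its own class name.
  public getClientCompatibilityType(_model?: Model): string[] {
    return [this.constructor.name]
  }
}

class OpenAIAPIClient extends BaseApiClient {}
class AnthropicAPIClient extends BaseApiClient {}

// Decorator-style client: routes each request to an inner client,
// so it reports the inner client's type instead of its own.
class AihubmixAPIClient extends BaseApiClient {
  private openai = new OpenAIAPIClient()
  private anthropic = new AnthropicAPIClient()

  // Illustrative routing rule only.
  private getClient(model: Model): BaseApiClient {
    return model.id.startsWith('claude') ? this.anthropic : this.openai
  }

  public override getClientCompatibilityType(model?: Model): string[] {
    if (!model) {
      return [this.constructor.name]
    }
    return this.getClient(model).getClientCompatibilityType(model)
  }
}

// Consumers test the reported names instead of using instanceof:
const client: BaseApiClient = new AihubmixAPIClient()
const types = client.getClientCompatibilityType({ id: 'claude-3-5-sonnet' })
console.log(types.includes('AnthropicAPIClient')) // true
```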
@@ -136,6 +136,18 @@ export class AihubmixAPIClient extends BaseApiClient {
     return this.currentClient
   }
 
+  /**
+   * Overrides the base-class method to return the client type actually used internally
+   */
+  public override getClientCompatibilityType(model?: Model): string[] {
+    if (!model) {
+      return [this.constructor.name]
+    }
+
+    const actualClient = this.getClient(model)
+    return actualClient.getClientCompatibilityType(model)
+  }
+
   // ============ BaseApiClient abstract method implementations ============
 
   async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
@@ -75,6 +75,17 @@ export abstract class BaseApiClient<
     this.apiKey = this.getApiKey()
   }
 
+  /**
+   * Gets the client's compatibility type
+   * Used to decide whether a client supports a capability, avoiding the type-narrowing problems of instanceof checks
+   * Decorator-style clients (such as AihubmixAPIClient) should return the client type they actually use internally
+   */
+  // eslint-disable-next-line @typescript-eslint/no-unused-vars
+  public getClientCompatibilityType(_model?: Model): string[] {
+    // Defaults to the class name
+    return [this.constructor.name]
+  }
+
   // // The core completions method - in the middleware architecture this is usually just a placeholder
   // abstract completions(params: CompletionsParams, internal?: ProcessingState): Promise<CompletionsResult>
 
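The doc comment's point about narrowing, made concrete: with a union-typed client, a negative `instanceof` check both narrows the type in the branch body and misclassifies decorator clients, which never match the inner class. A tiny hypothetical illustration:

```ts
// Hypothetical illustration of the instanceof pitfalls the comment refers to.
class OpenAIAPIClient {}
class AnthropicAPIClient {}
class AihubmixAPIClient {} // decorator: wraps one of the clients above

type AnyClient = OpenAIAPIClient | AnthropicAPIClient | AihubmixAPIClient

function willRemoveOpenAIMiddleware(client: AnyClient): boolean {
  // A decorator client that routes to an OpenAI client still fails this check,
  // and inside a !instanceof branch TypeScript narrows `client` to exclude
  // OpenAIAPIClient entirely. The string-based check avoids both problems.
  return !(client instanceof OpenAIAPIClient)
}

console.log(willRemoveOpenAIMiddleware(new AihubmixAPIClient())) // true — wrong for OpenAI-routed models
```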
@@ -128,6 +128,18 @@ export class NewAPIClient extends BaseApiClient {
     return this.currentClient
   }
 
+  /**
+   * Overrides the base-class method to return the client type actually used internally
+   */
+  public override getClientCompatibilityType(model?: Model): string[] {
+    if (!model) {
+      return [this.constructor.name]
+    }
+
+    const actualClient = this.getClient(model)
+    return actualClient.getClientCompatibilityType(model)
+  }
+
   // ============ BaseApiClient abstract method implementations ============
 
   async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
@@ -14,6 +14,7 @@ import {
   isSupportedThinkingTokenClaudeModel,
   isSupportedThinkingTokenDoubaoModel,
   isSupportedThinkingTokenGeminiModel,
+  isSupportedThinkingTokenHunyuanModel,
   isSupportedThinkingTokenModel,
   isSupportedThinkingTokenQwenModel,
   isVisionModel
@@ -128,7 +129,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
       }
       return { reasoning: { enabled: false, exclude: true } }
     }
-    if (isSupportedThinkingTokenQwenModel(model)) {
+    if (isSupportedThinkingTokenQwenModel(model) || isSupportedThinkingTokenHunyuanModel(model)) {
       return { enable_thinking: false }
     }
 
@@ -188,6 +189,13 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
       return thinkConfig
     }
 
+    // Hunyuan models
+    if (isSupportedThinkingTokenHunyuanModel(model)) {
+      return {
+        enable_thinking: true
+      }
+    }
+
     // Grok models
     if (isSupportedReasoningEffortGrokModel(model)) {
       return {
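Net effect for Hunyuan: the `enable_thinking` flag (the same switch the Qwen branch uses) is turned off in the first hunk and on in the second. A small sketch of the resulting behavior, assuming the two-state 'off' | 'auto' effort levels the ThinkingButton hunk below registers for Hunyuan, and a simplified stand-in for the model check:

```ts
// Sketch of the control flow the two OpenAIAPIClient hunks implement for Hunyuan.
interface Model {
  id: string
}

// Simplified stand-in for the new model check (the real one uses getLowerBaseModelName).
const isSupportedThinkingTokenHunyuanModel = (model?: Model): boolean =>
  !!model && model.id.toLowerCase().includes('hunyuan-a13b')

function hunyuanThinkingParams(model: Model, effort: 'off' | 'auto'): Record<string, boolean> {
  if (!isSupportedThinkingTokenHunyuanModel(model)) {
    return {}
  }
  // 'off' → { enable_thinking: false }; any enabled level → { enable_thinking: true }
  return { enable_thinking: effort !== 'off' }
}

console.log(hunyuanThinkingParams({ id: 'hunyuan-a13b' }, 'auto')) // { enable_thinking: true }
console.log(hunyuanThinkingParams({ id: 'hunyuan-a13b' }, 'off')) // { enable_thinking: false }
```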
@@ -96,6 +96,18 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
     }
   }
 
+  /**
+   * Overrides the base-class method to return the client type actually used internally
+   */
+  public override getClientCompatibilityType(model?: Model): string[] {
+    if (!model) {
+      return [this.constructor.name]
+    }
+
+    const actualClient = this.getClient(model)
+    return actualClient.getClientCompatibilityType(model)
+  }
+
   override async getSdkInstance() {
     if (this.sdkInstance) {
       return this.sdkInstance
@@ -9,9 +9,7 @@ import type { GenerateImageParams, Model, Provider } from '@renderer/types'
 import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
 import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
 
-import { OpenAIAPIClient } from './clients'
 import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
-import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient'
 import { NewAPIClient } from './clients/NewAPIClient'
 import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
 import { CompletionsMiddlewareBuilder } from './middleware/builder'
@@ -87,12 +85,18 @@ export default class AiProvider {
       builder.remove(ThinkChunkMiddlewareName)
       logger.silly('ThinkChunkMiddleware is removed')
     }
-    // Note: checking the client with instanceof causes TypeScript type narrowing
-    if (!(this.apiClient instanceof OpenAIAPIClient) && !(this.apiClient instanceof OpenAIResponseAPIClient)) {
+    // Use the compatibility-type check to avoid TypeScript type narrowing and decorator-pattern issues
+    const clientTypes = client.getClientCompatibilityType(model)
+    const isOpenAICompatible =
+      clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
+    if (!isOpenAICompatible) {
       logger.silly('ThinkingTagExtractionMiddleware is removed')
       builder.remove(ThinkingTagExtractionMiddlewareName)
     }
-    if (!(this.apiClient instanceof AnthropicAPIClient) && !(this.apiClient instanceof OpenAIResponseAPIClient)) {
+
+    const isAnthropicOrOpenAIResponseCompatible =
+      clientTypes.includes('AnthropicAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
+    if (!isAnthropicOrOpenAIResponseCompatible) {
       logger.silly('RawStreamListenerMiddleware is removed')
       builder.remove(RawStreamListenerMiddlewareName)
     }
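Pulled out of context, the new pruning logic reduces to a pure function over the reported type names, which makes it easy to see which middleware survives for a given client. A hedged sketch (the helper function and the bare middleware-name strings are illustrative; the real code uses the imported `*MiddlewareName` constants):

```ts
// Sketch: the middleware-pruning decision above as a pure function.
function middlewaresToRemove(clientTypes: string[]): string[] {
  const removed: string[] = []

  const isOpenAICompatible =
    clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
  if (!isOpenAICompatible) {
    removed.push('ThinkingTagExtractionMiddleware')
  }

  const isAnthropicOrOpenAIResponseCompatible =
    clientTypes.includes('AnthropicAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
  if (!isAnthropicOrOpenAIResponseCompatible) {
    removed.push('RawStreamListenerMiddleware')
  }

  return removed
}

// A Gemini-routed request keeps neither middleware:
console.log(middlewaresToRemove(['GeminiAPIClient']))
// [ 'ThinkingTagExtractionMiddleware', 'RawStreamListenerMiddleware' ]
```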
@@ -123,6 +127,7 @@ export default class AiProvider {
     }
 
     const middlewares = builder.build()
+    logger.silly('middlewares', middlewares)
 
     // 3. Create the wrapped SDK method with middlewares
     const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
@@ -2513,7 +2513,8 @@ export function isSupportedThinkingTokenModel(model?: Model): boolean {
     isSupportedThinkingTokenGeminiModel(model) ||
     isSupportedThinkingTokenQwenModel(model) ||
     isSupportedThinkingTokenClaudeModel(model) ||
-    isSupportedThinkingTokenDoubaoModel(model)
+    isSupportedThinkingTokenDoubaoModel(model) ||
+    isSupportedThinkingTokenHunyuanModel(model)
   )
 }
 
@@ -2598,6 +2599,10 @@ export function isSupportedThinkingTokenQwenModel(model?: Model): boolean {
 
   const baseName = getLowerBaseModelName(model.id, '/')
 
+  if (baseName.includes('coder')) {
+    return false
+  }
+
   return (
     baseName.startsWith('qwen3') ||
     [
@@ -2639,12 +2644,27 @@ export function isClaudeReasoningModel(model?: Model): boolean {
 
 export const isSupportedThinkingTokenClaudeModel = isClaudeReasoningModel
 
+export const isSupportedThinkingTokenHunyuanModel = (model?: Model): boolean => {
+  if (!model) {
+    return false
+  }
+  const baseName = getLowerBaseModelName(model.id, '/')
+  return baseName.includes('hunyuan-a13b')
+}
+
+export const isHunyuanReasoningModel = (model?: Model): boolean => {
+  if (!model) {
+    return false
+  }
+  return isSupportedThinkingTokenHunyuanModel(model) || model.id.toLowerCase().includes('hunyuan-t1')
+}
+
 export function isReasoningModel(model?: Model): boolean {
   if (!model) {
     return false
   }
 
-  if (isEmbeddingModel(model)) {
+  if (isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
     return false
   }
 
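Assuming `getLowerBaseModelName(id, '/')` lower-cases the id and keeps the segment after the separator (an assumption; the helper's implementation is not part of this diff), the new matchers behave like this:

```ts
// Assumed stand-in for getLowerBaseModelName: lower-case the id and take the
// segment after the last '/'. The real helper may differ.
const getLowerBaseModelName = (id: string, separator = '/'): string => {
  const segments = id.toLowerCase().split(separator)
  return segments[segments.length - 1]
}

const isSupportedThinkingTokenHunyuanModel = (id: string): boolean =>
  getLowerBaseModelName(id, '/').includes('hunyuan-a13b')

const isHunyuanReasoningModel = (id: string): boolean =>
  isSupportedThinkingTokenHunyuanModel(id) || id.toLowerCase().includes('hunyuan-t1')

console.log(isSupportedThinkingTokenHunyuanModel('tencent/Hunyuan-A13B-Instruct')) // true
console.log(isSupportedThinkingTokenHunyuanModel('hunyuan-t1-latest')) // false
console.log(isHunyuanReasoningModel('hunyuan-t1-latest')) // true — reasoning model, but no thinking toggle
```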
@@ -2664,8 +2684,10 @@ export function isReasoningModel(model?: Model): boolean {
     isGeminiReasoningModel(model) ||
     isQwenReasoningModel(model) ||
     isGrokReasoningModel(model) ||
-    model.id.includes('glm-z1') ||
-    model.id.includes('magistral')
+    isHunyuanReasoningModel(model) ||
+    model.id.toLowerCase().includes('glm-z1') ||
+    model.id.toLowerCase().includes('magistral') ||
+    model.id.toLowerCase().includes('minimax-m1')
   ) {
     return true
   }
@@ -12,6 +12,7 @@ import {
   isSupportedReasoningEffortGrokModel,
   isSupportedThinkingTokenDoubaoModel,
   isSupportedThinkingTokenGeminiModel,
+  isSupportedThinkingTokenHunyuanModel,
   isSupportedThinkingTokenQwenModel
 } from '@renderer/config/models'
 import { useAssistant } from '@renderer/hooks/useAssistant'
@@ -40,7 +41,8 @@ const MODEL_SUPPORTED_OPTIONS: Record<string, ThinkingOption[]> = {
   gemini: ['off', 'low', 'medium', 'high', 'auto'],
   gemini_pro: ['low', 'medium', 'high', 'auto'],
   qwen: ['off', 'low', 'medium', 'high'],
-  doubao: ['off', 'auto', 'high']
+  doubao: ['off', 'auto', 'high'],
+  hunyuan: ['off', 'auto']
 }
 
 // Option conversion table: fallback options to use when an option is not supported
@@ -62,6 +64,7 @@ const ThinkingButton: FC<Props> = ({ ref, model, assistant, ToolbarButton }): Re
   const isGeminiFlashModel = GEMINI_FLASH_MODEL_REGEX.test(model.id)
   const isQwenModel = isSupportedThinkingTokenQwenModel(model)
   const isDoubaoModel = isSupportedThinkingTokenDoubaoModel(model)
+  const isHunyuanModel = isSupportedThinkingTokenHunyuanModel(model)
 
   const currentReasoningEffort = useMemo(() => {
     return assistant.settings?.reasoning_effort || 'off'
@@ -79,8 +82,9 @@ const ThinkingButton: FC<Props> = ({ ref, model, assistant, ToolbarButton }): Re
     if (isGrokModel) return 'grok'
     if (isQwenModel) return 'qwen'
     if (isDoubaoModel) return 'doubao'
+    if (isHunyuanModel) return 'hunyuan'
     return 'default'
-  }, [isGeminiModel, isGrokModel, isQwenModel, isDoubaoModel, isGeminiFlashModel])
+  }, [isGeminiModel, isGrokModel, isQwenModel, isDoubaoModel, isGeminiFlashModel, isHunyuanModel])
 
   // Get the options supported by the current model
   const supportedOptions = useMemo(() => {
@@ -145,7 +149,7 @@ const ThinkingButton: FC<Props> = ({ ref, model, assistant, ToolbarButton }): Re
     [updateAssistantSettings]
   )
 
-  const baseOptions = useMemo(() => {
+  const panelItems = useMemo(() => {
     // Create the UI options from the options defined in the table
     return supportedOptions.map((option) => ({
       level: option,
@@ -157,8 +161,6 @@ const ThinkingButton: FC<Props> = ({ ref, model, assistant, ToolbarButton }): Re
     }))
   }, [t, createThinkingIcon, currentReasoningEffort, supportedOptions, onThinkingChange])
 
-  const panelItems = baseOptions
-
   const openQuickPanel = useCallback(() => {
     quickPanel.open({
       title: t('assistants.settings.reasoning_effort'),
@@ -1047,7 +1047,8 @@ const mockOpenaiApiClient = {
   provider: {} as Provider,
   useSystemPromptForTools: true,
   getBaseURL: vi.fn(() => 'https://api.openai.com'),
-  getApiKey: vi.fn(() => 'mock-api-key')
+  getApiKey: vi.fn(() => 'mock-api-key'),
+  getClientCompatibilityType: vi.fn(() => ['OpenAIAPIClient'])
 } as unknown as OpenAIAPIClient
 
 // Create a mock GeminiAPIClient
@@ -1165,7 +1166,8 @@ const mockGeminiApiClient = {
   provider: {} as Provider,
   useSystemPromptForTools: true,
   getBaseURL: vi.fn(() => 'https://api.gemini.com'),
-  getApiKey: vi.fn(() => 'mock-api-key')
+  getApiKey: vi.fn(() => 'mock-api-key'),
+  getClientCompatibilityType: vi.fn(() => ['GeminiAPIClient'])
 } as unknown as GeminiAPIClient
 
 const mockGeminiThinkingApiClient = cloneDeep(mockGeminiApiClient)
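With `getClientCompatibilityType` stubbed on the mocks, the ApiService tests can drive the new middleware-builder branch without constructing real clients. A hedged sketch of the kind of assertion this enables (the test body is illustrative, not copied from the suite):

```ts
import { describe, expect, it, vi } from 'vitest'

describe('mock client compatibility type', () => {
  it('reports the type the middleware builder checks against', () => {
    // Same shape as the stub added to mockOpenaiApiClient above.
    const getClientCompatibilityType = vi.fn(() => ['OpenAIAPIClient'])

    const clientTypes = getClientCompatibilityType()
    expect(clientTypes).toContain('OpenAIAPIClient')
    expect(getClientCompatibilityType).toHaveBeenCalledTimes(1)
  })
})
```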