From 6f1101e96d4d3c7500b15e9384fa94f384e4491e Mon Sep 17 00:00:00 2001 From: kangfenmao Date: Thu, 8 May 2025 09:52:53 +0800 Subject: [PATCH] feat: customize aihubmix provider request logic (#5728) --- src/renderer/src/config/providers.ts | 2 +- .../providers/AiProvider/AihubmixProvider.ts | 92 +++++++++++++++++++ .../providers/AiProvider/ProviderFactory.ts | 4 + 3 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 src/renderer/src/providers/AiProvider/AihubmixProvider.ts diff --git a/src/renderer/src/config/providers.ts b/src/renderer/src/config/providers.ts index 4879e67419..81407a8ae2 100644 --- a/src/renderer/src/config/providers.ts +++ b/src/renderer/src/config/providers.ts @@ -95,7 +95,7 @@ export function getProviderLogo(providerId: string) { return PROVIDER_LOGO_MAP[providerId as keyof typeof PROVIDER_LOGO_MAP] } -export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope'] +export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope', 'aihubmix'] export const PROVIDER_CONFIG = { openai: { diff --git a/src/renderer/src/providers/AiProvider/AihubmixProvider.ts b/src/renderer/src/providers/AiProvider/AihubmixProvider.ts new file mode 100644 index 0000000000..4c82acc14f --- /dev/null +++ b/src/renderer/src/providers/AiProvider/AihubmixProvider.ts @@ -0,0 +1,92 @@ +import { Assistant, Model, Provider, Suggestion } from '@renderer/types' +import { Message } from '@renderer/types/newMessage' +import OpenAI from 'openai' + +import { CompletionsParams } from '.' +import AnthropicProvider from './AnthropicProvider' +import BaseProvider from './BaseProvider' +import GeminiProvider from './GeminiProvider' +import OpenAIProvider from './OpenAIProvider' + +/** + * AihubmixProvider - 根据模型类型自动选择合适的提供商 + * 使用装饰器模式实现 + */ +export default class AihubmixProvider extends BaseProvider { + private providers: Map = new Map() + private defaultProvider: BaseProvider + + constructor(provider: Provider) { + super(provider) + + // 初始化各个提供商 + this.providers.set('claude', new AnthropicProvider(provider)) + this.providers.set('gemini', new GeminiProvider({ ...provider, apiHost: 'https://aihubmix.com/gemini' })) + this.providers.set('default', new OpenAIProvider(provider)) + + // 设置默认提供商 + this.defaultProvider = this.providers.get('default')! + } + + /** + * 根据模型ID获取合适的提供商 + */ + private getProvider(modelId: string = ''): BaseProvider { + const id = modelId.toLowerCase() + + if (id.includes('claude')) return this.providers.get('claude')! + if (id.includes('gemini')) return this.providers.get('gemini')! + + return this.defaultProvider + } + + // 直接使用默认提供商的方法 + public async models(): Promise { + return this.defaultProvider.models() + } + + public async generateText(params: { prompt: string; content: string }): Promise { + return this.defaultProvider.generateText(params) + } + + public async generateImage(params: any): Promise { + return this.defaultProvider.generateImage(params) + } + + public async generateImageByChat(params: any): Promise { + return this.defaultProvider.generateImageByChat(params) + } + + public async completions(params: CompletionsParams): Promise { + const modelId = params.assistant.model?.id || '' + return this.getProvider(modelId).completions(params) + } + + public async translate( + content: string, + assistant: Assistant, + onResponse?: (text: string, isComplete: boolean) => void + ): Promise { + return this.getProvider(assistant.model?.id).translate(content, assistant, onResponse) + } + + public async summaries(messages: Message[], assistant: Assistant): Promise { + return this.getProvider(assistant.model?.id).summaries(messages, assistant) + } + + public async summaryForSearch(messages: Message[], assistant: Assistant): Promise { + return this.getProvider(assistant.model?.id).summaryForSearch(messages, assistant) + } + + public async suggestions(messages: Message[], assistant: Assistant): Promise { + return this.getProvider(assistant.model?.id).suggestions(messages, assistant) + } + + public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> { + return this.getProvider(model.id).check(model, stream) + } + + public async getEmbeddingDimensions(model: Model): Promise { + return this.getProvider(model.id).getEmbeddingDimensions(model) + } +} diff --git a/src/renderer/src/providers/AiProvider/ProviderFactory.ts b/src/renderer/src/providers/AiProvider/ProviderFactory.ts index ff3515e119..f4ef15c7f4 100644 --- a/src/renderer/src/providers/AiProvider/ProviderFactory.ts +++ b/src/renderer/src/providers/AiProvider/ProviderFactory.ts @@ -1,5 +1,6 @@ import { Provider } from '@renderer/types' +import AihubmixProvider from './AihubmixProvider' import AnthropicProvider from './AnthropicProvider' import BaseProvider from './BaseProvider' import GeminiProvider from './GeminiProvider' @@ -10,6 +11,9 @@ export default class ProviderFactory { static create(provider: Provider): BaseProvider { switch (provider.type) { case 'openai': + if (provider.id === 'aihubmix') { + return new AihubmixProvider(provider) + } return new OpenAIProvider(provider) case 'openai-compatible': return new OpenAICompatibleProvider(provider)