diff --git a/src/renderer/src/providers/AiProvider/AihubmixProvider.ts b/src/renderer/src/providers/AiProvider/AihubmixProvider.ts index 4c82acc14f..8961a050d4 100644 --- a/src/renderer/src/providers/AiProvider/AihubmixProvider.ts +++ b/src/renderer/src/providers/AiProvider/AihubmixProvider.ts @@ -1,3 +1,5 @@ +import { isOpenAILLMModel } from '@renderer/config/models' +import { getDefaultModel } from '@renderer/services/AssistantService' import { Assistant, Model, Provider, Suggestion } from '@renderer/types' import { Message } from '@renderer/types/newMessage' import OpenAI from 'openai' @@ -6,6 +8,7 @@ import { CompletionsParams } from '.' import AnthropicProvider from './AnthropicProvider' import BaseProvider from './BaseProvider' import GeminiProvider from './GeminiProvider' +import OpenAICompatibleProvider from './OpenAICompatibleProvider' import OpenAIProvider from './OpenAIProvider' /** @@ -22,20 +25,28 @@ export default class AihubmixProvider extends BaseProvider { // 初始化各个提供商 this.providers.set('claude', new AnthropicProvider(provider)) this.providers.set('gemini', new GeminiProvider({ ...provider, apiHost: 'https://aihubmix.com/gemini' })) - this.providers.set('default', new OpenAIProvider(provider)) + this.providers.set('openai', new OpenAIProvider(provider)) + this.providers.set('default', new OpenAICompatibleProvider(provider)) // 设置默认提供商 this.defaultProvider = this.providers.get('default')! } /** - * 根据模型ID获取合适的提供商 + * 根据模型获取合适的提供商 */ - private getProvider(modelId: string = ''): BaseProvider { - const id = modelId.toLowerCase() + private getProvider(model: Model): BaseProvider { + const id = model.id.toLowerCase() - if (id.includes('claude')) return this.providers.get('claude')! - if (id.includes('gemini')) return this.providers.get('gemini')! + if (id.includes('claude')) { + return this.providers.get('claude')! + } + if (id.includes('gemini')) { + return this.providers.get('gemini')! + } + if (isOpenAILLMModel(model)) { + return this.providers.get('openai')! + } return this.defaultProvider } @@ -58,8 +69,8 @@ export default class AihubmixProvider extends BaseProvider { } public async completions(params: CompletionsParams): Promise { - const modelId = params.assistant.model?.id || '' - return this.getProvider(modelId).completions(params) + const model = params.assistant.model + return this.getProvider(model!).completions(params) } public async translate( @@ -67,26 +78,26 @@ export default class AihubmixProvider extends BaseProvider { assistant: Assistant, onResponse?: (text: string, isComplete: boolean) => void ): Promise { - return this.getProvider(assistant.model?.id).translate(content, assistant, onResponse) + return this.getProvider(assistant.model || getDefaultModel()).translate(content, assistant, onResponse) } public async summaries(messages: Message[], assistant: Assistant): Promise { - return this.getProvider(assistant.model?.id).summaries(messages, assistant) + return this.getProvider(assistant.model || getDefaultModel()).summaries(messages, assistant) } public async summaryForSearch(messages: Message[], assistant: Assistant): Promise { - return this.getProvider(assistant.model?.id).summaryForSearch(messages, assistant) + return this.getProvider(assistant.model || getDefaultModel()).summaryForSearch(messages, assistant) } public async suggestions(messages: Message[], assistant: Assistant): Promise { - return this.getProvider(assistant.model?.id).suggestions(messages, assistant) + return this.getProvider(assistant.model || getDefaultModel()).suggestions(messages, assistant) } public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> { - return this.getProvider(model.id).check(model, stream) + return this.getProvider(model).check(model, stream) } public async getEmbeddingDimensions(model: Model): Promise { - return this.getProvider(model.id).getEmbeddingDimensions(model) + return this.getProvider(model).getEmbeddingDimensions(model) } } diff --git a/src/renderer/src/providers/AiProvider/OpenAICompatibleProvider.ts b/src/renderer/src/providers/AiProvider/OpenAICompatibleProvider.ts index 950b2c138f..dce83cb6dc 100644 --- a/src/renderer/src/providers/AiProvider/OpenAICompatibleProvider.ts +++ b/src/renderer/src/providers/AiProvider/OpenAICompatibleProvider.ts @@ -2,7 +2,6 @@ import { findTokenLimit, getOpenAIWebSearchParams, isHunyuanSearchModel, - isOpenAILLMModel, isOpenAIReasoningModel, isOpenAIWebSearch, isReasoningModel, @@ -331,10 +330,6 @@ export default class OpenAICompatibleProvider extends OpenAIProvider { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel - if (assistant.model?.provider === 'aihubmix' && isOpenAILLMModel(model)) { - await super.completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }) - return - } const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) const isEnabledWebSearch = assistant.enableWebSearch || !!assistant.webSearchProviderId messages = addImageFileToContents(messages) @@ -693,9 +688,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider { async translate(content: string, assistant: Assistant, onResponse?: (text: string, isComplete: boolean) => void) { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel - if (assistant.model?.provider === 'aihubmix' && isOpenAILLMModel(model)) { - return await super.translate(content, assistant, onResponse) - } + const messagesForApi = content ? [ { role: 'system', content: assistant.prompt }, @@ -770,10 +763,6 @@ export default class OpenAICompatibleProvider extends OpenAIProvider { public async summaries(messages: Message[], assistant: Assistant): Promise { const model = getTopNamingModel() || assistant.model || getDefaultModel() - if (assistant.model?.provider === 'aihubmix' && isOpenAILLMModel(model)) { - return await super.summaries(messages, assistant) - } - const userMessages = takeRight(messages, 5) .filter((message) => !message.isPreset) .map((message) => ({ @@ -823,10 +812,6 @@ export default class OpenAICompatibleProvider extends OpenAIProvider { public async summaryForSearch(messages: Message[], assistant: Assistant): Promise { const model = assistant.model || getDefaultModel() - if (assistant.model?.provider === 'aihubmix' && isOpenAILLMModel(model)) { - return await super.summaryForSearch(messages, assistant) - } - const systemMessage = { role: 'system', content: assistant.prompt @@ -938,9 +923,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider { if (!model) { return { valid: false, error: new Error('No model found') } } - if (model.provider === 'aihubmix' && isOpenAILLMModel(model)) { - return await super.check(model, stream) - } + const body = { model: model.id, messages: [{ role: 'user', content: 'hi' }], diff --git a/src/renderer/src/providers/AiProvider/ProviderFactory.ts b/src/renderer/src/providers/AiProvider/ProviderFactory.ts index f4ef15c7f4..6d3c10468e 100644 --- a/src/renderer/src/providers/AiProvider/ProviderFactory.ts +++ b/src/renderer/src/providers/AiProvider/ProviderFactory.ts @@ -11,11 +11,11 @@ export default class ProviderFactory { static create(provider: Provider): BaseProvider { switch (provider.type) { case 'openai': + return new OpenAIProvider(provider) + case 'openai-compatible': if (provider.id === 'aihubmix') { return new AihubmixProvider(provider) } - return new OpenAIProvider(provider) - case 'openai-compatible': return new OpenAICompatibleProvider(provider) case 'anthropic': return new AnthropicProvider(provider) diff --git a/src/renderer/src/store/llm.ts b/src/renderer/src/store/llm.ts index 805acaf941..27a68b342b 100644 --- a/src/renderer/src/store/llm.ts +++ b/src/renderer/src/store/llm.ts @@ -38,7 +38,7 @@ export const INITIAL_PROVIDERS: Provider[] = [ { id: 'aihubmix', name: 'AiHubMix', - type: 'openai', + type: 'openai-compatible', apiKey: '', apiHost: 'https://aihubmix.com', models: SYSTEM_MODELS.aihubmix, @@ -68,7 +68,7 @@ export const INITIAL_PROVIDERS: Provider[] = [ { id: 'openrouter', name: 'OpenRouter', - type: 'openai', + type: 'openai-compatible', apiKey: '', apiHost: 'https://openrouter.ai/api/v1/', models: SYSTEM_MODELS.openrouter,