feat: refactor AihubmixProvider and OpenAICompatibleProvider for improved model handling (#5732)

This commit is contained in:
kangfenmao 2025-05-08 13:20:12 +08:00
parent 5c2998cc48
commit aeb8091c89
4 changed files with 31 additions and 37 deletions

View File

@@ -1,3 +1,5 @@
import { isOpenAILLMModel } from '@renderer/config/models'
import { getDefaultModel } from '@renderer/services/AssistantService'
import { Assistant, Model, Provider, Suggestion } from '@renderer/types' import { Assistant, Model, Provider, Suggestion } from '@renderer/types'
import { Message } from '@renderer/types/newMessage' import { Message } from '@renderer/types/newMessage'
import OpenAI from 'openai' import OpenAI from 'openai'
@@ -6,6 +8,7 @@ import { CompletionsParams } from '.'
import AnthropicProvider from './AnthropicProvider' import AnthropicProvider from './AnthropicProvider'
import BaseProvider from './BaseProvider' import BaseProvider from './BaseProvider'
import GeminiProvider from './GeminiProvider' import GeminiProvider from './GeminiProvider'
import OpenAICompatibleProvider from './OpenAICompatibleProvider'
import OpenAIProvider from './OpenAIProvider' import OpenAIProvider from './OpenAIProvider'
/** /**
@@ -22,20 +25,28 @@ export default class AihubmixProvider extends BaseProvider {
// 初始化各个提供商 // 初始化各个提供商
this.providers.set('claude', new AnthropicProvider(provider)) this.providers.set('claude', new AnthropicProvider(provider))
this.providers.set('gemini', new GeminiProvider({ ...provider, apiHost: 'https://aihubmix.com/gemini' })) this.providers.set('gemini', new GeminiProvider({ ...provider, apiHost: 'https://aihubmix.com/gemini' }))
this.providers.set('default', new OpenAIProvider(provider)) this.providers.set('openai', new OpenAIProvider(provider))
this.providers.set('default', new OpenAICompatibleProvider(provider))
// 设置默认提供商 // 设置默认提供商
this.defaultProvider = this.providers.get('default')! this.defaultProvider = this.providers.get('default')!
} }
/** /**
* 根据模型ID获取合适的提供商 * 根据模型获取合适的提供商
*/ */
private getProvider(modelId: string = ''): BaseProvider { private getProvider(model: Model): BaseProvider {
const id = modelId.toLowerCase() const id = model.id.toLowerCase()
if (id.includes('claude')) return this.providers.get('claude')! if (id.includes('claude')) {
if (id.includes('gemini')) return this.providers.get('gemini')! return this.providers.get('claude')!
}
if (id.includes('gemini')) {
return this.providers.get('gemini')!
}
if (isOpenAILLMModel(model)) {
return this.providers.get('openai')!
}
return this.defaultProvider return this.defaultProvider
} }
@@ -58,8 +69,8 @@ export default class AihubmixProvider extends BaseProvider {
} }
public async completions(params: CompletionsParams): Promise<void> { public async completions(params: CompletionsParams): Promise<void> {
const modelId = params.assistant.model?.id || '' const model = params.assistant.model
return this.getProvider(modelId).completions(params) return this.getProvider(model!).completions(params)
} }
public async translate( public async translate(
@@ -67,26 +78,26 @@ export default class AihubmixProvider extends BaseProvider {
assistant: Assistant, assistant: Assistant,
onResponse?: (text: string, isComplete: boolean) => void onResponse?: (text: string, isComplete: boolean) => void
): Promise<string> { ): Promise<string> {
return this.getProvider(assistant.model?.id).translate(content, assistant, onResponse) return this.getProvider(assistant.model || getDefaultModel()).translate(content, assistant, onResponse)
} }
public async summaries(messages: Message[], assistant: Assistant): Promise<string> { public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
return this.getProvider(assistant.model?.id).summaries(messages, assistant) return this.getProvider(assistant.model || getDefaultModel()).summaries(messages, assistant)
} }
public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> { public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
return this.getProvider(assistant.model?.id).summaryForSearch(messages, assistant) return this.getProvider(assistant.model || getDefaultModel()).summaryForSearch(messages, assistant)
} }
public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> { public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
return this.getProvider(assistant.model?.id).suggestions(messages, assistant) return this.getProvider(assistant.model || getDefaultModel()).suggestions(messages, assistant)
} }
public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> { public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> {
return this.getProvider(model.id).check(model, stream) return this.getProvider(model).check(model, stream)
} }
public async getEmbeddingDimensions(model: Model): Promise<number> { public async getEmbeddingDimensions(model: Model): Promise<number> {
return this.getProvider(model.id).getEmbeddingDimensions(model) return this.getProvider(model).getEmbeddingDimensions(model)
} }
} }

View File

@@ -2,7 +2,6 @@ import {
findTokenLimit, findTokenLimit,
getOpenAIWebSearchParams, getOpenAIWebSearchParams,
isHunyuanSearchModel, isHunyuanSearchModel,
isOpenAILLMModel,
isOpenAIReasoningModel, isOpenAIReasoningModel,
isOpenAIWebSearch, isOpenAIWebSearch,
isReasoningModel, isReasoningModel,
@@ -331,10 +330,6 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
const defaultModel = getDefaultModel() const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel const model = assistant.model || defaultModel
if (assistant.model?.provider === 'aihubmix' && isOpenAILLMModel(model)) {
await super.completions({ messages, assistant, mcpTools, onChunk, onFilterMessages })
return
}
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
const isEnabledWebSearch = assistant.enableWebSearch || !!assistant.webSearchProviderId const isEnabledWebSearch = assistant.enableWebSearch || !!assistant.webSearchProviderId
messages = addImageFileToContents(messages) messages = addImageFileToContents(messages)
@@ -693,9 +688,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
async translate(content: string, assistant: Assistant, onResponse?: (text: string, isComplete: boolean) => void) { async translate(content: string, assistant: Assistant, onResponse?: (text: string, isComplete: boolean) => void) {
const defaultModel = getDefaultModel() const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel const model = assistant.model || defaultModel
if (assistant.model?.provider === 'aihubmix' && isOpenAILLMModel(model)) {
return await super.translate(content, assistant, onResponse)
}
const messagesForApi = content const messagesForApi = content
? [ ? [
{ role: 'system', content: assistant.prompt }, { role: 'system', content: assistant.prompt },
@@ -770,10 +763,6 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
public async summaries(messages: Message[], assistant: Assistant): Promise<string> { public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
const model = getTopNamingModel() || assistant.model || getDefaultModel() const model = getTopNamingModel() || assistant.model || getDefaultModel()
if (assistant.model?.provider === 'aihubmix' && isOpenAILLMModel(model)) {
return await super.summaries(messages, assistant)
}
const userMessages = takeRight(messages, 5) const userMessages = takeRight(messages, 5)
.filter((message) => !message.isPreset) .filter((message) => !message.isPreset)
.map((message) => ({ .map((message) => ({
@@ -823,10 +812,6 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> { public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
const model = assistant.model || getDefaultModel() const model = assistant.model || getDefaultModel()
if (assistant.model?.provider === 'aihubmix' && isOpenAILLMModel(model)) {
return await super.summaryForSearch(messages, assistant)
}
const systemMessage = { const systemMessage = {
role: 'system', role: 'system',
content: assistant.prompt content: assistant.prompt
@@ -938,9 +923,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
if (!model) { if (!model) {
return { valid: false, error: new Error('No model found') } return { valid: false, error: new Error('No model found') }
} }
if (model.provider === 'aihubmix' && isOpenAILLMModel(model)) {
return await super.check(model, stream)
}
const body = { const body = {
model: model.id, model: model.id,
messages: [{ role: 'user', content: 'hi' }], messages: [{ role: 'user', content: 'hi' }],

View File

@@ -11,11 +11,11 @@ export default class ProviderFactory {
static create(provider: Provider): BaseProvider { static create(provider: Provider): BaseProvider {
switch (provider.type) { switch (provider.type) {
case 'openai': case 'openai':
return new OpenAIProvider(provider)
case 'openai-compatible':
if (provider.id === 'aihubmix') { if (provider.id === 'aihubmix') {
return new AihubmixProvider(provider) return new AihubmixProvider(provider)
} }
return new OpenAIProvider(provider)
case 'openai-compatible':
return new OpenAICompatibleProvider(provider) return new OpenAICompatibleProvider(provider)
case 'anthropic': case 'anthropic':
return new AnthropicProvider(provider) return new AnthropicProvider(provider)

View File

@@ -38,7 +38,7 @@ export const INITIAL_PROVIDERS: Provider[] = [
{ {
id: 'aihubmix', id: 'aihubmix',
name: 'AiHubMix', name: 'AiHubMix',
type: 'openai', type: 'openai-compatible',
apiKey: '', apiKey: '',
apiHost: 'https://aihubmix.com', apiHost: 'https://aihubmix.com',
models: SYSTEM_MODELS.aihubmix, models: SYSTEM_MODELS.aihubmix,
@@ -68,7 +68,7 @@ export const INITIAL_PROVIDERS: Provider[] = [
{ {
id: 'openrouter', id: 'openrouter',
name: 'OpenRouter', name: 'OpenRouter',
type: 'openai', type: 'openai-compatible',
apiKey: '', apiKey: '',
apiHost: 'https://openrouter.ai/api/v1/', apiHost: 'https://openrouter.ai/api/v1/',
models: SYSTEM_MODELS.openrouter, models: SYSTEM_MODELS.openrouter,