mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-05 12:29:44 +08:00
feat: refactor AihubmixProvider and OpenAICompatibleProvider for improved model handling (#5732)
This commit is contained in:
parent
5c2998cc48
commit
aeb8091c89
@ -1,3 +1,5 @@
|
|||||||
|
import { isOpenAILLMModel } from '@renderer/config/models'
|
||||||
|
import { getDefaultModel } from '@renderer/services/AssistantService'
|
||||||
import { Assistant, Model, Provider, Suggestion } from '@renderer/types'
|
import { Assistant, Model, Provider, Suggestion } from '@renderer/types'
|
||||||
import { Message } from '@renderer/types/newMessage'
|
import { Message } from '@renderer/types/newMessage'
|
||||||
import OpenAI from 'openai'
|
import OpenAI from 'openai'
|
||||||
@ -6,6 +8,7 @@ import { CompletionsParams } from '.'
|
|||||||
import AnthropicProvider from './AnthropicProvider'
|
import AnthropicProvider from './AnthropicProvider'
|
||||||
import BaseProvider from './BaseProvider'
|
import BaseProvider from './BaseProvider'
|
||||||
import GeminiProvider from './GeminiProvider'
|
import GeminiProvider from './GeminiProvider'
|
||||||
|
import OpenAICompatibleProvider from './OpenAICompatibleProvider'
|
||||||
import OpenAIProvider from './OpenAIProvider'
|
import OpenAIProvider from './OpenAIProvider'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -22,20 +25,28 @@ export default class AihubmixProvider extends BaseProvider {
|
|||||||
// 初始化各个提供商
|
// 初始化各个提供商
|
||||||
this.providers.set('claude', new AnthropicProvider(provider))
|
this.providers.set('claude', new AnthropicProvider(provider))
|
||||||
this.providers.set('gemini', new GeminiProvider({ ...provider, apiHost: 'https://aihubmix.com/gemini' }))
|
this.providers.set('gemini', new GeminiProvider({ ...provider, apiHost: 'https://aihubmix.com/gemini' }))
|
||||||
this.providers.set('default', new OpenAIProvider(provider))
|
this.providers.set('openai', new OpenAIProvider(provider))
|
||||||
|
this.providers.set('default', new OpenAICompatibleProvider(provider))
|
||||||
|
|
||||||
// 设置默认提供商
|
// 设置默认提供商
|
||||||
this.defaultProvider = this.providers.get('default')!
|
this.defaultProvider = this.providers.get('default')!
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 根据模型ID获取合适的提供商
|
* 根据模型获取合适的提供商
|
||||||
*/
|
*/
|
||||||
private getProvider(modelId: string = ''): BaseProvider {
|
private getProvider(model: Model): BaseProvider {
|
||||||
const id = modelId.toLowerCase()
|
const id = model.id.toLowerCase()
|
||||||
|
|
||||||
if (id.includes('claude')) return this.providers.get('claude')!
|
if (id.includes('claude')) {
|
||||||
if (id.includes('gemini')) return this.providers.get('gemini')!
|
return this.providers.get('claude')!
|
||||||
|
}
|
||||||
|
if (id.includes('gemini')) {
|
||||||
|
return this.providers.get('gemini')!
|
||||||
|
}
|
||||||
|
if (isOpenAILLMModel(model)) {
|
||||||
|
return this.providers.get('openai')!
|
||||||
|
}
|
||||||
|
|
||||||
return this.defaultProvider
|
return this.defaultProvider
|
||||||
}
|
}
|
||||||
@ -58,8 +69,8 @@ export default class AihubmixProvider extends BaseProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public async completions(params: CompletionsParams): Promise<void> {
|
public async completions(params: CompletionsParams): Promise<void> {
|
||||||
const modelId = params.assistant.model?.id || ''
|
const model = params.assistant.model
|
||||||
return this.getProvider(modelId).completions(params)
|
return this.getProvider(model!).completions(params)
|
||||||
}
|
}
|
||||||
|
|
||||||
public async translate(
|
public async translate(
|
||||||
@ -67,26 +78,26 @@ export default class AihubmixProvider extends BaseProvider {
|
|||||||
assistant: Assistant,
|
assistant: Assistant,
|
||||||
onResponse?: (text: string, isComplete: boolean) => void
|
onResponse?: (text: string, isComplete: boolean) => void
|
||||||
): Promise<string> {
|
): Promise<string> {
|
||||||
return this.getProvider(assistant.model?.id).translate(content, assistant, onResponse)
|
return this.getProvider(assistant.model || getDefaultModel()).translate(content, assistant, onResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
|
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
|
||||||
return this.getProvider(assistant.model?.id).summaries(messages, assistant)
|
return this.getProvider(assistant.model || getDefaultModel()).summaries(messages, assistant)
|
||||||
}
|
}
|
||||||
|
|
||||||
public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
|
public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
|
||||||
return this.getProvider(assistant.model?.id).summaryForSearch(messages, assistant)
|
return this.getProvider(assistant.model || getDefaultModel()).summaryForSearch(messages, assistant)
|
||||||
}
|
}
|
||||||
|
|
||||||
public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
|
public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
|
||||||
return this.getProvider(assistant.model?.id).suggestions(messages, assistant)
|
return this.getProvider(assistant.model || getDefaultModel()).suggestions(messages, assistant)
|
||||||
}
|
}
|
||||||
|
|
||||||
public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> {
|
public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> {
|
||||||
return this.getProvider(model.id).check(model, stream)
|
return this.getProvider(model).check(model, stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||||
return this.getProvider(model.id).getEmbeddingDimensions(model)
|
return this.getProvider(model).getEmbeddingDimensions(model)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import {
|
|||||||
findTokenLimit,
|
findTokenLimit,
|
||||||
getOpenAIWebSearchParams,
|
getOpenAIWebSearchParams,
|
||||||
isHunyuanSearchModel,
|
isHunyuanSearchModel,
|
||||||
isOpenAILLMModel,
|
|
||||||
isOpenAIReasoningModel,
|
isOpenAIReasoningModel,
|
||||||
isOpenAIWebSearch,
|
isOpenAIWebSearch,
|
||||||
isReasoningModel,
|
isReasoningModel,
|
||||||
@ -331,10 +330,6 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
|
|||||||
const defaultModel = getDefaultModel()
|
const defaultModel = getDefaultModel()
|
||||||
const model = assistant.model || defaultModel
|
const model = assistant.model || defaultModel
|
||||||
|
|
||||||
if (assistant.model?.provider === 'aihubmix' && isOpenAILLMModel(model)) {
|
|
||||||
await super.completions({ messages, assistant, mcpTools, onChunk, onFilterMessages })
|
|
||||||
return
|
|
||||||
}
|
|
||||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
||||||
const isEnabledWebSearch = assistant.enableWebSearch || !!assistant.webSearchProviderId
|
const isEnabledWebSearch = assistant.enableWebSearch || !!assistant.webSearchProviderId
|
||||||
messages = addImageFileToContents(messages)
|
messages = addImageFileToContents(messages)
|
||||||
@ -693,9 +688,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
|
|||||||
async translate(content: string, assistant: Assistant, onResponse?: (text: string, isComplete: boolean) => void) {
|
async translate(content: string, assistant: Assistant, onResponse?: (text: string, isComplete: boolean) => void) {
|
||||||
const defaultModel = getDefaultModel()
|
const defaultModel = getDefaultModel()
|
||||||
const model = assistant.model || defaultModel
|
const model = assistant.model || defaultModel
|
||||||
if (assistant.model?.provider === 'aihubmix' && isOpenAILLMModel(model)) {
|
|
||||||
return await super.translate(content, assistant, onResponse)
|
|
||||||
}
|
|
||||||
const messagesForApi = content
|
const messagesForApi = content
|
||||||
? [
|
? [
|
||||||
{ role: 'system', content: assistant.prompt },
|
{ role: 'system', content: assistant.prompt },
|
||||||
@ -770,10 +763,6 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
|
|||||||
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
|
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
|
||||||
const model = getTopNamingModel() || assistant.model || getDefaultModel()
|
const model = getTopNamingModel() || assistant.model || getDefaultModel()
|
||||||
|
|
||||||
if (assistant.model?.provider === 'aihubmix' && isOpenAILLMModel(model)) {
|
|
||||||
return await super.summaries(messages, assistant)
|
|
||||||
}
|
|
||||||
|
|
||||||
const userMessages = takeRight(messages, 5)
|
const userMessages = takeRight(messages, 5)
|
||||||
.filter((message) => !message.isPreset)
|
.filter((message) => !message.isPreset)
|
||||||
.map((message) => ({
|
.map((message) => ({
|
||||||
@ -823,10 +812,6 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
|
|||||||
public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
|
public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
|
||||||
const model = assistant.model || getDefaultModel()
|
const model = assistant.model || getDefaultModel()
|
||||||
|
|
||||||
if (assistant.model?.provider === 'aihubmix' && isOpenAILLMModel(model)) {
|
|
||||||
return await super.summaryForSearch(messages, assistant)
|
|
||||||
}
|
|
||||||
|
|
||||||
const systemMessage = {
|
const systemMessage = {
|
||||||
role: 'system',
|
role: 'system',
|
||||||
content: assistant.prompt
|
content: assistant.prompt
|
||||||
@ -938,9 +923,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider {
|
|||||||
if (!model) {
|
if (!model) {
|
||||||
return { valid: false, error: new Error('No model found') }
|
return { valid: false, error: new Error('No model found') }
|
||||||
}
|
}
|
||||||
if (model.provider === 'aihubmix' && isOpenAILLMModel(model)) {
|
|
||||||
return await super.check(model, stream)
|
|
||||||
}
|
|
||||||
const body = {
|
const body = {
|
||||||
model: model.id,
|
model: model.id,
|
||||||
messages: [{ role: 'user', content: 'hi' }],
|
messages: [{ role: 'user', content: 'hi' }],
|
||||||
|
|||||||
@ -11,11 +11,11 @@ export default class ProviderFactory {
|
|||||||
static create(provider: Provider): BaseProvider {
|
static create(provider: Provider): BaseProvider {
|
||||||
switch (provider.type) {
|
switch (provider.type) {
|
||||||
case 'openai':
|
case 'openai':
|
||||||
|
return new OpenAIProvider(provider)
|
||||||
|
case 'openai-compatible':
|
||||||
if (provider.id === 'aihubmix') {
|
if (provider.id === 'aihubmix') {
|
||||||
return new AihubmixProvider(provider)
|
return new AihubmixProvider(provider)
|
||||||
}
|
}
|
||||||
return new OpenAIProvider(provider)
|
|
||||||
case 'openai-compatible':
|
|
||||||
return new OpenAICompatibleProvider(provider)
|
return new OpenAICompatibleProvider(provider)
|
||||||
case 'anthropic':
|
case 'anthropic':
|
||||||
return new AnthropicProvider(provider)
|
return new AnthropicProvider(provider)
|
||||||
|
|||||||
@ -38,7 +38,7 @@ export const INITIAL_PROVIDERS: Provider[] = [
|
|||||||
{
|
{
|
||||||
id: 'aihubmix',
|
id: 'aihubmix',
|
||||||
name: 'AiHubMix',
|
name: 'AiHubMix',
|
||||||
type: 'openai',
|
type: 'openai-compatible',
|
||||||
apiKey: '',
|
apiKey: '',
|
||||||
apiHost: 'https://aihubmix.com',
|
apiHost: 'https://aihubmix.com',
|
||||||
models: SYSTEM_MODELS.aihubmix,
|
models: SYSTEM_MODELS.aihubmix,
|
||||||
@ -68,7 +68,7 @@ export const INITIAL_PROVIDERS: Provider[] = [
|
|||||||
{
|
{
|
||||||
id: 'openrouter',
|
id: 'openrouter',
|
||||||
name: 'OpenRouter',
|
name: 'OpenRouter',
|
||||||
type: 'openai',
|
type: 'openai-compatible',
|
||||||
apiKey: '',
|
apiKey: '',
|
||||||
apiHost: 'https://openrouter.ai/api/v1/',
|
apiHost: 'https://openrouter.ai/api/v1/',
|
||||||
models: SYSTEM_MODELS.openrouter,
|
models: SYSTEM_MODELS.openrouter,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user