mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
feat(aiCore): enhance dynamic provider registration and refactor HubProvider
- Introduced dynamic provider registration functionality, allowing for flexible management of providers through a new registry system. - Refactored HubProvider to streamline model resolution and improve error handling for unsupported models. - Added utility functions for managing dynamic providers, including registration, cleanup, and alias resolution. - Updated index exports to include new dynamic provider APIs, enhancing overall usability and integration. - Removed outdated provider files and simplified the provider management structure for better maintainability.
This commit is contained in:
parent
53dcda6942
commit
84eef25ff9
@ -5,7 +5,7 @@
|
||||
* 例如: aihubmix:anthropic:claude-3.5-sonnet
|
||||
*/
|
||||
|
||||
import { ProviderV2 } from '@ai-sdk/provider'
|
||||
import { EmbeddingModelV2, ImageModelV2, ProviderV2, SpeechModelV2, TranscriptionModelV2 } from '@ai-sdk/provider'
|
||||
import { customProvider } from 'ai'
|
||||
|
||||
import { globalRegistryManagement } from './RegistryManagement'
|
||||
@ -47,13 +47,7 @@ function parseHubModelId(modelId: string): { provider: string; actualModelId: st
|
||||
* 创建Hub Provider
|
||||
*/
|
||||
export function createHubProvider(config: HubProviderConfig): ProviderV2 {
|
||||
const { hubId, debug = false } = config
|
||||
|
||||
function logDebug(message: string, ...args: any[]) {
|
||||
if (debug) {
|
||||
console.log(`[HubProvider:${hubId}] ${message}`, ...args)
|
||||
}
|
||||
}
|
||||
const { hubId } = config
|
||||
|
||||
function getTargetProvider(providerId: string): ProviderV2 {
|
||||
// 从全局注册表获取provider实例
|
||||
@ -77,72 +71,26 @@ export function createHubProvider(config: HubProviderConfig): ProviderV2 {
|
||||
}
|
||||
}
|
||||
|
||||
function resolveModel<T>(modelId: string, modelType: string, methodName: keyof ProviderV2): T {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
if (!targetProvider[methodName]) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider)
|
||||
}
|
||||
|
||||
return (targetProvider[methodName] as any)(actualModelId)
|
||||
}
|
||||
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
languageModel: (modelId: string) => {
|
||||
logDebug('Resolving language model:', modelId)
|
||||
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
if (!targetProvider.languageModel) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support language models`, hubId, provider)
|
||||
}
|
||||
|
||||
return targetProvider.languageModel(actualModelId)
|
||||
},
|
||||
|
||||
textEmbeddingModel: (modelId: string) => {
|
||||
logDebug('Resolving text embedding model:', modelId)
|
||||
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
if (!targetProvider.textEmbeddingModel) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support text embedding models`, hubId, provider)
|
||||
}
|
||||
|
||||
return targetProvider.textEmbeddingModel(actualModelId)
|
||||
},
|
||||
|
||||
imageModel: (modelId: string) => {
|
||||
logDebug('Resolving image model:', modelId)
|
||||
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
if (!targetProvider.imageModel) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support image models`, hubId, provider)
|
||||
}
|
||||
|
||||
return targetProvider.imageModel(actualModelId)
|
||||
},
|
||||
|
||||
transcriptionModel: (modelId: string) => {
|
||||
logDebug('Resolving transcription model:', modelId)
|
||||
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
if (!targetProvider.transcriptionModel) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support transcription models`, hubId, provider)
|
||||
}
|
||||
|
||||
return targetProvider.transcriptionModel(actualModelId)
|
||||
},
|
||||
|
||||
speechModel: (modelId: string) => {
|
||||
logDebug('Resolving speech model:', modelId)
|
||||
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
if (!targetProvider.speechModel) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support speech models`, hubId, provider)
|
||||
}
|
||||
|
||||
return targetProvider.speechModel(actualModelId)
|
||||
}
|
||||
languageModel: (modelId: string) => resolveModel(modelId, 'language models', 'languageModel'),
|
||||
textEmbeddingModel: (modelId: string) =>
|
||||
resolveModel<EmbeddingModelV2<string>>(modelId, 'text embedding models', 'textEmbeddingModel'),
|
||||
imageModel: (modelId: string) => resolveModel<ImageModelV2>(modelId, 'image models', 'imageModel'),
|
||||
transcriptionModel: (modelId: string) =>
|
||||
resolveModel<TranscriptionModelV2>(modelId, 'transcription models', 'transcriptionModel'),
|
||||
speechModel: (modelId: string) => resolveModel<SpeechModelV2>(modelId, 'speech models', 'speechModel')
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -15,6 +15,7 @@ export const DEFAULT_SEPARATOR = ':'
|
||||
|
||||
export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARATOR> {
|
||||
private providers: PROVIDERS = {}
|
||||
private aliases: Set<string> = new Set() // 记录哪些key是别名
|
||||
private separator: SEPARATOR
|
||||
private registry: ProviderRegistryProvider<PROVIDERS, SEPARATOR> | null = null
|
||||
|
||||
@ -25,8 +26,18 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
|
||||
/**
|
||||
* 注册已配置好的 provider 实例
|
||||
*/
|
||||
registerProvider(id: string, provider: ProviderV2): this {
|
||||
registerProvider(id: string, provider: ProviderV2, aliases?: string[]): this {
|
||||
// 注册主provider
|
||||
this.providers[id] = provider
|
||||
|
||||
// 注册别名(都指向同一个provider实例)
|
||||
if (aliases) {
|
||||
aliases.forEach((alias) => {
|
||||
this.providers[alias] = provider // 直接存储引用
|
||||
this.aliases.add(alias) // 标记为别名
|
||||
})
|
||||
}
|
||||
|
||||
this.rebuildRegistry()
|
||||
return this
|
||||
}
|
||||
@ -48,9 +59,31 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
|
||||
}
|
||||
|
||||
/**
|
||||
* 移除 provider
|
||||
* 移除 provider(同时清理相关别名)
|
||||
*/
|
||||
unregisterProvider(id: string): this {
|
||||
const provider = this.providers[id]
|
||||
if (!provider) return this
|
||||
|
||||
// 如果移除的是真实ID,需要清理所有指向它的别名
|
||||
if (!this.aliases.has(id)) {
|
||||
// 找到所有指向此provider的别名并删除
|
||||
const aliasesToRemove: string[] = []
|
||||
this.aliases.forEach((alias) => {
|
||||
if (this.providers[alias] === provider) {
|
||||
aliasesToRemove.push(alias)
|
||||
}
|
||||
})
|
||||
|
||||
aliasesToRemove.forEach((alias) => {
|
||||
delete this.providers[alias]
|
||||
this.aliases.delete(alias)
|
||||
})
|
||||
} else {
|
||||
// 如果移除的是别名,只删除别名记录
|
||||
this.aliases.delete(id)
|
||||
}
|
||||
|
||||
delete this.providers[id]
|
||||
this.rebuildRegistry()
|
||||
return this
|
||||
@ -121,10 +154,10 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取已注册的 provider 列表
|
||||
* 获取已注册的 provider 列表(排除别名)
|
||||
*/
|
||||
getRegisteredProviders(): string[] {
|
||||
return Object.keys(this.providers)
|
||||
return Object.keys(this.providers).filter((id) => !this.aliases.has(id))
|
||||
}
|
||||
|
||||
/**
|
||||
@ -139,9 +172,46 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
|
||||
*/
|
||||
clear(): this {
|
||||
this.providers = {}
|
||||
this.aliases.clear()
|
||||
this.registry = null
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析真实的Provider ID(供getAiSdkProviderId使用)
|
||||
* 如果传入的是别名,返回真实的Provider ID
|
||||
* 如果传入的是真实ID,直接返回
|
||||
*/
|
||||
resolveProviderId(id: string): string {
|
||||
if (!this.aliases.has(id)) return id // 不是别名,直接返回
|
||||
|
||||
// 是别名,找到真实ID
|
||||
const targetProvider = this.providers[id]
|
||||
for (const [realId, provider] of Object.entries(this.providers)) {
|
||||
if (provider === targetProvider && !this.aliases.has(realId)) {
|
||||
return realId
|
||||
}
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否为别名
|
||||
*/
|
||||
isAlias(id: string): boolean {
|
||||
return this.aliases.has(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有别名映射关系
|
||||
*/
|
||||
getAllAliases(): Record<string, string> {
|
||||
const result: Record<string, string> = {}
|
||||
this.aliases.forEach((alias) => {
|
||||
result[alias] = this.resolveProviderId(alias)
|
||||
})
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -28,6 +28,20 @@ export {
|
||||
reinitializeProvider
|
||||
} from './registry'
|
||||
|
||||
// 动态Provider注册功能
|
||||
export {
|
||||
cleanup,
|
||||
getAllAliases,
|
||||
getAllDynamicMappings,
|
||||
getDynamicProviders,
|
||||
getProviderMapping,
|
||||
isAlias,
|
||||
isDynamicProvider,
|
||||
registerDynamicProvider,
|
||||
registerMultipleProviders,
|
||||
resolveProviderId
|
||||
} from './registry'
|
||||
|
||||
// ==================== 保留的导出(兼容性)====================
|
||||
|
||||
// 基础Provider数据源
|
||||
|
||||
@ -8,7 +8,7 @@ import { customProvider } from 'ai'
|
||||
|
||||
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
|
||||
import { globalRegistryManagement } from './RegistryManagement'
|
||||
import { baseProviders } from './schemas'
|
||||
import { baseProviders, type DynamicProviderRegistration } from './schemas'
|
||||
|
||||
/**
|
||||
* Provider 初始化错误类型
|
||||
@ -227,6 +227,104 @@ export function hasInitializedProviders(): boolean {
|
||||
return globalRegistryManagement.hasProviders()
|
||||
}
|
||||
|
||||
// ==================== 动态Provider注册功能 ====================
|
||||
|
||||
// 全局动态provider存储
|
||||
const dynamicProviders = new Map<string, DynamicProviderRegistration>()
|
||||
|
||||
/**
|
||||
* 注册动态provider
|
||||
*/
|
||||
export function registerDynamicProvider(config: DynamicProviderRegistration): boolean {
|
||||
try {
|
||||
// 验证配置
|
||||
if (!config.id || !config.name) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否与基础provider冲突
|
||||
if (baseProviders.find((p) => p.id === config.id)) {
|
||||
console.warn(`Dynamic provider "${config.id}" conflicts with base provider`)
|
||||
return false
|
||||
}
|
||||
|
||||
// 存储动态provider配置
|
||||
dynamicProviders.set(config.id, config)
|
||||
|
||||
// 如果有creator函数,立即初始化
|
||||
if (config.creator) {
|
||||
try {
|
||||
const provider = config.creator({}) as any // 使用空配置初始化,类型断言为any
|
||||
const aliases = config.mappings ? Object.keys(config.mappings) : undefined
|
||||
globalRegistryManagement.registerProvider(config.id, provider, aliases)
|
||||
} catch (error) {
|
||||
console.error(`Failed to initialize dynamic provider "${config.id}":`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
} catch (error) {
|
||||
console.error(`Failed to register dynamic provider:`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量注册动态providers
|
||||
*/
|
||||
export function registerMultipleProviders(configs: DynamicProviderRegistration[]): number {
|
||||
let successCount = 0
|
||||
configs.forEach((config) => {
|
||||
if (registerDynamicProvider(config)) {
|
||||
successCount++
|
||||
}
|
||||
})
|
||||
return successCount
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取provider映射(解析别名)
|
||||
*/
|
||||
export function getProviderMapping(providerId: string): string {
|
||||
return globalRegistryManagement.resolveProviderId(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否为动态provider
|
||||
*/
|
||||
export function isDynamicProvider(providerId: string): boolean {
|
||||
return dynamicProviders.has(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有动态providers
|
||||
*/
|
||||
export function getDynamicProviders(): DynamicProviderRegistration[] {
|
||||
return Array.from(dynamicProviders.values())
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有别名映射关系
|
||||
*/
|
||||
export function getAllDynamicMappings(): Record<string, string> {
|
||||
return globalRegistryManagement.getAllAliases()
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理所有动态providers
|
||||
*/
|
||||
export function cleanup(): void {
|
||||
dynamicProviders.clear()
|
||||
globalRegistryManagement.clear()
|
||||
}
|
||||
|
||||
// ==================== 导出别名相关API ====================
|
||||
|
||||
export const resolveProviderId = (id: string) => globalRegistryManagement.resolveProviderId(id)
|
||||
export const isAlias = (id: string) => globalRegistryManagement.isAlias(id)
|
||||
export const getAllAliases = () => globalRegistryManagement.getAllAliases()
|
||||
|
||||
// ==================== 导出错误类型和工具函数 ====================
|
||||
|
||||
export { isOpenAIChatCompletionOnlyModel, ProviderInitializationError }
|
||||
|
||||
@ -38,7 +38,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
|
||||
}
|
||||
|
||||
createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
|
||||
private createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
|
||||
return definePlugin({
|
||||
name: '_internal_resolveModel',
|
||||
enforce: 'post',
|
||||
@ -50,7 +50,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
})
|
||||
}
|
||||
|
||||
createResolveImageModelPlugin() {
|
||||
private createResolveImageModelPlugin() {
|
||||
return definePlugin({
|
||||
name: '_internal_resolveImageModel',
|
||||
enforce: 'post',
|
||||
@ -61,7 +61,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
})
|
||||
}
|
||||
|
||||
createConfigureContextPlugin() {
|
||||
private createConfigureContextPlugin() {
|
||||
return definePlugin({
|
||||
name: '_internal_configureContext',
|
||||
configureContext: async (context: AiRequestContext) => {
|
||||
|
||||
@ -133,6 +133,20 @@ export {
|
||||
reinitializeProvider
|
||||
} from './core/providers/registry'
|
||||
|
||||
// ==================== 动态Provider注册和别名映射 ====================
|
||||
export {
|
||||
cleanup,
|
||||
getAllAliases,
|
||||
getAllDynamicMappings,
|
||||
getDynamicProviders,
|
||||
getProviderMapping,
|
||||
isAlias,
|
||||
isDynamicProvider,
|
||||
registerDynamicProvider,
|
||||
registerMultipleProviders,
|
||||
resolveProviderId
|
||||
} from './core/providers/registry'
|
||||
|
||||
// ==================== Zod Schema 和验证 ====================
|
||||
export { baseProviderIds, validateProviderId } from './core/providers'
|
||||
|
||||
|
||||
@ -284,6 +284,12 @@ export class AiSdkToChunkAdapter {
|
||||
}
|
||||
})
|
||||
break
|
||||
case 'abort':
|
||||
this.onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
error: new DOMException('Request was aborted', 'AbortError')
|
||||
})
|
||||
break
|
||||
case 'error':
|
||||
this.onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
|
||||
@ -6,7 +6,7 @@ import {
|
||||
InvokeModelWithResponseStreamCommand
|
||||
} from '@aws-sdk/client-bedrock-runtime'
|
||||
import { loggerService } from '@logger'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import { findTokenLimit, isReasoningModel } from '@renderer/config/models'
|
||||
import {
|
||||
|
||||
@ -1,109 +1,124 @@
|
||||
import { AiCore, ProviderConfigFactory, type ProviderId, type ProviderSettingsMap } from '@cherrystudio/ai-core'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import { cloneDeep } from 'lodash'
|
||||
|
||||
import { createAihubmixProvider } from './aihubmix'
|
||||
import { aihubmixProviderCreator, newApiResolverCreator } from './config'
|
||||
import { getAiSdkProviderId } from './factory'
|
||||
|
||||
export function getActualProvider(model: Model): Provider {
|
||||
const provider = getProviderByModel(model)
|
||||
// 如果是 vertexai 类型且没有 googleCredentials,转换为 VertexProvider
|
||||
let actualProvider = cloneDeep(provider)
|
||||
/**
|
||||
* 处理特殊provider的转换逻辑
|
||||
*/
|
||||
function handleSpecialProviders(model: Model, provider: Provider): Provider {
|
||||
if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
|
||||
if (!isVertexAIConfigured()) {
|
||||
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
|
||||
}
|
||||
actualProvider = createVertexProvider(provider)
|
||||
return createVertexProvider(provider)
|
||||
}
|
||||
|
||||
if (provider.id === 'aihubmix') {
|
||||
actualProvider = createAihubmixProvider(model, actualProvider)
|
||||
return aihubmixProviderCreator(model, provider)
|
||||
}
|
||||
if (actualProvider.type === 'gemini') {
|
||||
actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
|
||||
if (provider.id === 'newapi') {
|
||||
return newApiResolverCreator(model, provider)
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
/**
|
||||
* 格式化provider的API Host
|
||||
*/
|
||||
function formatProviderApiHost(provider: Provider): Provider {
|
||||
const formatted = { ...provider }
|
||||
if (formatted.type === 'gemini') {
|
||||
formatted.apiHost = formatApiHost(formatted.apiHost, 'v1beta')
|
||||
} else {
|
||||
actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
|
||||
formatted.apiHost = formatApiHost(formatted.apiHost)
|
||||
}
|
||||
return formatted
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取实际的Provider配置
|
||||
* 简化版:将逻辑分解为小函数
|
||||
*/
|
||||
export function getActualProvider(model: Model): Provider {
|
||||
const baseProvider = getProviderByModel(model)
|
||||
|
||||
// 按顺序处理各种转换
|
||||
let actualProvider = cloneDeep(baseProvider)
|
||||
actualProvider = handleSpecialProviders(model, actualProvider)
|
||||
actualProvider = formatProviderApiHost(actualProvider)
|
||||
|
||||
return actualProvider
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 Provider 配置转换为新 AI SDK 格式
|
||||
* 简化版:利用新的别名映射系统
|
||||
*/
|
||||
export function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
providerId: ProviderId | 'openai-compatible'
|
||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
} {
|
||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||
const actualProviderType = actualProvider.type
|
||||
const openaiResponseOptions =
|
||||
actualProviderType === 'openai-response'
|
||||
? {
|
||||
mode: 'responses'
|
||||
}
|
||||
: aiSdkProviderId === 'openai'
|
||||
? {
|
||||
mode: 'chat'
|
||||
}
|
||||
: undefined
|
||||
console.log('openaiResponseOptions', openaiResponseOptions)
|
||||
console.log('actualProvider', actualProvider)
|
||||
console.log('aiSdkProviderId', aiSdkProviderId)
|
||||
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
const options = ProviderConfigFactory.fromProvider(
|
||||
aiSdkProviderId,
|
||||
{
|
||||
baseURL: actualProvider.apiHost,
|
||||
apiKey: actualProvider.apiKey
|
||||
},
|
||||
{ ...openaiResponseOptions, headers: actualProvider.extra_headers }
|
||||
)
|
||||
|
||||
// 构建基础配置
|
||||
const baseConfig = {
|
||||
baseURL: actualProvider.apiHost,
|
||||
apiKey: actualProvider.apiKey
|
||||
}
|
||||
|
||||
// 处理OpenAI模式(简化逻辑)
|
||||
const extraOptions: any = {}
|
||||
if (actualProvider.type === 'openai-response') {
|
||||
extraOptions.mode = 'responses'
|
||||
} else if (aiSdkProviderId === 'openai') {
|
||||
extraOptions.mode = 'chat'
|
||||
}
|
||||
|
||||
// 添加额外headers
|
||||
if (actualProvider.extra_headers) {
|
||||
extraOptions.headers = actualProvider.extra_headers
|
||||
}
|
||||
|
||||
// 如果AI SDK支持该provider,使用原生配置
|
||||
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
||||
return {
|
||||
providerId: aiSdkProviderId as ProviderId,
|
||||
options
|
||||
}
|
||||
} else {
|
||||
console.log(`Using openai-compatible fallback for provider: ${actualProvider.type}`)
|
||||
const options = ProviderConfigFactory.createOpenAICompatible(actualProvider.apiHost, actualProvider.apiKey)
|
||||
}
|
||||
|
||||
return {
|
||||
providerId: 'openai-compatible',
|
||||
options: {
|
||||
...options,
|
||||
name: actualProvider.id
|
||||
}
|
||||
// 否则fallback到openai-compatible
|
||||
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
|
||||
return {
|
||||
providerId: 'openai-compatible',
|
||||
options: {
|
||||
...options,
|
||||
name: actualProvider.id,
|
||||
...extraOptions
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否支持使用新的AI SDK
|
||||
* 简化版:利用新的别名映射和动态provider系统
|
||||
*/
|
||||
export function isModernSdkSupported(provider: Provider, model?: Model): boolean {
|
||||
// 目前支持主要的providers
|
||||
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai']
|
||||
|
||||
// 检查provider类型
|
||||
if (!supportedProviders.includes(provider.type)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 对于 vertexai,检查配置是否完整
|
||||
export function isModernSdkSupported(provider: Provider): boolean {
|
||||
// 特殊检查:vertexai需要配置完整
|
||||
if (provider.type === 'vertexai' && !isVertexAIConfigured()) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 图像生成模型现在支持新的 AI SDK
|
||||
// (但需要确保 provider 是支持的
|
||||
// 使用getAiSdkProviderId获取映射后的providerId,然后检查AI SDK是否支持
|
||||
const aiSdkProviderId = getAiSdkProviderId(provider)
|
||||
|
||||
if (model && isDedicatedImageGenerationModel(model)) {
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
// 如果映射到了支持的provider,则支持现代SDK
|
||||
return AiCore.isSupported(aiSdkProviderId)
|
||||
}
|
||||
|
||||
@ -1,56 +0,0 @@
|
||||
import { ProviderId } from '@cherrystudio/ai-core/types'
|
||||
import { isOpenAIModel } from '@renderer/config/models'
|
||||
import { Model, Provider } from '@renderer/types'
|
||||
|
||||
export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'openai-compatible' {
|
||||
console.log('getAiSdkProviderIdForAihubmix', model)
|
||||
const id = model.id.toLowerCase()
|
||||
|
||||
if (id.startsWith('claude')) {
|
||||
return 'anthropic'
|
||||
}
|
||||
// TODO:暂时注释,不清楚为什么排除,webSearch时会导致gemini模型走openai的逻辑
|
||||
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
|
||||
return 'google'
|
||||
}
|
||||
|
||||
if (isOpenAIModel(model)) {
|
||||
return 'openai'
|
||||
}
|
||||
|
||||
return 'openai-compatible'
|
||||
}
|
||||
|
||||
export function createAihubmixProvider(model: Model, provider: Provider): Provider {
|
||||
const providerId = getAiSdkProviderIdForAihubmix(model)
|
||||
provider = {
|
||||
...provider,
|
||||
extra_headers: {
|
||||
...provider.extra_headers,
|
||||
'APP-Code': 'MLTG2087'
|
||||
}
|
||||
}
|
||||
if (providerId === 'google') {
|
||||
return {
|
||||
...provider,
|
||||
type: 'gemini',
|
||||
apiHost: 'https://aihubmix.com/gemini'
|
||||
}
|
||||
}
|
||||
|
||||
if (providerId === 'openai') {
|
||||
return {
|
||||
...provider,
|
||||
type: 'openai-response'
|
||||
}
|
||||
}
|
||||
|
||||
if (providerId === 'anthropic') {
|
||||
return {
|
||||
...provider,
|
||||
type: 'anthropic'
|
||||
}
|
||||
}
|
||||
|
||||
return provider
|
||||
}
|
||||
57
src/renderer/src/aiCore/provider/config/aihubmix.ts
Normal file
57
src/renderer/src/aiCore/provider/config/aihubmix.ts
Normal file
@ -0,0 +1,57 @@
|
||||
/**
|
||||
* AiHubMix规则集
|
||||
*/
|
||||
import { isOpenAIModel } from '@renderer/config/models'
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
import { startsWith } from './helper'
|
||||
import { provider2Provider } from './helper'
|
||||
import type { ModelRule } from './types'
|
||||
|
||||
const extraProviderConfig = (provider: Provider) => {
|
||||
return {
|
||||
...provider,
|
||||
extra_headers: {
|
||||
...provider.extra_headers,
|
||||
'APP-Code': 'MLTG2087'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const AIHUBMIX_RULES: ModelRule[] = [
|
||||
{
|
||||
name: 'claude',
|
||||
match: startsWith('claude'),
|
||||
provider: (provider: Provider) => {
|
||||
return extraProviderConfig({
|
||||
...provider,
|
||||
type: 'anthropic'
|
||||
})
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'gemini',
|
||||
match: (model) =>
|
||||
(startsWith('gemini')(model) || startsWith('imagen')(model)) &&
|
||||
!model.id.endsWith('-nothink') &&
|
||||
!model.id.endsWith('-search'),
|
||||
provider: (provider: Provider) => {
|
||||
return extraProviderConfig({
|
||||
...provider,
|
||||
apiHost: 'https://aihubmix.com/gemini'
|
||||
})
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'openai',
|
||||
match: isOpenAIModel,
|
||||
provider: (provider: Provider) => {
|
||||
return extraProviderConfig({
|
||||
...provider,
|
||||
type: 'openai-response'
|
||||
})
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES)
|
||||
22
src/renderer/src/aiCore/provider/config/helper.ts
Normal file
22
src/renderer/src/aiCore/provider/config/helper.ts
Normal file
@ -0,0 +1,22 @@
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
|
||||
import type { ModelRule } from './types'
|
||||
|
||||
export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase())
|
||||
export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type
|
||||
|
||||
/**
|
||||
* 解析模型对应的Provider ID
|
||||
* @param model 模型对象
|
||||
* @param rules 匹配规则数组
|
||||
* @param fallback 默认fallback的providerId
|
||||
* @returns 解析出的providerId
|
||||
*/
|
||||
export function provider2Provider(rules: ModelRule[], model: Model, provider: Provider): Provider {
|
||||
for (const rule of rules) {
|
||||
if (rule.match(model)) {
|
||||
return rule.provider(provider)
|
||||
}
|
||||
}
|
||||
return provider
|
||||
}
|
||||
16
src/renderer/src/aiCore/provider/config/index.ts
Normal file
16
src/renderer/src/aiCore/provider/config/index.ts
Normal file
@ -0,0 +1,16 @@
|
||||
// /**
|
||||
// * Provider解析规则模块导出
|
||||
// */
|
||||
|
||||
// // 导出类型
|
||||
// export type { ModelRule } from './types'
|
||||
|
||||
// // 导出匹配函数和解析器
|
||||
// export { endpointIs, resolveProvider, startsWith } from './helper'
|
||||
|
||||
// // 导出规则集
|
||||
// export { AIHUBMIX_RULES } from './aihubmix'
|
||||
// export { NEWAPI_RULES } from './newApi'
|
||||
|
||||
export { aihubmixProviderCreator } from './aihubmix'
|
||||
export { newApiResolverCreator } from './newApi'
|
||||
52
src/renderer/src/aiCore/provider/config/newApi.ts
Normal file
52
src/renderer/src/aiCore/provider/config/newApi.ts
Normal file
@ -0,0 +1,52 @@
|
||||
/**
|
||||
* NewAPI规则集
|
||||
*/
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
import { endpointIs, provider2Provider } from './helper'
|
||||
import type { ModelRule } from './types'
|
||||
|
||||
const NEWAPI_RULES: ModelRule[] = [
|
||||
{
|
||||
name: 'anthropic',
|
||||
match: endpointIs('anthropic'),
|
||||
provider: (provider: Provider) => {
|
||||
return {
|
||||
...provider,
|
||||
type: 'anthropic'
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'gemini',
|
||||
match: endpointIs('gemini'),
|
||||
provider: (provider: Provider) => {
|
||||
return {
|
||||
...provider,
|
||||
type: 'gemini'
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'openai-response',
|
||||
match: endpointIs('openai-response'),
|
||||
provider: (provider: Provider) => {
|
||||
return {
|
||||
...provider,
|
||||
type: 'openai-response'
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'openai',
|
||||
match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model),
|
||||
provider: (provider: Provider) => {
|
||||
return {
|
||||
...provider,
|
||||
type: 'openai'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES)
|
||||
7
src/renderer/src/aiCore/provider/config/types.ts
Normal file
7
src/renderer/src/aiCore/provider/config/types.ts
Normal file
@ -0,0 +1,7 @@
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
|
||||
export interface ModelRule {
|
||||
name: string
|
||||
match: (model: Model) => boolean
|
||||
provider: (provider: Provider) => Provider
|
||||
}
|
||||
@ -1,49 +1,75 @@
|
||||
import { AiCore, type ProviderId } from '@cherrystudio/ai-core'
|
||||
import { AiCore, getProviderMapping, type ProviderId } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
// TODO
|
||||
// 初始化新的Provider注册系统
|
||||
// initializeNewProviders()
|
||||
import { initializeNewProviders } from './providerConfigs'
|
||||
|
||||
// 静态Provider映射 - 核心providers
|
||||
const logger = loggerService.withContext('ProviderFactory')
|
||||
|
||||
/**
|
||||
* 初始化动态Provider系统
|
||||
* 在模块加载时自动注册新的providers
|
||||
*/
|
||||
;(async () => {
|
||||
try {
|
||||
await initializeNewProviders()
|
||||
} catch (error) {
|
||||
logger.warn('Failed to initialize new providers:', error as Error)
|
||||
}
|
||||
})()
|
||||
|
||||
/**
|
||||
* 静态Provider映射表
|
||||
* 处理Cherry Studio特有的provider ID到AI SDK标准ID的映射
|
||||
*/
|
||||
const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
|
||||
// anthropic: 'anthropic',
|
||||
gemini: 'google',
|
||||
'azure-openai': 'azure',
|
||||
'openai-response': 'openai',
|
||||
grok: 'xai'
|
||||
gemini: 'google', // Google Gemini -> google
|
||||
'azure-openai': 'azure', // Azure OpenAI -> azure
|
||||
'openai-response': 'openai', // OpenAI Responses -> openai
|
||||
grok: 'xai' // Grok -> xai
|
||||
}
|
||||
|
||||
/**
|
||||
* 尝试解析provider标识符(支持静态映射和动态映射)
|
||||
*/
|
||||
function tryResolveProviderId(identifier: string): ProviderId | null {
|
||||
// 1. 检查静态映射
|
||||
const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
|
||||
if (staticMapping) {
|
||||
return staticMapping
|
||||
}
|
||||
|
||||
// 2. 检查动态映射
|
||||
const dynamicMapping = getProviderMapping(identifier)
|
||||
if (dynamicMapping && dynamicMapping !== identifier) {
|
||||
return dynamicMapping as ProviderId
|
||||
}
|
||||
|
||||
// 3. 检查AiCore是否直接支持
|
||||
if (AiCore.isSupported(identifier)) {
|
||||
return identifier as ProviderId
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取AI SDK Provider ID
|
||||
* 简化版:减少重复逻辑,利用通用解析函数
|
||||
*/
|
||||
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
|
||||
// 1. 首先检查静态映射
|
||||
const staticProviderId = STATIC_PROVIDER_MAPPING[provider.id]
|
||||
if (staticProviderId) {
|
||||
return staticProviderId
|
||||
}
|
||||
// TODO
|
||||
// 2. 检查动态注册的provider映射(使用aiCore的函数)
|
||||
// const dynamicProviderId = getProviderMapping(provider.id)
|
||||
// if (dynamicProviderId) {
|
||||
// return dynamicProviderId as ProviderId
|
||||
// }
|
||||
|
||||
// 3. 检查provider.type的静态映射
|
||||
const staticProviderType = STATIC_PROVIDER_MAPPING[provider.type]
|
||||
if (staticProviderType) {
|
||||
return staticProviderType
|
||||
}
|
||||
// TODO
|
||||
// 4. 检查provider.type的动态映射
|
||||
// const dynamicProviderType = getProviderMapping(provider.type)
|
||||
// if (dynamicProviderType) {
|
||||
// return dynamicProviderType as ProviderId
|
||||
// }
|
||||
|
||||
// 5. 检查AiCore是否直接支持
|
||||
if (AiCore.isSupported(provider.id)) {
|
||||
return provider.id as ProviderId
|
||||
// 1. 尝试解析provider.id
|
||||
const resolvedFromId = tryResolveProviderId(provider.id)
|
||||
if (resolvedFromId) {
|
||||
return resolvedFromId
|
||||
}
|
||||
|
||||
// 6. 最后的fallback
|
||||
// 2. 尝试解析provider.type
|
||||
const resolvedFromType = tryResolveProviderId(provider.type)
|
||||
if (resolvedFromType) {
|
||||
return resolvedFromType
|
||||
}
|
||||
|
||||
// 3. 最后的fallback(通常会成为openai-compatible)
|
||||
return provider.id as ProviderId
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { type ProviderConfig } from '@cherrystudio/ai-core'
|
||||
import { type ProviderConfig, registerMultipleProviders } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
|
||||
const logger = loggerService.withContext('ProviderConfigs')
|
||||
@ -43,19 +43,29 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
||||
}
|
||||
] as const
|
||||
|
||||
// TODO
|
||||
// /**
|
||||
// * 初始化新的Providers
|
||||
// * 使用aiCore的动态注册功能
|
||||
// */
|
||||
// export async function initializeNewProviders(): Promise<void> {
|
||||
// try {
|
||||
// const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS)
|
||||
/**
|
||||
* 初始化新的Providers
|
||||
* 使用aiCore的动态注册功能
|
||||
*/
|
||||
export async function initializeNewProviders(): Promise<void> {
|
||||
try {
|
||||
logger.info('Starting to register new providers', {
|
||||
providerCount: NEW_PROVIDER_CONFIGS.length,
|
||||
providerIds: NEW_PROVIDER_CONFIGS.map((p) => p.id)
|
||||
})
|
||||
|
||||
// if (successCount < NEW_PROVIDER_CONFIGS.length) {
|
||||
// logger.warn('Some providers failed to register. Check previous error logs.')
|
||||
// }
|
||||
// } catch (error) {
|
||||
// logger.error('Failed to initialize new providers:', error as Error)
|
||||
// }
|
||||
// }
|
||||
const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS)
|
||||
|
||||
logger.info('Provider registration completed', {
|
||||
successCount,
|
||||
totalCount: NEW_PROVIDER_CONFIGS.length,
|
||||
failedCount: NEW_PROVIDER_CONFIGS.length - successCount
|
||||
})
|
||||
|
||||
if (successCount < NEW_PROVIDER_CONFIGS.length) {
|
||||
logger.warn('Some providers failed to register. Check previous error logs.')
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to initialize new providers:', error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,7 +28,6 @@ export function buildProviderOptions(
|
||||
}
|
||||
): Record<string, any> {
|
||||
const providerId = getAiSdkProviderId(actualProvider)
|
||||
|
||||
// 构建 provider 特定的选项
|
||||
let providerSpecificOptions: Record<string, any> = {}
|
||||
|
||||
@ -37,11 +36,8 @@ export function buildProviderOptions(
|
||||
case 'openai':
|
||||
case 'azure':
|
||||
providerSpecificOptions = {
|
||||
...buildOpenAIProviderOptions(assistant, model, capabilities),
|
||||
// 函数内有对于真实provider.id的判断,应该不会影响原生provider
|
||||
...buildGenericProviderOptions(assistant, model, capabilities)
|
||||
...buildOpenAIProviderOptions(assistant, model, capabilities)
|
||||
}
|
||||
|
||||
break
|
||||
|
||||
case 'anthropic':
|
||||
|
||||
Loading…
Reference in New Issue
Block a user