feat(aiCore): enhance dynamic provider registration and refactor HubProvider

- Introduced dynamic provider registration functionality, allowing for flexible management of providers through a new registry system.
- Refactored HubProvider to streamline model resolution and improve error handling for unsupported models.
- Added utility functions for managing dynamic providers, including registration, cleanup, and alias resolution.
- Updated index exports to include new dynamic provider APIs, enhancing overall usability and integration.
- Removed outdated provider files and simplified the provider management structure for better maintainability.
This commit is contained in:
MyPrototypeWhat 2025-08-26 16:17:01 +08:00
parent 53dcda6942
commit 84eef25ff9
18 changed files with 551 additions and 256 deletions

View File

@ -5,7 +5,7 @@
* 例如: aihubmix:anthropic:claude-3.5-sonnet * 例如: aihubmix:anthropic:claude-3.5-sonnet
*/ */
import { ProviderV2 } from '@ai-sdk/provider' import { EmbeddingModelV2, ImageModelV2, ProviderV2, SpeechModelV2, TranscriptionModelV2 } from '@ai-sdk/provider'
import { customProvider } from 'ai' import { customProvider } from 'ai'
import { globalRegistryManagement } from './RegistryManagement' import { globalRegistryManagement } from './RegistryManagement'
@ -47,13 +47,7 @@ function parseHubModelId(modelId: string): { provider: string; actualModelId: st
* Hub Provider * Hub Provider
*/ */
export function createHubProvider(config: HubProviderConfig): ProviderV2 { export function createHubProvider(config: HubProviderConfig): ProviderV2 {
const { hubId, debug = false } = config const { hubId } = config
function logDebug(message: string, ...args: any[]) {
if (debug) {
console.log(`[HubProvider:${hubId}] ${message}`, ...args)
}
}
function getTargetProvider(providerId: string): ProviderV2 { function getTargetProvider(providerId: string): ProviderV2 {
// 从全局注册表获取provider实例 // 从全局注册表获取provider实例
@ -77,72 +71,26 @@ export function createHubProvider(config: HubProviderConfig): ProviderV2 {
} }
} }
function resolveModel<T>(modelId: string, modelType: string, methodName: keyof ProviderV2): T {
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
if (!targetProvider[methodName]) {
throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider)
}
return (targetProvider[methodName] as any)(actualModelId)
}
return customProvider({ return customProvider({
fallbackProvider: { fallbackProvider: {
languageModel: (modelId: string) => { languageModel: (modelId: string) => resolveModel(modelId, 'language models', 'languageModel'),
logDebug('Resolving language model:', modelId) textEmbeddingModel: (modelId: string) =>
resolveModel<EmbeddingModelV2<string>>(modelId, 'text embedding models', 'textEmbeddingModel'),
const { provider, actualModelId } = parseHubModelId(modelId) imageModel: (modelId: string) => resolveModel<ImageModelV2>(modelId, 'image models', 'imageModel'),
const targetProvider = getTargetProvider(provider) transcriptionModel: (modelId: string) =>
resolveModel<TranscriptionModelV2>(modelId, 'transcription models', 'transcriptionModel'),
if (!targetProvider.languageModel) { speechModel: (modelId: string) => resolveModel<SpeechModelV2>(modelId, 'speech models', 'speechModel')
throw new HubProviderError(`Provider "${provider}" does not support language models`, hubId, provider)
}
return targetProvider.languageModel(actualModelId)
},
textEmbeddingModel: (modelId: string) => {
logDebug('Resolving text embedding model:', modelId)
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
if (!targetProvider.textEmbeddingModel) {
throw new HubProviderError(`Provider "${provider}" does not support text embedding models`, hubId, provider)
}
return targetProvider.textEmbeddingModel(actualModelId)
},
imageModel: (modelId: string) => {
logDebug('Resolving image model:', modelId)
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
if (!targetProvider.imageModel) {
throw new HubProviderError(`Provider "${provider}" does not support image models`, hubId, provider)
}
return targetProvider.imageModel(actualModelId)
},
transcriptionModel: (modelId: string) => {
logDebug('Resolving transcription model:', modelId)
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
if (!targetProvider.transcriptionModel) {
throw new HubProviderError(`Provider "${provider}" does not support transcription models`, hubId, provider)
}
return targetProvider.transcriptionModel(actualModelId)
},
speechModel: (modelId: string) => {
logDebug('Resolving speech model:', modelId)
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
if (!targetProvider.speechModel) {
throw new HubProviderError(`Provider "${provider}" does not support speech models`, hubId, provider)
}
return targetProvider.speechModel(actualModelId)
}
} }
}) })
} }

View File

@ -15,6 +15,7 @@ export const DEFAULT_SEPARATOR = ':'
export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARATOR> { export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARATOR> {
private providers: PROVIDERS = {} private providers: PROVIDERS = {}
private aliases: Set<string> = new Set() // 记录哪些key是别名
private separator: SEPARATOR private separator: SEPARATOR
private registry: ProviderRegistryProvider<PROVIDERS, SEPARATOR> | null = null private registry: ProviderRegistryProvider<PROVIDERS, SEPARATOR> | null = null
@ -25,8 +26,18 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
/** /**
* provider * provider
*/ */
registerProvider(id: string, provider: ProviderV2): this { registerProvider(id: string, provider: ProviderV2, aliases?: string[]): this {
// 注册主provider
this.providers[id] = provider this.providers[id] = provider
// 注册别名都指向同一个provider实例
if (aliases) {
aliases.forEach((alias) => {
this.providers[alias] = provider // 直接存储引用
this.aliases.add(alias) // 标记为别名
})
}
this.rebuildRegistry() this.rebuildRegistry()
return this return this
} }
@ -48,9 +59,31 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
} }
/** /**
* provider * provider
*/ */
unregisterProvider(id: string): this { unregisterProvider(id: string): this {
const provider = this.providers[id]
if (!provider) return this
// 如果移除的是真实ID需要清理所有指向它的别名
if (!this.aliases.has(id)) {
// 找到所有指向此provider的别名并删除
const aliasesToRemove: string[] = []
this.aliases.forEach((alias) => {
if (this.providers[alias] === provider) {
aliasesToRemove.push(alias)
}
})
aliasesToRemove.forEach((alias) => {
delete this.providers[alias]
this.aliases.delete(alias)
})
} else {
// 如果移除的是别名,只删除别名记录
this.aliases.delete(id)
}
delete this.providers[id] delete this.providers[id]
this.rebuildRegistry() this.rebuildRegistry()
return this return this
@ -121,10 +154,10 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
} }
/** /**
* provider * provider
*/ */
getRegisteredProviders(): string[] { getRegisteredProviders(): string[] {
return Object.keys(this.providers) return Object.keys(this.providers).filter((id) => !this.aliases.has(id))
} }
/** /**
@ -139,9 +172,46 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
*/ */
clear(): this { clear(): this {
this.providers = {} this.providers = {}
this.aliases.clear()
this.registry = null this.registry = null
return this return this
} }
/**
* Provider IDgetAiSdkProviderId使用
* Provider ID
* ID
*/
resolveProviderId(id: string): string {
if (!this.aliases.has(id)) return id // 不是别名,直接返回
// 是别名找到真实ID
const targetProvider = this.providers[id]
for (const [realId, provider] of Object.entries(this.providers)) {
if (provider === targetProvider && !this.aliases.has(realId)) {
return realId
}
}
return id
}
/**
*
*/
isAlias(id: string): boolean {
return this.aliases.has(id)
}
/**
*
*/
getAllAliases(): Record<string, string> {
const result: Record<string, string> = {}
this.aliases.forEach((alias) => {
result[alias] = this.resolveProviderId(alias)
})
return result
}
} }
/** /**

View File

@ -28,6 +28,20 @@ export {
reinitializeProvider reinitializeProvider
} from './registry' } from './registry'
// 动态Provider注册功能
export {
cleanup,
getAllAliases,
getAllDynamicMappings,
getDynamicProviders,
getProviderMapping,
isAlias,
isDynamicProvider,
registerDynamicProvider,
registerMultipleProviders,
resolveProviderId
} from './registry'
// ==================== 保留的导出(兼容性)==================== // ==================== 保留的导出(兼容性)====================
// 基础Provider数据源 // 基础Provider数据源

View File

@ -8,7 +8,7 @@ import { customProvider } from 'ai'
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model' import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
import { globalRegistryManagement } from './RegistryManagement' import { globalRegistryManagement } from './RegistryManagement'
import { baseProviders } from './schemas' import { baseProviders, type DynamicProviderRegistration } from './schemas'
/** /**
* Provider * Provider
@ -227,6 +227,104 @@ export function hasInitializedProviders(): boolean {
return globalRegistryManagement.hasProviders() return globalRegistryManagement.hasProviders()
} }
// ==================== 动态Provider注册功能 ====================
// 全局动态provider存储
const dynamicProviders = new Map<string, DynamicProviderRegistration>()
/**
* provider
*/
export function registerDynamicProvider(config: DynamicProviderRegistration): boolean {
try {
// 验证配置
if (!config.id || !config.name) {
return false
}
// 检查是否与基础provider冲突
if (baseProviders.find((p) => p.id === config.id)) {
console.warn(`Dynamic provider "${config.id}" conflicts with base provider`)
return false
}
// 存储动态provider配置
dynamicProviders.set(config.id, config)
// 如果有creator函数立即初始化
if (config.creator) {
try {
const provider = config.creator({}) as any // 使用空配置初始化类型断言为any
const aliases = config.mappings ? Object.keys(config.mappings) : undefined
globalRegistryManagement.registerProvider(config.id, provider, aliases)
} catch (error) {
console.error(`Failed to initialize dynamic provider "${config.id}":`, error)
return false
}
}
return true
} catch (error) {
console.error(`Failed to register dynamic provider:`, error)
return false
}
}
/**
* providers
*/
export function registerMultipleProviders(configs: DynamicProviderRegistration[]): number {
let successCount = 0
configs.forEach((config) => {
if (registerDynamicProvider(config)) {
successCount++
}
})
return successCount
}
/**
* provider映射
*/
export function getProviderMapping(providerId: string): string {
return globalRegistryManagement.resolveProviderId(providerId)
}
/**
* provider
*/
export function isDynamicProvider(providerId: string): boolean {
return dynamicProviders.has(providerId)
}
/**
* providers
*/
export function getDynamicProviders(): DynamicProviderRegistration[] {
return Array.from(dynamicProviders.values())
}
/**
*
*/
export function getAllDynamicMappings(): Record<string, string> {
return globalRegistryManagement.getAllAliases()
}
/**
* providers
*/
export function cleanup(): void {
dynamicProviders.clear()
globalRegistryManagement.clear()
}
// ==================== 导出别名相关API ====================
export const resolveProviderId = (id: string) => globalRegistryManagement.resolveProviderId(id)
export const isAlias = (id: string) => globalRegistryManagement.isAlias(id)
export const getAllAliases = () => globalRegistryManagement.getAllAliases()
// ==================== 导出错误类型和工具函数 ==================== // ==================== 导出错误类型和工具函数 ====================
export { isOpenAIChatCompletionOnlyModel, ProviderInitializationError } export { isOpenAIChatCompletionOnlyModel, ProviderInitializationError }

View File

@ -38,7 +38,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || []) this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
} }
createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) { private createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
return definePlugin({ return definePlugin({
name: '_internal_resolveModel', name: '_internal_resolveModel',
enforce: 'post', enforce: 'post',
@ -50,7 +50,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
}) })
} }
createResolveImageModelPlugin() { private createResolveImageModelPlugin() {
return definePlugin({ return definePlugin({
name: '_internal_resolveImageModel', name: '_internal_resolveImageModel',
enforce: 'post', enforce: 'post',
@ -61,7 +61,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
}) })
} }
createConfigureContextPlugin() { private createConfigureContextPlugin() {
return definePlugin({ return definePlugin({
name: '_internal_configureContext', name: '_internal_configureContext',
configureContext: async (context: AiRequestContext) => { configureContext: async (context: AiRequestContext) => {

View File

@ -133,6 +133,20 @@ export {
reinitializeProvider reinitializeProvider
} from './core/providers/registry' } from './core/providers/registry'
// ==================== 动态Provider注册和别名映射 ====================
export {
cleanup,
getAllAliases,
getAllDynamicMappings,
getDynamicProviders,
getProviderMapping,
isAlias,
isDynamicProvider,
registerDynamicProvider,
registerMultipleProviders,
resolveProviderId
} from './core/providers/registry'
// ==================== Zod Schema 和验证 ==================== // ==================== Zod Schema 和验证 ====================
export { baseProviderIds, validateProviderId } from './core/providers' export { baseProviderIds, validateProviderId } from './core/providers'

View File

@ -284,6 +284,12 @@ export class AiSdkToChunkAdapter {
} }
}) })
break break
case 'abort':
this.onChunk({
type: ChunkType.ERROR,
error: new DOMException('Request was aborted', 'AbortError')
})
break
case 'error': case 'error':
this.onChunk({ this.onChunk({
type: ChunkType.ERROR, type: ChunkType.ERROR,

View File

@ -6,7 +6,7 @@ import {
InvokeModelWithResponseStreamCommand InvokeModelWithResponseStreamCommand
} from '@aws-sdk/client-bedrock-runtime' } from '@aws-sdk/client-bedrock-runtime'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas' import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { findTokenLimit, isReasoningModel } from '@renderer/config/models' import { findTokenLimit, isReasoningModel } from '@renderer/config/models'
import { import {

View File

@ -1,109 +1,124 @@
import { AiCore, ProviderConfigFactory, type ProviderId, type ProviderSettingsMap } from '@cherrystudio/ai-core' import { AiCore, ProviderConfigFactory, type ProviderId, type ProviderSettingsMap } from '@cherrystudio/ai-core'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService' import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types' import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api' import { formatApiHost } from '@renderer/utils/api'
import { cloneDeep } from 'lodash' import { cloneDeep } from 'lodash'
import { createAihubmixProvider } from './aihubmix' import { aihubmixProviderCreator, newApiResolverCreator } from './config'
import { getAiSdkProviderId } from './factory' import { getAiSdkProviderId } from './factory'
export function getActualProvider(model: Model): Provider { /**
const provider = getProviderByModel(model) * provider的转换逻辑
// 如果是 vertexai 类型且没有 googleCredentials转换为 VertexProvider */
let actualProvider = cloneDeep(provider) function handleSpecialProviders(model: Model, provider: Provider): Provider {
if (provider.type === 'vertexai' && !isVertexProvider(provider)) { if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
if (!isVertexAIConfigured()) { if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.') throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
} }
actualProvider = createVertexProvider(provider) return createVertexProvider(provider)
} }
if (provider.id === 'aihubmix') { if (provider.id === 'aihubmix') {
actualProvider = createAihubmixProvider(model, actualProvider) return aihubmixProviderCreator(model, provider)
} }
if (actualProvider.type === 'gemini') { if (provider.id === 'newapi') {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta') return newApiResolverCreator(model, provider)
}
return provider
}
/**
* provider的API Host
*/
function formatProviderApiHost(provider: Provider): Provider {
const formatted = { ...provider }
if (formatted.type === 'gemini') {
formatted.apiHost = formatApiHost(formatted.apiHost, 'v1beta')
} else { } else {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost) formatted.apiHost = formatApiHost(formatted.apiHost)
} }
return formatted
}
/**
* Provider配置
*
*/
export function getActualProvider(model: Model): Provider {
const baseProvider = getProviderByModel(model)
// 按顺序处理各种转换
let actualProvider = cloneDeep(baseProvider)
actualProvider = handleSpecialProviders(model, actualProvider)
actualProvider = formatProviderApiHost(actualProvider)
return actualProvider return actualProvider
} }
/** /**
* Provider AI SDK * Provider AI SDK
*
*/ */
export function providerToAiSdkConfig(actualProvider: Provider): { export function providerToAiSdkConfig(actualProvider: Provider): {
providerId: ProviderId | 'openai-compatible' providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap] options: ProviderSettingsMap[keyof ProviderSettingsMap]
} { } {
const aiSdkProviderId = getAiSdkProviderId(actualProvider) const aiSdkProviderId = getAiSdkProviderId(actualProvider)
const actualProviderType = actualProvider.type
const openaiResponseOptions =
actualProviderType === 'openai-response'
? {
mode: 'responses'
}
: aiSdkProviderId === 'openai'
? {
mode: 'chat'
}
: undefined
console.log('openaiResponseOptions', openaiResponseOptions)
console.log('actualProvider', actualProvider)
console.log('aiSdkProviderId', aiSdkProviderId)
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(
aiSdkProviderId,
{
baseURL: actualProvider.apiHost,
apiKey: actualProvider.apiKey
},
{ ...openaiResponseOptions, headers: actualProvider.extra_headers }
)
// 构建基础配置
const baseConfig = {
baseURL: actualProvider.apiHost,
apiKey: actualProvider.apiKey
}
// 处理OpenAI模式简化逻辑
const extraOptions: any = {}
if (actualProvider.type === 'openai-response') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai') {
extraOptions.mode = 'chat'
}
// 添加额外headers
if (actualProvider.extra_headers) {
extraOptions.headers = actualProvider.extra_headers
}
// 如果AI SDK支持该provider使用原生配置
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return { return {
providerId: aiSdkProviderId as ProviderId, providerId: aiSdkProviderId as ProviderId,
options options
} }
} else { }
console.log(`Using openai-compatible fallback for provider: ${actualProvider.type}`)
const options = ProviderConfigFactory.createOpenAICompatible(actualProvider.apiHost, actualProvider.apiKey)
return { // 否则fallback到openai-compatible
providerId: 'openai-compatible', const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
options: { return {
...options, providerId: 'openai-compatible',
name: actualProvider.id options: {
} ...options,
name: actualProvider.id,
...extraOptions
} }
} }
} }
/** /**
* 使AI SDK * 使AI SDK
* provider系统
*/ */
export function isModernSdkSupported(provider: Provider, model?: Model): boolean { export function isModernSdkSupported(provider: Provider): boolean {
// 目前支持主要的providers // 特殊检查vertexai需要配置完整
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai']
// 检查provider类型
if (!supportedProviders.includes(provider.type)) {
return false
}
// 对于 vertexai检查配置是否完整
if (provider.type === 'vertexai' && !isVertexAIConfigured()) { if (provider.type === 'vertexai' && !isVertexAIConfigured()) {
return false return false
} }
// 图像生成模型现在支持新的 AI SDK // 使用getAiSdkProviderId获取映射后的providerId然后检查AI SDK是否支持
// (但需要确保 provider 是支持的 const aiSdkProviderId = getAiSdkProviderId(provider)
if (model && isDedicatedImageGenerationModel(model)) { // 如果映射到了支持的provider则支持现代SDK
return true return AiCore.isSupported(aiSdkProviderId)
}
return true
} }

View File

@ -1,56 +0,0 @@
import { ProviderId } from '@cherrystudio/ai-core/types'
import { isOpenAIModel } from '@renderer/config/models'
import { Model, Provider } from '@renderer/types'
export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'openai-compatible' {
console.log('getAiSdkProviderIdForAihubmix', model)
const id = model.id.toLowerCase()
if (id.startsWith('claude')) {
return 'anthropic'
}
// TODO:暂时注释,不清楚为什么排除,webSearch时会导致gemini模型走openai的逻辑
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
return 'google'
}
if (isOpenAIModel(model)) {
return 'openai'
}
return 'openai-compatible'
}
export function createAihubmixProvider(model: Model, provider: Provider): Provider {
const providerId = getAiSdkProviderIdForAihubmix(model)
provider = {
...provider,
extra_headers: {
...provider.extra_headers,
'APP-Code': 'MLTG2087'
}
}
if (providerId === 'google') {
return {
...provider,
type: 'gemini',
apiHost: 'https://aihubmix.com/gemini'
}
}
if (providerId === 'openai') {
return {
...provider,
type: 'openai-response'
}
}
if (providerId === 'anthropic') {
return {
...provider,
type: 'anthropic'
}
}
return provider
}

View File

@ -0,0 +1,57 @@
/**
* AiHubMix规则集
*/
import { isOpenAIModel } from '@renderer/config/models'
import { Provider } from '@renderer/types'
import { startsWith } from './helper'
import { provider2Provider } from './helper'
import type { ModelRule } from './types'
const extraProviderConfig = (provider: Provider) => {
return {
...provider,
extra_headers: {
...provider.extra_headers,
'APP-Code': 'MLTG2087'
}
}
}
const AIHUBMIX_RULES: ModelRule[] = [
{
name: 'claude',
match: startsWith('claude'),
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'anthropic'
})
}
},
{
name: 'gemini',
match: (model) =>
(startsWith('gemini')(model) || startsWith('imagen')(model)) &&
!model.id.endsWith('-nothink') &&
!model.id.endsWith('-search'),
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
apiHost: 'https://aihubmix.com/gemini'
})
}
},
{
name: 'openai',
match: isOpenAIModel,
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'openai-response'
})
}
}
]
export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES)

View File

@ -0,0 +1,22 @@
import type { Model, Provider } from '@renderer/types'
import type { ModelRule } from './types'
export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase())
export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type
/**
* Provider ID
* @param model
* @param rules
* @param fallback fallback的providerId
* @returns providerId
*/
export function provider2Provider(rules: ModelRule[], model: Model, provider: Provider): Provider {
for (const rule of rules) {
if (rule.match(model)) {
return rule.provider(provider)
}
}
return provider
}

View File

@ -0,0 +1,16 @@
// /**
// * Provider解析规则模块导出
// */
// // 导出类型
// export type { ModelRule } from './types'
// // 导出匹配函数和解析器
// export { endpointIs, resolveProvider, startsWith } from './helper'
// // 导出规则集
// export { AIHUBMIX_RULES } from './aihubmix'
// export { NEWAPI_RULES } from './newApi'
export { aihubmixProviderCreator } from './aihubmix'
export { newApiResolverCreator } from './newApi'

View File

@ -0,0 +1,52 @@
/**
* NewAPI规则集
*/
import { Provider } from '@renderer/types'
import { endpointIs, provider2Provider } from './helper'
import type { ModelRule } from './types'
const NEWAPI_RULES: ModelRule[] = [
{
name: 'anthropic',
match: endpointIs('anthropic'),
provider: (provider: Provider) => {
return {
...provider,
type: 'anthropic'
}
}
},
{
name: 'gemini',
match: endpointIs('gemini'),
provider: (provider: Provider) => {
return {
...provider,
type: 'gemini'
}
}
},
{
name: 'openai-response',
match: endpointIs('openai-response'),
provider: (provider: Provider) => {
return {
...provider,
type: 'openai-response'
}
}
},
{
name: 'openai',
match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model),
provider: (provider: Provider) => {
return {
...provider,
type: 'openai'
}
}
}
]
export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES)

View File

@ -0,0 +1,7 @@
import type { Model, Provider } from '@renderer/types'
export interface ModelRule {
name: string
match: (model: Model) => boolean
provider: (provider: Provider) => Provider
}

View File

@ -1,49 +1,75 @@
import { AiCore, type ProviderId } from '@cherrystudio/ai-core' import { AiCore, getProviderMapping, type ProviderId } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { Provider } from '@renderer/types' import { Provider } from '@renderer/types'
// TODO import { initializeNewProviders } from './providerConfigs'
// 初始化新的Provider注册系统
// initializeNewProviders()
// 静态Provider映射 - 核心providers const logger = loggerService.withContext('ProviderFactory')
/**
* Provider系统
* providers
*/
;(async () => {
try {
await initializeNewProviders()
} catch (error) {
logger.warn('Failed to initialize new providers:', error as Error)
}
})()
/**
* Provider映射表
* Cherry Studio特有的provider ID到AI SDK标准ID的映射
*/
const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = { const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
// anthropic: 'anthropic', gemini: 'google', // Google Gemini -> google
gemini: 'google', 'azure-openai': 'azure', // Azure OpenAI -> azure
'azure-openai': 'azure', 'openai-response': 'openai', // OpenAI Responses -> openai
'openai-response': 'openai', grok: 'xai' // Grok -> xai
grok: 'xai'
} }
/**
* provider标识符
*/
function tryResolveProviderId(identifier: string): ProviderId | null {
// 1. 检查静态映射
const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
if (staticMapping) {
return staticMapping
}
// 2. 检查动态映射
const dynamicMapping = getProviderMapping(identifier)
if (dynamicMapping && dynamicMapping !== identifier) {
return dynamicMapping as ProviderId
}
// 3. 检查AiCore是否直接支持
if (AiCore.isSupported(identifier)) {
return identifier as ProviderId
}
return null
}
/**
* AI SDK Provider ID
*
*/
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' { export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
// 1. 首先检查静态映射 // 1. 尝试解析provider.id
const staticProviderId = STATIC_PROVIDER_MAPPING[provider.id] const resolvedFromId = tryResolveProviderId(provider.id)
if (staticProviderId) { if (resolvedFromId) {
return staticProviderId return resolvedFromId
}
// TODO
// 2. 检查动态注册的provider映射使用aiCore的函数
// const dynamicProviderId = getProviderMapping(provider.id)
// if (dynamicProviderId) {
// return dynamicProviderId as ProviderId
// }
// 3. 检查provider.type的静态映射
const staticProviderType = STATIC_PROVIDER_MAPPING[provider.type]
if (staticProviderType) {
return staticProviderType
}
// TODO
// 4. 检查provider.type的动态映射
// const dynamicProviderType = getProviderMapping(provider.type)
// if (dynamicProviderType) {
// return dynamicProviderType as ProviderId
// }
// 5. 检查AiCore是否直接支持
if (AiCore.isSupported(provider.id)) {
return provider.id as ProviderId
} }
// 6. 最后的fallback // 2. 尝试解析provider.type
const resolvedFromType = tryResolveProviderId(provider.type)
if (resolvedFromType) {
return resolvedFromType
}
// 3. 最后的fallback通常会成为openai-compatible
return provider.id as ProviderId return provider.id as ProviderId
} }

View File

@ -1,4 +1,4 @@
import { type ProviderConfig } from '@cherrystudio/ai-core' import { type ProviderConfig, registerMultipleProviders } from '@cherrystudio/ai-core'
import { loggerService } from '@logger' import { loggerService } from '@logger'
const logger = loggerService.withContext('ProviderConfigs') const logger = loggerService.withContext('ProviderConfigs')
@ -43,19 +43,29 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
} }
] as const ] as const
// TODO /**
// /** * Providers
// * 初始化新的Providers * 使aiCore的动态注册功能
// * 使用aiCore的动态注册功能 */
// */ export async function initializeNewProviders(): Promise<void> {
// export async function initializeNewProviders(): Promise<void> { try {
// try { logger.info('Starting to register new providers', {
// const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS) providerCount: NEW_PROVIDER_CONFIGS.length,
providerIds: NEW_PROVIDER_CONFIGS.map((p) => p.id)
})
// if (successCount < NEW_PROVIDER_CONFIGS.length) { const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS)
// logger.warn('Some providers failed to register. Check previous error logs.')
// } logger.info('Provider registration completed', {
// } catch (error) { successCount,
// logger.error('Failed to initialize new providers:', error as Error) totalCount: NEW_PROVIDER_CONFIGS.length,
// } failedCount: NEW_PROVIDER_CONFIGS.length - successCount
// } })
if (successCount < NEW_PROVIDER_CONFIGS.length) {
logger.warn('Some providers failed to register. Check previous error logs.')
}
} catch (error) {
logger.error('Failed to initialize new providers:', error as Error)
}
}

View File

@ -28,7 +28,6 @@ export function buildProviderOptions(
} }
): Record<string, any> { ): Record<string, any> {
const providerId = getAiSdkProviderId(actualProvider) const providerId = getAiSdkProviderId(actualProvider)
// 构建 provider 特定的选项 // 构建 provider 特定的选项
let providerSpecificOptions: Record<string, any> = {} let providerSpecificOptions: Record<string, any> = {}
@ -37,11 +36,8 @@ export function buildProviderOptions(
case 'openai': case 'openai':
case 'azure': case 'azure':
providerSpecificOptions = { providerSpecificOptions = {
...buildOpenAIProviderOptions(assistant, model, capabilities), ...buildOpenAIProviderOptions(assistant, model, capabilities)
// 函数内有对于真实provider.id的判断,应该不会影响原生provider
...buildGenericProviderOptions(assistant, model, capabilities)
} }
break break
case 'anthropic': case 'anthropic':