feat(aiCore): enhance dynamic provider registration and refactor HubProvider

- Introduced dynamic provider registration functionality, allowing for flexible management of providers through a new registry system.
- Refactored HubProvider to streamline model resolution and improve error handling for unsupported models.
- Added utility functions for managing dynamic providers, including registration, cleanup, and alias resolution.
- Updated index exports to include new dynamic provider APIs, enhancing overall usability and integration.
- Removed outdated provider files and simplified the provider management structure for better maintainability.
This commit is contained in:
MyPrototypeWhat 2025-08-26 16:17:01 +08:00
parent 53dcda6942
commit 84eef25ff9
18 changed files with 551 additions and 256 deletions

View File

@ -5,7 +5,7 @@
* 例如: aihubmix:anthropic:claude-3.5-sonnet
*/
import { ProviderV2 } from '@ai-sdk/provider'
import { EmbeddingModelV2, ImageModelV2, ProviderV2, SpeechModelV2, TranscriptionModelV2 } from '@ai-sdk/provider'
import { customProvider } from 'ai'
import { globalRegistryManagement } from './RegistryManagement'
@ -47,13 +47,7 @@ function parseHubModelId(modelId: string): { provider: string; actualModelId: st
* Hub Provider
*/
export function createHubProvider(config: HubProviderConfig): ProviderV2 {
const { hubId, debug = false } = config
function logDebug(message: string, ...args: any[]) {
if (debug) {
console.log(`[HubProvider:${hubId}] ${message}`, ...args)
}
}
const { hubId } = config
function getTargetProvider(providerId: string): ProviderV2 {
// 从全局注册表获取provider实例
@ -77,72 +71,26 @@ export function createHubProvider(config: HubProviderConfig): ProviderV2 {
}
}
function resolveModel<T>(modelId: string, modelType: string, methodName: keyof ProviderV2): T {
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
if (!targetProvider[methodName]) {
throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider)
}
return (targetProvider[methodName] as any)(actualModelId)
}
return customProvider({
fallbackProvider: {
languageModel: (modelId: string) => {
logDebug('Resolving language model:', modelId)
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
if (!targetProvider.languageModel) {
throw new HubProviderError(`Provider "${provider}" does not support language models`, hubId, provider)
}
return targetProvider.languageModel(actualModelId)
},
textEmbeddingModel: (modelId: string) => {
logDebug('Resolving text embedding model:', modelId)
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
if (!targetProvider.textEmbeddingModel) {
throw new HubProviderError(`Provider "${provider}" does not support text embedding models`, hubId, provider)
}
return targetProvider.textEmbeddingModel(actualModelId)
},
imageModel: (modelId: string) => {
logDebug('Resolving image model:', modelId)
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
if (!targetProvider.imageModel) {
throw new HubProviderError(`Provider "${provider}" does not support image models`, hubId, provider)
}
return targetProvider.imageModel(actualModelId)
},
transcriptionModel: (modelId: string) => {
logDebug('Resolving transcription model:', modelId)
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
if (!targetProvider.transcriptionModel) {
throw new HubProviderError(`Provider "${provider}" does not support transcription models`, hubId, provider)
}
return targetProvider.transcriptionModel(actualModelId)
},
speechModel: (modelId: string) => {
logDebug('Resolving speech model:', modelId)
const { provider, actualModelId } = parseHubModelId(modelId)
const targetProvider = getTargetProvider(provider)
if (!targetProvider.speechModel) {
throw new HubProviderError(`Provider "${provider}" does not support speech models`, hubId, provider)
}
return targetProvider.speechModel(actualModelId)
}
languageModel: (modelId: string) => resolveModel(modelId, 'language models', 'languageModel'),
textEmbeddingModel: (modelId: string) =>
resolveModel<EmbeddingModelV2<string>>(modelId, 'text embedding models', 'textEmbeddingModel'),
imageModel: (modelId: string) => resolveModel<ImageModelV2>(modelId, 'image models', 'imageModel'),
transcriptionModel: (modelId: string) =>
resolveModel<TranscriptionModelV2>(modelId, 'transcription models', 'transcriptionModel'),
speechModel: (modelId: string) => resolveModel<SpeechModelV2>(modelId, 'speech models', 'speechModel')
}
})
}

View File

@ -15,6 +15,7 @@ export const DEFAULT_SEPARATOR = ':'
export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARATOR> {
private providers: PROVIDERS = {}
private aliases: Set<string> = new Set() // 记录哪些key是别名
private separator: SEPARATOR
private registry: ProviderRegistryProvider<PROVIDERS, SEPARATOR> | null = null
@ -25,8 +26,18 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
/**
* provider
*/
registerProvider(id: string, provider: ProviderV2): this {
registerProvider(id: string, provider: ProviderV2, aliases?: string[]): this {
// 注册主provider
this.providers[id] = provider
// 注册别名都指向同一个provider实例
if (aliases) {
aliases.forEach((alias) => {
this.providers[alias] = provider // 直接存储引用
this.aliases.add(alias) // 标记为别名
})
}
this.rebuildRegistry()
return this
}
@ -48,9 +59,31 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
}
/**
* provider
* provider
*/
unregisterProvider(id: string): this {
const provider = this.providers[id]
if (!provider) return this
// 如果移除的是真实ID需要清理所有指向它的别名
if (!this.aliases.has(id)) {
// 找到所有指向此provider的别名并删除
const aliasesToRemove: string[] = []
this.aliases.forEach((alias) => {
if (this.providers[alias] === provider) {
aliasesToRemove.push(alias)
}
})
aliasesToRemove.forEach((alias) => {
delete this.providers[alias]
this.aliases.delete(alias)
})
} else {
// 如果移除的是别名,只删除别名记录
this.aliases.delete(id)
}
delete this.providers[id]
this.rebuildRegistry()
return this
@ -121,10 +154,10 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
}
/**
* provider
* provider
*/
getRegisteredProviders(): string[] {
return Object.keys(this.providers)
return Object.keys(this.providers).filter((id) => !this.aliases.has(id))
}
/**
@ -139,9 +172,46 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
*/
clear(): this {
this.providers = {}
this.aliases.clear()
this.registry = null
return this
}
/**
* Provider IDgetAiSdkProviderId使用
* Provider ID
* ID
*/
resolveProviderId(id: string): string {
if (!this.aliases.has(id)) return id // 不是别名,直接返回
// 是别名找到真实ID
const targetProvider = this.providers[id]
for (const [realId, provider] of Object.entries(this.providers)) {
if (provider === targetProvider && !this.aliases.has(realId)) {
return realId
}
}
return id
}
/**
*
*/
isAlias(id: string): boolean {
return this.aliases.has(id)
}
/**
*
*/
getAllAliases(): Record<string, string> {
const result: Record<string, string> = {}
this.aliases.forEach((alias) => {
result[alias] = this.resolveProviderId(alias)
})
return result
}
}
/**

View File

@ -28,6 +28,20 @@ export {
reinitializeProvider
} from './registry'
// 动态Provider注册功能
export {
cleanup,
getAllAliases,
getAllDynamicMappings,
getDynamicProviders,
getProviderMapping,
isAlias,
isDynamicProvider,
registerDynamicProvider,
registerMultipleProviders,
resolveProviderId
} from './registry'
// ==================== 保留的导出(兼容性)====================
// 基础Provider数据源

View File

@ -8,7 +8,7 @@ import { customProvider } from 'ai'
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
import { globalRegistryManagement } from './RegistryManagement'
import { baseProviders } from './schemas'
import { baseProviders, type DynamicProviderRegistration } from './schemas'
/**
* Provider
@ -227,6 +227,104 @@ export function hasInitializedProviders(): boolean {
return globalRegistryManagement.hasProviders()
}
// ==================== 动态Provider注册功能 ====================
// 全局动态provider存储
const dynamicProviders = new Map<string, DynamicProviderRegistration>()
/**
* provider
*/
export function registerDynamicProvider(config: DynamicProviderRegistration): boolean {
try {
// 验证配置
if (!config.id || !config.name) {
return false
}
// 检查是否与基础provider冲突
if (baseProviders.find((p) => p.id === config.id)) {
console.warn(`Dynamic provider "${config.id}" conflicts with base provider`)
return false
}
// 存储动态provider配置
dynamicProviders.set(config.id, config)
// 如果有creator函数立即初始化
if (config.creator) {
try {
const provider = config.creator({}) as any // 使用空配置初始化类型断言为any
const aliases = config.mappings ? Object.keys(config.mappings) : undefined
globalRegistryManagement.registerProvider(config.id, provider, aliases)
} catch (error) {
console.error(`Failed to initialize dynamic provider "${config.id}":`, error)
return false
}
}
return true
} catch (error) {
console.error(`Failed to register dynamic provider:`, error)
return false
}
}
/**
* providers
*/
export function registerMultipleProviders(configs: DynamicProviderRegistration[]): number {
let successCount = 0
configs.forEach((config) => {
if (registerDynamicProvider(config)) {
successCount++
}
})
return successCount
}
/**
* provider映射
*/
export function getProviderMapping(providerId: string): string {
return globalRegistryManagement.resolveProviderId(providerId)
}
/**
* provider
*/
export function isDynamicProvider(providerId: string): boolean {
return dynamicProviders.has(providerId)
}
/**
* providers
*/
export function getDynamicProviders(): DynamicProviderRegistration[] {
return Array.from(dynamicProviders.values())
}
/**
*
*/
export function getAllDynamicMappings(): Record<string, string> {
return globalRegistryManagement.getAllAliases()
}
/**
* providers
*/
export function cleanup(): void {
dynamicProviders.clear()
globalRegistryManagement.clear()
}
// ==================== 导出别名相关API ====================
export const resolveProviderId = (id: string) => globalRegistryManagement.resolveProviderId(id)
export const isAlias = (id: string) => globalRegistryManagement.isAlias(id)
export const getAllAliases = () => globalRegistryManagement.getAllAliases()
// ==================== 导出错误类型和工具函数 ====================
export { isOpenAIChatCompletionOnlyModel, ProviderInitializationError }

View File

@ -38,7 +38,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
}
createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
private createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
return definePlugin({
name: '_internal_resolveModel',
enforce: 'post',
@ -50,7 +50,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
})
}
createResolveImageModelPlugin() {
private createResolveImageModelPlugin() {
return definePlugin({
name: '_internal_resolveImageModel',
enforce: 'post',
@ -61,7 +61,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
})
}
createConfigureContextPlugin() {
private createConfigureContextPlugin() {
return definePlugin({
name: '_internal_configureContext',
configureContext: async (context: AiRequestContext) => {

View File

@ -133,6 +133,20 @@ export {
reinitializeProvider
} from './core/providers/registry'
// ==================== 动态Provider注册和别名映射 ====================
export {
cleanup,
getAllAliases,
getAllDynamicMappings,
getDynamicProviders,
getProviderMapping,
isAlias,
isDynamicProvider,
registerDynamicProvider,
registerMultipleProviders,
resolveProviderId
} from './core/providers/registry'
// ==================== Zod Schema 和验证 ====================
export { baseProviderIds, validateProviderId } from './core/providers'

View File

@ -284,6 +284,12 @@ export class AiSdkToChunkAdapter {
}
})
break
case 'abort':
this.onChunk({
type: ChunkType.ERROR,
error: new DOMException('Request was aborted', 'AbortError')
})
break
case 'error':
this.onChunk({
type: ChunkType.ERROR,

View File

@ -6,7 +6,7 @@ import {
InvokeModelWithResponseStreamCommand
} from '@aws-sdk/client-bedrock-runtime'
import { loggerService } from '@logger'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { findTokenLimit, isReasoningModel } from '@renderer/config/models'
import {

View File

@ -1,109 +1,124 @@
import { AiCore, ProviderConfigFactory, type ProviderId, type ProviderSettingsMap } from '@cherrystudio/ai-core'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api'
import { cloneDeep } from 'lodash'
import { createAihubmixProvider } from './aihubmix'
import { aihubmixProviderCreator, newApiResolverCreator } from './config'
import { getAiSdkProviderId } from './factory'
export function getActualProvider(model: Model): Provider {
const provider = getProviderByModel(model)
// 如果是 vertexai 类型且没有 googleCredentials转换为 VertexProvider
let actualProvider = cloneDeep(provider)
/**
* provider的转换逻辑
*/
function handleSpecialProviders(model: Model, provider: Provider): Provider {
if (provider.type === 'vertexai' && !isVertexProvider(provider)) {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
actualProvider = createVertexProvider(provider)
return createVertexProvider(provider)
}
if (provider.id === 'aihubmix') {
actualProvider = createAihubmixProvider(model, actualProvider)
return aihubmixProviderCreator(model, provider)
}
if (actualProvider.type === 'gemini') {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost, 'v1beta')
if (provider.id === 'newapi') {
return newApiResolverCreator(model, provider)
}
return provider
}
/**
* provider的API Host
*/
function formatProviderApiHost(provider: Provider): Provider {
const formatted = { ...provider }
if (formatted.type === 'gemini') {
formatted.apiHost = formatApiHost(formatted.apiHost, 'v1beta')
} else {
actualProvider.apiHost = formatApiHost(actualProvider.apiHost)
formatted.apiHost = formatApiHost(formatted.apiHost)
}
return formatted
}
/**
* Provider配置
*
*/
export function getActualProvider(model: Model): Provider {
const baseProvider = getProviderByModel(model)
// 按顺序处理各种转换
let actualProvider = cloneDeep(baseProvider)
actualProvider = handleSpecialProviders(model, actualProvider)
actualProvider = formatProviderApiHost(actualProvider)
return actualProvider
}
/**
* Provider AI SDK
*
*/
export function providerToAiSdkConfig(actualProvider: Provider): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
const actualProviderType = actualProvider.type
const openaiResponseOptions =
actualProviderType === 'openai-response'
? {
mode: 'responses'
}
: aiSdkProviderId === 'openai'
? {
mode: 'chat'
}
: undefined
console.log('openaiResponseOptions', openaiResponseOptions)
console.log('actualProvider', actualProvider)
console.log('aiSdkProviderId', aiSdkProviderId)
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(
aiSdkProviderId,
{
baseURL: actualProvider.apiHost,
apiKey: actualProvider.apiKey
},
{ ...openaiResponseOptions, headers: actualProvider.extra_headers }
)
// 构建基础配置
const baseConfig = {
baseURL: actualProvider.apiHost,
apiKey: actualProvider.apiKey
}
// 处理OpenAI模式简化逻辑
const extraOptions: any = {}
if (actualProvider.type === 'openai-response') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai') {
extraOptions.mode = 'chat'
}
// 添加额外headers
if (actualProvider.extra_headers) {
extraOptions.headers = actualProvider.extra_headers
}
// 如果AI SDK支持该provider使用原生配置
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId as ProviderId,
options
}
} else {
console.log(`Using openai-compatible fallback for provider: ${actualProvider.type}`)
const options = ProviderConfigFactory.createOpenAICompatible(actualProvider.apiHost, actualProvider.apiKey)
}
return {
providerId: 'openai-compatible',
options: {
...options,
name: actualProvider.id
}
// 否则fallback到openai-compatible
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: actualProvider.id,
...extraOptions
}
}
}
/**
* 使AI SDK
* provider系统
*/
export function isModernSdkSupported(provider: Provider, model?: Model): boolean {
// 目前支持主要的providers
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai']
// 检查provider类型
if (!supportedProviders.includes(provider.type)) {
return false
}
// 对于 vertexai检查配置是否完整
export function isModernSdkSupported(provider: Provider): boolean {
// 特殊检查vertexai需要配置完整
if (provider.type === 'vertexai' && !isVertexAIConfigured()) {
return false
}
// 图像生成模型现在支持新的 AI SDK
// (但需要确保 provider 是支持的
// 使用getAiSdkProviderId获取映射后的providerId然后检查AI SDK是否支持
const aiSdkProviderId = getAiSdkProviderId(provider)
if (model && isDedicatedImageGenerationModel(model)) {
return true
}
return true
// 如果映射到了支持的provider则支持现代SDK
return AiCore.isSupported(aiSdkProviderId)
}

View File

@ -1,56 +0,0 @@
import { ProviderId } from '@cherrystudio/ai-core/types'
import { isOpenAIModel } from '@renderer/config/models'
import { Model, Provider } from '@renderer/types'
export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'openai-compatible' {
console.log('getAiSdkProviderIdForAihubmix', model)
const id = model.id.toLowerCase()
if (id.startsWith('claude')) {
return 'anthropic'
}
// TODO:暂时注释,不清楚为什么排除,webSearch时会导致gemini模型走openai的逻辑
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
return 'google'
}
if (isOpenAIModel(model)) {
return 'openai'
}
return 'openai-compatible'
}
export function createAihubmixProvider(model: Model, provider: Provider): Provider {
const providerId = getAiSdkProviderIdForAihubmix(model)
provider = {
...provider,
extra_headers: {
...provider.extra_headers,
'APP-Code': 'MLTG2087'
}
}
if (providerId === 'google') {
return {
...provider,
type: 'gemini',
apiHost: 'https://aihubmix.com/gemini'
}
}
if (providerId === 'openai') {
return {
...provider,
type: 'openai-response'
}
}
if (providerId === 'anthropic') {
return {
...provider,
type: 'anthropic'
}
}
return provider
}

View File

@ -0,0 +1,57 @@
/**
* AiHubMix规则集
*/
import { isOpenAIModel } from '@renderer/config/models'
import { Provider } from '@renderer/types'
import { startsWith } from './helper'
import { provider2Provider } from './helper'
import type { ModelRule } from './types'
const extraProviderConfig = (provider: Provider) => {
return {
...provider,
extra_headers: {
...provider.extra_headers,
'APP-Code': 'MLTG2087'
}
}
}
const AIHUBMIX_RULES: ModelRule[] = [
{
name: 'claude',
match: startsWith('claude'),
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'anthropic'
})
}
},
{
name: 'gemini',
match: (model) =>
(startsWith('gemini')(model) || startsWith('imagen')(model)) &&
!model.id.endsWith('-nothink') &&
!model.id.endsWith('-search'),
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
apiHost: 'https://aihubmix.com/gemini'
})
}
},
{
name: 'openai',
match: isOpenAIModel,
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'openai-response'
})
}
}
]
export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES)

View File

@ -0,0 +1,22 @@
import type { Model, Provider } from '@renderer/types'
import type { ModelRule } from './types'
export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase())
export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type
/**
* Provider ID
* @param model
* @param rules
* @param fallback fallback的providerId
* @returns providerId
*/
export function provider2Provider(rules: ModelRule[], model: Model, provider: Provider): Provider {
for (const rule of rules) {
if (rule.match(model)) {
return rule.provider(provider)
}
}
return provider
}

View File

@ -0,0 +1,16 @@
// /**
// * Provider解析规则模块导出
// */
// // 导出类型
// export type { ModelRule } from './types'
// // 导出匹配函数和解析器
// export { endpointIs, resolveProvider, startsWith } from './helper'
// // 导出规则集
// export { AIHUBMIX_RULES } from './aihubmix'
// export { NEWAPI_RULES } from './newApi'
export { aihubmixProviderCreator } from './aihubmix'
export { newApiResolverCreator } from './newApi'

View File

@ -0,0 +1,52 @@
/**
* NewAPI规则集
*/
import { Provider } from '@renderer/types'
import { endpointIs, provider2Provider } from './helper'
import type { ModelRule } from './types'
const NEWAPI_RULES: ModelRule[] = [
{
name: 'anthropic',
match: endpointIs('anthropic'),
provider: (provider: Provider) => {
return {
...provider,
type: 'anthropic'
}
}
},
{
name: 'gemini',
match: endpointIs('gemini'),
provider: (provider: Provider) => {
return {
...provider,
type: 'gemini'
}
}
},
{
name: 'openai-response',
match: endpointIs('openai-response'),
provider: (provider: Provider) => {
return {
...provider,
type: 'openai-response'
}
}
},
{
name: 'openai',
match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model),
provider: (provider: Provider) => {
return {
...provider,
type: 'openai'
}
}
}
]
export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES)

View File

@ -0,0 +1,7 @@
import type { Model, Provider } from '@renderer/types'
export interface ModelRule {
name: string
match: (model: Model) => boolean
provider: (provider: Provider) => Provider
}

View File

@ -1,49 +1,75 @@
import { AiCore, type ProviderId } from '@cherrystudio/ai-core'
import { AiCore, getProviderMapping, type ProviderId } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { Provider } from '@renderer/types'
// TODO
// 初始化新的Provider注册系统
// initializeNewProviders()
import { initializeNewProviders } from './providerConfigs'
// 静态Provider映射 - 核心providers
const logger = loggerService.withContext('ProviderFactory')
/**
* Provider系统
* providers
*/
;(async () => {
try {
await initializeNewProviders()
} catch (error) {
logger.warn('Failed to initialize new providers:', error as Error)
}
})()
/**
* Provider映射表
* Cherry Studio特有的provider ID到AI SDK标准ID的映射
*/
const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
// anthropic: 'anthropic',
gemini: 'google',
'azure-openai': 'azure',
'openai-response': 'openai',
grok: 'xai'
gemini: 'google', // Google Gemini -> google
'azure-openai': 'azure', // Azure OpenAI -> azure
'openai-response': 'openai', // OpenAI Responses -> openai
grok: 'xai' // Grok -> xai
}
/**
* provider标识符
*/
function tryResolveProviderId(identifier: string): ProviderId | null {
// 1. 检查静态映射
const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
if (staticMapping) {
return staticMapping
}
// 2. 检查动态映射
const dynamicMapping = getProviderMapping(identifier)
if (dynamicMapping && dynamicMapping !== identifier) {
return dynamicMapping as ProviderId
}
// 3. 检查AiCore是否直接支持
if (AiCore.isSupported(identifier)) {
return identifier as ProviderId
}
return null
}
/**
* AI SDK Provider ID
*
*/
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
// 1. 首先检查静态映射
const staticProviderId = STATIC_PROVIDER_MAPPING[provider.id]
if (staticProviderId) {
return staticProviderId
}
// TODO
// 2. 检查动态注册的provider映射使用aiCore的函数
// const dynamicProviderId = getProviderMapping(provider.id)
// if (dynamicProviderId) {
// return dynamicProviderId as ProviderId
// }
// 3. 检查provider.type的静态映射
const staticProviderType = STATIC_PROVIDER_MAPPING[provider.type]
if (staticProviderType) {
return staticProviderType
}
// TODO
// 4. 检查provider.type的动态映射
// const dynamicProviderType = getProviderMapping(provider.type)
// if (dynamicProviderType) {
// return dynamicProviderType as ProviderId
// }
// 5. 检查AiCore是否直接支持
if (AiCore.isSupported(provider.id)) {
return provider.id as ProviderId
// 1. 尝试解析provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
if (resolvedFromId) {
return resolvedFromId
}
// 6. 最后的fallback
// 2. 尝试解析provider.type
const resolvedFromType = tryResolveProviderId(provider.type)
if (resolvedFromType) {
return resolvedFromType
}
// 3. 最后的fallback通常会成为openai-compatible
return provider.id as ProviderId
}

View File

@ -1,4 +1,4 @@
import { type ProviderConfig } from '@cherrystudio/ai-core'
import { type ProviderConfig, registerMultipleProviders } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
const logger = loggerService.withContext('ProviderConfigs')
@ -43,19 +43,29 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
}
] as const
// TODO
// /**
// * 初始化新的Providers
// * 使用aiCore的动态注册功能
// */
// export async function initializeNewProviders(): Promise<void> {
// try {
// const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS)
/**
* Providers
* 使aiCore的动态注册功能
*/
export async function initializeNewProviders(): Promise<void> {
try {
logger.info('Starting to register new providers', {
providerCount: NEW_PROVIDER_CONFIGS.length,
providerIds: NEW_PROVIDER_CONFIGS.map((p) => p.id)
})
// if (successCount < NEW_PROVIDER_CONFIGS.length) {
// logger.warn('Some providers failed to register. Check previous error logs.')
// }
// } catch (error) {
// logger.error('Failed to initialize new providers:', error as Error)
// }
// }
const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS)
logger.info('Provider registration completed', {
successCount,
totalCount: NEW_PROVIDER_CONFIGS.length,
failedCount: NEW_PROVIDER_CONFIGS.length - successCount
})
if (successCount < NEW_PROVIDER_CONFIGS.length) {
logger.warn('Some providers failed to register. Check previous error logs.')
}
} catch (error) {
logger.error('Failed to initialize new providers:', error as Error)
}
}

View File

@ -28,7 +28,6 @@ export function buildProviderOptions(
}
): Record<string, any> {
const providerId = getAiSdkProviderId(actualProvider)
// 构建 provider 特定的选项
let providerSpecificOptions: Record<string, any> = {}
@ -37,11 +36,8 @@ export function buildProviderOptions(
case 'openai':
case 'azure':
providerSpecificOptions = {
...buildOpenAIProviderOptions(assistant, model, capabilities),
// 函数内有对于真实provider.id的判断,应该不会影响原生provider
...buildGenericProviderOptions(assistant, model, capabilities)
...buildOpenAIProviderOptions(assistant, model, capabilities)
}
break
case 'anthropic':