feat(aiCore): enhance provider management and registration system

- Added support for a new provider configuration structure in package.json, enabling better integration of provider types.
- Updated tsdown.config.ts to include new entry points for provider modules, improving build organization.
- Refactored index.ts to streamline exports and enhance type handling for provider-related functionalities.
- Simplified provider initialization and registration processes, allowing for more flexible provider management.
- Improved type definitions and removed deprecated methods to enhance code clarity and maintainability.
This commit is contained in:
MyPrototypeWhat 2025-08-27 16:17:57 +08:00
parent 2ce9314a10
commit 9c01e24317
17 changed files with 329 additions and 586 deletions

View File

@ -71,6 +71,13 @@
"import": "./dist/built-in/plugins/index.mjs",
"require": "./dist/built-in/plugins/index.js",
"default": "./dist/built-in/plugins/index.js"
},
"./provider": {
"types": "./dist/provider/index.d.ts",
"react-native": "./dist/provider/index.js",
"import": "./dist/provider/index.mjs",
"require": "./dist/provider/index.js",
"default": "./dist/provider/index.js"
}
}
}

View File

@ -1,6 +1,6 @@
// 核心类型和接口
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
import type { ProviderId } from '../../types'
import type { ProviderId } from '../providers'
import type { AiPlugin, AiRequestContext } from './types'
// 插件管理器

View File

@ -154,10 +154,10 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
}
/**
* provider
* provider
*/
getRegisteredProviders(): string[] {
return Object.keys(this.providers).filter((id) => !this.aliases.has(id))
return Object.keys(this.providers)
}
/**

View File

@ -1,60 +1,83 @@
/**
* Providers -
* Providers - Provider包
*/
// ==================== 新版架构(推荐使用)====================
// ==================== 核心管理器 ====================
// Provider 注册表管理器
export { globalRegistryManagement, RegistryManagement } from './RegistryManagement'
// Provider 初始化器(核心功能
// Provider 核心功能
export {
// 状态管理
cleanup,
clearAllProviders,
createAndRegisterProvider,
createProvider,
getAllProviderConfigAliases,
getAllProviderConfigs,
getImageModel,
// 工具函数
getInitializedProviders,
getLanguageModel,
getProviderInfo,
getProviderConfig,
getProviderConfigByAlias,
getSupportedProviders,
getTextEmbeddingModel,
hasInitializedProviders,
// initializeImageProvider, // deprecated: 使用 initializeProvider 即可
initializeProvider,
initializeProviders,
isProviderInitialized,
isProviderSupported,
// 工具函数
hasProviderConfig,
// 别名支持
hasProviderConfigByAlias,
isProviderConfigAlias,
// 错误类型
ProviderInitializationError,
ProviderInitializer,
// 全局访问
providerRegistry,
reinitializeProvider
registerMultipleProviderConfigs,
registerProvider,
// 统一Provider系统
registerProviderConfig,
resolveProviderConfigId
} from './registry'
// 动态Provider注册功能
export {
cleanup,
getAllAliases,
getAllDynamicMappings,
getDynamicProviders,
getProviderMapping,
isAlias,
isDynamicProvider,
registerDynamicProvider,
registerMultipleProviders,
resolveProviderId
} from './registry'
// ==================== 保留的导出(兼容性)====================
// ==================== 基础数据和类型 ====================
// 基础Provider数据源
export { baseProviderIds, baseProviders } from './schemas'
// 类型定义和Schema
export type {
BaseProviderId,
CustomProviderId,
DynamicProviderRegistration,
ProviderConfig,
ProviderId
} from './schemas' // 从 schemas 导出的类型
export { baseProviderIdSchema, customProviderIdSchema, providerConfigSchema, providerIdSchema } from './schemas' // Schema 导出
export type {
DynamicProviderRegistry,
ExtensibleProviderSettingsMap,
ProviderError,
ProviderSettingsMap,
ProviderTypeRegistrar
} from './types'
// ==================== 工具函数 ====================
// Provider配置工厂
export {
type BaseProviderConfig,
createProviderConfig,
ProviderConfigBuilder,
providerConfigBuilder,
ProviderConfigFactory
} from './factory'
// 工具函数
export { formatPrivateKey } from './utils'
// ==================== 扩展功能 ====================
// Hub Provider 功能
export { createHubProvider, type HubProviderConfig, HubProviderError } from './HubProvider'
// 类型定义(可能被其他模块使用)
export type { ProviderConfig, ProviderId, ProviderSettingsMap } from './types'
// Provider验证功能使用更好的Zod版本
export { validateProviderConfig } from './schemas'
// 验证功能(可能被其他地方使用)
export { validateProviderId } from './schemas'

View File

@ -6,9 +6,8 @@
import { customProvider } from 'ai'
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
import { globalRegistryManagement } from './RegistryManagement'
import { baseProviders, type DynamicProviderRegistration } from './schemas'
import { baseProviders, type ProviderConfig } from './schemas'
/**
* Provider
@ -24,141 +23,6 @@ class ProviderInitializationError extends Error {
}
}
/**
* Provider
*/
export class ProviderInitializer {
/**
* provider
*/
static initializeProvider(providerId: string, options: any): void {
try {
// 1. 从 schemas 获取 provider 配置
const providerConfig = baseProviders.find((p) => p.id === providerId)
if (!providerConfig) {
throw new ProviderInitializationError(`Provider configuration for '${providerId}' not found`, providerId)
}
// 2. 使用 creator 函数创建已配置的 provider
const configuredProvider = providerConfig.creator(options)
// 3. 处理特殊逻辑并注册到全局管理器
this.handleProviderSpecificLogic(configuredProvider, providerId)
} catch (error) {
if (error instanceof ProviderInitializationError) {
throw error
}
throw new ProviderInitializationError(
`Failed to initialize provider ${providerId}: ${error instanceof Error ? error.message : 'Unknown error'}`,
providerId,
error instanceof Error ? error : undefined
)
}
}
/**
* providers
*/
static initializeProviders(providers: Record<string, any>): void {
Object.entries(providers).forEach(([providerId, options]) => {
try {
this.initializeProvider(providerId, options)
} catch (error) {
console.error(`Failed to initialize provider ${providerId}:`, error)
}
})
}
/**
* provider ( ModelCreator )
*/
private static handleProviderSpecificLogic(provider: any, providerId: string): void {
if (providerId === 'openai') {
// 🎯 OpenAI 默认注册 (responses 模式)
globalRegistryManagement.registerProvider('openai', provider)
// 🎯 使用 AI SDK 官方的 customProvider 创建 chat 模式变体
const openaiChatProvider = customProvider({
fallbackProvider: {
...provider,
// 覆盖 languageModel 方法指向 chat
languageModel: (modelId: string) => provider.chat(modelId)
}
})
globalRegistryManagement.registerProvider('openai-chat', openaiChatProvider)
} else {
// 其他 provider 直接注册
globalRegistryManagement.registerProvider(providerId, provider)
}
}
/**
* provider ( ModelCreator )
*
* @deprecated provider初始化使 initializeProvider()
* provider实例可以同时支持文本和图像功能
*
* TODO: 考虑在下个版本中删除此方法
*/
// static initializeImageProvider(providerId: string, options: any): void {
// try {
// const providerConfig = baseProviders.find((p) => p.id === providerId)
// if (!providerConfig) {
// throw new ProviderInitializationError(`Provider configuration for '${providerId}' not found`, providerId)
// }
// if (!providerConfig.supportsImageGeneration) {
// throw new ProviderInitializationError(`Provider "${providerId}" does not support image generation`, providerId)
// }
// const provider = providerConfig.creator(options)
// // 注册图像 provider (使用特殊前缀区分)
// globalRegistryManagement.registerProvider(`${providerId}-image`, provider as any)
// } catch (error) {
// if (error instanceof ProviderInitializationError) {
// throw error
// }
// throw new ProviderInitializationError(
// `Failed to initialize image provider ${providerId}: ${error instanceof Error ? error.message : 'Unknown error'}`,
// providerId,
// error instanceof Error ? error : undefined
// )
// }
// }
/**
* provider
*/
static isProviderInitialized(providerId: string): boolean {
return globalRegistryManagement.getRegisteredProviders().includes(providerId)
}
/**
* provider
*/
static reinitializeProvider(providerId: string, options: any): void {
this.initializeProvider(providerId, options) // 会覆盖已有的
}
/**
* providers
*/
static clearAllProviders(): void {
globalRegistryManagement.clear()
}
}
// ==================== 便捷函数导出 ====================
export const initializeProvider = ProviderInitializer.initializeProvider.bind(ProviderInitializer)
export const initializeProviders = ProviderInitializer.initializeProviders.bind(ProviderInitializer)
// export const initializeImageProvider = ProviderInitializer.initializeImageProvider.bind(ProviderInitializer) // deprecated: 使用 initializeProvider 即可
export const isProviderInitialized = ProviderInitializer.isProviderInitialized.bind(ProviderInitializer)
export const reinitializeProvider = ProviderInitializer.reinitializeProvider.bind(ProviderInitializer)
export const clearAllProviders = ProviderInitializer.clearAllProviders.bind(ProviderInitializer)
// ==================== 全局管理器导出 ====================
export { globalRegistryManagement as providerRegistry }
@ -169,10 +33,10 @@ export const getLanguageModel = (id: string) => globalRegistryManagement.languag
export const getTextEmbeddingModel = (id: string) => globalRegistryManagement.textEmbeddingModel(id as any)
export const getImageModel = (id: string) => globalRegistryManagement.imageModel(id as any)
// ==================== 工具函数 (从 ModelCreator 迁移) ====================
// ==================== 工具函数 ====================
/**
* Providers ( ModelCreator )
* Providers
*/
export function getSupportedProviders(): Array<{
id: string
@ -184,35 +48,6 @@ export function getSupportedProviders(): Array<{
}))
}
/**
* Provider
*/
export function isProviderSupported(providerId: string): boolean {
return getProviderInfo(providerId).isSupported
}
/**
* Provider ( ModelCreator )
*/
export function getProviderInfo(providerId: string): {
id: string
name: string
isSupported: boolean
isInitialized: boolean
effectiveProvider: string
} {
const provider = baseProviders.find((p) => p.id === providerId)
const isInitialized = globalRegistryManagement.getRegisteredProviders().includes(providerId)
return {
id: providerId,
name: provider?.name || providerId,
isSupported: !!provider,
isInitialized,
effectiveProvider: isInitialized ? providerId : 'openai-compatible'
}
}
/**
* providers
*/
@ -227,56 +62,165 @@ export function hasInitializedProviders(): boolean {
return globalRegistryManagement.hasProviders()
}
// ==================== 动态Provider注册功能 ====================
// ==================== 统一Provider配置系统 ====================
// 全局动态provider存储
const dynamicProviders = new Map<string, DynamicProviderRegistration>()
// 全局Provider配置存储
const providerConfigs = new Map<string, ProviderConfig>()
// 全局ProviderConfig别名映射 - 借鉴RegistryManagement模式
const providerConfigAliases = new Map<string, string>() // alias -> realId
/**
* provider
* - baseProviders转换为统一格式
*/
export function registerDynamicProvider(config: DynamicProviderRegistration): boolean {
function initializeBuiltInConfigs(): void {
baseProviders.forEach((provider) => {
const config: ProviderConfig = {
id: provider.id,
name: provider.name,
creator: provider.creator as any, // 类型转换以兼容多种creator签名
supportsImageGeneration: provider.supportsImageGeneration || false
}
providerConfigs.set(provider.id, config)
})
}
// 启动时自动注册内置配置
initializeBuiltInConfigs()
/**
* 步骤1: 注册Provider配置 -
*/
export function registerProviderConfig(config: ProviderConfig): boolean {
try {
// 验证配置
if (!config.id || !config.name) {
return false
}
// 检查是否与基础provider冲突
if (baseProviders.find((p) => p.id === config.id)) {
console.warn(`Dynamic provider "${config.id}" conflicts with base provider`)
return false
// 检查是否与已有配置冲突(包括内置配置)
if (providerConfigs.has(config.id)) {
console.warn(`ProviderConfig "${config.id}" already exists, will override`)
}
// 存储动态provider配置
dynamicProviders.set(config.id, config)
// 存储配置(内置和用户配置统一处理)
providerConfigs.set(config.id, config)
// 如果有creator函数立即初始化
if (config.creator) {
try {
const provider = config.creator({}) as any // 使用空配置初始化类型断言为any
const aliases = config.mappings ? Object.keys(config.mappings) : undefined
globalRegistryManagement.registerProvider(config.id, provider, aliases)
} catch (error) {
console.error(`Failed to initialize dynamic provider "${config.id}":`, error)
return false
}
// 处理别名
if (config.aliases && config.aliases.length > 0) {
config.aliases.forEach((alias) => {
if (providerConfigAliases.has(alias)) {
console.warn(`ProviderConfig alias "${alias}" already exists, will override`)
}
providerConfigAliases.set(alias, config.id)
})
}
return true
} catch (error) {
console.error(`Failed to register dynamic provider:`, error)
console.error(`Failed to register ProviderConfig:`, error)
return false
}
}
/**
* providers
* 步骤2: 创建Provider -
*/
export function registerMultipleProviders(configs: DynamicProviderRegistration[]): number {
export async function createProvider(providerId: string, options: any): Promise<any> {
// 支持通过别名查找配置
const config = getProviderConfigByAlias(providerId)
if (!config) {
throw new Error(`ProviderConfig not found for id: ${providerId}`)
}
try {
let creator: (options: any) => any
if (config.creator) {
// 方式1: 直接执行 creator
creator = config.creator
} else if (config.import && config.creatorFunctionName) {
// 方式2: 动态导入并执行
const module = await config.import()
creator = (module as any)[config.creatorFunctionName]
if (!creator || typeof creator !== 'function') {
throw new Error(`Creator function "${config.creatorFunctionName}" not found in imported module`)
}
} else {
throw new Error('No valid creator method provided in ProviderConfig')
}
// 使用真实配置创建provider实例
return creator(options)
} catch (error) {
console.error(`Failed to create provider "${providerId}":`, error)
throw error
}
}
/**
* 步骤3: 注册Provider到全局管理器
*/
export function registerProvider(providerId: string, provider: any): boolean {
try {
const config = providerConfigs.get(providerId)
if (!config) {
console.error(`ProviderConfig not found for id: ${providerId}`)
return false
}
// 获取aliases配置
const aliases = config.aliases
// 处理特殊provider逻辑
if (providerId === 'openai') {
// 注册默认 openai
globalRegistryManagement.registerProvider('openai', provider, aliases)
// 创建并注册 openai-chat 变体
const openaiChatProvider = customProvider({
fallbackProvider: {
...provider,
languageModel: (modelId: string) => provider.chat(modelId)
}
})
globalRegistryManagement.registerProvider('openai-chat', openaiChatProvider)
} else {
// 其他provider直接注册
globalRegistryManagement.registerProvider(providerId, provider, aliases)
}
return true
} catch (error) {
console.error(`Failed to register provider "${providerId}" to global registry:`, error)
return false
}
}
/**
* 便捷函数: 一次性完成创建+
*/
export async function createAndRegisterProvider(providerId: string, options: any): Promise<boolean> {
try {
// 步骤2: 创建provider
const provider = await createProvider(providerId, options)
// 步骤3: 注册到全局管理器
return registerProvider(providerId, provider)
} catch (error) {
console.error(`Failed to create and register provider "${providerId}":`, error)
return false
}
}
/**
* Provider配置
*/
export function registerMultipleProviderConfigs(configs: ProviderConfig[]): number {
let successCount = 0
configs.forEach((config) => {
if (registerDynamicProvider(config)) {
if (registerProviderConfig(config)) {
successCount++
}
})
@ -284,47 +228,83 @@ export function registerMultipleProviders(configs: DynamicProviderRegistration[]
}
/**
* provider映射
* Provider配置
*/
export function getProviderMapping(providerId: string): string {
return globalRegistryManagement.resolveProviderId(providerId)
export function hasProviderConfig(providerId: string): boolean {
return providerConfigs.has(providerId)
}
/**
* provider
* ID检查是否有对应的Provider配置
*/
export function isDynamicProvider(providerId: string): boolean {
return dynamicProviders.has(providerId)
export function hasProviderConfigByAlias(aliasOrId: string): boolean {
const realId = resolveProviderConfigId(aliasOrId)
return providerConfigs.has(realId)
}
/**
* providers
* Provider配置
*/
export function getDynamicProviders(): DynamicProviderRegistration[] {
return Array.from(dynamicProviders.values())
export function getAllProviderConfigs(): ProviderConfig[] {
return Array.from(providerConfigs.values())
}
/**
*
* ID获取Provider配置
*/
export function getAllDynamicMappings(): Record<string, string> {
return globalRegistryManagement.getAllAliases()
export function getProviderConfig(providerId: string): ProviderConfig | undefined {
return providerConfigs.get(providerId)
}
/**
* providers
* ID获取Provider配置
*/
export function getProviderConfigByAlias(aliasOrId: string): ProviderConfig | undefined {
// 先检查是否为别名如果是则解析为真实ID
const realId = providerConfigAliases.get(aliasOrId) || aliasOrId
return providerConfigs.get(realId)
}
/**
* ProviderConfig ID
*/
export function resolveProviderConfigId(aliasOrId: string): string {
return providerConfigAliases.get(aliasOrId) || aliasOrId
}
/**
* ProviderConfig别名
*/
export function isProviderConfigAlias(id: string): boolean {
return providerConfigAliases.has(id)
}
/**
* ProviderConfig别名映射关系
*/
export function getAllProviderConfigAliases(): Record<string, string> {
const result: Record<string, string> = {}
providerConfigAliases.forEach((realId, alias) => {
result[alias] = realId
})
return result
}
/**
* Provider配置和已注册的providers
*/
export function cleanup(): void {
dynamicProviders.clear()
providerConfigs.clear()
providerConfigAliases.clear() // 清理别名映射
globalRegistryManagement.clear()
// 重新初始化内置配置
initializeBuiltInConfigs()
}
export function clearAllProviders(): void {
globalRegistryManagement.clear()
}
// ==================== 导出别名相关API ====================
// ==================== 导出错误类型 ====================
export const resolveProviderId = (id: string) => globalRegistryManagement.resolveProviderId(id)
export const isAlias = (id: string) => globalRegistryManagement.isAlias(id)
export const getAllAliases = () => globalRegistryManagement.getAllAliases()
// ==================== 导出错误类型和工具函数 ====================
export { isOpenAIChatCompletionOnlyModel, ProviderInitializationError }
export { ProviderInitializationError }

View File

@ -1,8 +1,5 @@
/**
* Zod Provider
* -
* - Provider
* -
* Provider Config
*/
import { createAnthropic } from '@ai-sdk/anthropic'
@ -71,7 +68,7 @@ export const baseProviders = [
/**
* Provider IDs
* baseProviders
* baseProviders
*/
export const baseProviderIds = baseProviders.map((p) => p.id) as unknown as readonly [string, ...string[]]
@ -81,46 +78,28 @@ export const baseProviderIds = baseProviders.map((p) => p.id) as unknown as read
export const baseProviderIdSchema = z.enum(baseProviderIds)
/**
* Provider ID Schema
* Provider ID Schema
* provider IDs
*/
export const dynamicProviderIdSchema = z
export const customProviderIdSchema = z
.string()
.min(1)
.refine((id) => !baseProviderIds.includes(id as any), {
message: 'Dynamic provider ID cannot conflict with base provider IDs'
message: 'Custom provider ID cannot conflict with base provider IDs'
})
/**
* Provider ID Schema
* providers +
* Provider ID Schema -
*/
export const providerIdSchema = z.union([baseProviderIdSchema, dynamicProviderIdSchema])
export const providerIdSchema = z.union([baseProviderIdSchema, customProviderIdSchema])
/**
* Provider Schema
* Provider的配置验证
*/
export const providerConfigSchema = z
.object({
id: providerIdSchema,
name: z.string().min(1),
creator: z.function().optional(),
import: z.function().optional(),
creatorFunctionName: z.string().optional(),
supportsImageGeneration: z.boolean().default(false),
imageCreator: z.function().optional(),
validateOptions: z.function().optional()
})
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
message: 'Must provide either creator function or import configuration'
})
/**
* Provider Schema
*/
export const dynamicProviderRegistrationSchema = z
.object({
id: dynamicProviderIdSchema,
id: customProviderIdSchema, // 只允许自定义ID
name: z.string().min(1),
creator: z.function().optional(),
import: z.function().optional(),
@ -128,68 +107,26 @@ export const dynamicProviderRegistrationSchema = z
supportsImageGeneration: z.boolean().default(false),
imageCreator: z.function().optional(),
validateOptions: z.function().optional(),
mappings: z.record(z.string()).optional()
aliases: z.array(z.string()).optional()
})
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
message: 'Must provide either creator function or import configuration'
})
// ===== 类型推导 =====
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
export type DynamicProviderId = z.infer<typeof dynamicProviderIdSchema>
/**
* Provider ID - zod schema
*/
export type ProviderId = z.infer<typeof providerIdSchema>
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
export type CustomProviderId = z.infer<typeof customProviderIdSchema>
/**
* Provider
*/
export type ProviderConfig = z.infer<typeof providerConfigSchema>
export type DynamicProviderRegistration = z.infer<typeof dynamicProviderRegistrationSchema>
// ===== 纯验证函数 =====
/**
* Provider ID
*
* @deprecated 使 ProviderConfig
*/
export function validateProviderId(id: string): boolean {
return providerIdSchema.safeParse(id).success
}
/**
* Provider ID
*/
export function isBaseProviderId(id: string): id is BaseProviderId {
return baseProviderIdSchema.safeParse(id).success
}
/**
* Provider ID
*/
export function isValidDynamicProviderId(id: string): boolean {
return dynamicProviderIdSchema.safeParse(id).success
}
/**
* Provider
*/
export function validateProviderConfig(config: unknown): ProviderConfig | null {
const result = providerConfigSchema.safeParse(config)
if (result.success) {
return result.data
}
return null
}
/**
* Provider
*/
export function validateDynamicProviderRegistration(config: unknown): DynamicProviderRegistration | null {
const result = dynamicProviderRegistrationSchema.safeParse(config)
if (result.success) {
return result.data
}
return null
}
/**
* Provider
*/
export function getBaseProviderConfig(id: BaseProviderId): (typeof baseProviders)[number] | undefined {
return baseProviders.find((p) => p.id === id)
}
export type DynamicProviderRegistration = ProviderConfig

View File

@ -7,13 +7,7 @@ import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible
import { type XaiProviderSettings } from '@ai-sdk/xai'
// 导入基于 Zod 的 ProviderId 类型
import {
type BaseProviderId,
type DynamicProviderId,
type DynamicProviderRegistration,
type ProviderConfig,
type ProviderId as ZodProviderId
} from './schemas'
import { type ProviderId as ZodProviderId } from './schemas'
export interface ExtensibleProviderSettingsMap {
// 基础的静态providers
@ -76,10 +70,6 @@ export class ProviderError extends Error {
// 动态ProviderId类型 - 基于 Zod Schema支持运行时扩展和验证
export type ProviderId = ZodProviderId
// 重新导出相关类型
export type { BaseProviderId, DynamicProviderId, DynamicProviderRegistration, ProviderConfig }
// Provider类型注册工具
export interface ProviderTypeRegistrar {
registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
getProviderSettings<T extends string>(providerId: T): any

View File

@ -12,11 +12,10 @@ import {
streamText
} from 'ai'
import { type ProviderId } from '../../types'
import { globalModelResolver } from '../models'
import { type ModelConfig } from '../models/types'
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
import { getProviderInfo } from '../providers/registry'
import { type ProviderId } from '../providers'
import { ImageGenerationError, ImageModelResolutionError } from './errors'
import { PluginEngine } from './pluginEngine'
import { type RuntimeConfig } from './types'
@ -311,13 +310,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
}
}
/**
*
*/
getClientInfo() {
return getProviderInfo(this.config.providerId)
}
// === 静态工厂方法 ===
/**

View File

@ -3,7 +3,6 @@ import { ImageModelV2 } from '@ai-sdk/provider'
import { LanguageModel } from 'ai'
import { type AiPlugin, createContext, PluginManager } from '../plugins'
import { isProviderSupported } from '../providers/registry'
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
/**
@ -240,20 +239,4 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
): PluginEngine<'openai-compatible'> {
return new PluginEngine('openai-compatible', plugins)
}
/**
*
*/
static create<T extends ProviderId>(providerId: T, plugins?: AiPlugin[]): PluginEngine<T>
static create(providerId: string, plugins?: AiPlugin[]): PluginEngine<'openai-compatible'>
static create(providerId: string, plugins: AiPlugin[] = []): PluginEngine {
if (isProviderSupported(providerId)) {
return new PluginEngine(providerId as ProviderId, plugins)
} else {
// 对于未知 provider使用 openai-compatible
return new PluginEngine('openai-compatible', plugins)
}
}
}

View File

@ -4,10 +4,6 @@
*/
// 导入内部使用的类和函数
import { getSupportedProviders, isProviderSupported } from './core/providers/registry'
import type { ProviderId } from './core/providers/types'
import type { ProviderSettingsMap } from './core/providers/types'
import { createExecutor } from './core/runtime'
// ==================== 主要用户接口 ====================
export {
@ -28,29 +24,8 @@ export { createContext, definePlugin, PluginManager } from './core/plugins'
// export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
export { PluginEngine } from './core/runtime/pluginEngine'
// ==================== 低级 API ====================
export { providerRegistry } from './core/providers/registry'
// ==================== 类型定义 ====================
export type { ProviderConfig } from './core/providers/types'
export type { ProviderError } from './core/providers/types'
export type {
AnthropicProviderSettings,
AzureOpenAIProviderSettings,
DeepSeekProviderSettings,
GenerateObjectParams,
GenerateTextParams,
GoogleGenerativeAIProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
ProviderId,
ProviderSettings,
ProviderSettingsMap,
StreamObjectParams,
StreamTextParams,
XaiProviderSettings
} from './types'
export * as aiSdk from 'ai'
export type { GenerateObjectParams, GenerateTextParams, StreamObjectParams, StreamTextParams } from './types'
// ==================== AI SDK 常用类型导出 ====================
// 直接导出 AI SDK 的常用类型,方便使用
@ -112,115 +87,6 @@ export {
type TypedProviderOptions
} from './core/options'
// ==================== Provider 初始化和管理 ====================
export {
clearAllProviders,
getImageModel,
getInitializedProviders,
// 访问功能
getLanguageModel,
getProviderInfo,
getTextEmbeddingModel,
hasInitializedProviders,
// initializeImageProvider, // deprecated: 使用 initializeProvider 即可
// 初始化功能
initializeProvider,
initializeProviders,
isProviderInitialized,
isProviderSupported,
// 错误类型
ProviderInitializationError,
reinitializeProvider
} from './core/providers/registry'
// ==================== 动态Provider注册和别名映射 ====================
export {
cleanup,
getAllAliases,
getAllDynamicMappings,
getDynamicProviders,
getProviderMapping,
isAlias,
isDynamicProvider,
registerDynamicProvider,
registerMultipleProviders,
resolveProviderId
} from './core/providers/registry'
// ==================== Zod Schema 和验证 ====================
export { baseProviderIds, validateProviderId } from './core/providers'
// ==================== Hub Provider ====================
export { createHubProvider, type HubProviderConfig, HubProviderError } from './core/providers/HubProvider'
// ==================== Provider 配置工厂 ====================
export {
type BaseProviderConfig,
createProviderConfig,
type ProviderConfigBuilder,
providerConfigBuilder,
ProviderConfigFactory
} from './core/providers/factory'
// ==================== 包信息 ====================
export const AI_CORE_VERSION = '1.0.0'
export const AI_CORE_NAME = '@cherrystudio/ai-core'
// ==================== 便捷 API ====================
// 主要的便捷工厂类
export const AiCore = {
version: AI_CORE_VERSION,
name: AI_CORE_NAME,
// 创建主要执行器(推荐使用)
create(providerId: ProviderId, options: ProviderSettingsMap[ProviderId], plugins: any[] = []) {
return createExecutor(providerId, options, plugins)
},
// 获取支持的providers
getSupportedProviders() {
return getSupportedProviders()
},
isSupported(providerId: ProviderId) {
return isProviderSupported(providerId)
}
}
// 推荐使用的执行器创建函数
export const createOpenAIExecutor = (options: ProviderSettingsMap['openai'], plugins?: any[]) => {
return createExecutor('openai', options, plugins)
}
export const createAnthropicExecutor = (options: ProviderSettingsMap['anthropic'], plugins?: any[]) => {
return createExecutor('anthropic', options, plugins)
}
export const createGoogleExecutor = (options: ProviderSettingsMap['google'], plugins?: any[]) => {
return createExecutor('google', options, plugins)
}
export const createXAIExecutor = (options: ProviderSettingsMap['xai'], plugins?: any[]) => {
return createExecutor('xai', options, plugins)
}
// ==================== 调试和开发工具 ====================
export const DevTools = {
// 列出所有支持的providers
listProviders() {
return getSupportedProviders()
},
// 获取provider详细信息
getProviderDetails() {
const supportedProviders = getSupportedProviders()
return {
supportedProviders: supportedProviders.length,
providers: supportedProviders.map((p) => ({
id: p.id,
name: p.name
}))
}
}
}

View File

@ -1,27 +1,9 @@
import { generateObject, generateText, streamObject, streamText } from 'ai'
import type { ProviderSettingsMap } from './core/providers/types'
// ProviderSettings 是所有 Provider Settings 的联合类型
export type ProviderSettings = ProviderSettingsMap[keyof ProviderSettingsMap]
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>
// 重新导出 ProviderSettingsMap 中的所有类型
export type {
AnthropicProviderSettings,
AzureOpenAIProviderSettings,
DeepSeekProviderSettings,
GoogleGenerativeAIProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
ProviderId,
ProviderSettingsMap,
XaiProviderSettings
} from './core/providers/types'
// 重新导出插件类型
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'

View File

@ -3,7 +3,8 @@ import { defineConfig } from 'tsdown'
export default defineConfig({
entry: {
index: 'src/index.ts',
'built-in/plugins/index': 'src/core/plugins/built-in/index.ts'
'built-in/plugins/index': 'src/core/plugins/built-in/index.ts',
'provider/index': 'src/core/providers/index.ts'
},
outDir: 'dist',
format: ['esm', 'cjs'],

View File

@ -8,7 +8,8 @@
* 3.
*/
import { createExecutor, generateImage, initializeProvider, StreamTextParams } from '@cherrystudio/ai-core'
import { createExecutor, generateImage, StreamTextParams } from '@cherrystudio/ai-core'
import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
@ -36,21 +37,6 @@ export default class ModernAiProvider {
// 只保存配置不预先创建executor
this.config = providerToAiSdkConfig(this.actualProvider)
// 初始化 provider 到全局管理器
try {
initializeProvider(this.config.providerId, this.config.options)
logger.debug('Provider initialized successfully', {
providerId: this.config.providerId,
hasOptions: !!this.config.options
})
} catch (error) {
// 如果 provider 已经初始化过,可能会抛出错误,这里可以忽略
logger.debug('Provider initialization skipped (may already be initialized)', {
providerId: this.config.providerId,
error: error instanceof Error ? error.message : String(error)
})
}
}
public getActualProvider() {
@ -67,6 +53,21 @@ export default class ModernAiProvider {
callType: string
}
): Promise<CompletionsResult> {
// 初始化 provider 到全局管理器
try {
await createAndRegisterProvider(this.config.providerId, this.config.options)
logger.debug('Provider initialized successfully', {
providerId: this.config.providerId,
hasOptions: !!this.config.options
})
} catch (error) {
// 如果 provider 已经初始化过,可能会抛出错误,这里可以忽略
logger.debug('Provider initialization skipped (may already be initialized)', {
providerId: this.config.providerId,
error: error instanceof Error ? error.message : String(error)
})
}
if (config.isImageGenerationEndpoint) {
return await this.modernImageGeneration(modelId, params, config)
}

View File

@ -1,4 +1,9 @@
import { AiCore, ProviderConfigFactory, type ProviderId, type ProviderSettingsMap } from '@cherrystudio/ai-core'
import {
hasProviderConfig,
ProviderConfigFactory,
type ProviderId,
type ProviderSettingsMap
} from '@cherrystudio/ai-core/provider'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types'
@ -86,7 +91,7 @@ export function providerToAiSdkConfig(actualProvider: Provider): {
}
// 如果AI SDK支持该provider使用原生配置
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId as ProviderId,
@ -120,5 +125,5 @@ export function isModernSdkSupported(provider: Provider): boolean {
const aiSdkProviderId = getAiSdkProviderId(provider)
// 如果映射到了支持的provider则支持现代SDK
return AiCore.isSupported(aiSdkProviderId)
return hasProviderConfig(aiSdkProviderId)
}

View File

@ -1,4 +1,4 @@
import { AiCore, getProviderMapping, type ProviderId } from '@cherrystudio/ai-core'
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { Provider } from '@renderer/types'
@ -30,7 +30,7 @@ const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
}
/**
* provider标识符
* provider标识符
*/
function tryResolveProviderId(identifier: string): ProviderId | null {
// 1. 检查静态映射
@ -39,15 +39,10 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
return staticMapping
}
// 2. 检查动态映射
const dynamicMapping = getProviderMapping(identifier)
if (dynamicMapping && dynamicMapping !== identifier) {
return dynamicMapping as ProviderId
}
// 3. 检查AiCore是否直接支持
if (AiCore.isSupported(identifier)) {
return identifier as ProviderId
// 2. 检查AiCore是否支持包括别名支持
if (hasProviderConfigByAlias(identifier)) {
// 解析为真实的Provider ID
return resolveProviderConfigId(identifier) as ProviderId
}
return null

View File

@ -1,4 +1,4 @@
import { type ProviderConfig, registerMultipleProviders } from '@cherrystudio/ai-core'
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
const logger = loggerService.withContext('ProviderConfigs')
@ -16,9 +16,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: true,
mappings: {
openrouter: 'openrouter'
}
aliases: ['openrouter']
},
{
id: 'google-vertex',
@ -26,10 +24,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
import: () => import('@ai-sdk/google-vertex'),
creatorFunctionName: 'createGoogleVertex',
supportsImageGeneration: true,
mappings: {
'google-vertex': 'google-vertex',
vertexai: 'google-vertex'
}
aliases: ['google-vertex', 'vertexai']
},
{
id: 'bedrock',
@ -37,9 +32,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: true,
mappings: {
'aws-bedrock': 'bedrock'
}
aliases: ['aws-bedrock']
}
] as const
@ -49,19 +42,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
*/
export async function initializeNewProviders(): Promise<void> {
try {
logger.info('Starting to register new providers', {
providerCount: NEW_PROVIDER_CONFIGS.length,
providerIds: NEW_PROVIDER_CONFIGS.map((p) => p.id)
})
const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS)
logger.info('Provider registration completed', {
successCount,
totalCount: NEW_PROVIDER_CONFIGS.length,
failedCount: NEW_PROVIDER_CONFIGS.length - successCount
})
const successCount = registerMultipleProviderConfigs(NEW_PROVIDER_CONFIGS)
if (successCount < NEW_PROVIDER_CONFIGS.length) {
logger.warn('Some providers failed to register. Check previous error logs.')
}

View File

@ -1,11 +1,11 @@
import { aiSdk, Tool } from '@cherrystudio/ai-core'
import { Tool } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
import { MCPTool, MCPToolResponse } from '@renderer/types'
import { callMCPTool } from '@renderer/utils/mcp-tools'
import { jsonSchema, tool } from 'ai'
import { JSONSchema7 } from 'json-schema'
const { tool } = aiSdk
const logger = loggerService.withContext('MCP-utils')
// Setup tools configuration based on provided parameters
@ -30,7 +30,7 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string,
for (const mcpTool of mcpTools) {
tools[mcpTool.name] = tool({
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
inputSchema: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
execute: async (params) => {
// 创建适配的 MCPToolResponse 对象
const toolResponse: MCPToolResponse = {