mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-06 21:35:52 +08:00
feat(aiCore): enhance provider management and registration system
- Added support for a new provider configuration structure in package.json, enabling better integration of provider types. - Updated tsdown.config.ts to include new entry points for provider modules, improving build organization. - Refactored index.ts to streamline exports and enhance type handling for provider-related functionalities. - Simplified provider initialization and registration processes, allowing for more flexible provider management. - Improved type definitions and removed deprecated methods to enhance code clarity and maintainability.
This commit is contained in:
parent
2ce9314a10
commit
9c01e24317
@ -71,6 +71,13 @@
|
|||||||
"import": "./dist/built-in/plugins/index.mjs",
|
"import": "./dist/built-in/plugins/index.mjs",
|
||||||
"require": "./dist/built-in/plugins/index.js",
|
"require": "./dist/built-in/plugins/index.js",
|
||||||
"default": "./dist/built-in/plugins/index.js"
|
"default": "./dist/built-in/plugins/index.js"
|
||||||
|
},
|
||||||
|
"./provider": {
|
||||||
|
"types": "./dist/provider/index.d.ts",
|
||||||
|
"react-native": "./dist/provider/index.js",
|
||||||
|
"import": "./dist/provider/index.mjs",
|
||||||
|
"require": "./dist/provider/index.js",
|
||||||
|
"default": "./dist/provider/index.js"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
// 核心类型和接口
|
// 核心类型和接口
|
||||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
|
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
|
||||||
import type { ProviderId } from '../../types'
|
import type { ProviderId } from '../providers'
|
||||||
import type { AiPlugin, AiRequestContext } from './types'
|
import type { AiPlugin, AiRequestContext } from './types'
|
||||||
|
|
||||||
// 插件管理器
|
// 插件管理器
|
||||||
|
|||||||
@ -154,10 +154,10 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取已注册的 provider 列表(排除别名)
|
* 获取已注册的 provider 列表
|
||||||
*/
|
*/
|
||||||
getRegisteredProviders(): string[] {
|
getRegisteredProviders(): string[] {
|
||||||
return Object.keys(this.providers).filter((id) => !this.aliases.has(id))
|
return Object.keys(this.providers)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -1,60 +1,83 @@
|
|||||||
/**
|
/**
|
||||||
* Providers 模块统一导出 - 现代化架构
|
* Providers 模块统一导出 - 独立Provider包
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// ==================== 新版架构(推荐使用)====================
|
// ==================== 核心管理器 ====================
|
||||||
|
|
||||||
// Provider 注册表管理器
|
// Provider 注册表管理器
|
||||||
export { globalRegistryManagement, RegistryManagement } from './RegistryManagement'
|
export { globalRegistryManagement, RegistryManagement } from './RegistryManagement'
|
||||||
|
|
||||||
// Provider 初始化器(核心功能)
|
// Provider 核心功能
|
||||||
export {
|
export {
|
||||||
|
// 状态管理
|
||||||
|
cleanup,
|
||||||
clearAllProviders,
|
clearAllProviders,
|
||||||
|
createAndRegisterProvider,
|
||||||
|
createProvider,
|
||||||
|
getAllProviderConfigAliases,
|
||||||
|
getAllProviderConfigs,
|
||||||
getImageModel,
|
getImageModel,
|
||||||
|
// 工具函数
|
||||||
getInitializedProviders,
|
getInitializedProviders,
|
||||||
getLanguageModel,
|
getLanguageModel,
|
||||||
getProviderInfo,
|
getProviderConfig,
|
||||||
|
getProviderConfigByAlias,
|
||||||
getSupportedProviders,
|
getSupportedProviders,
|
||||||
getTextEmbeddingModel,
|
getTextEmbeddingModel,
|
||||||
hasInitializedProviders,
|
hasInitializedProviders,
|
||||||
// initializeImageProvider, // deprecated: 使用 initializeProvider 即可
|
// 工具函数
|
||||||
initializeProvider,
|
hasProviderConfig,
|
||||||
initializeProviders,
|
// 别名支持
|
||||||
isProviderInitialized,
|
hasProviderConfigByAlias,
|
||||||
isProviderSupported,
|
isProviderConfigAlias,
|
||||||
|
// 错误类型
|
||||||
ProviderInitializationError,
|
ProviderInitializationError,
|
||||||
ProviderInitializer,
|
// 全局访问
|
||||||
providerRegistry,
|
providerRegistry,
|
||||||
reinitializeProvider
|
registerMultipleProviderConfigs,
|
||||||
|
registerProvider,
|
||||||
|
// 统一Provider系统
|
||||||
|
registerProviderConfig,
|
||||||
|
resolveProviderConfigId
|
||||||
} from './registry'
|
} from './registry'
|
||||||
|
|
||||||
// 动态Provider注册功能
|
// ==================== 基础数据和类型 ====================
|
||||||
export {
|
|
||||||
cleanup,
|
|
||||||
getAllAliases,
|
|
||||||
getAllDynamicMappings,
|
|
||||||
getDynamicProviders,
|
|
||||||
getProviderMapping,
|
|
||||||
isAlias,
|
|
||||||
isDynamicProvider,
|
|
||||||
registerDynamicProvider,
|
|
||||||
registerMultipleProviders,
|
|
||||||
resolveProviderId
|
|
||||||
} from './registry'
|
|
||||||
|
|
||||||
// ==================== 保留的导出(兼容性)====================
|
|
||||||
|
|
||||||
// 基础Provider数据源
|
// 基础Provider数据源
|
||||||
export { baseProviderIds, baseProviders } from './schemas'
|
export { baseProviderIds, baseProviders } from './schemas'
|
||||||
|
|
||||||
|
// 类型定义和Schema
|
||||||
|
export type {
|
||||||
|
BaseProviderId,
|
||||||
|
CustomProviderId,
|
||||||
|
DynamicProviderRegistration,
|
||||||
|
ProviderConfig,
|
||||||
|
ProviderId
|
||||||
|
} from './schemas' // 从 schemas 导出的类型
|
||||||
|
export { baseProviderIdSchema, customProviderIdSchema, providerConfigSchema, providerIdSchema } from './schemas' // Schema 导出
|
||||||
|
export type {
|
||||||
|
DynamicProviderRegistry,
|
||||||
|
ExtensibleProviderSettingsMap,
|
||||||
|
ProviderError,
|
||||||
|
ProviderSettingsMap,
|
||||||
|
ProviderTypeRegistrar
|
||||||
|
} from './types'
|
||||||
|
|
||||||
|
// ==================== 工具函数 ====================
|
||||||
|
|
||||||
|
// Provider配置工厂
|
||||||
|
export {
|
||||||
|
type BaseProviderConfig,
|
||||||
|
createProviderConfig,
|
||||||
|
ProviderConfigBuilder,
|
||||||
|
providerConfigBuilder,
|
||||||
|
ProviderConfigFactory
|
||||||
|
} from './factory'
|
||||||
|
|
||||||
|
// 工具函数
|
||||||
|
export { formatPrivateKey } from './utils'
|
||||||
|
|
||||||
|
// ==================== 扩展功能 ====================
|
||||||
|
|
||||||
// Hub Provider 功能
|
// Hub Provider 功能
|
||||||
export { createHubProvider, type HubProviderConfig, HubProviderError } from './HubProvider'
|
export { createHubProvider, type HubProviderConfig, HubProviderError } from './HubProvider'
|
||||||
|
|
||||||
// 类型定义(可能被其他模块使用)
|
|
||||||
export type { ProviderConfig, ProviderId, ProviderSettingsMap } from './types'
|
|
||||||
|
|
||||||
// Provider验证功能(使用更好的Zod版本)
|
|
||||||
export { validateProviderConfig } from './schemas'
|
|
||||||
|
|
||||||
// 验证功能(可能被其他地方使用)
|
|
||||||
export { validateProviderId } from './schemas'
|
|
||||||
|
|||||||
@ -6,9 +6,8 @@
|
|||||||
|
|
||||||
import { customProvider } from 'ai'
|
import { customProvider } from 'ai'
|
||||||
|
|
||||||
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
|
|
||||||
import { globalRegistryManagement } from './RegistryManagement'
|
import { globalRegistryManagement } from './RegistryManagement'
|
||||||
import { baseProviders, type DynamicProviderRegistration } from './schemas'
|
import { baseProviders, type ProviderConfig } from './schemas'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Provider 初始化错误类型
|
* Provider 初始化错误类型
|
||||||
@ -24,141 +23,6 @@ class ProviderInitializationError extends Error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Provider 初始化器类
|
|
||||||
*/
|
|
||||||
export class ProviderInitializer {
|
|
||||||
/**
|
|
||||||
* 初始化单个 provider 并注册
|
|
||||||
*/
|
|
||||||
static initializeProvider(providerId: string, options: any): void {
|
|
||||||
try {
|
|
||||||
// 1. 从 schemas 获取 provider 配置
|
|
||||||
const providerConfig = baseProviders.find((p) => p.id === providerId)
|
|
||||||
if (!providerConfig) {
|
|
||||||
throw new ProviderInitializationError(`Provider configuration for '${providerId}' not found`, providerId)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 使用 creator 函数创建已配置的 provider
|
|
||||||
const configuredProvider = providerConfig.creator(options)
|
|
||||||
|
|
||||||
// 3. 处理特殊逻辑并注册到全局管理器
|
|
||||||
this.handleProviderSpecificLogic(configuredProvider, providerId)
|
|
||||||
} catch (error) {
|
|
||||||
if (error instanceof ProviderInitializationError) {
|
|
||||||
throw error
|
|
||||||
}
|
|
||||||
throw new ProviderInitializationError(
|
|
||||||
`Failed to initialize provider ${providerId}: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
|
||||||
providerId,
|
|
||||||
error instanceof Error ? error : undefined
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 批量初始化 providers
|
|
||||||
*/
|
|
||||||
static initializeProviders(providers: Record<string, any>): void {
|
|
||||||
Object.entries(providers).forEach(([providerId, options]) => {
|
|
||||||
try {
|
|
||||||
this.initializeProvider(providerId, options)
|
|
||||||
} catch (error) {
|
|
||||||
console.error(`Failed to initialize provider ${providerId}:`, error)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 处理特定 provider 的特殊逻辑 (从 ModelCreator 迁移并改进)
|
|
||||||
*/
|
|
||||||
private static handleProviderSpecificLogic(provider: any, providerId: string): void {
|
|
||||||
if (providerId === 'openai') {
|
|
||||||
// 🎯 OpenAI 默认注册 (responses 模式)
|
|
||||||
globalRegistryManagement.registerProvider('openai', provider)
|
|
||||||
|
|
||||||
// 🎯 使用 AI SDK 官方的 customProvider 创建 chat 模式变体
|
|
||||||
const openaiChatProvider = customProvider({
|
|
||||||
fallbackProvider: {
|
|
||||||
...provider,
|
|
||||||
// 覆盖 languageModel 方法指向 chat
|
|
||||||
languageModel: (modelId: string) => provider.chat(modelId)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
globalRegistryManagement.registerProvider('openai-chat', openaiChatProvider)
|
|
||||||
} else {
|
|
||||||
// 其他 provider 直接注册
|
|
||||||
globalRegistryManagement.registerProvider(providerId, provider)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 初始化图像生成 provider (从 ModelCreator 迁移)
|
|
||||||
*
|
|
||||||
* @deprecated 不再需要单独的图像provider初始化,使用 initializeProvider() 即可
|
|
||||||
* 一个provider实例可以同时支持文本和图像功能,无需分别初始化
|
|
||||||
*
|
|
||||||
* TODO: 考虑在下个版本中删除此方法
|
|
||||||
*/
|
|
||||||
// static initializeImageProvider(providerId: string, options: any): void {
|
|
||||||
// try {
|
|
||||||
// const providerConfig = baseProviders.find((p) => p.id === providerId)
|
|
||||||
// if (!providerConfig) {
|
|
||||||
// throw new ProviderInitializationError(`Provider configuration for '${providerId}' not found`, providerId)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if (!providerConfig.supportsImageGeneration) {
|
|
||||||
// throw new ProviderInitializationError(`Provider "${providerId}" does not support image generation`, providerId)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// const provider = providerConfig.creator(options)
|
|
||||||
|
|
||||||
// // 注册图像 provider (使用特殊前缀区分)
|
|
||||||
// globalRegistryManagement.registerProvider(`${providerId}-image`, provider as any)
|
|
||||||
// } catch (error) {
|
|
||||||
// if (error instanceof ProviderInitializationError) {
|
|
||||||
// throw error
|
|
||||||
// }
|
|
||||||
// throw new ProviderInitializationError(
|
|
||||||
// `Failed to initialize image provider ${providerId}: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
|
||||||
// providerId,
|
|
||||||
// error instanceof Error ? error : undefined
|
|
||||||
// )
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 检查 provider 是否已初始化
|
|
||||||
*/
|
|
||||||
static isProviderInitialized(providerId: string): boolean {
|
|
||||||
return globalRegistryManagement.getRegisteredProviders().includes(providerId)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 重新初始化 provider(更新配置)
|
|
||||||
*/
|
|
||||||
static reinitializeProvider(providerId: string, options: any): void {
|
|
||||||
this.initializeProvider(providerId, options) // 会覆盖已有的
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 清除所有已初始化的 providers
|
|
||||||
*/
|
|
||||||
static clearAllProviders(): void {
|
|
||||||
globalRegistryManagement.clear()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== 便捷函数导出 ====================
|
|
||||||
|
|
||||||
export const initializeProvider = ProviderInitializer.initializeProvider.bind(ProviderInitializer)
|
|
||||||
export const initializeProviders = ProviderInitializer.initializeProviders.bind(ProviderInitializer)
|
|
||||||
// export const initializeImageProvider = ProviderInitializer.initializeImageProvider.bind(ProviderInitializer) // deprecated: 使用 initializeProvider 即可
|
|
||||||
export const isProviderInitialized = ProviderInitializer.isProviderInitialized.bind(ProviderInitializer)
|
|
||||||
export const reinitializeProvider = ProviderInitializer.reinitializeProvider.bind(ProviderInitializer)
|
|
||||||
export const clearAllProviders = ProviderInitializer.clearAllProviders.bind(ProviderInitializer)
|
|
||||||
|
|
||||||
// ==================== 全局管理器导出 ====================
|
// ==================== 全局管理器导出 ====================
|
||||||
|
|
||||||
export { globalRegistryManagement as providerRegistry }
|
export { globalRegistryManagement as providerRegistry }
|
||||||
@ -169,10 +33,10 @@ export const getLanguageModel = (id: string) => globalRegistryManagement.languag
|
|||||||
export const getTextEmbeddingModel = (id: string) => globalRegistryManagement.textEmbeddingModel(id as any)
|
export const getTextEmbeddingModel = (id: string) => globalRegistryManagement.textEmbeddingModel(id as any)
|
||||||
export const getImageModel = (id: string) => globalRegistryManagement.imageModel(id as any)
|
export const getImageModel = (id: string) => globalRegistryManagement.imageModel(id as any)
|
||||||
|
|
||||||
// ==================== 工具函数 (从 ModelCreator 迁移) ====================
|
// ==================== 工具函数 ====================
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取支持的 Providers 列表 (从 ModelCreator 迁移)
|
* 获取支持的 Providers 列表
|
||||||
*/
|
*/
|
||||||
export function getSupportedProviders(): Array<{
|
export function getSupportedProviders(): Array<{
|
||||||
id: string
|
id: string
|
||||||
@ -184,35 +48,6 @@ export function getSupportedProviders(): Array<{
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 检查 Provider 是否被支持
|
|
||||||
*/
|
|
||||||
export function isProviderSupported(providerId: string): boolean {
|
|
||||||
return getProviderInfo(providerId).isSupported
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 获取 Provider 信息 (从 ModelCreator 迁移并改进)
|
|
||||||
*/
|
|
||||||
export function getProviderInfo(providerId: string): {
|
|
||||||
id: string
|
|
||||||
name: string
|
|
||||||
isSupported: boolean
|
|
||||||
isInitialized: boolean
|
|
||||||
effectiveProvider: string
|
|
||||||
} {
|
|
||||||
const provider = baseProviders.find((p) => p.id === providerId)
|
|
||||||
const isInitialized = globalRegistryManagement.getRegisteredProviders().includes(providerId)
|
|
||||||
|
|
||||||
return {
|
|
||||||
id: providerId,
|
|
||||||
name: provider?.name || providerId,
|
|
||||||
isSupported: !!provider,
|
|
||||||
isInitialized,
|
|
||||||
effectiveProvider: isInitialized ? providerId : 'openai-compatible'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取所有已初始化的 providers
|
* 获取所有已初始化的 providers
|
||||||
*/
|
*/
|
||||||
@ -227,56 +62,165 @@ export function hasInitializedProviders(): boolean {
|
|||||||
return globalRegistryManagement.hasProviders()
|
return globalRegistryManagement.hasProviders()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================== 动态Provider注册功能 ====================
|
// ==================== 统一Provider配置系统 ====================
|
||||||
|
|
||||||
// 全局动态provider存储
|
// 全局Provider配置存储
|
||||||
const dynamicProviders = new Map<string, DynamicProviderRegistration>()
|
const providerConfigs = new Map<string, ProviderConfig>()
|
||||||
|
// 全局ProviderConfig别名映射 - 借鉴RegistryManagement模式
|
||||||
|
const providerConfigAliases = new Map<string, string>() // alias -> realId
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 注册动态provider
|
* 初始化内置配置 - 将baseProviders转换为统一格式
|
||||||
*/
|
*/
|
||||||
export function registerDynamicProvider(config: DynamicProviderRegistration): boolean {
|
function initializeBuiltInConfigs(): void {
|
||||||
|
baseProviders.forEach((provider) => {
|
||||||
|
const config: ProviderConfig = {
|
||||||
|
id: provider.id,
|
||||||
|
name: provider.name,
|
||||||
|
creator: provider.creator as any, // 类型转换以兼容多种creator签名
|
||||||
|
supportsImageGeneration: provider.supportsImageGeneration || false
|
||||||
|
}
|
||||||
|
providerConfigs.set(provider.id, config)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动时自动注册内置配置
|
||||||
|
initializeBuiltInConfigs()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 步骤1: 注册Provider配置 - 仅存储配置,不执行创建
|
||||||
|
*/
|
||||||
|
export function registerProviderConfig(config: ProviderConfig): boolean {
|
||||||
try {
|
try {
|
||||||
// 验证配置
|
// 验证配置
|
||||||
if (!config.id || !config.name) {
|
if (!config.id || !config.name) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否与基础provider冲突
|
// 检查是否与已有配置冲突(包括内置配置)
|
||||||
if (baseProviders.find((p) => p.id === config.id)) {
|
if (providerConfigs.has(config.id)) {
|
||||||
console.warn(`Dynamic provider "${config.id}" conflicts with base provider`)
|
console.warn(`ProviderConfig "${config.id}" already exists, will override`)
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 存储动态provider配置
|
// 存储配置(内置和用户配置统一处理)
|
||||||
dynamicProviders.set(config.id, config)
|
providerConfigs.set(config.id, config)
|
||||||
|
|
||||||
// 如果有creator函数,立即初始化
|
// 处理别名
|
||||||
if (config.creator) {
|
if (config.aliases && config.aliases.length > 0) {
|
||||||
try {
|
config.aliases.forEach((alias) => {
|
||||||
const provider = config.creator({}) as any // 使用空配置初始化,类型断言为any
|
if (providerConfigAliases.has(alias)) {
|
||||||
const aliases = config.mappings ? Object.keys(config.mappings) : undefined
|
console.warn(`ProviderConfig alias "${alias}" already exists, will override`)
|
||||||
globalRegistryManagement.registerProvider(config.id, provider, aliases)
|
|
||||||
} catch (error) {
|
|
||||||
console.error(`Failed to initialize dynamic provider "${config.id}":`, error)
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
providerConfigAliases.set(alias, config.id)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`Failed to register dynamic provider:`, error)
|
console.error(`Failed to register ProviderConfig:`, error)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 批量注册动态providers
|
* 步骤2: 创建Provider - 根据配置执行实际创建
|
||||||
*/
|
*/
|
||||||
export function registerMultipleProviders(configs: DynamicProviderRegistration[]): number {
|
export async function createProvider(providerId: string, options: any): Promise<any> {
|
||||||
|
// 支持通过别名查找配置
|
||||||
|
const config = getProviderConfigByAlias(providerId)
|
||||||
|
|
||||||
|
if (!config) {
|
||||||
|
throw new Error(`ProviderConfig not found for id: ${providerId}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
let creator: (options: any) => any
|
||||||
|
|
||||||
|
if (config.creator) {
|
||||||
|
// 方式1: 直接执行 creator
|
||||||
|
creator = config.creator
|
||||||
|
} else if (config.import && config.creatorFunctionName) {
|
||||||
|
// 方式2: 动态导入并执行
|
||||||
|
const module = await config.import()
|
||||||
|
creator = (module as any)[config.creatorFunctionName]
|
||||||
|
|
||||||
|
if (!creator || typeof creator !== 'function') {
|
||||||
|
throw new Error(`Creator function "${config.creatorFunctionName}" not found in imported module`)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new Error('No valid creator method provided in ProviderConfig')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用真实配置创建provider实例
|
||||||
|
return creator(options)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to create provider "${providerId}":`, error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 步骤3: 注册Provider到全局管理器
|
||||||
|
*/
|
||||||
|
export function registerProvider(providerId: string, provider: any): boolean {
|
||||||
|
try {
|
||||||
|
const config = providerConfigs.get(providerId)
|
||||||
|
if (!config) {
|
||||||
|
console.error(`ProviderConfig not found for id: ${providerId}`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取aliases配置
|
||||||
|
const aliases = config.aliases
|
||||||
|
|
||||||
|
// 处理特殊provider逻辑
|
||||||
|
if (providerId === 'openai') {
|
||||||
|
// 注册默认 openai
|
||||||
|
globalRegistryManagement.registerProvider('openai', provider, aliases)
|
||||||
|
|
||||||
|
// 创建并注册 openai-chat 变体
|
||||||
|
const openaiChatProvider = customProvider({
|
||||||
|
fallbackProvider: {
|
||||||
|
...provider,
|
||||||
|
languageModel: (modelId: string) => provider.chat(modelId)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
globalRegistryManagement.registerProvider('openai-chat', openaiChatProvider)
|
||||||
|
} else {
|
||||||
|
// 其他provider直接注册
|
||||||
|
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to register provider "${providerId}" to global registry:`, error)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 便捷函数: 一次性完成创建+注册
|
||||||
|
*/
|
||||||
|
export async function createAndRegisterProvider(providerId: string, options: any): Promise<boolean> {
|
||||||
|
try {
|
||||||
|
// 步骤2: 创建provider
|
||||||
|
const provider = await createProvider(providerId, options)
|
||||||
|
|
||||||
|
// 步骤3: 注册到全局管理器
|
||||||
|
return registerProvider(providerId, provider)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to create and register provider "${providerId}":`, error)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 批量注册Provider配置
|
||||||
|
*/
|
||||||
|
export function registerMultipleProviderConfigs(configs: ProviderConfig[]): number {
|
||||||
let successCount = 0
|
let successCount = 0
|
||||||
configs.forEach((config) => {
|
configs.forEach((config) => {
|
||||||
if (registerDynamicProvider(config)) {
|
if (registerProviderConfig(config)) {
|
||||||
successCount++
|
successCount++
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -284,47 +228,83 @@ export function registerMultipleProviders(configs: DynamicProviderRegistration[]
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取provider映射(解析别名)
|
* 检查是否有对应的Provider配置
|
||||||
*/
|
*/
|
||||||
export function getProviderMapping(providerId: string): string {
|
export function hasProviderConfig(providerId: string): boolean {
|
||||||
return globalRegistryManagement.resolveProviderId(providerId)
|
return providerConfigs.has(providerId)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 检查是否为动态provider
|
* 通过别名或ID检查是否有对应的Provider配置
|
||||||
*/
|
*/
|
||||||
export function isDynamicProvider(providerId: string): boolean {
|
export function hasProviderConfigByAlias(aliasOrId: string): boolean {
|
||||||
return dynamicProviders.has(providerId)
|
const realId = resolveProviderConfigId(aliasOrId)
|
||||||
|
return providerConfigs.has(realId)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取所有动态providers
|
* 获取所有Provider配置
|
||||||
*/
|
*/
|
||||||
export function getDynamicProviders(): DynamicProviderRegistration[] {
|
export function getAllProviderConfigs(): ProviderConfig[] {
|
||||||
return Array.from(dynamicProviders.values())
|
return Array.from(providerConfigs.values())
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取所有别名映射关系
|
* 根据ID获取Provider配置
|
||||||
*/
|
*/
|
||||||
export function getAllDynamicMappings(): Record<string, string> {
|
export function getProviderConfig(providerId: string): ProviderConfig | undefined {
|
||||||
return globalRegistryManagement.getAllAliases()
|
return providerConfigs.get(providerId)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 清理所有动态providers
|
* 通过别名或ID获取Provider配置
|
||||||
|
*/
|
||||||
|
export function getProviderConfigByAlias(aliasOrId: string): ProviderConfig | undefined {
|
||||||
|
// 先检查是否为别名,如果是则解析为真实ID
|
||||||
|
const realId = providerConfigAliases.get(aliasOrId) || aliasOrId
|
||||||
|
return providerConfigs.get(realId)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析真实的ProviderConfig ID(去别名化)
|
||||||
|
*/
|
||||||
|
export function resolveProviderConfigId(aliasOrId: string): string {
|
||||||
|
return providerConfigAliases.get(aliasOrId) || aliasOrId
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查是否为ProviderConfig别名
|
||||||
|
*/
|
||||||
|
export function isProviderConfigAlias(id: string): boolean {
|
||||||
|
return providerConfigAliases.has(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取所有ProviderConfig别名映射关系
|
||||||
|
*/
|
||||||
|
export function getAllProviderConfigAliases(): Record<string, string> {
|
||||||
|
const result: Record<string, string> = {}
|
||||||
|
providerConfigAliases.forEach((realId, alias) => {
|
||||||
|
result[alias] = realId
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 清理所有Provider配置和已注册的providers
|
||||||
*/
|
*/
|
||||||
export function cleanup(): void {
|
export function cleanup(): void {
|
||||||
dynamicProviders.clear()
|
providerConfigs.clear()
|
||||||
|
providerConfigAliases.clear() // 清理别名映射
|
||||||
|
globalRegistryManagement.clear()
|
||||||
|
// 重新初始化内置配置
|
||||||
|
initializeBuiltInConfigs()
|
||||||
|
}
|
||||||
|
|
||||||
|
export function clearAllProviders(): void {
|
||||||
globalRegistryManagement.clear()
|
globalRegistryManagement.clear()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================== 导出别名相关API ====================
|
// ==================== 导出错误类型 ====================
|
||||||
|
|
||||||
export const resolveProviderId = (id: string) => globalRegistryManagement.resolveProviderId(id)
|
export { ProviderInitializationError }
|
||||||
export const isAlias = (id: string) => globalRegistryManagement.isAlias(id)
|
|
||||||
export const getAllAliases = () => globalRegistryManagement.getAllAliases()
|
|
||||||
|
|
||||||
// ==================== 导出错误类型和工具函数 ====================
|
|
||||||
|
|
||||||
export { isOpenAIChatCompletionOnlyModel, ProviderInitializationError }
|
|
||||||
|
|||||||
@ -1,8 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* 基于 Zod 的 Provider 验证系统
|
* Provider Config 定义
|
||||||
* - 纯验证层,无状态管理
|
|
||||||
* - 数据驱动的 Provider 定义
|
|
||||||
* - 完整的类型安全
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { createAnthropic } from '@ai-sdk/anthropic'
|
import { createAnthropic } from '@ai-sdk/anthropic'
|
||||||
@ -71,7 +68,7 @@ export const baseProviders = [
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 基础 Provider IDs
|
* 基础 Provider IDs
|
||||||
* 从 baseProviders 自动提取,避免重复维护
|
* 从 baseProviders 动态生成
|
||||||
*/
|
*/
|
||||||
export const baseProviderIds = baseProviders.map((p) => p.id) as unknown as readonly [string, ...string[]]
|
export const baseProviderIds = baseProviders.map((p) => p.id) as unknown as readonly [string, ...string[]]
|
||||||
|
|
||||||
@ -81,46 +78,28 @@ export const baseProviderIds = baseProviders.map((p) => p.id) as unknown as read
|
|||||||
export const baseProviderIdSchema = z.enum(baseProviderIds)
|
export const baseProviderIdSchema = z.enum(baseProviderIds)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 动态 Provider ID Schema
|
* 用户自定义 Provider ID Schema
|
||||||
* 允许任意字符串,但排除基础 provider IDs 以避免冲突
|
* 允许任意字符串,但排除基础 provider IDs 以避免冲突
|
||||||
*/
|
*/
|
||||||
export const dynamicProviderIdSchema = z
|
export const customProviderIdSchema = z
|
||||||
.string()
|
.string()
|
||||||
.min(1)
|
.min(1)
|
||||||
.refine((id) => !baseProviderIds.includes(id as any), {
|
.refine((id) => !baseProviderIds.includes(id as any), {
|
||||||
message: 'Dynamic provider ID cannot conflict with base provider IDs'
|
message: 'Custom provider ID cannot conflict with base provider IDs'
|
||||||
})
|
})
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 组合的 Provider ID Schema
|
* Provider ID Schema - 支持基础和自定义
|
||||||
* 支持基础 providers + 动态扩展
|
|
||||||
*/
|
*/
|
||||||
export const providerIdSchema = z.union([baseProviderIdSchema, dynamicProviderIdSchema])
|
export const providerIdSchema = z.union([baseProviderIdSchema, customProviderIdSchema])
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Provider 配置 Schema
|
* Provider 配置 Schema
|
||||||
|
* 用于Provider的配置验证
|
||||||
*/
|
*/
|
||||||
export const providerConfigSchema = z
|
export const providerConfigSchema = z
|
||||||
.object({
|
.object({
|
||||||
id: providerIdSchema,
|
id: customProviderIdSchema, // 只允许自定义ID
|
||||||
name: z.string().min(1),
|
|
||||||
creator: z.function().optional(),
|
|
||||||
import: z.function().optional(),
|
|
||||||
creatorFunctionName: z.string().optional(),
|
|
||||||
supportsImageGeneration: z.boolean().default(false),
|
|
||||||
imageCreator: z.function().optional(),
|
|
||||||
validateOptions: z.function().optional()
|
|
||||||
})
|
|
||||||
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
|
|
||||||
message: 'Must provide either creator function or import configuration'
|
|
||||||
})
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 动态 Provider 注册配置 Schema
|
|
||||||
*/
|
|
||||||
export const dynamicProviderRegistrationSchema = z
|
|
||||||
.object({
|
|
||||||
id: dynamicProviderIdSchema,
|
|
||||||
name: z.string().min(1),
|
name: z.string().min(1),
|
||||||
creator: z.function().optional(),
|
creator: z.function().optional(),
|
||||||
import: z.function().optional(),
|
import: z.function().optional(),
|
||||||
@ -128,68 +107,26 @@ export const dynamicProviderRegistrationSchema = z
|
|||||||
supportsImageGeneration: z.boolean().default(false),
|
supportsImageGeneration: z.boolean().default(false),
|
||||||
imageCreator: z.function().optional(),
|
imageCreator: z.function().optional(),
|
||||||
validateOptions: z.function().optional(),
|
validateOptions: z.function().optional(),
|
||||||
mappings: z.record(z.string()).optional()
|
aliases: z.array(z.string()).optional()
|
||||||
})
|
})
|
||||||
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
|
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
|
||||||
message: 'Must provide either creator function or import configuration'
|
message: 'Must provide either creator function or import configuration'
|
||||||
})
|
})
|
||||||
|
|
||||||
// ===== 类型推导 =====
|
/**
|
||||||
|
* Provider ID 类型 - 基于 zod schema 推导
|
||||||
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
|
*/
|
||||||
export type DynamicProviderId = z.infer<typeof dynamicProviderIdSchema>
|
|
||||||
export type ProviderId = z.infer<typeof providerIdSchema>
|
export type ProviderId = z.infer<typeof providerIdSchema>
|
||||||
|
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
|
||||||
|
export type CustomProviderId = z.infer<typeof customProviderIdSchema>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provider 配置类型
|
||||||
|
*/
|
||||||
export type ProviderConfig = z.infer<typeof providerConfigSchema>
|
export type ProviderConfig = z.infer<typeof providerConfigSchema>
|
||||||
export type DynamicProviderRegistration = z.infer<typeof dynamicProviderRegistrationSchema>
|
|
||||||
|
|
||||||
// ===== 纯验证函数 =====
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 验证 Provider ID 是否有效(包括基础和动态格式)
|
* 兼容性类型别名
|
||||||
|
* @deprecated 使用 ProviderConfig 替代
|
||||||
*/
|
*/
|
||||||
export function validateProviderId(id: string): boolean {
|
export type DynamicProviderRegistration = ProviderConfig
|
||||||
return providerIdSchema.safeParse(id).success
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 验证是否为基础 Provider ID
|
|
||||||
*/
|
|
||||||
export function isBaseProviderId(id: string): id is BaseProviderId {
|
|
||||||
return baseProviderIdSchema.safeParse(id).success
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 验证是否为有效的动态 Provider ID 格式
|
|
||||||
*/
|
|
||||||
export function isValidDynamicProviderId(id: string): boolean {
|
|
||||||
return dynamicProviderIdSchema.safeParse(id).success
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 验证 Provider 配置
|
|
||||||
*/
|
|
||||||
export function validateProviderConfig(config: unknown): ProviderConfig | null {
|
|
||||||
const result = providerConfigSchema.safeParse(config)
|
|
||||||
if (result.success) {
|
|
||||||
return result.data
|
|
||||||
}
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 验证动态 Provider 注册配置
|
|
||||||
*/
|
|
||||||
export function validateDynamicProviderRegistration(config: unknown): DynamicProviderRegistration | null {
|
|
||||||
const result = dynamicProviderRegistrationSchema.safeParse(config)
|
|
||||||
if (result.success) {
|
|
||||||
return result.data
|
|
||||||
}
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 获取基础 Provider 配置
|
|
||||||
*/
|
|
||||||
export function getBaseProviderConfig(id: BaseProviderId): (typeof baseProviders)[number] | undefined {
|
|
||||||
return baseProviders.find((p) => p.id === id)
|
|
||||||
}
|
|
||||||
|
|||||||
@ -7,13 +7,7 @@ import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible
|
|||||||
import { type XaiProviderSettings } from '@ai-sdk/xai'
|
import { type XaiProviderSettings } from '@ai-sdk/xai'
|
||||||
|
|
||||||
// 导入基于 Zod 的 ProviderId 类型
|
// 导入基于 Zod 的 ProviderId 类型
|
||||||
import {
|
import { type ProviderId as ZodProviderId } from './schemas'
|
||||||
type BaseProviderId,
|
|
||||||
type DynamicProviderId,
|
|
||||||
type DynamicProviderRegistration,
|
|
||||||
type ProviderConfig,
|
|
||||||
type ProviderId as ZodProviderId
|
|
||||||
} from './schemas'
|
|
||||||
|
|
||||||
export interface ExtensibleProviderSettingsMap {
|
export interface ExtensibleProviderSettingsMap {
|
||||||
// 基础的静态providers
|
// 基础的静态providers
|
||||||
@ -76,10 +70,6 @@ export class ProviderError extends Error {
|
|||||||
// 动态ProviderId类型 - 基于 Zod Schema,支持运行时扩展和验证
|
// 动态ProviderId类型 - 基于 Zod Schema,支持运行时扩展和验证
|
||||||
export type ProviderId = ZodProviderId
|
export type ProviderId = ZodProviderId
|
||||||
|
|
||||||
// 重新导出相关类型
|
|
||||||
export type { BaseProviderId, DynamicProviderId, DynamicProviderRegistration, ProviderConfig }
|
|
||||||
|
|
||||||
// Provider类型注册工具
|
|
||||||
export interface ProviderTypeRegistrar {
|
export interface ProviderTypeRegistrar {
|
||||||
registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
|
registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
|
||||||
getProviderSettings<T extends string>(providerId: T): any
|
getProviderSettings<T extends string>(providerId: T): any
|
||||||
|
|||||||
@ -12,11 +12,10 @@ import {
|
|||||||
streamText
|
streamText
|
||||||
} from 'ai'
|
} from 'ai'
|
||||||
|
|
||||||
import { type ProviderId } from '../../types'
|
|
||||||
import { globalModelResolver } from '../models'
|
import { globalModelResolver } from '../models'
|
||||||
import { type ModelConfig } from '../models/types'
|
import { type ModelConfig } from '../models/types'
|
||||||
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
||||||
import { getProviderInfo } from '../providers/registry'
|
import { type ProviderId } from '../providers'
|
||||||
import { ImageGenerationError, ImageModelResolutionError } from './errors'
|
import { ImageGenerationError, ImageModelResolutionError } from './errors'
|
||||||
import { PluginEngine } from './pluginEngine'
|
import { PluginEngine } from './pluginEngine'
|
||||||
import { type RuntimeConfig } from './types'
|
import { type RuntimeConfig } from './types'
|
||||||
@ -311,13 +310,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 获取客户端信息
|
|
||||||
*/
|
|
||||||
getClientInfo() {
|
|
||||||
return getProviderInfo(this.config.providerId)
|
|
||||||
}
|
|
||||||
|
|
||||||
// === 静态工厂方法 ===
|
// === 静态工厂方法 ===
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -3,7 +3,6 @@ import { ImageModelV2 } from '@ai-sdk/provider'
|
|||||||
import { LanguageModel } from 'ai'
|
import { LanguageModel } from 'ai'
|
||||||
|
|
||||||
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
||||||
import { isProviderSupported } from '../providers/registry'
|
|
||||||
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
|
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -240,20 +239,4 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
|||||||
): PluginEngine<'openai-compatible'> {
|
): PluginEngine<'openai-compatible'> {
|
||||||
return new PluginEngine('openai-compatible', plugins)
|
return new PluginEngine('openai-compatible', plugins)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 创建标准提供商客户端
|
|
||||||
*/
|
|
||||||
static create<T extends ProviderId>(providerId: T, plugins?: AiPlugin[]): PluginEngine<T>
|
|
||||||
|
|
||||||
static create(providerId: string, plugins?: AiPlugin[]): PluginEngine<'openai-compatible'>
|
|
||||||
|
|
||||||
static create(providerId: string, plugins: AiPlugin[] = []): PluginEngine {
|
|
||||||
if (isProviderSupported(providerId)) {
|
|
||||||
return new PluginEngine(providerId as ProviderId, plugins)
|
|
||||||
} else {
|
|
||||||
// 对于未知 provider,使用 openai-compatible
|
|
||||||
return new PluginEngine('openai-compatible', plugins)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,10 +4,6 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
// 导入内部使用的类和函数
|
// 导入内部使用的类和函数
|
||||||
import { getSupportedProviders, isProviderSupported } from './core/providers/registry'
|
|
||||||
import type { ProviderId } from './core/providers/types'
|
|
||||||
import type { ProviderSettingsMap } from './core/providers/types'
|
|
||||||
import { createExecutor } from './core/runtime'
|
|
||||||
|
|
||||||
// ==================== 主要用户接口 ====================
|
// ==================== 主要用户接口 ====================
|
||||||
export {
|
export {
|
||||||
@ -28,29 +24,8 @@ export { createContext, definePlugin, PluginManager } from './core/plugins'
|
|||||||
// export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
|
// export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
|
||||||
export { PluginEngine } from './core/runtime/pluginEngine'
|
export { PluginEngine } from './core/runtime/pluginEngine'
|
||||||
|
|
||||||
// ==================== 低级 API ====================
|
|
||||||
export { providerRegistry } from './core/providers/registry'
|
|
||||||
|
|
||||||
// ==================== 类型定义 ====================
|
// ==================== 类型定义 ====================
|
||||||
export type { ProviderConfig } from './core/providers/types'
|
export type { GenerateObjectParams, GenerateTextParams, StreamObjectParams, StreamTextParams } from './types'
|
||||||
export type { ProviderError } from './core/providers/types'
|
|
||||||
export type {
|
|
||||||
AnthropicProviderSettings,
|
|
||||||
AzureOpenAIProviderSettings,
|
|
||||||
DeepSeekProviderSettings,
|
|
||||||
GenerateObjectParams,
|
|
||||||
GenerateTextParams,
|
|
||||||
GoogleGenerativeAIProviderSettings,
|
|
||||||
OpenAICompatibleProviderSettings,
|
|
||||||
OpenAIProviderSettings,
|
|
||||||
ProviderId,
|
|
||||||
ProviderSettings,
|
|
||||||
ProviderSettingsMap,
|
|
||||||
StreamObjectParams,
|
|
||||||
StreamTextParams,
|
|
||||||
XaiProviderSettings
|
|
||||||
} from './types'
|
|
||||||
export * as aiSdk from 'ai'
|
|
||||||
|
|
||||||
// ==================== AI SDK 常用类型导出 ====================
|
// ==================== AI SDK 常用类型导出 ====================
|
||||||
// 直接导出 AI SDK 的常用类型,方便使用
|
// 直接导出 AI SDK 的常用类型,方便使用
|
||||||
@ -112,115 +87,6 @@ export {
|
|||||||
type TypedProviderOptions
|
type TypedProviderOptions
|
||||||
} from './core/options'
|
} from './core/options'
|
||||||
|
|
||||||
// ==================== Provider 初始化和管理 ====================
|
|
||||||
export {
|
|
||||||
clearAllProviders,
|
|
||||||
getImageModel,
|
|
||||||
getInitializedProviders,
|
|
||||||
// 访问功能
|
|
||||||
getLanguageModel,
|
|
||||||
getProviderInfo,
|
|
||||||
getTextEmbeddingModel,
|
|
||||||
hasInitializedProviders,
|
|
||||||
// initializeImageProvider, // deprecated: 使用 initializeProvider 即可
|
|
||||||
// 初始化功能
|
|
||||||
initializeProvider,
|
|
||||||
initializeProviders,
|
|
||||||
isProviderInitialized,
|
|
||||||
isProviderSupported,
|
|
||||||
// 错误类型
|
|
||||||
ProviderInitializationError,
|
|
||||||
reinitializeProvider
|
|
||||||
} from './core/providers/registry'
|
|
||||||
|
|
||||||
// ==================== 动态Provider注册和别名映射 ====================
|
|
||||||
export {
|
|
||||||
cleanup,
|
|
||||||
getAllAliases,
|
|
||||||
getAllDynamicMappings,
|
|
||||||
getDynamicProviders,
|
|
||||||
getProviderMapping,
|
|
||||||
isAlias,
|
|
||||||
isDynamicProvider,
|
|
||||||
registerDynamicProvider,
|
|
||||||
registerMultipleProviders,
|
|
||||||
resolveProviderId
|
|
||||||
} from './core/providers/registry'
|
|
||||||
|
|
||||||
// ==================== Zod Schema 和验证 ====================
|
|
||||||
export { baseProviderIds, validateProviderId } from './core/providers'
|
|
||||||
|
|
||||||
// ==================== Hub Provider ====================
|
|
||||||
export { createHubProvider, type HubProviderConfig, HubProviderError } from './core/providers/HubProvider'
|
|
||||||
|
|
||||||
// ==================== Provider 配置工厂 ====================
|
|
||||||
export {
|
|
||||||
type BaseProviderConfig,
|
|
||||||
createProviderConfig,
|
|
||||||
type ProviderConfigBuilder,
|
|
||||||
providerConfigBuilder,
|
|
||||||
ProviderConfigFactory
|
|
||||||
} from './core/providers/factory'
|
|
||||||
|
|
||||||
// ==================== 包信息 ====================
|
// ==================== 包信息 ====================
|
||||||
export const AI_CORE_VERSION = '1.0.0'
|
export const AI_CORE_VERSION = '1.0.0'
|
||||||
export const AI_CORE_NAME = '@cherrystudio/ai-core'
|
export const AI_CORE_NAME = '@cherrystudio/ai-core'
|
||||||
|
|
||||||
// ==================== 便捷 API ====================
|
|
||||||
// 主要的便捷工厂类
|
|
||||||
export const AiCore = {
|
|
||||||
version: AI_CORE_VERSION,
|
|
||||||
name: AI_CORE_NAME,
|
|
||||||
|
|
||||||
// 创建主要执行器(推荐使用)
|
|
||||||
create(providerId: ProviderId, options: ProviderSettingsMap[ProviderId], plugins: any[] = []) {
|
|
||||||
return createExecutor(providerId, options, plugins)
|
|
||||||
},
|
|
||||||
|
|
||||||
// 获取支持的providers
|
|
||||||
getSupportedProviders() {
|
|
||||||
return getSupportedProviders()
|
|
||||||
},
|
|
||||||
|
|
||||||
isSupported(providerId: ProviderId) {
|
|
||||||
return isProviderSupported(providerId)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 推荐使用的执行器创建函数
|
|
||||||
export const createOpenAIExecutor = (options: ProviderSettingsMap['openai'], plugins?: any[]) => {
|
|
||||||
return createExecutor('openai', options, plugins)
|
|
||||||
}
|
|
||||||
|
|
||||||
export const createAnthropicExecutor = (options: ProviderSettingsMap['anthropic'], plugins?: any[]) => {
|
|
||||||
return createExecutor('anthropic', options, plugins)
|
|
||||||
}
|
|
||||||
|
|
||||||
export const createGoogleExecutor = (options: ProviderSettingsMap['google'], plugins?: any[]) => {
|
|
||||||
return createExecutor('google', options, plugins)
|
|
||||||
}
|
|
||||||
|
|
||||||
export const createXAIExecutor = (options: ProviderSettingsMap['xai'], plugins?: any[]) => {
|
|
||||||
return createExecutor('xai', options, plugins)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== 调试和开发工具 ====================
|
|
||||||
export const DevTools = {
|
|
||||||
// 列出所有支持的providers
|
|
||||||
listProviders() {
|
|
||||||
return getSupportedProviders()
|
|
||||||
},
|
|
||||||
|
|
||||||
// 获取provider详细信息
|
|
||||||
getProviderDetails() {
|
|
||||||
const supportedProviders = getSupportedProviders()
|
|
||||||
|
|
||||||
return {
|
|
||||||
supportedProviders: supportedProviders.length,
|
|
||||||
providers: supportedProviders.map((p) => ({
|
|
||||||
id: p.id,
|
|
||||||
name: p.name
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,27 +1,9 @@
|
|||||||
import { generateObject, generateText, streamObject, streamText } from 'ai'
|
import { generateObject, generateText, streamObject, streamText } from 'ai'
|
||||||
|
|
||||||
import type { ProviderSettingsMap } from './core/providers/types'
|
|
||||||
|
|
||||||
// ProviderSettings 是所有 Provider Settings 的联合类型
|
|
||||||
export type ProviderSettings = ProviderSettingsMap[keyof ProviderSettingsMap]
|
|
||||||
|
|
||||||
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
|
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
|
||||||
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
|
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
|
||||||
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
|
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
|
||||||
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>
|
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>
|
||||||
|
|
||||||
// 重新导出 ProviderSettingsMap 中的所有类型
|
|
||||||
export type {
|
|
||||||
AnthropicProviderSettings,
|
|
||||||
AzureOpenAIProviderSettings,
|
|
||||||
DeepSeekProviderSettings,
|
|
||||||
GoogleGenerativeAIProviderSettings,
|
|
||||||
OpenAICompatibleProviderSettings,
|
|
||||||
OpenAIProviderSettings,
|
|
||||||
ProviderId,
|
|
||||||
ProviderSettingsMap,
|
|
||||||
XaiProviderSettings
|
|
||||||
} from './core/providers/types'
|
|
||||||
|
|
||||||
// 重新导出插件类型
|
// 重新导出插件类型
|
||||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'
|
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'
|
||||||
|
|||||||
@ -3,7 +3,8 @@ import { defineConfig } from 'tsdown'
|
|||||||
export default defineConfig({
|
export default defineConfig({
|
||||||
entry: {
|
entry: {
|
||||||
index: 'src/index.ts',
|
index: 'src/index.ts',
|
||||||
'built-in/plugins/index': 'src/core/plugins/built-in/index.ts'
|
'built-in/plugins/index': 'src/core/plugins/built-in/index.ts',
|
||||||
|
'provider/index': 'src/core/providers/index.ts'
|
||||||
},
|
},
|
||||||
outDir: 'dist',
|
outDir: 'dist',
|
||||||
format: ['esm', 'cjs'],
|
format: ['esm', 'cjs'],
|
||||||
|
|||||||
@ -8,7 +8,8 @@
|
|||||||
* 3. 暂时保持接口兼容性
|
* 3. 暂时保持接口兼容性
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { createExecutor, generateImage, initializeProvider, StreamTextParams } from '@cherrystudio/ai-core'
|
import { createExecutor, generateImage, StreamTextParams } from '@cherrystudio/ai-core'
|
||||||
|
import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
|
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
|
||||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||||
@ -36,21 +37,6 @@ export default class ModernAiProvider {
|
|||||||
|
|
||||||
// 只保存配置,不预先创建executor
|
// 只保存配置,不预先创建executor
|
||||||
this.config = providerToAiSdkConfig(this.actualProvider)
|
this.config = providerToAiSdkConfig(this.actualProvider)
|
||||||
|
|
||||||
// 初始化 provider 到全局管理器
|
|
||||||
try {
|
|
||||||
initializeProvider(this.config.providerId, this.config.options)
|
|
||||||
logger.debug('Provider initialized successfully', {
|
|
||||||
providerId: this.config.providerId,
|
|
||||||
hasOptions: !!this.config.options
|
|
||||||
})
|
|
||||||
} catch (error) {
|
|
||||||
// 如果 provider 已经初始化过,可能会抛出错误,这里可以忽略
|
|
||||||
logger.debug('Provider initialization skipped (may already be initialized)', {
|
|
||||||
providerId: this.config.providerId,
|
|
||||||
error: error instanceof Error ? error.message : String(error)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public getActualProvider() {
|
public getActualProvider() {
|
||||||
@ -67,6 +53,21 @@ export default class ModernAiProvider {
|
|||||||
callType: string
|
callType: string
|
||||||
}
|
}
|
||||||
): Promise<CompletionsResult> {
|
): Promise<CompletionsResult> {
|
||||||
|
// 初始化 provider 到全局管理器
|
||||||
|
try {
|
||||||
|
await createAndRegisterProvider(this.config.providerId, this.config.options)
|
||||||
|
logger.debug('Provider initialized successfully', {
|
||||||
|
providerId: this.config.providerId,
|
||||||
|
hasOptions: !!this.config.options
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
// 如果 provider 已经初始化过,可能会抛出错误,这里可以忽略
|
||||||
|
logger.debug('Provider initialization skipped (may already be initialized)', {
|
||||||
|
providerId: this.config.providerId,
|
||||||
|
error: error instanceof Error ? error.message : String(error)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if (config.isImageGenerationEndpoint) {
|
if (config.isImageGenerationEndpoint) {
|
||||||
return await this.modernImageGeneration(modelId, params, config)
|
return await this.modernImageGeneration(modelId, params, config)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,9 @@
|
|||||||
import { AiCore, ProviderConfigFactory, type ProviderId, type ProviderSettingsMap } from '@cherrystudio/ai-core'
|
import {
|
||||||
|
hasProviderConfig,
|
||||||
|
ProviderConfigFactory,
|
||||||
|
type ProviderId,
|
||||||
|
type ProviderSettingsMap
|
||||||
|
} from '@cherrystudio/ai-core/provider'
|
||||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import type { Model, Provider } from '@renderer/types'
|
import type { Model, Provider } from '@renderer/types'
|
||||||
@ -86,7 +91,7 @@ export function providerToAiSdkConfig(actualProvider: Provider): {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 如果AI SDK支持该provider,使用原生配置
|
// 如果AI SDK支持该provider,使用原生配置
|
||||||
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||||
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
||||||
return {
|
return {
|
||||||
providerId: aiSdkProviderId as ProviderId,
|
providerId: aiSdkProviderId as ProviderId,
|
||||||
@ -120,5 +125,5 @@ export function isModernSdkSupported(provider: Provider): boolean {
|
|||||||
const aiSdkProviderId = getAiSdkProviderId(provider)
|
const aiSdkProviderId = getAiSdkProviderId(provider)
|
||||||
|
|
||||||
// 如果映射到了支持的provider,则支持现代SDK
|
// 如果映射到了支持的provider,则支持现代SDK
|
||||||
return AiCore.isSupported(aiSdkProviderId)
|
return hasProviderConfig(aiSdkProviderId)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import { AiCore, getProviderMapping, type ProviderId } from '@cherrystudio/ai-core'
|
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { Provider } from '@renderer/types'
|
import { Provider } from '@renderer/types'
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 尝试解析provider标识符(支持静态映射和动态映射)
|
* 尝试解析provider标识符(支持静态映射和别名)
|
||||||
*/
|
*/
|
||||||
function tryResolveProviderId(identifier: string): ProviderId | null {
|
function tryResolveProviderId(identifier: string): ProviderId | null {
|
||||||
// 1. 检查静态映射
|
// 1. 检查静态映射
|
||||||
@ -39,15 +39,10 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
|
|||||||
return staticMapping
|
return staticMapping
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 检查动态映射
|
// 2. 检查AiCore是否支持(包括别名支持)
|
||||||
const dynamicMapping = getProviderMapping(identifier)
|
if (hasProviderConfigByAlias(identifier)) {
|
||||||
if (dynamicMapping && dynamicMapping !== identifier) {
|
// 解析为真实的Provider ID
|
||||||
return dynamicMapping as ProviderId
|
return resolveProviderConfigId(identifier) as ProviderId
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 检查AiCore是否直接支持
|
|
||||||
if (AiCore.isSupported(identifier)) {
|
|
||||||
return identifier as ProviderId
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return null
|
return null
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import { type ProviderConfig, registerMultipleProviders } from '@cherrystudio/ai-core'
|
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ProviderConfigs')
|
const logger = loggerService.withContext('ProviderConfigs')
|
||||||
@ -16,9 +16,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
|||||||
import: () => import('@openrouter/ai-sdk-provider'),
|
import: () => import('@openrouter/ai-sdk-provider'),
|
||||||
creatorFunctionName: 'createOpenRouter',
|
creatorFunctionName: 'createOpenRouter',
|
||||||
supportsImageGeneration: true,
|
supportsImageGeneration: true,
|
||||||
mappings: {
|
aliases: ['openrouter']
|
||||||
openrouter: 'openrouter'
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'google-vertex',
|
id: 'google-vertex',
|
||||||
@ -26,10 +24,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
|||||||
import: () => import('@ai-sdk/google-vertex'),
|
import: () => import('@ai-sdk/google-vertex'),
|
||||||
creatorFunctionName: 'createGoogleVertex',
|
creatorFunctionName: 'createGoogleVertex',
|
||||||
supportsImageGeneration: true,
|
supportsImageGeneration: true,
|
||||||
mappings: {
|
aliases: ['google-vertex', 'vertexai']
|
||||||
'google-vertex': 'google-vertex',
|
|
||||||
vertexai: 'google-vertex'
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'bedrock',
|
id: 'bedrock',
|
||||||
@ -37,9 +32,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
|||||||
import: () => import('@ai-sdk/amazon-bedrock'),
|
import: () => import('@ai-sdk/amazon-bedrock'),
|
||||||
creatorFunctionName: 'createAmazonBedrock',
|
creatorFunctionName: 'createAmazonBedrock',
|
||||||
supportsImageGeneration: true,
|
supportsImageGeneration: true,
|
||||||
mappings: {
|
aliases: ['aws-bedrock']
|
||||||
'aws-bedrock': 'bedrock'
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
] as const
|
] as const
|
||||||
|
|
||||||
@ -49,19 +42,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
|||||||
*/
|
*/
|
||||||
export async function initializeNewProviders(): Promise<void> {
|
export async function initializeNewProviders(): Promise<void> {
|
||||||
try {
|
try {
|
||||||
logger.info('Starting to register new providers', {
|
const successCount = registerMultipleProviderConfigs(NEW_PROVIDER_CONFIGS)
|
||||||
providerCount: NEW_PROVIDER_CONFIGS.length,
|
|
||||||
providerIds: NEW_PROVIDER_CONFIGS.map((p) => p.id)
|
|
||||||
})
|
|
||||||
|
|
||||||
const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS)
|
|
||||||
|
|
||||||
logger.info('Provider registration completed', {
|
|
||||||
successCount,
|
|
||||||
totalCount: NEW_PROVIDER_CONFIGS.length,
|
|
||||||
failedCount: NEW_PROVIDER_CONFIGS.length - successCount
|
|
||||||
})
|
|
||||||
|
|
||||||
if (successCount < NEW_PROVIDER_CONFIGS.length) {
|
if (successCount < NEW_PROVIDER_CONFIGS.length) {
|
||||||
logger.warn('Some providers failed to register. Check previous error logs.')
|
logger.warn('Some providers failed to register. Check previous error logs.')
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
import { aiSdk, Tool } from '@cherrystudio/ai-core'
|
import { Tool } from '@cherrystudio/ai-core'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
||||||
import { MCPTool, MCPToolResponse } from '@renderer/types'
|
import { MCPTool, MCPToolResponse } from '@renderer/types'
|
||||||
import { callMCPTool } from '@renderer/utils/mcp-tools'
|
import { callMCPTool } from '@renderer/utils/mcp-tools'
|
||||||
|
import { jsonSchema, tool } from 'ai'
|
||||||
import { JSONSchema7 } from 'json-schema'
|
import { JSONSchema7 } from 'json-schema'
|
||||||
|
|
||||||
const { tool } = aiSdk
|
|
||||||
const logger = loggerService.withContext('MCP-utils')
|
const logger = loggerService.withContext('MCP-utils')
|
||||||
|
|
||||||
// Setup tools configuration based on provided parameters
|
// Setup tools configuration based on provided parameters
|
||||||
@ -30,7 +30,7 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string,
|
|||||||
for (const mcpTool of mcpTools) {
|
for (const mcpTool of mcpTools) {
|
||||||
tools[mcpTool.name] = tool({
|
tools[mcpTool.name] = tool({
|
||||||
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
||||||
inputSchema: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||||
execute: async (params) => {
|
execute: async (params) => {
|
||||||
// 创建适配的 MCPToolResponse 对象
|
// 创建适配的 MCPToolResponse 对象
|
||||||
const toolResponse: MCPToolResponse = {
|
const toolResponse: MCPToolResponse = {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user