mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 21:01:32 +08:00
feat(aiCore): enhance provider management and registration system
- Added support for a new provider configuration structure in package.json, enabling better integration of provider types. - Updated tsdown.config.ts to include new entry points for provider modules, improving build organization. - Refactored index.ts to streamline exports and enhance type handling for provider-related functionalities. - Simplified provider initialization and registration processes, allowing for more flexible provider management. - Improved type definitions and removed deprecated methods to enhance code clarity and maintainability.
This commit is contained in:
parent
2ce9314a10
commit
9c01e24317
@ -71,6 +71,13 @@
|
||||
"import": "./dist/built-in/plugins/index.mjs",
|
||||
"require": "./dist/built-in/plugins/index.js",
|
||||
"default": "./dist/built-in/plugins/index.js"
|
||||
},
|
||||
"./provider": {
|
||||
"types": "./dist/provider/index.d.ts",
|
||||
"react-native": "./dist/provider/index.js",
|
||||
"import": "./dist/provider/index.mjs",
|
||||
"require": "./dist/provider/index.js",
|
||||
"default": "./dist/provider/index.js"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
// 核心类型和接口
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
|
||||
import type { ProviderId } from '../../types'
|
||||
import type { ProviderId } from '../providers'
|
||||
import type { AiPlugin, AiRequestContext } from './types'
|
||||
|
||||
// 插件管理器
|
||||
|
||||
@ -154,10 +154,10 @@ export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARA
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取已注册的 provider 列表(排除别名)
|
||||
* 获取已注册的 provider 列表
|
||||
*/
|
||||
getRegisteredProviders(): string[] {
|
||||
return Object.keys(this.providers).filter((id) => !this.aliases.has(id))
|
||||
return Object.keys(this.providers)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -1,60 +1,83 @@
|
||||
/**
|
||||
* Providers 模块统一导出 - 现代化架构
|
||||
* Providers 模块统一导出 - 独立Provider包
|
||||
*/
|
||||
|
||||
// ==================== 新版架构(推荐使用)====================
|
||||
// ==================== 核心管理器 ====================
|
||||
|
||||
// Provider 注册表管理器
|
||||
export { globalRegistryManagement, RegistryManagement } from './RegistryManagement'
|
||||
|
||||
// Provider 初始化器(核心功能)
|
||||
// Provider 核心功能
|
||||
export {
|
||||
// 状态管理
|
||||
cleanup,
|
||||
clearAllProviders,
|
||||
createAndRegisterProvider,
|
||||
createProvider,
|
||||
getAllProviderConfigAliases,
|
||||
getAllProviderConfigs,
|
||||
getImageModel,
|
||||
// 工具函数
|
||||
getInitializedProviders,
|
||||
getLanguageModel,
|
||||
getProviderInfo,
|
||||
getProviderConfig,
|
||||
getProviderConfigByAlias,
|
||||
getSupportedProviders,
|
||||
getTextEmbeddingModel,
|
||||
hasInitializedProviders,
|
||||
// initializeImageProvider, // deprecated: 使用 initializeProvider 即可
|
||||
initializeProvider,
|
||||
initializeProviders,
|
||||
isProviderInitialized,
|
||||
isProviderSupported,
|
||||
// 工具函数
|
||||
hasProviderConfig,
|
||||
// 别名支持
|
||||
hasProviderConfigByAlias,
|
||||
isProviderConfigAlias,
|
||||
// 错误类型
|
||||
ProviderInitializationError,
|
||||
ProviderInitializer,
|
||||
// 全局访问
|
||||
providerRegistry,
|
||||
reinitializeProvider
|
||||
registerMultipleProviderConfigs,
|
||||
registerProvider,
|
||||
// 统一Provider系统
|
||||
registerProviderConfig,
|
||||
resolveProviderConfigId
|
||||
} from './registry'
|
||||
|
||||
// 动态Provider注册功能
|
||||
export {
|
||||
cleanup,
|
||||
getAllAliases,
|
||||
getAllDynamicMappings,
|
||||
getDynamicProviders,
|
||||
getProviderMapping,
|
||||
isAlias,
|
||||
isDynamicProvider,
|
||||
registerDynamicProvider,
|
||||
registerMultipleProviders,
|
||||
resolveProviderId
|
||||
} from './registry'
|
||||
|
||||
// ==================== 保留的导出(兼容性)====================
|
||||
// ==================== 基础数据和类型 ====================
|
||||
|
||||
// 基础Provider数据源
|
||||
export { baseProviderIds, baseProviders } from './schemas'
|
||||
|
||||
// 类型定义和Schema
|
||||
export type {
|
||||
BaseProviderId,
|
||||
CustomProviderId,
|
||||
DynamicProviderRegistration,
|
||||
ProviderConfig,
|
||||
ProviderId
|
||||
} from './schemas' // 从 schemas 导出的类型
|
||||
export { baseProviderIdSchema, customProviderIdSchema, providerConfigSchema, providerIdSchema } from './schemas' // Schema 导出
|
||||
export type {
|
||||
DynamicProviderRegistry,
|
||||
ExtensibleProviderSettingsMap,
|
||||
ProviderError,
|
||||
ProviderSettingsMap,
|
||||
ProviderTypeRegistrar
|
||||
} from './types'
|
||||
|
||||
// ==================== 工具函数 ====================
|
||||
|
||||
// Provider配置工厂
|
||||
export {
|
||||
type BaseProviderConfig,
|
||||
createProviderConfig,
|
||||
ProviderConfigBuilder,
|
||||
providerConfigBuilder,
|
||||
ProviderConfigFactory
|
||||
} from './factory'
|
||||
|
||||
// 工具函数
|
||||
export { formatPrivateKey } from './utils'
|
||||
|
||||
// ==================== 扩展功能 ====================
|
||||
|
||||
// Hub Provider 功能
|
||||
export { createHubProvider, type HubProviderConfig, HubProviderError } from './HubProvider'
|
||||
|
||||
// 类型定义(可能被其他模块使用)
|
||||
export type { ProviderConfig, ProviderId, ProviderSettingsMap } from './types'
|
||||
|
||||
// Provider验证功能(使用更好的Zod版本)
|
||||
export { validateProviderConfig } from './schemas'
|
||||
|
||||
// 验证功能(可能被其他地方使用)
|
||||
export { validateProviderId } from './schemas'
|
||||
|
||||
@ -6,9 +6,8 @@
|
||||
|
||||
import { customProvider } from 'ai'
|
||||
|
||||
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
|
||||
import { globalRegistryManagement } from './RegistryManagement'
|
||||
import { baseProviders, type DynamicProviderRegistration } from './schemas'
|
||||
import { baseProviders, type ProviderConfig } from './schemas'
|
||||
|
||||
/**
|
||||
* Provider 初始化错误类型
|
||||
@ -24,141 +23,6 @@ class ProviderInitializationError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider 初始化器类
|
||||
*/
|
||||
export class ProviderInitializer {
|
||||
/**
|
||||
* 初始化单个 provider 并注册
|
||||
*/
|
||||
static initializeProvider(providerId: string, options: any): void {
|
||||
try {
|
||||
// 1. 从 schemas 获取 provider 配置
|
||||
const providerConfig = baseProviders.find((p) => p.id === providerId)
|
||||
if (!providerConfig) {
|
||||
throw new ProviderInitializationError(`Provider configuration for '${providerId}' not found`, providerId)
|
||||
}
|
||||
|
||||
// 2. 使用 creator 函数创建已配置的 provider
|
||||
const configuredProvider = providerConfig.creator(options)
|
||||
|
||||
// 3. 处理特殊逻辑并注册到全局管理器
|
||||
this.handleProviderSpecificLogic(configuredProvider, providerId)
|
||||
} catch (error) {
|
||||
if (error instanceof ProviderInitializationError) {
|
||||
throw error
|
||||
}
|
||||
throw new ProviderInitializationError(
|
||||
`Failed to initialize provider ${providerId}: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
providerId,
|
||||
error instanceof Error ? error : undefined
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量初始化 providers
|
||||
*/
|
||||
static initializeProviders(providers: Record<string, any>): void {
|
||||
Object.entries(providers).forEach(([providerId, options]) => {
|
||||
try {
|
||||
this.initializeProvider(providerId, options)
|
||||
} catch (error) {
|
||||
console.error(`Failed to initialize provider ${providerId}:`, error)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理特定 provider 的特殊逻辑 (从 ModelCreator 迁移并改进)
|
||||
*/
|
||||
private static handleProviderSpecificLogic(provider: any, providerId: string): void {
|
||||
if (providerId === 'openai') {
|
||||
// 🎯 OpenAI 默认注册 (responses 模式)
|
||||
globalRegistryManagement.registerProvider('openai', provider)
|
||||
|
||||
// 🎯 使用 AI SDK 官方的 customProvider 创建 chat 模式变体
|
||||
const openaiChatProvider = customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
// 覆盖 languageModel 方法指向 chat
|
||||
languageModel: (modelId: string) => provider.chat(modelId)
|
||||
}
|
||||
})
|
||||
|
||||
globalRegistryManagement.registerProvider('openai-chat', openaiChatProvider)
|
||||
} else {
|
||||
// 其他 provider 直接注册
|
||||
globalRegistryManagement.registerProvider(providerId, provider)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 初始化图像生成 provider (从 ModelCreator 迁移)
|
||||
*
|
||||
* @deprecated 不再需要单独的图像provider初始化,使用 initializeProvider() 即可
|
||||
* 一个provider实例可以同时支持文本和图像功能,无需分别初始化
|
||||
*
|
||||
* TODO: 考虑在下个版本中删除此方法
|
||||
*/
|
||||
// static initializeImageProvider(providerId: string, options: any): void {
|
||||
// try {
|
||||
// const providerConfig = baseProviders.find((p) => p.id === providerId)
|
||||
// if (!providerConfig) {
|
||||
// throw new ProviderInitializationError(`Provider configuration for '${providerId}' not found`, providerId)
|
||||
// }
|
||||
|
||||
// if (!providerConfig.supportsImageGeneration) {
|
||||
// throw new ProviderInitializationError(`Provider "${providerId}" does not support image generation`, providerId)
|
||||
// }
|
||||
|
||||
// const provider = providerConfig.creator(options)
|
||||
|
||||
// // 注册图像 provider (使用特殊前缀区分)
|
||||
// globalRegistryManagement.registerProvider(`${providerId}-image`, provider as any)
|
||||
// } catch (error) {
|
||||
// if (error instanceof ProviderInitializationError) {
|
||||
// throw error
|
||||
// }
|
||||
// throw new ProviderInitializationError(
|
||||
// `Failed to initialize image provider ${providerId}: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
// providerId,
|
||||
// error instanceof Error ? error : undefined
|
||||
// )
|
||||
// }
|
||||
// }
|
||||
|
||||
/**
|
||||
* 检查 provider 是否已初始化
|
||||
*/
|
||||
static isProviderInitialized(providerId: string): boolean {
|
||||
return globalRegistryManagement.getRegisteredProviders().includes(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 重新初始化 provider(更新配置)
|
||||
*/
|
||||
static reinitializeProvider(providerId: string, options: any): void {
|
||||
this.initializeProvider(providerId, options) // 会覆盖已有的
|
||||
}
|
||||
|
||||
/**
|
||||
* 清除所有已初始化的 providers
|
||||
*/
|
||||
static clearAllProviders(): void {
|
||||
globalRegistryManagement.clear()
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 便捷函数导出 ====================
|
||||
|
||||
export const initializeProvider = ProviderInitializer.initializeProvider.bind(ProviderInitializer)
|
||||
export const initializeProviders = ProviderInitializer.initializeProviders.bind(ProviderInitializer)
|
||||
// export const initializeImageProvider = ProviderInitializer.initializeImageProvider.bind(ProviderInitializer) // deprecated: 使用 initializeProvider 即可
|
||||
export const isProviderInitialized = ProviderInitializer.isProviderInitialized.bind(ProviderInitializer)
|
||||
export const reinitializeProvider = ProviderInitializer.reinitializeProvider.bind(ProviderInitializer)
|
||||
export const clearAllProviders = ProviderInitializer.clearAllProviders.bind(ProviderInitializer)
|
||||
|
||||
// ==================== 全局管理器导出 ====================
|
||||
|
||||
export { globalRegistryManagement as providerRegistry }
|
||||
@ -169,10 +33,10 @@ export const getLanguageModel = (id: string) => globalRegistryManagement.languag
|
||||
export const getTextEmbeddingModel = (id: string) => globalRegistryManagement.textEmbeddingModel(id as any)
|
||||
export const getImageModel = (id: string) => globalRegistryManagement.imageModel(id as any)
|
||||
|
||||
// ==================== 工具函数 (从 ModelCreator 迁移) ====================
|
||||
// ==================== 工具函数 ====================
|
||||
|
||||
/**
|
||||
* 获取支持的 Providers 列表 (从 ModelCreator 迁移)
|
||||
* 获取支持的 Providers 列表
|
||||
*/
|
||||
export function getSupportedProviders(): Array<{
|
||||
id: string
|
||||
@ -184,35 +48,6 @@ export function getSupportedProviders(): Array<{
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查 Provider 是否被支持
|
||||
*/
|
||||
export function isProviderSupported(providerId: string): boolean {
|
||||
return getProviderInfo(providerId).isSupported
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Provider 信息 (从 ModelCreator 迁移并改进)
|
||||
*/
|
||||
export function getProviderInfo(providerId: string): {
|
||||
id: string
|
||||
name: string
|
||||
isSupported: boolean
|
||||
isInitialized: boolean
|
||||
effectiveProvider: string
|
||||
} {
|
||||
const provider = baseProviders.find((p) => p.id === providerId)
|
||||
const isInitialized = globalRegistryManagement.getRegisteredProviders().includes(providerId)
|
||||
|
||||
return {
|
||||
id: providerId,
|
||||
name: provider?.name || providerId,
|
||||
isSupported: !!provider,
|
||||
isInitialized,
|
||||
effectiveProvider: isInitialized ? providerId : 'openai-compatible'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有已初始化的 providers
|
||||
*/
|
||||
@ -227,56 +62,165 @@ export function hasInitializedProviders(): boolean {
|
||||
return globalRegistryManagement.hasProviders()
|
||||
}
|
||||
|
||||
// ==================== 动态Provider注册功能 ====================
|
||||
// ==================== 统一Provider配置系统 ====================
|
||||
|
||||
// 全局动态provider存储
|
||||
const dynamicProviders = new Map<string, DynamicProviderRegistration>()
|
||||
// 全局Provider配置存储
|
||||
const providerConfigs = new Map<string, ProviderConfig>()
|
||||
// 全局ProviderConfig别名映射 - 借鉴RegistryManagement模式
|
||||
const providerConfigAliases = new Map<string, string>() // alias -> realId
|
||||
|
||||
/**
|
||||
* 注册动态provider
|
||||
* 初始化内置配置 - 将baseProviders转换为统一格式
|
||||
*/
|
||||
export function registerDynamicProvider(config: DynamicProviderRegistration): boolean {
|
||||
function initializeBuiltInConfigs(): void {
|
||||
baseProviders.forEach((provider) => {
|
||||
const config: ProviderConfig = {
|
||||
id: provider.id,
|
||||
name: provider.name,
|
||||
creator: provider.creator as any, // 类型转换以兼容多种creator签名
|
||||
supportsImageGeneration: provider.supportsImageGeneration || false
|
||||
}
|
||||
providerConfigs.set(provider.id, config)
|
||||
})
|
||||
}
|
||||
|
||||
// 启动时自动注册内置配置
|
||||
initializeBuiltInConfigs()
|
||||
|
||||
/**
|
||||
* 步骤1: 注册Provider配置 - 仅存储配置,不执行创建
|
||||
*/
|
||||
export function registerProviderConfig(config: ProviderConfig): boolean {
|
||||
try {
|
||||
// 验证配置
|
||||
if (!config.id || !config.name) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否与基础provider冲突
|
||||
if (baseProviders.find((p) => p.id === config.id)) {
|
||||
console.warn(`Dynamic provider "${config.id}" conflicts with base provider`)
|
||||
return false
|
||||
// 检查是否与已有配置冲突(包括内置配置)
|
||||
if (providerConfigs.has(config.id)) {
|
||||
console.warn(`ProviderConfig "${config.id}" already exists, will override`)
|
||||
}
|
||||
|
||||
// 存储动态provider配置
|
||||
dynamicProviders.set(config.id, config)
|
||||
// 存储配置(内置和用户配置统一处理)
|
||||
providerConfigs.set(config.id, config)
|
||||
|
||||
// 如果有creator函数,立即初始化
|
||||
if (config.creator) {
|
||||
try {
|
||||
const provider = config.creator({}) as any // 使用空配置初始化,类型断言为any
|
||||
const aliases = config.mappings ? Object.keys(config.mappings) : undefined
|
||||
globalRegistryManagement.registerProvider(config.id, provider, aliases)
|
||||
} catch (error) {
|
||||
console.error(`Failed to initialize dynamic provider "${config.id}":`, error)
|
||||
return false
|
||||
}
|
||||
// 处理别名
|
||||
if (config.aliases && config.aliases.length > 0) {
|
||||
config.aliases.forEach((alias) => {
|
||||
if (providerConfigAliases.has(alias)) {
|
||||
console.warn(`ProviderConfig alias "${alias}" already exists, will override`)
|
||||
}
|
||||
providerConfigAliases.set(alias, config.id)
|
||||
})
|
||||
}
|
||||
|
||||
return true
|
||||
} catch (error) {
|
||||
console.error(`Failed to register dynamic provider:`, error)
|
||||
console.error(`Failed to register ProviderConfig:`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量注册动态providers
|
||||
* 步骤2: 创建Provider - 根据配置执行实际创建
|
||||
*/
|
||||
export function registerMultipleProviders(configs: DynamicProviderRegistration[]): number {
|
||||
export async function createProvider(providerId: string, options: any): Promise<any> {
|
||||
// 支持通过别名查找配置
|
||||
const config = getProviderConfigByAlias(providerId)
|
||||
|
||||
if (!config) {
|
||||
throw new Error(`ProviderConfig not found for id: ${providerId}`)
|
||||
}
|
||||
|
||||
try {
|
||||
let creator: (options: any) => any
|
||||
|
||||
if (config.creator) {
|
||||
// 方式1: 直接执行 creator
|
||||
creator = config.creator
|
||||
} else if (config.import && config.creatorFunctionName) {
|
||||
// 方式2: 动态导入并执行
|
||||
const module = await config.import()
|
||||
creator = (module as any)[config.creatorFunctionName]
|
||||
|
||||
if (!creator || typeof creator !== 'function') {
|
||||
throw new Error(`Creator function "${config.creatorFunctionName}" not found in imported module`)
|
||||
}
|
||||
} else {
|
||||
throw new Error('No valid creator method provided in ProviderConfig')
|
||||
}
|
||||
|
||||
// 使用真实配置创建provider实例
|
||||
return creator(options)
|
||||
} catch (error) {
|
||||
console.error(`Failed to create provider "${providerId}":`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 步骤3: 注册Provider到全局管理器
|
||||
*/
|
||||
export function registerProvider(providerId: string, provider: any): boolean {
|
||||
try {
|
||||
const config = providerConfigs.get(providerId)
|
||||
if (!config) {
|
||||
console.error(`ProviderConfig not found for id: ${providerId}`)
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取aliases配置
|
||||
const aliases = config.aliases
|
||||
|
||||
// 处理特殊provider逻辑
|
||||
if (providerId === 'openai') {
|
||||
// 注册默认 openai
|
||||
globalRegistryManagement.registerProvider('openai', provider, aliases)
|
||||
|
||||
// 创建并注册 openai-chat 变体
|
||||
const openaiChatProvider = customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.chat(modelId)
|
||||
}
|
||||
})
|
||||
globalRegistryManagement.registerProvider('openai-chat', openaiChatProvider)
|
||||
} else {
|
||||
// 其他provider直接注册
|
||||
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
||||
}
|
||||
|
||||
return true
|
||||
} catch (error) {
|
||||
console.error(`Failed to register provider "${providerId}" to global registry:`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 便捷函数: 一次性完成创建+注册
|
||||
*/
|
||||
export async function createAndRegisterProvider(providerId: string, options: any): Promise<boolean> {
|
||||
try {
|
||||
// 步骤2: 创建provider
|
||||
const provider = await createProvider(providerId, options)
|
||||
|
||||
// 步骤3: 注册到全局管理器
|
||||
return registerProvider(providerId, provider)
|
||||
} catch (error) {
|
||||
console.error(`Failed to create and register provider "${providerId}":`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量注册Provider配置
|
||||
*/
|
||||
export function registerMultipleProviderConfigs(configs: ProviderConfig[]): number {
|
||||
let successCount = 0
|
||||
configs.forEach((config) => {
|
||||
if (registerDynamicProvider(config)) {
|
||||
if (registerProviderConfig(config)) {
|
||||
successCount++
|
||||
}
|
||||
})
|
||||
@ -284,47 +228,83 @@ export function registerMultipleProviders(configs: DynamicProviderRegistration[]
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取provider映射(解析别名)
|
||||
* 检查是否有对应的Provider配置
|
||||
*/
|
||||
export function getProviderMapping(providerId: string): string {
|
||||
return globalRegistryManagement.resolveProviderId(providerId)
|
||||
export function hasProviderConfig(providerId: string): boolean {
|
||||
return providerConfigs.has(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否为动态provider
|
||||
* 通过别名或ID检查是否有对应的Provider配置
|
||||
*/
|
||||
export function isDynamicProvider(providerId: string): boolean {
|
||||
return dynamicProviders.has(providerId)
|
||||
export function hasProviderConfigByAlias(aliasOrId: string): boolean {
|
||||
const realId = resolveProviderConfigId(aliasOrId)
|
||||
return providerConfigs.has(realId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有动态providers
|
||||
* 获取所有Provider配置
|
||||
*/
|
||||
export function getDynamicProviders(): DynamicProviderRegistration[] {
|
||||
return Array.from(dynamicProviders.values())
|
||||
export function getAllProviderConfigs(): ProviderConfig[] {
|
||||
return Array.from(providerConfigs.values())
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有别名映射关系
|
||||
* 根据ID获取Provider配置
|
||||
*/
|
||||
export function getAllDynamicMappings(): Record<string, string> {
|
||||
return globalRegistryManagement.getAllAliases()
|
||||
export function getProviderConfig(providerId: string): ProviderConfig | undefined {
|
||||
return providerConfigs.get(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理所有动态providers
|
||||
* 通过别名或ID获取Provider配置
|
||||
*/
|
||||
export function getProviderConfigByAlias(aliasOrId: string): ProviderConfig | undefined {
|
||||
// 先检查是否为别名,如果是则解析为真实ID
|
||||
const realId = providerConfigAliases.get(aliasOrId) || aliasOrId
|
||||
return providerConfigs.get(realId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析真实的ProviderConfig ID(去别名化)
|
||||
*/
|
||||
export function resolveProviderConfigId(aliasOrId: string): string {
|
||||
return providerConfigAliases.get(aliasOrId) || aliasOrId
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否为ProviderConfig别名
|
||||
*/
|
||||
export function isProviderConfigAlias(id: string): boolean {
|
||||
return providerConfigAliases.has(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有ProviderConfig别名映射关系
|
||||
*/
|
||||
export function getAllProviderConfigAliases(): Record<string, string> {
|
||||
const result: Record<string, string> = {}
|
||||
providerConfigAliases.forEach((realId, alias) => {
|
||||
result[alias] = realId
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理所有Provider配置和已注册的providers
|
||||
*/
|
||||
export function cleanup(): void {
|
||||
dynamicProviders.clear()
|
||||
providerConfigs.clear()
|
||||
providerConfigAliases.clear() // 清理别名映射
|
||||
globalRegistryManagement.clear()
|
||||
// 重新初始化内置配置
|
||||
initializeBuiltInConfigs()
|
||||
}
|
||||
|
||||
export function clearAllProviders(): void {
|
||||
globalRegistryManagement.clear()
|
||||
}
|
||||
|
||||
// ==================== 导出别名相关API ====================
|
||||
// ==================== 导出错误类型 ====================
|
||||
|
||||
export const resolveProviderId = (id: string) => globalRegistryManagement.resolveProviderId(id)
|
||||
export const isAlias = (id: string) => globalRegistryManagement.isAlias(id)
|
||||
export const getAllAliases = () => globalRegistryManagement.getAllAliases()
|
||||
|
||||
// ==================== 导出错误类型和工具函数 ====================
|
||||
|
||||
export { isOpenAIChatCompletionOnlyModel, ProviderInitializationError }
|
||||
export { ProviderInitializationError }
|
||||
|
||||
@ -1,8 +1,5 @@
|
||||
/**
|
||||
* 基于 Zod 的 Provider 验证系统
|
||||
* - 纯验证层,无状态管理
|
||||
* - 数据驱动的 Provider 定义
|
||||
* - 完整的类型安全
|
||||
* Provider Config 定义
|
||||
*/
|
||||
|
||||
import { createAnthropic } from '@ai-sdk/anthropic'
|
||||
@ -71,7 +68,7 @@ export const baseProviders = [
|
||||
|
||||
/**
|
||||
* 基础 Provider IDs
|
||||
* 从 baseProviders 自动提取,避免重复维护
|
||||
* 从 baseProviders 动态生成
|
||||
*/
|
||||
export const baseProviderIds = baseProviders.map((p) => p.id) as unknown as readonly [string, ...string[]]
|
||||
|
||||
@ -81,46 +78,28 @@ export const baseProviderIds = baseProviders.map((p) => p.id) as unknown as read
|
||||
export const baseProviderIdSchema = z.enum(baseProviderIds)
|
||||
|
||||
/**
|
||||
* 动态 Provider ID Schema
|
||||
* 用户自定义 Provider ID Schema
|
||||
* 允许任意字符串,但排除基础 provider IDs 以避免冲突
|
||||
*/
|
||||
export const dynamicProviderIdSchema = z
|
||||
export const customProviderIdSchema = z
|
||||
.string()
|
||||
.min(1)
|
||||
.refine((id) => !baseProviderIds.includes(id as any), {
|
||||
message: 'Dynamic provider ID cannot conflict with base provider IDs'
|
||||
message: 'Custom provider ID cannot conflict with base provider IDs'
|
||||
})
|
||||
|
||||
/**
|
||||
* 组合的 Provider ID Schema
|
||||
* 支持基础 providers + 动态扩展
|
||||
* Provider ID Schema - 支持基础和自定义
|
||||
*/
|
||||
export const providerIdSchema = z.union([baseProviderIdSchema, dynamicProviderIdSchema])
|
||||
export const providerIdSchema = z.union([baseProviderIdSchema, customProviderIdSchema])
|
||||
|
||||
/**
|
||||
* Provider 配置 Schema
|
||||
* 用于Provider的配置验证
|
||||
*/
|
||||
export const providerConfigSchema = z
|
||||
.object({
|
||||
id: providerIdSchema,
|
||||
name: z.string().min(1),
|
||||
creator: z.function().optional(),
|
||||
import: z.function().optional(),
|
||||
creatorFunctionName: z.string().optional(),
|
||||
supportsImageGeneration: z.boolean().default(false),
|
||||
imageCreator: z.function().optional(),
|
||||
validateOptions: z.function().optional()
|
||||
})
|
||||
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
|
||||
message: 'Must provide either creator function or import configuration'
|
||||
})
|
||||
|
||||
/**
|
||||
* 动态 Provider 注册配置 Schema
|
||||
*/
|
||||
export const dynamicProviderRegistrationSchema = z
|
||||
.object({
|
||||
id: dynamicProviderIdSchema,
|
||||
id: customProviderIdSchema, // 只允许自定义ID
|
||||
name: z.string().min(1),
|
||||
creator: z.function().optional(),
|
||||
import: z.function().optional(),
|
||||
@ -128,68 +107,26 @@ export const dynamicProviderRegistrationSchema = z
|
||||
supportsImageGeneration: z.boolean().default(false),
|
||||
imageCreator: z.function().optional(),
|
||||
validateOptions: z.function().optional(),
|
||||
mappings: z.record(z.string()).optional()
|
||||
aliases: z.array(z.string()).optional()
|
||||
})
|
||||
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
|
||||
message: 'Must provide either creator function or import configuration'
|
||||
})
|
||||
|
||||
// ===== 类型推导 =====
|
||||
|
||||
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
|
||||
export type DynamicProviderId = z.infer<typeof dynamicProviderIdSchema>
|
||||
/**
|
||||
* Provider ID 类型 - 基于 zod schema 推导
|
||||
*/
|
||||
export type ProviderId = z.infer<typeof providerIdSchema>
|
||||
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
|
||||
export type CustomProviderId = z.infer<typeof customProviderIdSchema>
|
||||
|
||||
/**
|
||||
* Provider 配置类型
|
||||
*/
|
||||
export type ProviderConfig = z.infer<typeof providerConfigSchema>
|
||||
export type DynamicProviderRegistration = z.infer<typeof dynamicProviderRegistrationSchema>
|
||||
|
||||
// ===== 纯验证函数 =====
|
||||
|
||||
/**
|
||||
* 验证 Provider ID 是否有效(包括基础和动态格式)
|
||||
* 兼容性类型别名
|
||||
* @deprecated 使用 ProviderConfig 替代
|
||||
*/
|
||||
export function validateProviderId(id: string): boolean {
|
||||
return providerIdSchema.safeParse(id).success
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证是否为基础 Provider ID
|
||||
*/
|
||||
export function isBaseProviderId(id: string): id is BaseProviderId {
|
||||
return baseProviderIdSchema.safeParse(id).success
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证是否为有效的动态 Provider ID 格式
|
||||
*/
|
||||
export function isValidDynamicProviderId(id: string): boolean {
|
||||
return dynamicProviderIdSchema.safeParse(id).success
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证 Provider 配置
|
||||
*/
|
||||
export function validateProviderConfig(config: unknown): ProviderConfig | null {
|
||||
const result = providerConfigSchema.safeParse(config)
|
||||
if (result.success) {
|
||||
return result.data
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证动态 Provider 注册配置
|
||||
*/
|
||||
export function validateDynamicProviderRegistration(config: unknown): DynamicProviderRegistration | null {
|
||||
const result = dynamicProviderRegistrationSchema.safeParse(config)
|
||||
if (result.success) {
|
||||
return result.data
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取基础 Provider 配置
|
||||
*/
|
||||
export function getBaseProviderConfig(id: BaseProviderId): (typeof baseProviders)[number] | undefined {
|
||||
return baseProviders.find((p) => p.id === id)
|
||||
}
|
||||
export type DynamicProviderRegistration = ProviderConfig
|
||||
|
||||
@ -7,13 +7,7 @@ import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible
|
||||
import { type XaiProviderSettings } from '@ai-sdk/xai'
|
||||
|
||||
// 导入基于 Zod 的 ProviderId 类型
|
||||
import {
|
||||
type BaseProviderId,
|
||||
type DynamicProviderId,
|
||||
type DynamicProviderRegistration,
|
||||
type ProviderConfig,
|
||||
type ProviderId as ZodProviderId
|
||||
} from './schemas'
|
||||
import { type ProviderId as ZodProviderId } from './schemas'
|
||||
|
||||
export interface ExtensibleProviderSettingsMap {
|
||||
// 基础的静态providers
|
||||
@ -76,10 +70,6 @@ export class ProviderError extends Error {
|
||||
// 动态ProviderId类型 - 基于 Zod Schema,支持运行时扩展和验证
|
||||
export type ProviderId = ZodProviderId
|
||||
|
||||
// 重新导出相关类型
|
||||
export type { BaseProviderId, DynamicProviderId, DynamicProviderRegistration, ProviderConfig }
|
||||
|
||||
// Provider类型注册工具
|
||||
export interface ProviderTypeRegistrar {
|
||||
registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
|
||||
getProviderSettings<T extends string>(providerId: T): any
|
||||
|
||||
@ -12,11 +12,10 @@ import {
|
||||
streamText
|
||||
} from 'ai'
|
||||
|
||||
import { type ProviderId } from '../../types'
|
||||
import { globalModelResolver } from '../models'
|
||||
import { type ModelConfig } from '../models/types'
|
||||
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
||||
import { getProviderInfo } from '../providers/registry'
|
||||
import { type ProviderId } from '../providers'
|
||||
import { ImageGenerationError, ImageModelResolutionError } from './errors'
|
||||
import { PluginEngine } from './pluginEngine'
|
||||
import { type RuntimeConfig } from './types'
|
||||
@ -311,13 +310,6 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取客户端信息
|
||||
*/
|
||||
getClientInfo() {
|
||||
return getProviderInfo(this.config.providerId)
|
||||
}
|
||||
|
||||
// === 静态工厂方法 ===
|
||||
|
||||
/**
|
||||
|
||||
@ -3,7 +3,6 @@ import { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import { LanguageModel } from 'ai'
|
||||
|
||||
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
||||
import { isProviderSupported } from '../providers/registry'
|
||||
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
|
||||
|
||||
/**
|
||||
@ -240,20 +239,4 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
): PluginEngine<'openai-compatible'> {
|
||||
return new PluginEngine('openai-compatible', plugins)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建标准提供商客户端
|
||||
*/
|
||||
static create<T extends ProviderId>(providerId: T, plugins?: AiPlugin[]): PluginEngine<T>
|
||||
|
||||
static create(providerId: string, plugins?: AiPlugin[]): PluginEngine<'openai-compatible'>
|
||||
|
||||
static create(providerId: string, plugins: AiPlugin[] = []): PluginEngine {
|
||||
if (isProviderSupported(providerId)) {
|
||||
return new PluginEngine(providerId as ProviderId, plugins)
|
||||
} else {
|
||||
// 对于未知 provider,使用 openai-compatible
|
||||
return new PluginEngine('openai-compatible', plugins)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,10 +4,6 @@
|
||||
*/
|
||||
|
||||
// 导入内部使用的类和函数
|
||||
import { getSupportedProviders, isProviderSupported } from './core/providers/registry'
|
||||
import type { ProviderId } from './core/providers/types'
|
||||
import type { ProviderSettingsMap } from './core/providers/types'
|
||||
import { createExecutor } from './core/runtime'
|
||||
|
||||
// ==================== 主要用户接口 ====================
|
||||
export {
|
||||
@ -28,29 +24,8 @@ export { createContext, definePlugin, PluginManager } from './core/plugins'
|
||||
// export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
|
||||
export { PluginEngine } from './core/runtime/pluginEngine'
|
||||
|
||||
// ==================== 低级 API ====================
|
||||
export { providerRegistry } from './core/providers/registry'
|
||||
|
||||
// ==================== 类型定义 ====================
|
||||
export type { ProviderConfig } from './core/providers/types'
|
||||
export type { ProviderError } from './core/providers/types'
|
||||
export type {
|
||||
AnthropicProviderSettings,
|
||||
AzureOpenAIProviderSettings,
|
||||
DeepSeekProviderSettings,
|
||||
GenerateObjectParams,
|
||||
GenerateTextParams,
|
||||
GoogleGenerativeAIProviderSettings,
|
||||
OpenAICompatibleProviderSettings,
|
||||
OpenAIProviderSettings,
|
||||
ProviderId,
|
||||
ProviderSettings,
|
||||
ProviderSettingsMap,
|
||||
StreamObjectParams,
|
||||
StreamTextParams,
|
||||
XaiProviderSettings
|
||||
} from './types'
|
||||
export * as aiSdk from 'ai'
|
||||
export type { GenerateObjectParams, GenerateTextParams, StreamObjectParams, StreamTextParams } from './types'
|
||||
|
||||
// ==================== AI SDK 常用类型导出 ====================
|
||||
// 直接导出 AI SDK 的常用类型,方便使用
|
||||
@ -112,115 +87,6 @@ export {
|
||||
type TypedProviderOptions
|
||||
} from './core/options'
|
||||
|
||||
// ==================== Provider 初始化和管理 ====================
|
||||
export {
|
||||
clearAllProviders,
|
||||
getImageModel,
|
||||
getInitializedProviders,
|
||||
// 访问功能
|
||||
getLanguageModel,
|
||||
getProviderInfo,
|
||||
getTextEmbeddingModel,
|
||||
hasInitializedProviders,
|
||||
// initializeImageProvider, // deprecated: 使用 initializeProvider 即可
|
||||
// 初始化功能
|
||||
initializeProvider,
|
||||
initializeProviders,
|
||||
isProviderInitialized,
|
||||
isProviderSupported,
|
||||
// 错误类型
|
||||
ProviderInitializationError,
|
||||
reinitializeProvider
|
||||
} from './core/providers/registry'
|
||||
|
||||
// ==================== 动态Provider注册和别名映射 ====================
|
||||
export {
|
||||
cleanup,
|
||||
getAllAliases,
|
||||
getAllDynamicMappings,
|
||||
getDynamicProviders,
|
||||
getProviderMapping,
|
||||
isAlias,
|
||||
isDynamicProvider,
|
||||
registerDynamicProvider,
|
||||
registerMultipleProviders,
|
||||
resolveProviderId
|
||||
} from './core/providers/registry'
|
||||
|
||||
// ==================== Zod Schema 和验证 ====================
|
||||
export { baseProviderIds, validateProviderId } from './core/providers'
|
||||
|
||||
// ==================== Hub Provider ====================
|
||||
export { createHubProvider, type HubProviderConfig, HubProviderError } from './core/providers/HubProvider'
|
||||
|
||||
// ==================== Provider 配置工厂 ====================
|
||||
export {
|
||||
type BaseProviderConfig,
|
||||
createProviderConfig,
|
||||
type ProviderConfigBuilder,
|
||||
providerConfigBuilder,
|
||||
ProviderConfigFactory
|
||||
} from './core/providers/factory'
|
||||
|
||||
// ==================== 包信息 ====================
|
||||
export const AI_CORE_VERSION = '1.0.0'
|
||||
export const AI_CORE_NAME = '@cherrystudio/ai-core'
|
||||
|
||||
// ==================== 便捷 API ====================
|
||||
// 主要的便捷工厂类
|
||||
export const AiCore = {
|
||||
version: AI_CORE_VERSION,
|
||||
name: AI_CORE_NAME,
|
||||
|
||||
// 创建主要执行器(推荐使用)
|
||||
create(providerId: ProviderId, options: ProviderSettingsMap[ProviderId], plugins: any[] = []) {
|
||||
return createExecutor(providerId, options, plugins)
|
||||
},
|
||||
|
||||
// 获取支持的providers
|
||||
getSupportedProviders() {
|
||||
return getSupportedProviders()
|
||||
},
|
||||
|
||||
isSupported(providerId: ProviderId) {
|
||||
return isProviderSupported(providerId)
|
||||
}
|
||||
}
|
||||
|
||||
// 推荐使用的执行器创建函数
|
||||
export const createOpenAIExecutor = (options: ProviderSettingsMap['openai'], plugins?: any[]) => {
|
||||
return createExecutor('openai', options, plugins)
|
||||
}
|
||||
|
||||
export const createAnthropicExecutor = (options: ProviderSettingsMap['anthropic'], plugins?: any[]) => {
|
||||
return createExecutor('anthropic', options, plugins)
|
||||
}
|
||||
|
||||
export const createGoogleExecutor = (options: ProviderSettingsMap['google'], plugins?: any[]) => {
|
||||
return createExecutor('google', options, plugins)
|
||||
}
|
||||
|
||||
export const createXAIExecutor = (options: ProviderSettingsMap['xai'], plugins?: any[]) => {
|
||||
return createExecutor('xai', options, plugins)
|
||||
}
|
||||
|
||||
// ==================== 调试和开发工具 ====================
|
||||
export const DevTools = {
|
||||
// 列出所有支持的providers
|
||||
listProviders() {
|
||||
return getSupportedProviders()
|
||||
},
|
||||
|
||||
// 获取provider详细信息
|
||||
getProviderDetails() {
|
||||
const supportedProviders = getSupportedProviders()
|
||||
|
||||
return {
|
||||
supportedProviders: supportedProviders.length,
|
||||
providers: supportedProviders.map((p) => ({
|
||||
id: p.id,
|
||||
name: p.name
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,27 +1,9 @@
|
||||
import { generateObject, generateText, streamObject, streamText } from 'ai'
|
||||
|
||||
import type { ProviderSettingsMap } from './core/providers/types'
|
||||
|
||||
// ProviderSettings 是所有 Provider Settings 的联合类型
|
||||
export type ProviderSettings = ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
|
||||
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
|
||||
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
|
||||
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
|
||||
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>
|
||||
|
||||
// 重新导出 ProviderSettingsMap 中的所有类型
|
||||
export type {
|
||||
AnthropicProviderSettings,
|
||||
AzureOpenAIProviderSettings,
|
||||
DeepSeekProviderSettings,
|
||||
GoogleGenerativeAIProviderSettings,
|
||||
OpenAICompatibleProviderSettings,
|
||||
OpenAIProviderSettings,
|
||||
ProviderId,
|
||||
ProviderSettingsMap,
|
||||
XaiProviderSettings
|
||||
} from './core/providers/types'
|
||||
|
||||
// 重新导出插件类型
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'
|
||||
|
||||
@ -3,7 +3,8 @@ import { defineConfig } from 'tsdown'
|
||||
export default defineConfig({
|
||||
entry: {
|
||||
index: 'src/index.ts',
|
||||
'built-in/plugins/index': 'src/core/plugins/built-in/index.ts'
|
||||
'built-in/plugins/index': 'src/core/plugins/built-in/index.ts',
|
||||
'provider/index': 'src/core/providers/index.ts'
|
||||
},
|
||||
outDir: 'dist',
|
||||
format: ['esm', 'cjs'],
|
||||
|
||||
@ -8,7 +8,8 @@
|
||||
* 3. 暂时保持接口兼容性
|
||||
*/
|
||||
|
||||
import { createExecutor, generateImage, initializeProvider, StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { createExecutor, generateImage, StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
|
||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||
@ -36,21 +37,6 @@ export default class ModernAiProvider {
|
||||
|
||||
// 只保存配置,不预先创建executor
|
||||
this.config = providerToAiSdkConfig(this.actualProvider)
|
||||
|
||||
// 初始化 provider 到全局管理器
|
||||
try {
|
||||
initializeProvider(this.config.providerId, this.config.options)
|
||||
logger.debug('Provider initialized successfully', {
|
||||
providerId: this.config.providerId,
|
||||
hasOptions: !!this.config.options
|
||||
})
|
||||
} catch (error) {
|
||||
// 如果 provider 已经初始化过,可能会抛出错误,这里可以忽略
|
||||
logger.debug('Provider initialization skipped (may already be initialized)', {
|
||||
providerId: this.config.providerId,
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
public getActualProvider() {
|
||||
@ -67,6 +53,21 @@ export default class ModernAiProvider {
|
||||
callType: string
|
||||
}
|
||||
): Promise<CompletionsResult> {
|
||||
// 初始化 provider 到全局管理器
|
||||
try {
|
||||
await createAndRegisterProvider(this.config.providerId, this.config.options)
|
||||
logger.debug('Provider initialized successfully', {
|
||||
providerId: this.config.providerId,
|
||||
hasOptions: !!this.config.options
|
||||
})
|
||||
} catch (error) {
|
||||
// 如果 provider 已经初始化过,可能会抛出错误,这里可以忽略
|
||||
logger.debug('Provider initialization skipped (may already be initialized)', {
|
||||
providerId: this.config.providerId,
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
})
|
||||
}
|
||||
|
||||
if (config.isImageGenerationEndpoint) {
|
||||
return await this.modernImageGeneration(modelId, params, config)
|
||||
}
|
||||
|
||||
@ -1,4 +1,9 @@
|
||||
import { AiCore, ProviderConfigFactory, type ProviderId, type ProviderSettingsMap } from '@cherrystudio/ai-core'
|
||||
import {
|
||||
hasProviderConfig,
|
||||
ProviderConfigFactory,
|
||||
type ProviderId,
|
||||
type ProviderSettingsMap
|
||||
} from '@cherrystudio/ai-core/provider'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
@ -86,7 +91,7 @@ export function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
}
|
||||
|
||||
// 如果AI SDK支持该provider,使用原生配置
|
||||
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
||||
return {
|
||||
providerId: aiSdkProviderId as ProviderId,
|
||||
@ -120,5 +125,5 @@ export function isModernSdkSupported(provider: Provider): boolean {
|
||||
const aiSdkProviderId = getAiSdkProviderId(provider)
|
||||
|
||||
// 如果映射到了支持的provider,则支持现代SDK
|
||||
return AiCore.isSupported(aiSdkProviderId)
|
||||
return hasProviderConfig(aiSdkProviderId)
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { AiCore, getProviderMapping, type ProviderId } from '@cherrystudio/ai-core'
|
||||
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
@ -30,7 +30,7 @@ const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
|
||||
}
|
||||
|
||||
/**
|
||||
* 尝试解析provider标识符(支持静态映射和动态映射)
|
||||
* 尝试解析provider标识符(支持静态映射和别名)
|
||||
*/
|
||||
function tryResolveProviderId(identifier: string): ProviderId | null {
|
||||
// 1. 检查静态映射
|
||||
@ -39,15 +39,10 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
|
||||
return staticMapping
|
||||
}
|
||||
|
||||
// 2. 检查动态映射
|
||||
const dynamicMapping = getProviderMapping(identifier)
|
||||
if (dynamicMapping && dynamicMapping !== identifier) {
|
||||
return dynamicMapping as ProviderId
|
||||
}
|
||||
|
||||
// 3. 检查AiCore是否直接支持
|
||||
if (AiCore.isSupported(identifier)) {
|
||||
return identifier as ProviderId
|
||||
// 2. 检查AiCore是否支持(包括别名支持)
|
||||
if (hasProviderConfigByAlias(identifier)) {
|
||||
// 解析为真实的Provider ID
|
||||
return resolveProviderConfigId(identifier) as ProviderId
|
||||
}
|
||||
|
||||
return null
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { type ProviderConfig, registerMultipleProviders } from '@cherrystudio/ai-core'
|
||||
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
|
||||
const logger = loggerService.withContext('ProviderConfigs')
|
||||
@ -16,9 +16,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
||||
import: () => import('@openrouter/ai-sdk-provider'),
|
||||
creatorFunctionName: 'createOpenRouter',
|
||||
supportsImageGeneration: true,
|
||||
mappings: {
|
||||
openrouter: 'openrouter'
|
||||
}
|
||||
aliases: ['openrouter']
|
||||
},
|
||||
{
|
||||
id: 'google-vertex',
|
||||
@ -26,10 +24,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
||||
import: () => import('@ai-sdk/google-vertex'),
|
||||
creatorFunctionName: 'createGoogleVertex',
|
||||
supportsImageGeneration: true,
|
||||
mappings: {
|
||||
'google-vertex': 'google-vertex',
|
||||
vertexai: 'google-vertex'
|
||||
}
|
||||
aliases: ['google-vertex', 'vertexai']
|
||||
},
|
||||
{
|
||||
id: 'bedrock',
|
||||
@ -37,9 +32,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
||||
import: () => import('@ai-sdk/amazon-bedrock'),
|
||||
creatorFunctionName: 'createAmazonBedrock',
|
||||
supportsImageGeneration: true,
|
||||
mappings: {
|
||||
'aws-bedrock': 'bedrock'
|
||||
}
|
||||
aliases: ['aws-bedrock']
|
||||
}
|
||||
] as const
|
||||
|
||||
@ -49,19 +42,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
||||
*/
|
||||
export async function initializeNewProviders(): Promise<void> {
|
||||
try {
|
||||
logger.info('Starting to register new providers', {
|
||||
providerCount: NEW_PROVIDER_CONFIGS.length,
|
||||
providerIds: NEW_PROVIDER_CONFIGS.map((p) => p.id)
|
||||
})
|
||||
|
||||
const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS)
|
||||
|
||||
logger.info('Provider registration completed', {
|
||||
successCount,
|
||||
totalCount: NEW_PROVIDER_CONFIGS.length,
|
||||
failedCount: NEW_PROVIDER_CONFIGS.length - successCount
|
||||
})
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(NEW_PROVIDER_CONFIGS)
|
||||
if (successCount < NEW_PROVIDER_CONFIGS.length) {
|
||||
logger.warn('Some providers failed to register. Check previous error logs.')
|
||||
}
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import { aiSdk, Tool } from '@cherrystudio/ai-core'
|
||||
import { Tool } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
||||
import { MCPTool, MCPToolResponse } from '@renderer/types'
|
||||
import { callMCPTool } from '@renderer/utils/mcp-tools'
|
||||
import { jsonSchema, tool } from 'ai'
|
||||
import { JSONSchema7 } from 'json-schema'
|
||||
|
||||
const { tool } = aiSdk
|
||||
const logger = loggerService.withContext('MCP-utils')
|
||||
|
||||
// Setup tools configuration based on provided parameters
|
||||
@ -30,7 +30,7 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string,
|
||||
for (const mcpTool of mcpTools) {
|
||||
tools[mcpTool.name] = tool({
|
||||
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
||||
inputSchema: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||
execute: async (params) => {
|
||||
// 创建适配的 MCPToolResponse 对象
|
||||
const toolResponse: MCPToolResponse = {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user