mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-07 05:39:05 +08:00
refactor: reorganize provider and model exports for improved structure
- Updated exports in index.ts and related files to streamline provider and model management. - Introduced a new ModelCreator module for better encapsulation of model creation logic. - Refactored type imports to enhance clarity and maintainability across the codebase. - Removed deprecated provider configurations and cleaned up unused code for better performance.
This commit is contained in:
parent
c3ad18b77e
commit
1248e3c49a
@ -8,17 +8,17 @@ export type { NamedMiddleware } from './middleware'
|
|||||||
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'
|
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'
|
||||||
|
|
||||||
// 创建管理
|
// 创建管理
|
||||||
export type { ModelConfig } from './models'
|
|
||||||
export {
|
export {
|
||||||
createBaseModel,
|
createBaseModel,
|
||||||
createImageModel,
|
createImageModel,
|
||||||
createModel,
|
createModel,
|
||||||
getProviderInfo,
|
getProviderInfo,
|
||||||
getSupportedProviders,
|
getSupportedProviders,
|
||||||
ProviderCreationError
|
ModelCreationError
|
||||||
} from './models'
|
} from './models'
|
||||||
|
export type { ModelConfig } from './models/types'
|
||||||
|
|
||||||
// 执行管理
|
// 执行管理
|
||||||
export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type'
|
export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type'
|
||||||
export type { ExecutionOptions, ExecutorConfig } from './runtime'
|
|
||||||
export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
|
export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
|
||||||
|
export type { RuntimeConfig } from './runtime/types'
|
||||||
|
|||||||
@ -1,22 +1,23 @@
|
|||||||
/**
|
/**
|
||||||
* Provider 创建器
|
* Model Creator
|
||||||
* 负责动态导入 AI SDK providers 并创建基础模型实例
|
* 负责基于 Provider 创建 AI SDK 的 Language Model 和 Image Model 实例
|
||||||
*/
|
*/
|
||||||
import { ImageModelV2, type LanguageModelV2 } from '@ai-sdk/provider'
|
import { ImageModelV2, type LanguageModelV2 } from '@ai-sdk/provider'
|
||||||
|
|
||||||
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
|
||||||
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
|
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
|
||||||
import { aiProviderRegistry, type ProviderConfig } from '../providers/registry'
|
import { createImageProvider, createProvider } from '../providers/creator'
|
||||||
|
import { aiProviderRegistry } from '../providers/registry'
|
||||||
|
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
|
||||||
|
|
||||||
// 错误类型
|
// 错误类型
|
||||||
export class ProviderCreationError extends Error {
|
export class ModelCreationError extends Error {
|
||||||
constructor(
|
constructor(
|
||||||
message: string,
|
message: string,
|
||||||
public providerId?: string,
|
public providerId?: string,
|
||||||
public cause?: Error
|
public cause?: Error
|
||||||
) {
|
) {
|
||||||
super(message)
|
super(message)
|
||||||
this.name = 'ProviderCreationError'
|
this.name = 'ModelCreationError'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29,13 +30,11 @@ export async function createBaseModel<T extends ProviderId>({
|
|||||||
modelId,
|
modelId,
|
||||||
providerSettings,
|
providerSettings,
|
||||||
extraModelConfig
|
extraModelConfig
|
||||||
// middlewares
|
|
||||||
}: {
|
}: {
|
||||||
providerId: T
|
providerId: T
|
||||||
modelId: string
|
modelId: string
|
||||||
providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
|
providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
|
||||||
extraModelConfig?: any
|
extraModelConfig?: any
|
||||||
// middlewares?: LanguageModelV1Middleware[]
|
|
||||||
}): Promise<LanguageModelV2>
|
}): Promise<LanguageModelV2>
|
||||||
|
|
||||||
export async function createBaseModel({
|
export async function createBaseModel({
|
||||||
@ -43,87 +42,49 @@ export async function createBaseModel({
|
|||||||
modelId,
|
modelId,
|
||||||
providerSettings,
|
providerSettings,
|
||||||
extraModelConfig
|
extraModelConfig
|
||||||
// middlewares
|
|
||||||
}: {
|
}: {
|
||||||
providerId: string
|
providerId: string
|
||||||
modelId: string
|
modelId: string
|
||||||
providerSettings: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' }
|
providerSettings: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' }
|
||||||
extraModelConfig?: any
|
extraModelConfig?: any
|
||||||
// middlewares?: LanguageModelV1Middleware[]
|
|
||||||
}): Promise<LanguageModelV2>
|
}): Promise<LanguageModelV2>
|
||||||
|
|
||||||
export async function createBaseModel({
|
export async function createBaseModel({
|
||||||
providerId,
|
providerId,
|
||||||
modelId,
|
modelId,
|
||||||
providerSettings,
|
providerSettings,
|
||||||
// middlewares,
|
|
||||||
extraModelConfig
|
extraModelConfig
|
||||||
}: {
|
}: {
|
||||||
providerId: string
|
providerId: string
|
||||||
modelId: string
|
modelId: string
|
||||||
providerSettings: ProviderSettingsMap[ProviderId] & { mode?: 'chat' | 'responses' }
|
providerSettings: ProviderSettingsMap[ProviderId] & { mode?: 'chat' | 'responses' }
|
||||||
// middlewares?: LanguageModelV1Middleware[]
|
|
||||||
extraModelConfig?: any
|
extraModelConfig?: any
|
||||||
}): Promise<LanguageModelV2> {
|
}): Promise<LanguageModelV2> {
|
||||||
try {
|
try {
|
||||||
// 对于不在注册表中的 provider,默认使用 openai-compatible
|
|
||||||
const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible'
|
|
||||||
|
|
||||||
// 获取Provider配置
|
// 获取Provider配置
|
||||||
const providerConfig = aiProviderRegistry.getProvider(effectiveProviderId)
|
const providerConfig = aiProviderRegistry.getProvider(providerId)
|
||||||
if (!providerConfig) {
|
if (!providerConfig) {
|
||||||
throw new ProviderCreationError(`Provider "${effectiveProviderId}" is not registered`, providerId)
|
throw new ModelCreationError(`Provider "${providerId}" is not registered`, providerId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 动态导入模块
|
// 创建 provider 实例
|
||||||
const module = await providerConfig.import()
|
const provider = await createProvider(providerConfig, providerSettings)
|
||||||
|
|
||||||
// 获取创建函数
|
// 根据 provider 类型处理特殊逻辑
|
||||||
const creatorFunction = module[providerConfig.creatorFunctionName]
|
const finalProvider = handleProviderSpecificLogic(provider, providerConfig.id, providerSettings, modelId)
|
||||||
|
|
||||||
if (typeof creatorFunction !== 'function') {
|
|
||||||
throw new ProviderCreationError(
|
|
||||||
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${effectiveProviderId}"`
|
|
||||||
)
|
|
||||||
}
|
|
||||||
// TODO: 对openai 的 providerSettings.mode参数是否要删除,目前看没毛病
|
|
||||||
// 创建provider实例
|
|
||||||
let provider = creatorFunction(providerSettings)
|
|
||||||
|
|
||||||
// 加一个特判
|
|
||||||
if (providerConfig.id === 'openai') {
|
|
||||||
if (
|
|
||||||
'mode' in providerSettings &&
|
|
||||||
providerSettings.mode === 'responses' &&
|
|
||||||
!isOpenAIChatCompletionOnlyModel(modelId)
|
|
||||||
) {
|
|
||||||
provider = provider.responses
|
|
||||||
} else {
|
|
||||||
provider = provider.chat
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 返回模型实例
|
|
||||||
if (typeof provider === 'function') {
|
|
||||||
// extraModelConfig:例如google的useSearchGrounding
|
|
||||||
const model: LanguageModelV2 = provider(modelId, extraModelConfig)
|
|
||||||
|
|
||||||
// // 应用 AI SDK 中间件
|
|
||||||
// if (middlewares && middlewares.length > 0) {
|
|
||||||
// model = wrapLanguageModel({
|
|
||||||
// model: model,
|
|
||||||
// middleware: middlewares
|
|
||||||
// })
|
|
||||||
// }
|
|
||||||
|
|
||||||
|
// 创建模型实例
|
||||||
|
if (typeof finalProvider === 'function') {
|
||||||
|
const model: LanguageModelV2 = finalProvider(modelId, extraModelConfig)
|
||||||
return model
|
return model
|
||||||
} else {
|
} else {
|
||||||
throw new ProviderCreationError(`Unknown model access pattern for provider "${effectiveProviderId}"`)
|
throw new ModelCreationError(`Unknown model access pattern for provider "${providerId}"`)
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof ProviderCreationError) {
|
if (error instanceof ModelCreationError) {
|
||||||
throw error
|
throw error
|
||||||
}
|
}
|
||||||
throw new ProviderCreationError(
|
throw new ModelCreationError(
|
||||||
`Failed to create base model for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
|
`Failed to create base model for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||||
providerId,
|
providerId,
|
||||||
error instanceof Error ? error : undefined
|
error instanceof Error ? error : undefined
|
||||||
@ -131,6 +92,27 @@ export async function createBaseModel({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理特定 Provider 的逻辑
|
||||||
|
*/
|
||||||
|
function handleProviderSpecificLogic(provider: any, providerId: string, providerSettings: any, modelId: string): any {
|
||||||
|
// OpenAI 特殊处理
|
||||||
|
if (providerId === 'openai') {
|
||||||
|
if (
|
||||||
|
'mode' in providerSettings &&
|
||||||
|
providerSettings.mode === 'responses' &&
|
||||||
|
!isOpenAIChatCompletionOnlyModel(modelId)
|
||||||
|
) {
|
||||||
|
return provider.responses
|
||||||
|
} else {
|
||||||
|
return provider.chat
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 其他 provider 直接返回
|
||||||
|
return provider
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 创建图像生成模型实例
|
* 创建图像生成模型实例
|
||||||
*/
|
*/
|
||||||
@ -151,40 +133,31 @@ export async function createImageModel(
|
|||||||
): Promise<ImageModelV2> {
|
): Promise<ImageModelV2> {
|
||||||
try {
|
try {
|
||||||
if (!aiProviderRegistry.isSupported(providerId)) {
|
if (!aiProviderRegistry.isSupported(providerId)) {
|
||||||
throw new ProviderCreationError(`Provider "${providerId}" is not supported`, providerId)
|
throw new ModelCreationError(`Provider "${providerId}" is not supported`, providerId)
|
||||||
}
|
}
|
||||||
|
|
||||||
const providerConfig = aiProviderRegistry.getProvider(providerId)
|
const providerConfig = aiProviderRegistry.getProvider(providerId)
|
||||||
if (!providerConfig) {
|
if (!providerConfig) {
|
||||||
throw new ProviderCreationError(`Provider "${providerId}" is not registered`, providerId)
|
throw new ModelCreationError(`Provider "${providerId}" is not registered`, providerId)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!providerConfig.supportsImageGeneration) {
|
if (!providerConfig.supportsImageGeneration) {
|
||||||
throw new ProviderCreationError(`Provider "${providerId}" does not support image generation`, providerId)
|
throw new ModelCreationError(`Provider "${providerId}" does not support image generation`, providerId)
|
||||||
}
|
}
|
||||||
|
|
||||||
const module = await providerConfig.import()
|
// 创建图像 provider 实例
|
||||||
|
const provider = await createImageProvider(providerConfig, options)
|
||||||
const creatorFunction = module[providerConfig.creatorFunctionName]
|
|
||||||
|
|
||||||
if (typeof creatorFunction !== 'function') {
|
|
||||||
throw new ProviderCreationError(
|
|
||||||
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${providerId}"`
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const provider = creatorFunction(options)
|
|
||||||
|
|
||||||
if (provider && typeof provider.image === 'function') {
|
if (provider && typeof provider.image === 'function') {
|
||||||
return provider.image(modelId)
|
return provider.image(modelId)
|
||||||
} else {
|
} else {
|
||||||
throw new ProviderCreationError(`Image model function not found for provider "${providerId}"`)
|
throw new ModelCreationError(`Image model function not found for provider "${providerId}"`)
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof ProviderCreationError) {
|
if (error instanceof ModelCreationError) {
|
||||||
throw error
|
throw error
|
||||||
}
|
}
|
||||||
throw new ProviderCreationError(
|
throw new ModelCreationError(
|
||||||
`Failed to create image model for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
|
`Failed to create image model for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||||
providerId,
|
providerId,
|
||||||
error instanceof Error ? error : undefined
|
error instanceof Error ? error : undefined
|
||||||
@ -199,7 +172,7 @@ export function getSupportedProviders(): Array<{
|
|||||||
id: string
|
id: string
|
||||||
name: string
|
name: string
|
||||||
}> {
|
}> {
|
||||||
return aiProviderRegistry.getAllProviders().map((provider: ProviderConfig) => ({
|
return aiProviderRegistry.getAllProviders().map((provider) => ({
|
||||||
id: provider.id,
|
id: provider.id,
|
||||||
name: provider.name
|
name: provider.name
|
||||||
}))
|
}))
|
||||||
@ -6,7 +6,7 @@ import { LanguageModelV2 } from '@ai-sdk/provider'
|
|||||||
import { LanguageModel } from 'ai'
|
import { LanguageModel } from 'ai'
|
||||||
|
|
||||||
import { wrapModelWithMiddlewares } from '../middleware'
|
import { wrapModelWithMiddlewares } from '../middleware'
|
||||||
import { createBaseModel } from './ProviderCreator'
|
import { createBaseModel } from './ModelCreator'
|
||||||
import { ModelConfig } from './types'
|
import { ModelConfig } from './types'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -1,19 +1,18 @@
|
|||||||
/**
|
/**
|
||||||
* Models 模块导出
|
* Models 模块统一导出
|
||||||
* 提供统一的模型创建和配置管理能力
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// 主要的模型创建API
|
// Model 创建相关
|
||||||
export { createModel, createModels } from './factory'
|
|
||||||
|
|
||||||
// 底层Provider创建功能(供高级用户使用)
|
|
||||||
export {
|
export {
|
||||||
createBaseModel,
|
createBaseModel,
|
||||||
createImageModel,
|
createImageModel,
|
||||||
getProviderInfo,
|
getProviderInfo,
|
||||||
getSupportedProviders,
|
getSupportedProviders,
|
||||||
ProviderCreationError
|
ModelCreationError
|
||||||
} from './ProviderCreator'
|
} from './ModelCreator'
|
||||||
|
|
||||||
// 保留原有类型
|
// Model 配置和工厂
|
||||||
export type { ModelConfig } from './types'
|
export { createModel } from './factory'
|
||||||
|
|
||||||
|
// 类型定义
|
||||||
|
export type { ModelConfig as ModelConfigType } from './types'
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
*/
|
*/
|
||||||
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||||
|
|
||||||
import type { ProviderId, ProviderSettingsMap } from '../../types'
|
import type { ProviderId, ProviderSettingsMap } from '../providers/types'
|
||||||
|
|
||||||
export interface ModelConfig<T extends ProviderId = ProviderId> {
|
export interface ModelConfig<T extends ProviderId = ProviderId> {
|
||||||
providerId: T
|
providerId: T
|
||||||
|
|||||||
@ -41,8 +41,8 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
case 'google':
|
case 'google': {
|
||||||
case 'google-vertex': {
|
// case 'google-vertex':
|
||||||
if (!params.tools) params.tools = {}
|
if (!params.tools) params.tools = {}
|
||||||
params.tools.web_search = google.tools.googleSearch(config.google || {})
|
params.tools.web_search = google.tools.googleSearch(config.google || {})
|
||||||
break
|
break
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
|
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
|
||||||
|
|
||||||
import { ProviderId } from '../providers/registry'
|
import { type ProviderId } from '../providers/types'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 递归调用函数类型
|
* 递归调用函数类型
|
||||||
|
|||||||
108
packages/aiCore/src/core/providers/creator.ts
Normal file
108
packages/aiCore/src/core/providers/creator.ts
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
/**
|
||||||
|
* Provider Creator
|
||||||
|
* 负责根据 ProviderConfig 创建 provider 实例
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { type ProviderConfig } from './types'
|
||||||
|
|
||||||
|
// 错误类型
|
||||||
|
export class ProviderCreationError extends Error {
|
||||||
|
constructor(
|
||||||
|
message: string,
|
||||||
|
public providerId?: string,
|
||||||
|
public cause?: Error
|
||||||
|
) {
|
||||||
|
super(message)
|
||||||
|
this.name = 'ProviderCreationError'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Provider 实例
|
||||||
|
* 支持两种模式:直接提供 creator 函数,或动态导入 + 函数名
|
||||||
|
*/
|
||||||
|
export async function createProvider(config: ProviderConfig, options: any): Promise<any> {
|
||||||
|
try {
|
||||||
|
// 验证配置
|
||||||
|
if (!config.creator && !(config.import && config.creatorFunctionName)) {
|
||||||
|
throw new ProviderCreationError(
|
||||||
|
'Invalid provider configuration: must provide either creator function or import configuration',
|
||||||
|
config.id
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 方式一:直接使用 creator 函数
|
||||||
|
if (config.creator) {
|
||||||
|
return config.creator(options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 方式二:动态导入 + 函数名
|
||||||
|
if (config.import && config.creatorFunctionName) {
|
||||||
|
const module = await config.import()
|
||||||
|
const creatorFunction = module[config.creatorFunctionName]
|
||||||
|
|
||||||
|
if (typeof creatorFunction !== 'function') {
|
||||||
|
throw new ProviderCreationError(
|
||||||
|
`Creator function "${config.creatorFunctionName}" not found in the imported module`,
|
||||||
|
config.id
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return creatorFunction(options)
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new ProviderCreationError('Unexpected provider configuration state', config.id)
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof ProviderCreationError) {
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
throw new ProviderCreationError(
|
||||||
|
`Failed to create provider "${config.id}": ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||||
|
config.id,
|
||||||
|
error instanceof Error ? error : undefined
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建图像生成 Provider 实例
|
||||||
|
*/
|
||||||
|
export async function createImageProvider(config: ProviderConfig, options: any): Promise<any> {
|
||||||
|
try {
|
||||||
|
if (!config.supportsImageGeneration) {
|
||||||
|
throw new ProviderCreationError(`Provider "${config.id}" does not support image generation`, config.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有专门的图像 creator
|
||||||
|
if (config.imageCreator) {
|
||||||
|
return config.imageCreator(options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 否则使用普通的 provider 创建流程
|
||||||
|
return await createProvider(config, options)
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof ProviderCreationError) {
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
throw new ProviderCreationError(
|
||||||
|
`Failed to create image provider "${config.id}": ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||||
|
config.id,
|
||||||
|
error instanceof Error ? error : undefined
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 Provider 配置
|
||||||
|
*/
|
||||||
|
export function validateProviderConfig(config: ProviderConfig): boolean {
|
||||||
|
if (!config.id || !config.name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.creator && !(config.import && config.creatorFunctionName)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
@ -3,8 +3,7 @@
|
|||||||
* 提供类型安全的 Provider 配置构建器
|
* 提供类型安全的 Provider 配置构建器
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import type { ProviderId, ProviderSettingsMap } from './registry'
|
import type { ProviderId, ProviderSettingsMap } from './types'
|
||||||
import { formatPrivateKey } from './utils'
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 通用配置基础类型,包含所有 Provider 共有的属性
|
* 通用配置基础类型,包含所有 Provider 共有的属性
|
||||||
@ -37,20 +36,20 @@ const configHandlers: {
|
|||||||
apiVersion: azureProvider.apiVersion,
|
apiVersion: azureProvider.apiVersion,
|
||||||
resourceName: azureProvider.resourceName
|
resourceName: azureProvider.resourceName
|
||||||
})
|
})
|
||||||
},
|
|
||||||
'google-vertex': (builder, provider) => {
|
|
||||||
const vertexBuilder = builder as ProviderConfigBuilder<'google-vertex'>
|
|
||||||
const vertexProvider = provider as CompleteProviderConfig<'google-vertex'>
|
|
||||||
vertexBuilder
|
|
||||||
.withGoogleVertexConfig({
|
|
||||||
project: vertexProvider.project,
|
|
||||||
location: vertexProvider.location
|
|
||||||
})
|
|
||||||
.withGoogleCredentials({
|
|
||||||
clientEmail: vertexProvider.googleCredentials?.clientEmail || '',
|
|
||||||
privateKey: vertexProvider.googleCredentials?.privateKey || ''
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
// 'google-vertex': (builder, provider) => {
|
||||||
|
// const vertexBuilder = builder as ProviderConfigBuilder<'google-vertex'>
|
||||||
|
// const vertexProvider = provider as CompleteProviderConfig<'google-vertex'>
|
||||||
|
// vertexBuilder
|
||||||
|
// .withGoogleVertexConfig({
|
||||||
|
// project: vertexProvider.project,
|
||||||
|
// location: vertexProvider.location
|
||||||
|
// })
|
||||||
|
// .withGoogleCredentials({
|
||||||
|
// clientEmail: vertexProvider.googleCredentials?.clientEmail || '',
|
||||||
|
// privateKey: vertexProvider.googleCredentials?.privateKey || ''
|
||||||
|
// })
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
|
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
|
||||||
@ -121,35 +120,36 @@ export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
|
|||||||
/**
|
/**
|
||||||
* Google 特定配置
|
* Google 特定配置
|
||||||
*/
|
*/
|
||||||
withGoogleVertexConfig(options: { project?: string; location?: string }): T extends 'google-vertex' ? this : never
|
// withGoogleVertexConfig(options: { project?: string; location?: string }): T extends 'google-vertex' ? this : never
|
||||||
withGoogleVertexConfig(options: any): any {
|
// withGoogleVertexConfig(options: any): any {
|
||||||
if (this.providerId === 'google-vertex') {
|
// if (this.providerId === 'google-vertex') {
|
||||||
const googleConfig = this.config as CompleteProviderConfig<'google-vertex'>
|
// const googleConfig = this.config as CompleteProviderConfig<'google-vertex'>
|
||||||
if (options.project) {
|
// if (options.project) {
|
||||||
googleConfig.project = options.project
|
// googleConfig.project = options.project
|
||||||
}
|
// }
|
||||||
if (options.location) {
|
// if (options.location) {
|
||||||
googleConfig.location = options.location
|
// googleConfig.location = options.location
|
||||||
if (options.location === 'global') {
|
// if (options.location === 'global') {
|
||||||
googleConfig.baseURL = 'https://aiplatform.googleapis.com'
|
// googleConfig.baseURL = 'https://aiplatform.googleapis.com'
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
return this
|
// return this
|
||||||
}
|
// }
|
||||||
|
|
||||||
withGoogleCredentials(credentials: {
|
withGoogleCredentials(credentials: {
|
||||||
clientEmail: string
|
clientEmail: string
|
||||||
privateKey: string
|
privateKey: string
|
||||||
}): T extends 'google-vertex' ? this : never
|
}): T extends 'google-vertex' ? this : never
|
||||||
withGoogleCredentials(credentials: any): any {
|
withGoogleCredentials(): any {
|
||||||
if (this.providerId === 'google-vertex') {
|
// withGoogleCredentials(credentials: any): any {
|
||||||
const vertexConfig = this.config as CompleteProviderConfig<'google-vertex'>
|
// if (this.providerId === 'google-vertex') {
|
||||||
vertexConfig.googleCredentials = {
|
// const vertexConfig = this.config as CompleteProviderConfig<'google-vertex'>
|
||||||
clientEmail: credentials.clientEmail,
|
// vertexConfig.googleCredentials = {
|
||||||
privateKey: formatPrivateKey(credentials.privateKey)
|
// clientEmail: credentials.clientEmail,
|
||||||
}
|
// privateKey: formatPrivateKey(credentials.privateKey)
|
||||||
}
|
// }
|
||||||
|
// }
|
||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -310,23 +310,22 @@ export class ProviderConfigFactory {
|
|||||||
/**
|
/**
|
||||||
* 快速创建 Vertex AI 配置
|
* 快速创建 Vertex AI 配置
|
||||||
*/
|
*/
|
||||||
static createVertexAI(
|
static createVertexAI() {
|
||||||
credentials: {
|
// credentials: {
|
||||||
clientEmail: string
|
// clientEmail: string
|
||||||
privateKey: string
|
// privateKey: string
|
||||||
},
|
// },
|
||||||
options?: {
|
// options?: {
|
||||||
project?: string
|
// project?: string
|
||||||
location?: string
|
// location?: string
|
||||||
}
|
// }
|
||||||
) {
|
// return this.builder('google-vertex')
|
||||||
return this.builder('google-vertex')
|
// .withGoogleCredentials(credentials)
|
||||||
.withGoogleCredentials(credentials)
|
// .withGoogleVertexConfig({
|
||||||
.withGoogleVertexConfig({
|
// project: options?.project,
|
||||||
project: options?.project,
|
// location: options?.location
|
||||||
location: options?.location
|
// })
|
||||||
})
|
// .build()
|
||||||
.build()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static createOpenAICompatible(baseURL: string, apiKey: string) {
|
static createOpenAICompatible(baseURL: string, apiKey: string) {
|
||||||
|
|||||||
15
packages/aiCore/src/core/providers/index.ts
Normal file
15
packages/aiCore/src/core/providers/index.ts
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
/**
|
||||||
|
* Providers 模块统一导出
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Provider 注册表
|
||||||
|
export { aiProviderRegistry, getAllProviders, getProvider, isProviderSupported, registerProvider } from './registry'
|
||||||
|
|
||||||
|
// Provider 创建
|
||||||
|
export { createImageProvider, createProvider, ProviderCreationError, validateProviderConfig } from './creator'
|
||||||
|
|
||||||
|
// 类型定义
|
||||||
|
export type { ProviderConfig, ProviderError, ProviderId, ProviderSettingsMap } from './types'
|
||||||
|
|
||||||
|
// 工厂和配置
|
||||||
|
export * from './factory'
|
||||||
@ -3,76 +3,8 @@
|
|||||||
* 静态类型 + 动态导入模式:所有类型静态导入,所有实现动态导入
|
* 静态类型 + 动态导入模式:所有类型静态导入,所有实现动态导入
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// 静态导入所有 AI SDK 类型
|
import { type ProviderConfig } from './types'
|
||||||
import { type AmazonBedrockProviderSettings } from '@ai-sdk/amazon-bedrock'
|
|
||||||
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
|
|
||||||
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
|
|
||||||
import { type CerebrasProviderSettings } from '@ai-sdk/cerebras'
|
|
||||||
import { type CohereProviderSettings } from '@ai-sdk/cohere'
|
|
||||||
import { type DeepInfraProviderSettings } from '@ai-sdk/deepinfra'
|
|
||||||
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
|
|
||||||
import { type FalProviderSettings } from '@ai-sdk/fal'
|
|
||||||
import { type FireworksProviderSettings } from '@ai-sdk/fireworks'
|
|
||||||
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
|
|
||||||
import { type GoogleVertexProviderSettings } from '@ai-sdk/google-vertex/edge'
|
|
||||||
import { type GroqProviderSettings } from '@ai-sdk/groq'
|
|
||||||
import { type MistralProviderSettings } from '@ai-sdk/mistral'
|
|
||||||
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
|
|
||||||
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
|
|
||||||
import { type PerplexityProviderSettings } from '@ai-sdk/perplexity'
|
|
||||||
import { type ReplicateProviderSettings } from '@ai-sdk/replicate'
|
|
||||||
import { type TogetherAIProviderSettings } from '@ai-sdk/togetherai'
|
|
||||||
import { type VercelProviderSettings } from '@ai-sdk/vercel'
|
|
||||||
import { type XaiProviderSettings } from '@ai-sdk/xai'
|
|
||||||
import { type OpenRouterProviderSettings } from '@openrouter/ai-sdk-provider'
|
|
||||||
import { type AnthropicVertexProviderSettings } from 'anthropic-vertex-ai'
|
|
||||||
import { type OllamaProviderSettings } from 'ollama-ai-provider'
|
|
||||||
|
|
||||||
// 类型安全的 Provider Settings 映射
|
|
||||||
export type ProviderSettingsMap = {
|
|
||||||
openai: OpenAIProviderSettings
|
|
||||||
'openai-compatible': OpenAICompatibleProviderSettings
|
|
||||||
openrouter: OpenRouterProviderSettings
|
|
||||||
anthropic: AnthropicProviderSettings
|
|
||||||
google: GoogleGenerativeAIProviderSettings
|
|
||||||
'google-vertex': GoogleVertexProviderSettings
|
|
||||||
mistral: MistralProviderSettings
|
|
||||||
xai: XaiProviderSettings
|
|
||||||
azure: AzureOpenAIProviderSettings
|
|
||||||
bedrock: AmazonBedrockProviderSettings
|
|
||||||
cohere: CohereProviderSettings
|
|
||||||
groq: GroqProviderSettings
|
|
||||||
together: TogetherAIProviderSettings
|
|
||||||
fireworks: FireworksProviderSettings
|
|
||||||
deepseek: DeepSeekProviderSettings
|
|
||||||
cerebras: CerebrasProviderSettings
|
|
||||||
deepinfra: DeepInfraProviderSettings
|
|
||||||
replicate: ReplicateProviderSettings
|
|
||||||
perplexity: PerplexityProviderSettings
|
|
||||||
fal: FalProviderSettings
|
|
||||||
vercel: VercelProviderSettings
|
|
||||||
ollama: OllamaProviderSettings
|
|
||||||
'anthropic-vertex': AnthropicVertexProviderSettings
|
|
||||||
}
|
|
||||||
|
|
||||||
export type ProviderId = keyof ProviderSettingsMap & string
|
|
||||||
|
|
||||||
// 统一的 Provider 配置接口(所有都使用动态导入)
|
|
||||||
export interface ProviderConfig {
|
|
||||||
id: string
|
|
||||||
name: string
|
|
||||||
// 动态导入函数
|
|
||||||
import: () => Promise<any>
|
|
||||||
// 创建函数名称
|
|
||||||
creatorFunctionName: string
|
|
||||||
// 是否支持图片生成
|
|
||||||
supportsImageGeneration?: boolean
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* AI SDK Provider 注册表
|
|
||||||
* 管理所有支持的 AI Providers 及其动态导入
|
|
||||||
*/
|
|
||||||
export class AiProviderRegistry {
|
export class AiProviderRegistry {
|
||||||
private static instance: AiProviderRegistry
|
private static instance: AiProviderRegistry
|
||||||
private registry = new Map<string, ProviderConfig>()
|
private registry = new Map<string, ProviderConfig>()
|
||||||
@ -123,20 +55,20 @@ export class AiProviderRegistry {
|
|||||||
creatorFunctionName: 'createGoogleGenerativeAI',
|
creatorFunctionName: 'createGoogleGenerativeAI',
|
||||||
supportsImageGeneration: true
|
supportsImageGeneration: true
|
||||||
},
|
},
|
||||||
{
|
// {
|
||||||
id: 'google-vertex',
|
// id: 'google-vertex',
|
||||||
name: 'Google Vertex AI',
|
// name: 'Google Vertex AI',
|
||||||
import: () => import('@ai-sdk/google-vertex/edge'),
|
// import: () => import('@ai-sdk/google-vertex/edge'),
|
||||||
creatorFunctionName: 'createVertex',
|
// creatorFunctionName: 'createVertex',
|
||||||
supportsImageGeneration: true
|
// supportsImageGeneration: true
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
id: 'mistral',
|
// id: 'mistral',
|
||||||
name: 'Mistral AI',
|
// name: 'Mistral AI',
|
||||||
import: () => import('@ai-sdk/mistral'),
|
// import: () => import('@ai-sdk/mistral'),
|
||||||
creatorFunctionName: 'createMistral',
|
// creatorFunctionName: 'createMistral',
|
||||||
supportsImageGeneration: false
|
// supportsImageGeneration: false
|
||||||
},
|
// },
|
||||||
{
|
{
|
||||||
id: 'xai',
|
id: 'xai',
|
||||||
name: 'xAI (Grok)',
|
name: 'xAI (Grok)',
|
||||||
@ -151,112 +83,112 @@ export class AiProviderRegistry {
|
|||||||
creatorFunctionName: 'createAzure',
|
creatorFunctionName: 'createAzure',
|
||||||
supportsImageGeneration: true
|
supportsImageGeneration: true
|
||||||
},
|
},
|
||||||
{
|
// {
|
||||||
id: 'bedrock',
|
// id: 'bedrock',
|
||||||
name: 'Amazon Bedrock',
|
// name: 'Amazon Bedrock',
|
||||||
import: () => import('@ai-sdk/amazon-bedrock'),
|
// import: () => import('@ai-sdk/amazon-bedrock'),
|
||||||
creatorFunctionName: 'createAmazonBedrock',
|
// creatorFunctionName: 'createAmazonBedrock',
|
||||||
supportsImageGeneration: false
|
// supportsImageGeneration: false
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
id: 'cohere',
|
// id: 'cohere',
|
||||||
name: 'Cohere',
|
// name: 'Cohere',
|
||||||
import: () => import('@ai-sdk/cohere'),
|
// import: () => import('@ai-sdk/cohere'),
|
||||||
creatorFunctionName: 'createCohere',
|
// creatorFunctionName: 'createCohere',
|
||||||
supportsImageGeneration: false
|
// supportsImageGeneration: false
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
id: 'groq',
|
// id: 'groq',
|
||||||
name: 'Groq',
|
// name: 'Groq',
|
||||||
import: () => import('@ai-sdk/groq'),
|
// import: () => import('@ai-sdk/groq'),
|
||||||
creatorFunctionName: 'createGroq',
|
// creatorFunctionName: 'createGroq',
|
||||||
supportsImageGeneration: false
|
// supportsImageGeneration: false
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
id: 'together',
|
// id: 'together',
|
||||||
name: 'Together.ai',
|
// name: 'Together.ai',
|
||||||
import: () => import('@ai-sdk/togetherai'),
|
// import: () => import('@ai-sdk/togetherai'),
|
||||||
creatorFunctionName: 'createTogetherAI',
|
// creatorFunctionName: 'createTogetherAI',
|
||||||
supportsImageGeneration: true
|
// supportsImageGeneration: true
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
id: 'fireworks',
|
// id: 'fireworks',
|
||||||
name: 'Fireworks',
|
// name: 'Fireworks',
|
||||||
import: () => import('@ai-sdk/fireworks'),
|
// import: () => import('@ai-sdk/fireworks'),
|
||||||
creatorFunctionName: 'createFireworks',
|
// creatorFunctionName: 'createFireworks',
|
||||||
supportsImageGeneration: true
|
// supportsImageGeneration: true
|
||||||
},
|
// },
|
||||||
{
|
{
|
||||||
id: 'deepseek',
|
id: 'deepseek',
|
||||||
name: 'DeepSeek',
|
name: 'DeepSeek',
|
||||||
import: () => import('@ai-sdk/deepseek'),
|
import: () => import('@ai-sdk/deepseek'),
|
||||||
creatorFunctionName: 'createDeepSeek',
|
creatorFunctionName: 'createDeepSeek',
|
||||||
supportsImageGeneration: false
|
supportsImageGeneration: false
|
||||||
},
|
}
|
||||||
{
|
// {
|
||||||
id: 'cerebras',
|
// id: 'cerebras',
|
||||||
name: 'Cerebras',
|
// name: 'Cerebras',
|
||||||
import: () => import('@ai-sdk/cerebras'),
|
// import: () => import('@ai-sdk/cerebras'),
|
||||||
creatorFunctionName: 'createCerebras',
|
// creatorFunctionName: 'createCerebras',
|
||||||
supportsImageGeneration: false
|
// supportsImageGeneration: false
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
id: 'deepinfra',
|
// id: 'deepinfra',
|
||||||
name: 'DeepInfra',
|
// name: 'DeepInfra',
|
||||||
import: () => import('@ai-sdk/deepinfra'),
|
// import: () => import('@ai-sdk/deepinfra'),
|
||||||
creatorFunctionName: 'createDeepInfra',
|
// creatorFunctionName: 'createDeepInfra',
|
||||||
supportsImageGeneration: false
|
// supportsImageGeneration: false
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
id: 'replicate',
|
// id: 'replicate',
|
||||||
name: 'Replicate',
|
// name: 'Replicate',
|
||||||
import: () => import('@ai-sdk/replicate'),
|
// import: () => import('@ai-sdk/replicate'),
|
||||||
creatorFunctionName: 'createReplicate',
|
// creatorFunctionName: 'createReplicate',
|
||||||
supportsImageGeneration: true
|
// supportsImageGeneration: true
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
id: 'perplexity',
|
// id: 'perplexity',
|
||||||
name: 'Perplexity',
|
// name: 'Perplexity',
|
||||||
import: () => import('@ai-sdk/perplexity'),
|
// import: () => import('@ai-sdk/perplexity'),
|
||||||
creatorFunctionName: 'createPerplexity',
|
// creatorFunctionName: 'createPerplexity',
|
||||||
supportsImageGeneration: false
|
// supportsImageGeneration: false
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
id: 'fal',
|
// id: 'fal',
|
||||||
name: 'Fal AI',
|
// name: 'Fal AI',
|
||||||
import: () => import('@ai-sdk/fal'),
|
// import: () => import('@ai-sdk/fal'),
|
||||||
creatorFunctionName: 'createFal',
|
// creatorFunctionName: 'createFal',
|
||||||
supportsImageGeneration: false
|
// supportsImageGeneration: false
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
id: 'vercel',
|
// id: 'vercel',
|
||||||
name: 'Vercel',
|
// name: 'Vercel',
|
||||||
import: () => import('@ai-sdk/vercel'),
|
// import: () => import('@ai-sdk/vercel'),
|
||||||
creatorFunctionName: 'createVercel'
|
// creatorFunctionName: 'createVercel'
|
||||||
},
|
// },
|
||||||
|
|
||||||
// 社区 Providers (5个)
|
// 社区 Providers (5个)
|
||||||
{
|
// {
|
||||||
id: 'ollama',
|
// id: 'ollama',
|
||||||
name: 'Ollama',
|
// name: 'Ollama',
|
||||||
import: () => import('ollama-ai-provider'),
|
// import: () => import('ollama-ai-provider'),
|
||||||
creatorFunctionName: 'createOllama',
|
// creatorFunctionName: 'createOllama',
|
||||||
supportsImageGeneration: false
|
// supportsImageGeneration: false
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
id: 'anthropic-vertex',
|
// id: 'anthropic-vertex',
|
||||||
name: 'Anthropic Vertex AI',
|
// name: 'Anthropic Vertex AI',
|
||||||
import: () => import('anthropic-vertex-ai'),
|
// import: () => import('anthropic-vertex-ai'),
|
||||||
creatorFunctionName: 'createAnthropicVertex',
|
// creatorFunctionName: 'createAnthropicVertex',
|
||||||
supportsImageGeneration: false
|
// supportsImageGeneration: false
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
id: 'openrouter',
|
// id: 'openrouter',
|
||||||
name: 'OpenRouter',
|
// name: 'OpenRouter',
|
||||||
import: () => import('@openrouter/ai-sdk-provider'),
|
// import: () => import('@openrouter/ai-sdk-provider'),
|
||||||
creatorFunctionName: 'createOpenRouter',
|
// creatorFunctionName: 'createOpenRouter',
|
||||||
supportsImageGeneration: false
|
// supportsImageGeneration: false
|
||||||
}
|
// }
|
||||||
]
|
]
|
||||||
|
|
||||||
providers.forEach((config) => {
|
providers.forEach((config) => {
|
||||||
@ -289,6 +221,16 @@ export class AiProviderRegistry {
|
|||||||
* 注册新的 Provider(用于扩展)
|
* 注册新的 Provider(用于扩展)
|
||||||
*/
|
*/
|
||||||
public registerProvider(config: ProviderConfig): void {
|
public registerProvider(config: ProviderConfig): void {
|
||||||
|
// 验证:必须提供 creator 或 (import + creatorFunctionName)
|
||||||
|
if (!config.creator && !(config.import && config.creatorFunctionName)) {
|
||||||
|
throw new Error('Must provide either creator function or import configuration')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证:不能同时提供两种方式
|
||||||
|
if (config.creator && config.import) {
|
||||||
|
console.warn('Both creator and import provided, creator will take precedence')
|
||||||
|
}
|
||||||
|
|
||||||
this.registry.set(config.id, config)
|
this.registry.set(config.id, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -299,21 +241,21 @@ export class AiProviderRegistry {
|
|||||||
this.registry.clear()
|
this.registry.clear()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
// /**
|
||||||
* 获取兼容现有实现的注册表格式
|
// * 获取兼容现有实现的注册表格式
|
||||||
*/
|
// */
|
||||||
public getCompatibleRegistry(): Record<string, { import: () => Promise<any>; creatorFunctionName: string }> {
|
// public getCompatibleRegistry(): Record<string, { import: () => Promise<any>; creatorFunctionName: string }> {
|
||||||
const compatibleRegistry: Record<string, { import: () => Promise<any>; creatorFunctionName: string }> = {}
|
// const compatibleRegistry: Record<string, { import: () => Promise<any>; creatorFunctionName: string }> = {}
|
||||||
|
|
||||||
this.getAllProviders().forEach((provider) => {
|
// this.getAllProviders().forEach((provider) => {
|
||||||
compatibleRegistry[provider.id] = {
|
// compatibleRegistry[provider.id] = {
|
||||||
import: provider.import,
|
// import: provider.import,
|
||||||
creatorFunctionName: provider.creatorFunctionName
|
// creatorFunctionName: provider.creatorFunctionName
|
||||||
}
|
// }
|
||||||
})
|
// })
|
||||||
|
|
||||||
return compatibleRegistry
|
// return compatibleRegistry
|
||||||
}
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
// 导出单例实例
|
// 导出单例实例
|
||||||
@ -326,31 +268,4 @@ export const isProviderSupported = (id: string) => aiProviderRegistry.isSupporte
|
|||||||
export const registerProvider = (config: ProviderConfig) => aiProviderRegistry.registerProvider(config)
|
export const registerProvider = (config: ProviderConfig) => aiProviderRegistry.registerProvider(config)
|
||||||
|
|
||||||
// 兼容现有实现的导出
|
// 兼容现有实现的导出
|
||||||
export const PROVIDER_REGISTRY = aiProviderRegistry.getCompatibleRegistry()
|
// export const PROVIDER_REGISTRY = aiProviderRegistry.getCompatibleRegistry()
|
||||||
|
|
||||||
// 重新导出所有类型供外部使用
|
|
||||||
export type {
|
|
||||||
AmazonBedrockProviderSettings,
|
|
||||||
AnthropicProviderSettings,
|
|
||||||
AnthropicVertexProviderSettings,
|
|
||||||
AzureOpenAIProviderSettings,
|
|
||||||
CerebrasProviderSettings,
|
|
||||||
CohereProviderSettings,
|
|
||||||
DeepInfraProviderSettings,
|
|
||||||
DeepSeekProviderSettings,
|
|
||||||
FalProviderSettings,
|
|
||||||
FireworksProviderSettings,
|
|
||||||
GoogleGenerativeAIProviderSettings,
|
|
||||||
GoogleVertexProviderSettings,
|
|
||||||
GroqProviderSettings,
|
|
||||||
MistralProviderSettings,
|
|
||||||
OllamaProviderSettings,
|
|
||||||
OpenAICompatibleProviderSettings,
|
|
||||||
OpenAIProviderSettings,
|
|
||||||
OpenRouterProviderSettings,
|
|
||||||
PerplexityProviderSettings,
|
|
||||||
ReplicateProviderSettings,
|
|
||||||
TogetherAIProviderSettings,
|
|
||||||
VercelProviderSettings,
|
|
||||||
XaiProviderSettings
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,14 +1,35 @@
|
|||||||
|
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
|
||||||
|
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
|
||||||
|
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
|
||||||
|
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
|
||||||
|
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||||
|
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
|
||||||
|
import { type XaiProviderSettings } from '@ai-sdk/xai'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Provider 相关核心类型定义
|
* Provider 相关核心类型定义
|
||||||
* 只定义必要的接口,其他类型直接使用 AI SDK
|
* 只定义必要的接口,其他类型直接使用 AI SDK
|
||||||
*/
|
*/
|
||||||
|
export type ProviderId = keyof ProviderSettingsMap & string
|
||||||
|
|
||||||
// Provider 配置接口(简化版)
|
// Provider 配置接口 - 支持灵活的创建方式
|
||||||
export interface ProviderConfig {
|
export interface ProviderConfig {
|
||||||
id: string
|
id: string
|
||||||
name: string
|
name: string
|
||||||
import: () => Promise<any>
|
|
||||||
creatorFunctionName: string
|
// 方式一:直接提供 creator 函数(推荐用于自定义)
|
||||||
|
creator?: (options: any) => any
|
||||||
|
|
||||||
|
// 方式二:动态导入 + 函数名(用于包导入)
|
||||||
|
import?: () => Promise<any>
|
||||||
|
creatorFunctionName?: string
|
||||||
|
|
||||||
|
// 图片生成支持
|
||||||
|
supportsImageGeneration?: boolean
|
||||||
|
imageCreator?: (options: any) => any
|
||||||
|
|
||||||
|
// 可选的验证函数
|
||||||
|
validateOptions?: (options: any) => boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
// API 客户端工厂接口
|
// API 客户端工厂接口
|
||||||
@ -45,3 +66,57 @@ export interface CacheStats {
|
|||||||
keys: string[]
|
keys: string[]
|
||||||
lastCleanup?: Date
|
lastCleanup?: Date
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 类型安全的 Provider Settings 映射
|
||||||
|
export type ProviderSettingsMap = {
|
||||||
|
openai: OpenAIProviderSettings
|
||||||
|
'openai-compatible': OpenAICompatibleProviderSettings
|
||||||
|
// openrouter: OpenRouterProviderSettings
|
||||||
|
anthropic: AnthropicProviderSettings
|
||||||
|
google: GoogleGenerativeAIProviderSettings
|
||||||
|
// 'google-vertex': GoogleVertexProviderSettings
|
||||||
|
// mistral: MistralProviderSettings
|
||||||
|
xai: XaiProviderSettings
|
||||||
|
azure: AzureOpenAIProviderSettings
|
||||||
|
// bedrock: AmazonBedrockProviderSettings
|
||||||
|
// cohere: CohereProviderSettings
|
||||||
|
// groq: GroqProviderSettings
|
||||||
|
// together: TogetherAIProviderSettings
|
||||||
|
// fireworks: FireworksProviderSettings
|
||||||
|
deepseek: DeepSeekProviderSettings
|
||||||
|
// cerebras: CerebrasProviderSettings
|
||||||
|
// deepinfra: DeepInfraProviderSettings
|
||||||
|
// replicate: ReplicateProviderSettings
|
||||||
|
// perplexity: PerplexityProviderSettings
|
||||||
|
// fal: FalProviderSettings
|
||||||
|
// vercel: VercelProviderSettings
|
||||||
|
// ollama: OllamaProviderSettings
|
||||||
|
// 'anthropic-vertex': AnthropicVertexProviderSettings
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重新导出所有类型供外部使用
|
||||||
|
export type {
|
||||||
|
// AmazonBedrockProviderSettings,
|
||||||
|
AnthropicProviderSettings,
|
||||||
|
// AnthropicVertexProviderSettings,
|
||||||
|
AzureOpenAIProviderSettings,
|
||||||
|
// CerebrasProviderSettings,
|
||||||
|
// CohereProviderSettings,
|
||||||
|
// DeepInfraProviderSettings,
|
||||||
|
DeepSeekProviderSettings,
|
||||||
|
// FalProviderSettings,
|
||||||
|
// FireworksProviderSettings,
|
||||||
|
GoogleGenerativeAIProviderSettings,
|
||||||
|
// GoogleVertexProviderSettings,
|
||||||
|
// GroqProviderSettings,
|
||||||
|
// MistralProviderSettings,
|
||||||
|
// OllamaProviderSettings,
|
||||||
|
OpenAICompatibleProviderSettings,
|
||||||
|
OpenAIProviderSettings,
|
||||||
|
// OpenRouterProviderSettings,
|
||||||
|
// PerplexityProviderSettings,
|
||||||
|
// ReplicateProviderSettings,
|
||||||
|
// TogetherAIProviderSettings,
|
||||||
|
// VercelProviderSettings,
|
||||||
|
XaiProviderSettings
|
||||||
|
}
|
||||||
|
|||||||
@ -7,19 +7,14 @@
|
|||||||
export { RuntimeExecutor } from './executor'
|
export { RuntimeExecutor } from './executor'
|
||||||
|
|
||||||
// 导出类型
|
// 导出类型
|
||||||
export type {
|
export type { RuntimeConfig } from './types'
|
||||||
ExecutionOptions,
|
|
||||||
// 向后兼容的类型别名
|
|
||||||
ExecutorConfig,
|
|
||||||
RuntimeConfig
|
|
||||||
} from './types'
|
|
||||||
|
|
||||||
// === 便捷工厂函数 ===
|
// === 便捷工厂函数 ===
|
||||||
|
|
||||||
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||||
|
|
||||||
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
|
||||||
import { type AiPlugin } from '../plugins'
|
import { type AiPlugin } from '../plugins'
|
||||||
|
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
|
||||||
import { RuntimeExecutor } from './executor'
|
import { RuntimeExecutor } from './executor'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
import { LanguageModel } from 'ai'
|
import { LanguageModel } from 'ai'
|
||||||
|
|
||||||
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
|
||||||
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
||||||
import { isProviderSupported } from '../providers/registry'
|
import { isProviderSupported } from '../providers/registry'
|
||||||
|
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 插件增强的 AI 客户端
|
* 插件增强的 AI 客户端
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
/**
|
/**
|
||||||
* Runtime 层类型定义
|
* Runtime 层类型定义
|
||||||
*/
|
*/
|
||||||
import { type ProviderId } from '../../types'
|
import { type ModelConfig } from '../models/types'
|
||||||
import { type ModelConfig } from '../models'
|
|
||||||
import { type AiPlugin } from '../plugins'
|
import { type AiPlugin } from '../plugins'
|
||||||
|
import { type ProviderId } from '../providers/types'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 运行时执行器配置
|
* 运行时执行器配置
|
||||||
@ -13,14 +13,3 @@ export interface RuntimeConfig<T extends ProviderId = ProviderId> {
|
|||||||
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
|
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
|
||||||
plugins?: AiPlugin[]
|
plugins?: AiPlugin[]
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 执行选项
|
|
||||||
*/
|
|
||||||
export interface ExecutionOptions {
|
|
||||||
// 未来可以添加执行级别的选项
|
|
||||||
// 比如:超时设置、重试机制等
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保留旧类型以保持向后兼容
|
|
||||||
export interface ExecutorConfig<T extends ProviderId = ProviderId> extends RuntimeConfig<T> {}
|
|
||||||
|
|||||||
@ -9,14 +9,15 @@ import {
|
|||||||
getSupportedProviders as factoryGetSupportedProviders
|
getSupportedProviders as factoryGetSupportedProviders
|
||||||
} from './core/models'
|
} from './core/models'
|
||||||
import { aiProviderRegistry, isProviderSupported } from './core/providers/registry'
|
import { aiProviderRegistry, isProviderSupported } from './core/providers/registry'
|
||||||
|
import type { ProviderId } from './core/providers/types'
|
||||||
|
import type { ProviderSettingsMap } from './core/providers/types'
|
||||||
import { createExecutor } from './core/runtime'
|
import { createExecutor } from './core/runtime'
|
||||||
import type { ProviderId, ProviderSettingsMap } from './types'
|
|
||||||
|
|
||||||
// ==================== 主要用户接口 ====================
|
// ==================== 主要用户接口 ====================
|
||||||
export { createExecutor, createOpenAICompatibleExecutor } from './core/runtime'
|
export { createExecutor, createOpenAICompatibleExecutor } from './core/runtime'
|
||||||
|
|
||||||
// ==================== 高级API ====================
|
// ==================== 高级API ====================
|
||||||
export { createModel, type ModelConfig } from './core/models'
|
export { createModel } from './core/models'
|
||||||
|
|
||||||
// ==================== 插件系统 ====================
|
// ==================== 插件系统 ====================
|
||||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins'
|
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins'
|
||||||
@ -30,43 +31,27 @@ export {
|
|||||||
createImageModel,
|
createImageModel,
|
||||||
getProviderInfo as getClientInfo,
|
getProviderInfo as getClientInfo,
|
||||||
getSupportedProviders,
|
getSupportedProviders,
|
||||||
ProviderCreationError
|
ModelCreationError
|
||||||
} from './core/models'
|
} from './core/models'
|
||||||
export { aiProviderRegistry } from './core/providers/registry'
|
export { aiProviderRegistry } from './core/providers/registry'
|
||||||
|
|
||||||
// ==================== 类型定义 ====================
|
// ==================== 类型定义 ====================
|
||||||
export type { ProviderConfig } from './core/providers/registry'
|
export type { ProviderConfig } from './core/providers/types'
|
||||||
export type { ProviderError } from './core/providers/types'
|
export type { ProviderError } from './core/providers/types'
|
||||||
export type {
|
export type {
|
||||||
AmazonBedrockProviderSettings,
|
|
||||||
AnthropicProviderSettings,
|
AnthropicProviderSettings,
|
||||||
AnthropicVertexProviderSettings,
|
|
||||||
AzureOpenAIProviderSettings,
|
AzureOpenAIProviderSettings,
|
||||||
CerebrasProviderSettings,
|
|
||||||
CohereProviderSettings,
|
|
||||||
DeepInfraProviderSettings,
|
|
||||||
DeepSeekProviderSettings,
|
DeepSeekProviderSettings,
|
||||||
FalProviderSettings,
|
|
||||||
FireworksProviderSettings,
|
|
||||||
GenerateObjectParams,
|
GenerateObjectParams,
|
||||||
GenerateTextParams,
|
GenerateTextParams,
|
||||||
GoogleGenerativeAIProviderSettings,
|
GoogleGenerativeAIProviderSettings,
|
||||||
GoogleVertexProviderSettings,
|
|
||||||
GroqProviderSettings,
|
|
||||||
MistralProviderSettings,
|
|
||||||
OllamaProviderSettings,
|
|
||||||
OpenAICompatibleProviderSettings,
|
OpenAICompatibleProviderSettings,
|
||||||
OpenAIProviderSettings,
|
OpenAIProviderSettings,
|
||||||
OpenRouterProviderSettings,
|
|
||||||
PerplexityProviderSettings,
|
|
||||||
ProviderId,
|
ProviderId,
|
||||||
ProviderSettings,
|
ProviderSettings,
|
||||||
ProviderSettingsMap,
|
ProviderSettingsMap,
|
||||||
ReplicateProviderSettings,
|
|
||||||
StreamObjectParams,
|
StreamObjectParams,
|
||||||
StreamTextParams,
|
StreamTextParams,
|
||||||
TogetherAIProviderSettings,
|
|
||||||
VercelProviderSettings,
|
|
||||||
XaiProviderSettings
|
XaiProviderSettings
|
||||||
} from './types'
|
} from './types'
|
||||||
export * as aiSdk from 'ai'
|
export * as aiSdk from 'ai'
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import { generateObject, generateText, streamObject, streamText } from 'ai'
|
import { generateObject, generateText, streamObject, streamText } from 'ai'
|
||||||
|
|
||||||
import type { ProviderSettingsMap } from './core/providers/registry'
|
import type { ProviderSettingsMap } from './core/providers/types'
|
||||||
|
|
||||||
// ProviderSettings 是所有 Provider Settings 的联合类型
|
// ProviderSettings 是所有 Provider Settings 的联合类型
|
||||||
export type ProviderSettings = ProviderSettingsMap[keyof ProviderSettingsMap]
|
export type ProviderSettings = ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||||
@ -12,32 +12,16 @@ export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'm
|
|||||||
|
|
||||||
// 重新导出 ProviderSettingsMap 中的所有类型
|
// 重新导出 ProviderSettingsMap 中的所有类型
|
||||||
export type {
|
export type {
|
||||||
AmazonBedrockProviderSettings,
|
|
||||||
AnthropicProviderSettings,
|
AnthropicProviderSettings,
|
||||||
AnthropicVertexProviderSettings,
|
|
||||||
AzureOpenAIProviderSettings,
|
AzureOpenAIProviderSettings,
|
||||||
CerebrasProviderSettings,
|
|
||||||
CohereProviderSettings,
|
|
||||||
DeepInfraProviderSettings,
|
|
||||||
DeepSeekProviderSettings,
|
DeepSeekProviderSettings,
|
||||||
FalProviderSettings,
|
|
||||||
FireworksProviderSettings,
|
|
||||||
GoogleGenerativeAIProviderSettings,
|
GoogleGenerativeAIProviderSettings,
|
||||||
GoogleVertexProviderSettings,
|
|
||||||
GroqProviderSettings,
|
|
||||||
MistralProviderSettings,
|
|
||||||
OllamaProviderSettings,
|
|
||||||
OpenAICompatibleProviderSettings,
|
OpenAICompatibleProviderSettings,
|
||||||
OpenAIProviderSettings,
|
OpenAIProviderSettings,
|
||||||
OpenRouterProviderSettings,
|
|
||||||
PerplexityProviderSettings,
|
|
||||||
ProviderId,
|
ProviderId,
|
||||||
ProviderSettingsMap,
|
ProviderSettingsMap,
|
||||||
ReplicateProviderSettings,
|
|
||||||
TogetherAIProviderSettings,
|
|
||||||
VercelProviderSettings,
|
|
||||||
XaiProviderSettings
|
XaiProviderSettings
|
||||||
} from './core/providers/registry'
|
} from './core/providers/types'
|
||||||
|
|
||||||
// 重新导出插件类型
|
// 重新导出插件类型
|
||||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'
|
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'
|
||||||
|
|||||||
@ -17,7 +17,7 @@ import {
|
|||||||
type ProviderSettingsMap,
|
type ProviderSettingsMap,
|
||||||
StreamTextParams
|
StreamTextParams
|
||||||
} from '@cherrystudio/ai-core'
|
} from '@cherrystudio/ai-core'
|
||||||
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core'
|
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user