refactor: reorganize provider and model exports for improved structure

- Updated exports in index.ts and related files to streamline provider and model management.
- Introduced a new ModelCreator module for better encapsulation of model creation logic.
- Refactored type imports to enhance clarity and maintainability across the codebase.
- Removed deprecated provider configurations and cleaned up unused code for better performance.
This commit is contained in:
MyPrototypeWhat 2025-07-18 15:35:44 +08:00
parent c3ad18b77e
commit 1248e3c49a
18 changed files with 472 additions and 435 deletions

View File

@ -8,17 +8,17 @@ export type { NamedMiddleware } from './middleware'
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'
// 创建管理
export type { ModelConfig } from './models'
export {
createBaseModel,
createImageModel,
createModel,
getProviderInfo,
getSupportedProviders,
ProviderCreationError
ModelCreationError
} from './models'
export type { ModelConfig } from './models/types'
// 执行管理
export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type'
export type { ExecutionOptions, ExecutorConfig } from './runtime'
export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
export type { RuntimeConfig } from './runtime/types'

View File

@ -1,22 +1,23 @@
/**
* Provider
* AI SDK providers
* Model Creator
* Provider AI SDK Language Model Image Model
*/
import { ImageModelV2, type LanguageModelV2 } from '@ai-sdk/provider'
import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
import { aiProviderRegistry, type ProviderConfig } from '../providers/registry'
import { createImageProvider, createProvider } from '../providers/creator'
import { aiProviderRegistry } from '../providers/registry'
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
// 错误类型
export class ProviderCreationError extends Error {
export class ModelCreationError extends Error {
constructor(
message: string,
public providerId?: string,
public cause?: Error
) {
super(message)
this.name = 'ProviderCreationError'
this.name = 'ModelCreationError'
}
}
@ -29,13 +30,11 @@ export async function createBaseModel<T extends ProviderId>({
modelId,
providerSettings,
extraModelConfig
// middlewares
}: {
providerId: T
modelId: string
providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
extraModelConfig?: any
// middlewares?: LanguageModelV1Middleware[]
}): Promise<LanguageModelV2>
export async function createBaseModel({
@ -43,87 +42,49 @@ export async function createBaseModel({
modelId,
providerSettings,
extraModelConfig
// middlewares
}: {
providerId: string
modelId: string
providerSettings: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' }
extraModelConfig?: any
// middlewares?: LanguageModelV1Middleware[]
}): Promise<LanguageModelV2>
export async function createBaseModel({
providerId,
modelId,
providerSettings,
// middlewares,
extraModelConfig
}: {
providerId: string
modelId: string
providerSettings: ProviderSettingsMap[ProviderId] & { mode?: 'chat' | 'responses' }
// middlewares?: LanguageModelV1Middleware[]
extraModelConfig?: any
}): Promise<LanguageModelV2> {
try {
// 对于不在注册表中的 provider默认使用 openai-compatible
const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible'
// 获取Provider配置
const providerConfig = aiProviderRegistry.getProvider(effectiveProviderId)
const providerConfig = aiProviderRegistry.getProvider(providerId)
if (!providerConfig) {
throw new ProviderCreationError(`Provider "${effectiveProviderId}" is not registered`, providerId)
throw new ModelCreationError(`Provider "${providerId}" is not registered`, providerId)
}
// 动态导入模块
const module = await providerConfig.import()
// 创建 provider 实例
const provider = await createProvider(providerConfig, providerSettings)
// 获取创建函数
const creatorFunction = module[providerConfig.creatorFunctionName]
if (typeof creatorFunction !== 'function') {
throw new ProviderCreationError(
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${effectiveProviderId}"`
)
}
// TODO: 对openai 的 providerSettings.mode参数是否要删除,目前看没毛病
// 创建provider实例
let provider = creatorFunction(providerSettings)
// 加一个特判
if (providerConfig.id === 'openai') {
if (
'mode' in providerSettings &&
providerSettings.mode === 'responses' &&
!isOpenAIChatCompletionOnlyModel(modelId)
) {
provider = provider.responses
} else {
provider = provider.chat
}
}
// 返回模型实例
if (typeof provider === 'function') {
// extraModelConfig:例如google的useSearchGrounding
const model: LanguageModelV2 = provider(modelId, extraModelConfig)
// // 应用 AI SDK 中间件
// if (middlewares && middlewares.length > 0) {
// model = wrapLanguageModel({
// model: model,
// middleware: middlewares
// })
// }
// 根据 provider 类型处理特殊逻辑
const finalProvider = handleProviderSpecificLogic(provider, providerConfig.id, providerSettings, modelId)
// 创建模型实例
if (typeof finalProvider === 'function') {
const model: LanguageModelV2 = finalProvider(modelId, extraModelConfig)
return model
} else {
throw new ProviderCreationError(`Unknown model access pattern for provider "${effectiveProviderId}"`)
throw new ModelCreationError(`Unknown model access pattern for provider "${providerId}"`)
}
} catch (error) {
if (error instanceof ProviderCreationError) {
if (error instanceof ModelCreationError) {
throw error
}
throw new ProviderCreationError(
throw new ModelCreationError(
`Failed to create base model for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
providerId,
error instanceof Error ? error : undefined
@ -131,6 +92,27 @@ export async function createBaseModel({
}
}
/**
* Provider
*/
function handleProviderSpecificLogic(provider: any, providerId: string, providerSettings: any, modelId: string): any {
// OpenAI 特殊处理
if (providerId === 'openai') {
if (
'mode' in providerSettings &&
providerSettings.mode === 'responses' &&
!isOpenAIChatCompletionOnlyModel(modelId)
) {
return provider.responses
} else {
return provider.chat
}
}
// 其他 provider 直接返回
return provider
}
/**
*
*/
@ -151,40 +133,31 @@ export async function createImageModel(
): Promise<ImageModelV2> {
try {
if (!aiProviderRegistry.isSupported(providerId)) {
throw new ProviderCreationError(`Provider "${providerId}" is not supported`, providerId)
throw new ModelCreationError(`Provider "${providerId}" is not supported`, providerId)
}
const providerConfig = aiProviderRegistry.getProvider(providerId)
if (!providerConfig) {
throw new ProviderCreationError(`Provider "${providerId}" is not registered`, providerId)
throw new ModelCreationError(`Provider "${providerId}" is not registered`, providerId)
}
if (!providerConfig.supportsImageGeneration) {
throw new ProviderCreationError(`Provider "${providerId}" does not support image generation`, providerId)
throw new ModelCreationError(`Provider "${providerId}" does not support image generation`, providerId)
}
const module = await providerConfig.import()
const creatorFunction = module[providerConfig.creatorFunctionName]
if (typeof creatorFunction !== 'function') {
throw new ProviderCreationError(
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${providerId}"`
)
}
const provider = creatorFunction(options)
// 创建图像 provider 实例
const provider = await createImageProvider(providerConfig, options)
if (provider && typeof provider.image === 'function') {
return provider.image(modelId)
} else {
throw new ProviderCreationError(`Image model function not found for provider "${providerId}"`)
throw new ModelCreationError(`Image model function not found for provider "${providerId}"`)
}
} catch (error) {
if (error instanceof ProviderCreationError) {
if (error instanceof ModelCreationError) {
throw error
}
throw new ProviderCreationError(
throw new ModelCreationError(
`Failed to create image model for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
providerId,
error instanceof Error ? error : undefined
@ -199,7 +172,7 @@ export function getSupportedProviders(): Array<{
id: string
name: string
}> {
return aiProviderRegistry.getAllProviders().map((provider: ProviderConfig) => ({
return aiProviderRegistry.getAllProviders().map((provider) => ({
id: provider.id,
name: provider.name
}))

View File

@ -6,7 +6,7 @@ import { LanguageModelV2 } from '@ai-sdk/provider'
import { LanguageModel } from 'ai'
import { wrapModelWithMiddlewares } from '../middleware'
import { createBaseModel } from './ProviderCreator'
import { createBaseModel } from './ModelCreator'
import { ModelConfig } from './types'
/**

View File

@ -1,19 +1,18 @@
/**
* Models
*
* Models
*/
// 主要的模型创建API
export { createModel, createModels } from './factory'
// 底层Provider创建功能供高级用户使用
// Model 创建相关
export {
createBaseModel,
createImageModel,
getProviderInfo,
getSupportedProviders,
ProviderCreationError
} from './ProviderCreator'
ModelCreationError
} from './ModelCreator'
// 保留原有类型
export type { ModelConfig } from './types'
// Model 配置和工厂
export { createModel } from './factory'
// 类型定义
export type { ModelConfig as ModelConfigType } from './types'

View File

@ -3,7 +3,7 @@
*/
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
import type { ProviderId, ProviderSettingsMap } from '../../types'
import type { ProviderId, ProviderSettingsMap } from '../providers/types'
export interface ModelConfig<T extends ProviderId = ProviderId> {
providerId: T

View File

@ -41,8 +41,8 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
break
}
case 'google':
case 'google-vertex': {
case 'google': {
// case 'google-vertex':
if (!params.tools) params.tools = {}
params.tools.web_search = google.tools.googleSearch(config.google || {})
break

View File

@ -1,6 +1,6 @@
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
import { ProviderId } from '../providers/registry'
import { type ProviderId } from '../providers/types'
/**
*

View File

@ -0,0 +1,108 @@
/**
* Provider Creator
* ProviderConfig provider
*/
import { type ProviderConfig } from './types'
// 错误类型
export class ProviderCreationError extends Error {
constructor(
message: string,
public providerId?: string,
public cause?: Error
) {
super(message)
this.name = 'ProviderCreationError'
}
}
/**
* Provider
* creator +
*/
export async function createProvider(config: ProviderConfig, options: any): Promise<any> {
try {
// 验证配置
if (!config.creator && !(config.import && config.creatorFunctionName)) {
throw new ProviderCreationError(
'Invalid provider configuration: must provide either creator function or import configuration',
config.id
)
}
// 方式一:直接使用 creator 函数
if (config.creator) {
return config.creator(options)
}
// 方式二:动态导入 + 函数名
if (config.import && config.creatorFunctionName) {
const module = await config.import()
const creatorFunction = module[config.creatorFunctionName]
if (typeof creatorFunction !== 'function') {
throw new ProviderCreationError(
`Creator function "${config.creatorFunctionName}" not found in the imported module`,
config.id
)
}
return creatorFunction(options)
}
throw new ProviderCreationError('Unexpected provider configuration state', config.id)
} catch (error) {
if (error instanceof ProviderCreationError) {
throw error
}
throw new ProviderCreationError(
`Failed to create provider "${config.id}": ${error instanceof Error ? error.message : 'Unknown error'}`,
config.id,
error instanceof Error ? error : undefined
)
}
}
/**
* Provider
*/
export async function createImageProvider(config: ProviderConfig, options: any): Promise<any> {
try {
if (!config.supportsImageGeneration) {
throw new ProviderCreationError(`Provider "${config.id}" does not support image generation`, config.id)
}
// 如果有专门的图像 creator
if (config.imageCreator) {
return config.imageCreator(options)
}
// 否则使用普通的 provider 创建流程
return await createProvider(config, options)
} catch (error) {
if (error instanceof ProviderCreationError) {
throw error
}
throw new ProviderCreationError(
`Failed to create image provider "${config.id}": ${error instanceof Error ? error.message : 'Unknown error'}`,
config.id,
error instanceof Error ? error : undefined
)
}
}
/**
* Provider
*/
export function validateProviderConfig(config: ProviderConfig): boolean {
if (!config.id || !config.name) {
return false
}
if (!config.creator && !(config.import && config.creatorFunctionName)) {
return false
}
return true
}

View File

@ -3,8 +3,7 @@
* Provider
*/
import type { ProviderId, ProviderSettingsMap } from './registry'
import { formatPrivateKey } from './utils'
import type { ProviderId, ProviderSettingsMap } from './types'
/**
* Provider
@ -37,20 +36,20 @@ const configHandlers: {
apiVersion: azureProvider.apiVersion,
resourceName: azureProvider.resourceName
})
},
'google-vertex': (builder, provider) => {
const vertexBuilder = builder as ProviderConfigBuilder<'google-vertex'>
const vertexProvider = provider as CompleteProviderConfig<'google-vertex'>
vertexBuilder
.withGoogleVertexConfig({
project: vertexProvider.project,
location: vertexProvider.location
})
.withGoogleCredentials({
clientEmail: vertexProvider.googleCredentials?.clientEmail || '',
privateKey: vertexProvider.googleCredentials?.privateKey || ''
})
}
// 'google-vertex': (builder, provider) => {
// const vertexBuilder = builder as ProviderConfigBuilder<'google-vertex'>
// const vertexProvider = provider as CompleteProviderConfig<'google-vertex'>
// vertexBuilder
// .withGoogleVertexConfig({
// project: vertexProvider.project,
// location: vertexProvider.location
// })
// .withGoogleCredentials({
// clientEmail: vertexProvider.googleCredentials?.clientEmail || '',
// privateKey: vertexProvider.googleCredentials?.privateKey || ''
// })
// }
}
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
@ -121,35 +120,36 @@ export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
/**
* Google
*/
withGoogleVertexConfig(options: { project?: string; location?: string }): T extends 'google-vertex' ? this : never
withGoogleVertexConfig(options: any): any {
if (this.providerId === 'google-vertex') {
const googleConfig = this.config as CompleteProviderConfig<'google-vertex'>
if (options.project) {
googleConfig.project = options.project
}
if (options.location) {
googleConfig.location = options.location
if (options.location === 'global') {
googleConfig.baseURL = 'https://aiplatform.googleapis.com'
}
}
}
return this
}
// withGoogleVertexConfig(options: { project?: string; location?: string }): T extends 'google-vertex' ? this : never
// withGoogleVertexConfig(options: any): any {
// if (this.providerId === 'google-vertex') {
// const googleConfig = this.config as CompleteProviderConfig<'google-vertex'>
// if (options.project) {
// googleConfig.project = options.project
// }
// if (options.location) {
// googleConfig.location = options.location
// if (options.location === 'global') {
// googleConfig.baseURL = 'https://aiplatform.googleapis.com'
// }
// }
// }
// return this
// }
withGoogleCredentials(credentials: {
clientEmail: string
privateKey: string
}): T extends 'google-vertex' ? this : never
withGoogleCredentials(credentials: any): any {
if (this.providerId === 'google-vertex') {
const vertexConfig = this.config as CompleteProviderConfig<'google-vertex'>
vertexConfig.googleCredentials = {
clientEmail: credentials.clientEmail,
privateKey: formatPrivateKey(credentials.privateKey)
}
}
withGoogleCredentials(): any {
// withGoogleCredentials(credentials: any): any {
// if (this.providerId === 'google-vertex') {
// const vertexConfig = this.config as CompleteProviderConfig<'google-vertex'>
// vertexConfig.googleCredentials = {
// clientEmail: credentials.clientEmail,
// privateKey: formatPrivateKey(credentials.privateKey)
// }
// }
return this
}
@ -310,23 +310,22 @@ export class ProviderConfigFactory {
/**
* Vertex AI
*/
static createVertexAI(
credentials: {
clientEmail: string
privateKey: string
},
options?: {
project?: string
location?: string
}
) {
return this.builder('google-vertex')
.withGoogleCredentials(credentials)
.withGoogleVertexConfig({
project: options?.project,
location: options?.location
})
.build()
static createVertexAI() {
// credentials: {
// clientEmail: string
// privateKey: string
// },
// options?: {
// project?: string
// location?: string
// }
// return this.builder('google-vertex')
// .withGoogleCredentials(credentials)
// .withGoogleVertexConfig({
// project: options?.project,
// location: options?.location
// })
// .build()
}
static createOpenAICompatible(baseURL: string, apiKey: string) {

View File

@ -0,0 +1,15 @@
/**
* Providers
*/
// Provider 注册表
export { aiProviderRegistry, getAllProviders, getProvider, isProviderSupported, registerProvider } from './registry'
// Provider 创建
export { createImageProvider, createProvider, ProviderCreationError, validateProviderConfig } from './creator'
// 类型定义
export type { ProviderConfig, ProviderError, ProviderId, ProviderSettingsMap } from './types'
// 工厂和配置
export * from './factory'

View File

@ -3,76 +3,8 @@
* +
*/
// 静态导入所有 AI SDK 类型
import { type AmazonBedrockProviderSettings } from '@ai-sdk/amazon-bedrock'
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { type CerebrasProviderSettings } from '@ai-sdk/cerebras'
import { type CohereProviderSettings } from '@ai-sdk/cohere'
import { type DeepInfraProviderSettings } from '@ai-sdk/deepinfra'
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
import { type FalProviderSettings } from '@ai-sdk/fal'
import { type FireworksProviderSettings } from '@ai-sdk/fireworks'
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
import { type GoogleVertexProviderSettings } from '@ai-sdk/google-vertex/edge'
import { type GroqProviderSettings } from '@ai-sdk/groq'
import { type MistralProviderSettings } from '@ai-sdk/mistral'
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
import { type PerplexityProviderSettings } from '@ai-sdk/perplexity'
import { type ReplicateProviderSettings } from '@ai-sdk/replicate'
import { type TogetherAIProviderSettings } from '@ai-sdk/togetherai'
import { type VercelProviderSettings } from '@ai-sdk/vercel'
import { type XaiProviderSettings } from '@ai-sdk/xai'
import { type OpenRouterProviderSettings } from '@openrouter/ai-sdk-provider'
import { type AnthropicVertexProviderSettings } from 'anthropic-vertex-ai'
import { type OllamaProviderSettings } from 'ollama-ai-provider'
import { type ProviderConfig } from './types'
// 类型安全的 Provider Settings 映射
export type ProviderSettingsMap = {
openai: OpenAIProviderSettings
'openai-compatible': OpenAICompatibleProviderSettings
openrouter: OpenRouterProviderSettings
anthropic: AnthropicProviderSettings
google: GoogleGenerativeAIProviderSettings
'google-vertex': GoogleVertexProviderSettings
mistral: MistralProviderSettings
xai: XaiProviderSettings
azure: AzureOpenAIProviderSettings
bedrock: AmazonBedrockProviderSettings
cohere: CohereProviderSettings
groq: GroqProviderSettings
together: TogetherAIProviderSettings
fireworks: FireworksProviderSettings
deepseek: DeepSeekProviderSettings
cerebras: CerebrasProviderSettings
deepinfra: DeepInfraProviderSettings
replicate: ReplicateProviderSettings
perplexity: PerplexityProviderSettings
fal: FalProviderSettings
vercel: VercelProviderSettings
ollama: OllamaProviderSettings
'anthropic-vertex': AnthropicVertexProviderSettings
}
export type ProviderId = keyof ProviderSettingsMap & string
// 统一的 Provider 配置接口(所有都使用动态导入)
export interface ProviderConfig {
id: string
name: string
// 动态导入函数
import: () => Promise<any>
// 创建函数名称
creatorFunctionName: string
// 是否支持图片生成
supportsImageGeneration?: boolean
}
/**
* AI SDK Provider
* AI Providers
*/
export class AiProviderRegistry {
private static instance: AiProviderRegistry
private registry = new Map<string, ProviderConfig>()
@ -123,20 +55,20 @@ export class AiProviderRegistry {
creatorFunctionName: 'createGoogleGenerativeAI',
supportsImageGeneration: true
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex/edge'),
creatorFunctionName: 'createVertex',
supportsImageGeneration: true
},
{
id: 'mistral',
name: 'Mistral AI',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral',
supportsImageGeneration: false
},
// {
// id: 'google-vertex',
// name: 'Google Vertex AI',
// import: () => import('@ai-sdk/google-vertex/edge'),
// creatorFunctionName: 'createVertex',
// supportsImageGeneration: true
// },
// {
// id: 'mistral',
// name: 'Mistral AI',
// import: () => import('@ai-sdk/mistral'),
// creatorFunctionName: 'createMistral',
// supportsImageGeneration: false
// },
{
id: 'xai',
name: 'xAI (Grok)',
@ -151,112 +83,112 @@ export class AiProviderRegistry {
creatorFunctionName: 'createAzure',
supportsImageGeneration: true
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: false
},
{
id: 'cohere',
name: 'Cohere',
import: () => import('@ai-sdk/cohere'),
creatorFunctionName: 'createCohere',
supportsImageGeneration: false
},
{
id: 'groq',
name: 'Groq',
import: () => import('@ai-sdk/groq'),
creatorFunctionName: 'createGroq',
supportsImageGeneration: false
},
{
id: 'together',
name: 'Together.ai',
import: () => import('@ai-sdk/togetherai'),
creatorFunctionName: 'createTogetherAI',
supportsImageGeneration: true
},
{
id: 'fireworks',
name: 'Fireworks',
import: () => import('@ai-sdk/fireworks'),
creatorFunctionName: 'createFireworks',
supportsImageGeneration: true
},
// {
// id: 'bedrock',
// name: 'Amazon Bedrock',
// import: () => import('@ai-sdk/amazon-bedrock'),
// creatorFunctionName: 'createAmazonBedrock',
// supportsImageGeneration: false
// },
// {
// id: 'cohere',
// name: 'Cohere',
// import: () => import('@ai-sdk/cohere'),
// creatorFunctionName: 'createCohere',
// supportsImageGeneration: false
// },
// {
// id: 'groq',
// name: 'Groq',
// import: () => import('@ai-sdk/groq'),
// creatorFunctionName: 'createGroq',
// supportsImageGeneration: false
// },
// {
// id: 'together',
// name: 'Together.ai',
// import: () => import('@ai-sdk/togetherai'),
// creatorFunctionName: 'createTogetherAI',
// supportsImageGeneration: true
// },
// {
// id: 'fireworks',
// name: 'Fireworks',
// import: () => import('@ai-sdk/fireworks'),
// creatorFunctionName: 'createFireworks',
// supportsImageGeneration: true
// },
{
id: 'deepseek',
name: 'DeepSeek',
import: () => import('@ai-sdk/deepseek'),
creatorFunctionName: 'createDeepSeek',
supportsImageGeneration: false
},
{
id: 'cerebras',
name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
},
{
id: 'deepinfra',
name: 'DeepInfra',
import: () => import('@ai-sdk/deepinfra'),
creatorFunctionName: 'createDeepInfra',
supportsImageGeneration: false
},
{
id: 'replicate',
name: 'Replicate',
import: () => import('@ai-sdk/replicate'),
creatorFunctionName: 'createReplicate',
supportsImageGeneration: true
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false
},
{
id: 'fal',
name: 'Fal AI',
import: () => import('@ai-sdk/fal'),
creatorFunctionName: 'createFal',
supportsImageGeneration: false
},
{
id: 'vercel',
name: 'Vercel',
import: () => import('@ai-sdk/vercel'),
creatorFunctionName: 'createVercel'
},
}
// {
// id: 'cerebras',
// name: 'Cerebras',
// import: () => import('@ai-sdk/cerebras'),
// creatorFunctionName: 'createCerebras',
// supportsImageGeneration: false
// },
// {
// id: 'deepinfra',
// name: 'DeepInfra',
// import: () => import('@ai-sdk/deepinfra'),
// creatorFunctionName: 'createDeepInfra',
// supportsImageGeneration: false
// },
// {
// id: 'replicate',
// name: 'Replicate',
// import: () => import('@ai-sdk/replicate'),
// creatorFunctionName: 'createReplicate',
// supportsImageGeneration: true
// },
// {
// id: 'perplexity',
// name: 'Perplexity',
// import: () => import('@ai-sdk/perplexity'),
// creatorFunctionName: 'createPerplexity',
// supportsImageGeneration: false
// },
// {
// id: 'fal',
// name: 'Fal AI',
// import: () => import('@ai-sdk/fal'),
// creatorFunctionName: 'createFal',
// supportsImageGeneration: false
// },
// {
// id: 'vercel',
// name: 'Vercel',
// import: () => import('@ai-sdk/vercel'),
// creatorFunctionName: 'createVercel'
// },
// 社区 Providers (5个)
{
id: 'ollama',
name: 'Ollama',
import: () => import('ollama-ai-provider'),
creatorFunctionName: 'createOllama',
supportsImageGeneration: false
},
{
id: 'anthropic-vertex',
name: 'Anthropic Vertex AI',
import: () => import('anthropic-vertex-ai'),
creatorFunctionName: 'createAnthropicVertex',
supportsImageGeneration: false
},
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: false
}
// {
// id: 'ollama',
// name: 'Ollama',
// import: () => import('ollama-ai-provider'),
// creatorFunctionName: 'createOllama',
// supportsImageGeneration: false
// },
// {
// id: 'anthropic-vertex',
// name: 'Anthropic Vertex AI',
// import: () => import('anthropic-vertex-ai'),
// creatorFunctionName: 'createAnthropicVertex',
// supportsImageGeneration: false
// },
// {
// id: 'openrouter',
// name: 'OpenRouter',
// import: () => import('@openrouter/ai-sdk-provider'),
// creatorFunctionName: 'createOpenRouter',
// supportsImageGeneration: false
// }
]
providers.forEach((config) => {
@ -289,6 +221,16 @@ export class AiProviderRegistry {
* Provider
*/
public registerProvider(config: ProviderConfig): void {
// 验证:必须提供 creator 或 (import + creatorFunctionName)
if (!config.creator && !(config.import && config.creatorFunctionName)) {
throw new Error('Must provide either creator function or import configuration')
}
// 验证:不能同时提供两种方式
if (config.creator && config.import) {
console.warn('Both creator and import provided, creator will take precedence')
}
this.registry.set(config.id, config)
}
@ -299,21 +241,21 @@ export class AiProviderRegistry {
this.registry.clear()
}
/**
*
*/
public getCompatibleRegistry(): Record<string, { import: () => Promise<any>; creatorFunctionName: string }> {
const compatibleRegistry: Record<string, { import: () => Promise<any>; creatorFunctionName: string }> = {}
// /**
// * 获取兼容现有实现的注册表格式
// */
// public getCompatibleRegistry(): Record<string, { import: () => Promise<any>; creatorFunctionName: string }> {
// const compatibleRegistry: Record<string, { import: () => Promise<any>; creatorFunctionName: string }> = {}
this.getAllProviders().forEach((provider) => {
compatibleRegistry[provider.id] = {
import: provider.import,
creatorFunctionName: provider.creatorFunctionName
}
})
// this.getAllProviders().forEach((provider) => {
// compatibleRegistry[provider.id] = {
// import: provider.import,
// creatorFunctionName: provider.creatorFunctionName
// }
// })
return compatibleRegistry
}
// return compatibleRegistry
// }
}
// 导出单例实例
@ -326,31 +268,4 @@ export const isProviderSupported = (id: string) => aiProviderRegistry.isSupporte
export const registerProvider = (config: ProviderConfig) => aiProviderRegistry.registerProvider(config)
// 兼容现有实现的导出
export const PROVIDER_REGISTRY = aiProviderRegistry.getCompatibleRegistry()
// 重新导出所有类型供外部使用
export type {
AmazonBedrockProviderSettings,
AnthropicProviderSettings,
AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings,
CerebrasProviderSettings,
CohereProviderSettings,
DeepInfraProviderSettings,
DeepSeekProviderSettings,
FalProviderSettings,
FireworksProviderSettings,
GoogleGenerativeAIProviderSettings,
GoogleVertexProviderSettings,
GroqProviderSettings,
MistralProviderSettings,
OllamaProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
OpenRouterProviderSettings,
PerplexityProviderSettings,
ReplicateProviderSettings,
TogetherAIProviderSettings,
VercelProviderSettings,
XaiProviderSettings
}
// export const PROVIDER_REGISTRY = aiProviderRegistry.getCompatibleRegistry()

View File

@ -1,14 +1,35 @@
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
import { type XaiProviderSettings } from '@ai-sdk/xai'
/**
* Provider
* 使 AI SDK
*/
export type ProviderId = keyof ProviderSettingsMap & string
// Provider 配置接口(简化版)
// Provider 配置接口 - 支持灵活的创建方式
export interface ProviderConfig {
id: string
name: string
import: () => Promise<any>
creatorFunctionName: string
// 方式一:直接提供 creator 函数(推荐用于自定义)
creator?: (options: any) => any
// 方式二:动态导入 + 函数名(用于包导入)
import?: () => Promise<any>
creatorFunctionName?: string
// 图片生成支持
supportsImageGeneration?: boolean
imageCreator?: (options: any) => any
// 可选的验证函数
validateOptions?: (options: any) => boolean
}
// API 客户端工厂接口
@ -45,3 +66,57 @@ export interface CacheStats {
keys: string[]
lastCleanup?: Date
}
// 类型安全的 Provider Settings 映射
export type ProviderSettingsMap = {
openai: OpenAIProviderSettings
'openai-compatible': OpenAICompatibleProviderSettings
// openrouter: OpenRouterProviderSettings
anthropic: AnthropicProviderSettings
google: GoogleGenerativeAIProviderSettings
// 'google-vertex': GoogleVertexProviderSettings
// mistral: MistralProviderSettings
xai: XaiProviderSettings
azure: AzureOpenAIProviderSettings
// bedrock: AmazonBedrockProviderSettings
// cohere: CohereProviderSettings
// groq: GroqProviderSettings
// together: TogetherAIProviderSettings
// fireworks: FireworksProviderSettings
deepseek: DeepSeekProviderSettings
// cerebras: CerebrasProviderSettings
// deepinfra: DeepInfraProviderSettings
// replicate: ReplicateProviderSettings
// perplexity: PerplexityProviderSettings
// fal: FalProviderSettings
// vercel: VercelProviderSettings
// ollama: OllamaProviderSettings
// 'anthropic-vertex': AnthropicVertexProviderSettings
}
// 重新导出所有类型供外部使用
export type {
// AmazonBedrockProviderSettings,
AnthropicProviderSettings,
// AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings,
// CerebrasProviderSettings,
// CohereProviderSettings,
// DeepInfraProviderSettings,
DeepSeekProviderSettings,
// FalProviderSettings,
// FireworksProviderSettings,
GoogleGenerativeAIProviderSettings,
// GoogleVertexProviderSettings,
// GroqProviderSettings,
// MistralProviderSettings,
// OllamaProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
// OpenRouterProviderSettings,
// PerplexityProviderSettings,
// ReplicateProviderSettings,
// TogetherAIProviderSettings,
// VercelProviderSettings,
XaiProviderSettings
}

View File

@ -7,19 +7,14 @@
export { RuntimeExecutor } from './executor'
// 导出类型
export type {
ExecutionOptions,
// 向后兼容的类型别名
ExecutorConfig,
RuntimeConfig
} from './types'
export type { RuntimeConfig } from './types'
// === 便捷工厂函数 ===
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { type AiPlugin } from '../plugins'
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
import { RuntimeExecutor } from './executor'
/**

View File

@ -1,8 +1,8 @@
import { LanguageModel } from 'ai'
import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { type AiPlugin, createContext, PluginManager } from '../plugins'
import { isProviderSupported } from '../providers/registry'
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
/**
* AI

View File

@ -1,9 +1,9 @@
/**
* Runtime
*/
import { type ProviderId } from '../../types'
import { type ModelConfig } from '../models'
import { type ModelConfig } from '../models/types'
import { type AiPlugin } from '../plugins'
import { type ProviderId } from '../providers/types'
/**
*
@ -13,14 +13,3 @@ export interface RuntimeConfig<T extends ProviderId = ProviderId> {
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
plugins?: AiPlugin[]
}
/**
*
*/
export interface ExecutionOptions {
// 未来可以添加执行级别的选项
// 比如:超时设置、重试机制等
}
// 保留旧类型以保持向后兼容
export interface ExecutorConfig<T extends ProviderId = ProviderId> extends RuntimeConfig<T> {}

View File

@ -9,14 +9,15 @@ import {
getSupportedProviders as factoryGetSupportedProviders
} from './core/models'
import { aiProviderRegistry, isProviderSupported } from './core/providers/registry'
import type { ProviderId } from './core/providers/types'
import type { ProviderSettingsMap } from './core/providers/types'
import { createExecutor } from './core/runtime'
import type { ProviderId, ProviderSettingsMap } from './types'
// ==================== 主要用户接口 ====================
export { createExecutor, createOpenAICompatibleExecutor } from './core/runtime'
// ==================== 高级API ====================
export { createModel, type ModelConfig } from './core/models'
export { createModel } from './core/models'
// ==================== 插件系统 ====================
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins'
@ -30,43 +31,27 @@ export {
createImageModel,
getProviderInfo as getClientInfo,
getSupportedProviders,
ProviderCreationError
ModelCreationError
} from './core/models'
export { aiProviderRegistry } from './core/providers/registry'
// ==================== 类型定义 ====================
export type { ProviderConfig } from './core/providers/registry'
export type { ProviderConfig } from './core/providers/types'
export type { ProviderError } from './core/providers/types'
export type {
AmazonBedrockProviderSettings,
AnthropicProviderSettings,
AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings,
CerebrasProviderSettings,
CohereProviderSettings,
DeepInfraProviderSettings,
DeepSeekProviderSettings,
FalProviderSettings,
FireworksProviderSettings,
GenerateObjectParams,
GenerateTextParams,
GoogleGenerativeAIProviderSettings,
GoogleVertexProviderSettings,
GroqProviderSettings,
MistralProviderSettings,
OllamaProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
OpenRouterProviderSettings,
PerplexityProviderSettings,
ProviderId,
ProviderSettings,
ProviderSettingsMap,
ReplicateProviderSettings,
StreamObjectParams,
StreamTextParams,
TogetherAIProviderSettings,
VercelProviderSettings,
XaiProviderSettings
} from './types'
export * as aiSdk from 'ai'

View File

@ -1,6 +1,6 @@
import { generateObject, generateText, streamObject, streamText } from 'ai'
import type { ProviderSettingsMap } from './core/providers/registry'
import type { ProviderSettingsMap } from './core/providers/types'
// ProviderSettings 是所有 Provider Settings 的联合类型
export type ProviderSettings = ProviderSettingsMap[keyof ProviderSettingsMap]
@ -12,32 +12,16 @@ export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'm
// 重新导出 ProviderSettingsMap 中的所有类型
export type {
AmazonBedrockProviderSettings,
AnthropicProviderSettings,
AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings,
CerebrasProviderSettings,
CohereProviderSettings,
DeepInfraProviderSettings,
DeepSeekProviderSettings,
FalProviderSettings,
FireworksProviderSettings,
GoogleGenerativeAIProviderSettings,
GoogleVertexProviderSettings,
GroqProviderSettings,
MistralProviderSettings,
OllamaProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
OpenRouterProviderSettings,
PerplexityProviderSettings,
ProviderId,
ProviderSettingsMap,
ReplicateProviderSettings,
TogetherAIProviderSettings,
VercelProviderSettings,
XaiProviderSettings
} from './core/providers/registry'
} from './core/providers/types'
// 重新导出插件类型
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'

View File

@ -17,7 +17,7 @@ import {
type ProviderSettingsMap,
StreamTextParams
} from '@cherrystudio/ai-core'
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core'
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'