refactor: reorganize provider and model exports for improved structure

- Updated exports in index.ts and related files to streamline provider and model management.
- Introduced a new ModelCreator module for better encapsulation of model creation logic.
- Refactored type imports to enhance clarity and maintainability across the codebase.
- Removed deprecated provider configurations and cleaned up unused code for better performance.
This commit is contained in:
MyPrototypeWhat 2025-07-18 15:35:44 +08:00
parent c3ad18b77e
commit 1248e3c49a
18 changed files with 472 additions and 435 deletions

View File

@ -8,17 +8,17 @@ export type { NamedMiddleware } from './middleware'
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware' export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'
// 创建管理 // 创建管理
export type { ModelConfig } from './models'
export { export {
createBaseModel, createBaseModel,
createImageModel, createImageModel,
createModel, createModel,
getProviderInfo, getProviderInfo,
getSupportedProviders, getSupportedProviders,
ProviderCreationError ModelCreationError
} from './models' } from './models'
export type { ModelConfig } from './models/types'
// 执行管理 // 执行管理
export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type' export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type'
export type { ExecutionOptions, ExecutorConfig } from './runtime'
export { createExecutor, createOpenAICompatibleExecutor } from './runtime' export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
export type { RuntimeConfig } from './runtime/types'

View File

@ -1,22 +1,23 @@
/** /**
* Provider * Model Creator
* AI SDK providers * Provider AI SDK Language Model Image Model
*/ */
import { ImageModelV2, type LanguageModelV2 } from '@ai-sdk/provider' import { ImageModelV2, type LanguageModelV2 } from '@ai-sdk/provider'
import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model' import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
import { aiProviderRegistry, type ProviderConfig } from '../providers/registry' import { createImageProvider, createProvider } from '../providers/creator'
import { aiProviderRegistry } from '../providers/registry'
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
// 错误类型 // 错误类型
export class ProviderCreationError extends Error { export class ModelCreationError extends Error {
constructor( constructor(
message: string, message: string,
public providerId?: string, public providerId?: string,
public cause?: Error public cause?: Error
) { ) {
super(message) super(message)
this.name = 'ProviderCreationError' this.name = 'ModelCreationError'
} }
} }
@ -29,13 +30,11 @@ export async function createBaseModel<T extends ProviderId>({
modelId, modelId,
providerSettings, providerSettings,
extraModelConfig extraModelConfig
// middlewares
}: { }: {
providerId: T providerId: T
modelId: string modelId: string
providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' } providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
extraModelConfig?: any extraModelConfig?: any
// middlewares?: LanguageModelV1Middleware[]
}): Promise<LanguageModelV2> }): Promise<LanguageModelV2>
export async function createBaseModel({ export async function createBaseModel({
@ -43,87 +42,49 @@ export async function createBaseModel({
modelId, modelId,
providerSettings, providerSettings,
extraModelConfig extraModelConfig
// middlewares
}: { }: {
providerId: string providerId: string
modelId: string modelId: string
providerSettings: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' } providerSettings: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' }
extraModelConfig?: any extraModelConfig?: any
// middlewares?: LanguageModelV1Middleware[]
}): Promise<LanguageModelV2> }): Promise<LanguageModelV2>
export async function createBaseModel({ export async function createBaseModel({
providerId, providerId,
modelId, modelId,
providerSettings, providerSettings,
// middlewares,
extraModelConfig extraModelConfig
}: { }: {
providerId: string providerId: string
modelId: string modelId: string
providerSettings: ProviderSettingsMap[ProviderId] & { mode?: 'chat' | 'responses' } providerSettings: ProviderSettingsMap[ProviderId] & { mode?: 'chat' | 'responses' }
// middlewares?: LanguageModelV1Middleware[]
extraModelConfig?: any extraModelConfig?: any
}): Promise<LanguageModelV2> { }): Promise<LanguageModelV2> {
try { try {
// 对于不在注册表中的 provider默认使用 openai-compatible
const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible'
// 获取Provider配置 // 获取Provider配置
const providerConfig = aiProviderRegistry.getProvider(effectiveProviderId) const providerConfig = aiProviderRegistry.getProvider(providerId)
if (!providerConfig) { if (!providerConfig) {
throw new ProviderCreationError(`Provider "${effectiveProviderId}" is not registered`, providerId) throw new ModelCreationError(`Provider "${providerId}" is not registered`, providerId)
} }
// 动态导入模块 // 创建 provider 实例
const module = await providerConfig.import() const provider = await createProvider(providerConfig, providerSettings)
// 获取创建函数 // 根据 provider 类型处理特殊逻辑
const creatorFunction = module[providerConfig.creatorFunctionName] const finalProvider = handleProviderSpecificLogic(provider, providerConfig.id, providerSettings, modelId)
if (typeof creatorFunction !== 'function') {
throw new ProviderCreationError(
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${effectiveProviderId}"`
)
}
// TODO: 对openai 的 providerSettings.mode参数是否要删除,目前看没毛病
// 创建provider实例
let provider = creatorFunction(providerSettings)
// 加一个特判
if (providerConfig.id === 'openai') {
if (
'mode' in providerSettings &&
providerSettings.mode === 'responses' &&
!isOpenAIChatCompletionOnlyModel(modelId)
) {
provider = provider.responses
} else {
provider = provider.chat
}
}
// 返回模型实例
if (typeof provider === 'function') {
// extraModelConfig:例如google的useSearchGrounding
const model: LanguageModelV2 = provider(modelId, extraModelConfig)
// // 应用 AI SDK 中间件
// if (middlewares && middlewares.length > 0) {
// model = wrapLanguageModel({
// model: model,
// middleware: middlewares
// })
// }
// 创建模型实例
if (typeof finalProvider === 'function') {
const model: LanguageModelV2 = finalProvider(modelId, extraModelConfig)
return model return model
} else { } else {
throw new ProviderCreationError(`Unknown model access pattern for provider "${effectiveProviderId}"`) throw new ModelCreationError(`Unknown model access pattern for provider "${providerId}"`)
} }
} catch (error) { } catch (error) {
if (error instanceof ProviderCreationError) { if (error instanceof ModelCreationError) {
throw error throw error
} }
throw new ProviderCreationError( throw new ModelCreationError(
`Failed to create base model for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`, `Failed to create base model for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
providerId, providerId,
error instanceof Error ? error : undefined error instanceof Error ? error : undefined
@ -131,6 +92,27 @@ export async function createBaseModel({
} }
} }
/**
* Provider
*/
function handleProviderSpecificLogic(provider: any, providerId: string, providerSettings: any, modelId: string): any {
// OpenAI 特殊处理
if (providerId === 'openai') {
if (
'mode' in providerSettings &&
providerSettings.mode === 'responses' &&
!isOpenAIChatCompletionOnlyModel(modelId)
) {
return provider.responses
} else {
return provider.chat
}
}
// 其他 provider 直接返回
return provider
}
/** /**
* *
*/ */
@ -151,40 +133,31 @@ export async function createImageModel(
): Promise<ImageModelV2> { ): Promise<ImageModelV2> {
try { try {
if (!aiProviderRegistry.isSupported(providerId)) { if (!aiProviderRegistry.isSupported(providerId)) {
throw new ProviderCreationError(`Provider "${providerId}" is not supported`, providerId) throw new ModelCreationError(`Provider "${providerId}" is not supported`, providerId)
} }
const providerConfig = aiProviderRegistry.getProvider(providerId) const providerConfig = aiProviderRegistry.getProvider(providerId)
if (!providerConfig) { if (!providerConfig) {
throw new ProviderCreationError(`Provider "${providerId}" is not registered`, providerId) throw new ModelCreationError(`Provider "${providerId}" is not registered`, providerId)
} }
if (!providerConfig.supportsImageGeneration) { if (!providerConfig.supportsImageGeneration) {
throw new ProviderCreationError(`Provider "${providerId}" does not support image generation`, providerId) throw new ModelCreationError(`Provider "${providerId}" does not support image generation`, providerId)
} }
const module = await providerConfig.import() // 创建图像 provider 实例
const provider = await createImageProvider(providerConfig, options)
const creatorFunction = module[providerConfig.creatorFunctionName]
if (typeof creatorFunction !== 'function') {
throw new ProviderCreationError(
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${providerId}"`
)
}
const provider = creatorFunction(options)
if (provider && typeof provider.image === 'function') { if (provider && typeof provider.image === 'function') {
return provider.image(modelId) return provider.image(modelId)
} else { } else {
throw new ProviderCreationError(`Image model function not found for provider "${providerId}"`) throw new ModelCreationError(`Image model function not found for provider "${providerId}"`)
} }
} catch (error) { } catch (error) {
if (error instanceof ProviderCreationError) { if (error instanceof ModelCreationError) {
throw error throw error
} }
throw new ProviderCreationError( throw new ModelCreationError(
`Failed to create image model for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`, `Failed to create image model for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
providerId, providerId,
error instanceof Error ? error : undefined error instanceof Error ? error : undefined
@ -199,7 +172,7 @@ export function getSupportedProviders(): Array<{
id: string id: string
name: string name: string
}> { }> {
return aiProviderRegistry.getAllProviders().map((provider: ProviderConfig) => ({ return aiProviderRegistry.getAllProviders().map((provider) => ({
id: provider.id, id: provider.id,
name: provider.name name: provider.name
})) }))

View File

@ -6,7 +6,7 @@ import { LanguageModelV2 } from '@ai-sdk/provider'
import { LanguageModel } from 'ai' import { LanguageModel } from 'ai'
import { wrapModelWithMiddlewares } from '../middleware' import { wrapModelWithMiddlewares } from '../middleware'
import { createBaseModel } from './ProviderCreator' import { createBaseModel } from './ModelCreator'
import { ModelConfig } from './types' import { ModelConfig } from './types'
/** /**

View File

@ -1,19 +1,18 @@
/** /**
* Models * Models
*
*/ */
// 主要的模型创建API // Model 创建相关
export { createModel, createModels } from './factory'
// 底层Provider创建功能供高级用户使用
export { export {
createBaseModel, createBaseModel,
createImageModel, createImageModel,
getProviderInfo, getProviderInfo,
getSupportedProviders, getSupportedProviders,
ProviderCreationError ModelCreationError
} from './ProviderCreator' } from './ModelCreator'
// 保留原有类型 // Model 配置和工厂
export type { ModelConfig } from './types' export { createModel } from './factory'
// 类型定义
export type { ModelConfig as ModelConfigType } from './types'

View File

@ -3,7 +3,7 @@
*/ */
import { LanguageModelV2Middleware } from '@ai-sdk/provider' import { LanguageModelV2Middleware } from '@ai-sdk/provider'
import type { ProviderId, ProviderSettingsMap } from '../../types' import type { ProviderId, ProviderSettingsMap } from '../providers/types'
export interface ModelConfig<T extends ProviderId = ProviderId> { export interface ModelConfig<T extends ProviderId = ProviderId> {
providerId: T providerId: T

View File

@ -41,8 +41,8 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
break break
} }
case 'google': case 'google': {
case 'google-vertex': { // case 'google-vertex':
if (!params.tools) params.tools = {} if (!params.tools) params.tools = {}
params.tools.web_search = google.tools.googleSearch(config.google || {}) params.tools.web_search = google.tools.googleSearch(config.google || {})
break break

View File

@ -1,6 +1,6 @@
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai' import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
import { ProviderId } from '../providers/registry' import { type ProviderId } from '../providers/types'
/** /**
* *

View File

@ -0,0 +1,108 @@
/**
* Provider Creator
* ProviderConfig provider
*/
import { type ProviderConfig } from './types'
// 错误类型
export class ProviderCreationError extends Error {
constructor(
message: string,
public providerId?: string,
public cause?: Error
) {
super(message)
this.name = 'ProviderCreationError'
}
}
/**
* Provider
* creator +
*/
export async function createProvider(config: ProviderConfig, options: any): Promise<any> {
try {
// 验证配置
if (!config.creator && !(config.import && config.creatorFunctionName)) {
throw new ProviderCreationError(
'Invalid provider configuration: must provide either creator function or import configuration',
config.id
)
}
// 方式一:直接使用 creator 函数
if (config.creator) {
return config.creator(options)
}
// 方式二:动态导入 + 函数名
if (config.import && config.creatorFunctionName) {
const module = await config.import()
const creatorFunction = module[config.creatorFunctionName]
if (typeof creatorFunction !== 'function') {
throw new ProviderCreationError(
`Creator function "${config.creatorFunctionName}" not found in the imported module`,
config.id
)
}
return creatorFunction(options)
}
throw new ProviderCreationError('Unexpected provider configuration state', config.id)
} catch (error) {
if (error instanceof ProviderCreationError) {
throw error
}
throw new ProviderCreationError(
`Failed to create provider "${config.id}": ${error instanceof Error ? error.message : 'Unknown error'}`,
config.id,
error instanceof Error ? error : undefined
)
}
}
/**
* Provider
*/
export async function createImageProvider(config: ProviderConfig, options: any): Promise<any> {
try {
if (!config.supportsImageGeneration) {
throw new ProviderCreationError(`Provider "${config.id}" does not support image generation`, config.id)
}
// 如果有专门的图像 creator
if (config.imageCreator) {
return config.imageCreator(options)
}
// 否则使用普通的 provider 创建流程
return await createProvider(config, options)
} catch (error) {
if (error instanceof ProviderCreationError) {
throw error
}
throw new ProviderCreationError(
`Failed to create image provider "${config.id}": ${error instanceof Error ? error.message : 'Unknown error'}`,
config.id,
error instanceof Error ? error : undefined
)
}
}
/**
* Provider
*/
export function validateProviderConfig(config: ProviderConfig): boolean {
if (!config.id || !config.name) {
return false
}
if (!config.creator && !(config.import && config.creatorFunctionName)) {
return false
}
return true
}

View File

@ -3,8 +3,7 @@
* Provider * Provider
*/ */
import type { ProviderId, ProviderSettingsMap } from './registry' import type { ProviderId, ProviderSettingsMap } from './types'
import { formatPrivateKey } from './utils'
/** /**
* Provider * Provider
@ -37,20 +36,20 @@ const configHandlers: {
apiVersion: azureProvider.apiVersion, apiVersion: azureProvider.apiVersion,
resourceName: azureProvider.resourceName resourceName: azureProvider.resourceName
}) })
},
'google-vertex': (builder, provider) => {
const vertexBuilder = builder as ProviderConfigBuilder<'google-vertex'>
const vertexProvider = provider as CompleteProviderConfig<'google-vertex'>
vertexBuilder
.withGoogleVertexConfig({
project: vertexProvider.project,
location: vertexProvider.location
})
.withGoogleCredentials({
clientEmail: vertexProvider.googleCredentials?.clientEmail || '',
privateKey: vertexProvider.googleCredentials?.privateKey || ''
})
} }
// 'google-vertex': (builder, provider) => {
// const vertexBuilder = builder as ProviderConfigBuilder<'google-vertex'>
// const vertexProvider = provider as CompleteProviderConfig<'google-vertex'>
// vertexBuilder
// .withGoogleVertexConfig({
// project: vertexProvider.project,
// location: vertexProvider.location
// })
// .withGoogleCredentials({
// clientEmail: vertexProvider.googleCredentials?.clientEmail || '',
// privateKey: vertexProvider.googleCredentials?.privateKey || ''
// })
// }
} }
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> { export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
@ -121,35 +120,36 @@ export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
/** /**
* Google * Google
*/ */
withGoogleVertexConfig(options: { project?: string; location?: string }): T extends 'google-vertex' ? this : never // withGoogleVertexConfig(options: { project?: string; location?: string }): T extends 'google-vertex' ? this : never
withGoogleVertexConfig(options: any): any { // withGoogleVertexConfig(options: any): any {
if (this.providerId === 'google-vertex') { // if (this.providerId === 'google-vertex') {
const googleConfig = this.config as CompleteProviderConfig<'google-vertex'> // const googleConfig = this.config as CompleteProviderConfig<'google-vertex'>
if (options.project) { // if (options.project) {
googleConfig.project = options.project // googleConfig.project = options.project
} // }
if (options.location) { // if (options.location) {
googleConfig.location = options.location // googleConfig.location = options.location
if (options.location === 'global') { // if (options.location === 'global') {
googleConfig.baseURL = 'https://aiplatform.googleapis.com' // googleConfig.baseURL = 'https://aiplatform.googleapis.com'
} // }
} // }
} // }
return this // return this
} // }
withGoogleCredentials(credentials: { withGoogleCredentials(credentials: {
clientEmail: string clientEmail: string
privateKey: string privateKey: string
}): T extends 'google-vertex' ? this : never }): T extends 'google-vertex' ? this : never
withGoogleCredentials(credentials: any): any { withGoogleCredentials(): any {
if (this.providerId === 'google-vertex') { // withGoogleCredentials(credentials: any): any {
const vertexConfig = this.config as CompleteProviderConfig<'google-vertex'> // if (this.providerId === 'google-vertex') {
vertexConfig.googleCredentials = { // const vertexConfig = this.config as CompleteProviderConfig<'google-vertex'>
clientEmail: credentials.clientEmail, // vertexConfig.googleCredentials = {
privateKey: formatPrivateKey(credentials.privateKey) // clientEmail: credentials.clientEmail,
} // privateKey: formatPrivateKey(credentials.privateKey)
} // }
// }
return this return this
} }
@ -310,23 +310,22 @@ export class ProviderConfigFactory {
/** /**
* Vertex AI * Vertex AI
*/ */
static createVertexAI( static createVertexAI() {
credentials: { // credentials: {
clientEmail: string // clientEmail: string
privateKey: string // privateKey: string
}, // },
options?: { // options?: {
project?: string // project?: string
location?: string // location?: string
} // }
) { // return this.builder('google-vertex')
return this.builder('google-vertex') // .withGoogleCredentials(credentials)
.withGoogleCredentials(credentials) // .withGoogleVertexConfig({
.withGoogleVertexConfig({ // project: options?.project,
project: options?.project, // location: options?.location
location: options?.location // })
}) // .build()
.build()
} }
static createOpenAICompatible(baseURL: string, apiKey: string) { static createOpenAICompatible(baseURL: string, apiKey: string) {

View File

@ -0,0 +1,15 @@
/**
* Providers
*/
// Provider 注册表
export { aiProviderRegistry, getAllProviders, getProvider, isProviderSupported, registerProvider } from './registry'
// Provider 创建
export { createImageProvider, createProvider, ProviderCreationError, validateProviderConfig } from './creator'
// 类型定义
export type { ProviderConfig, ProviderError, ProviderId, ProviderSettingsMap } from './types'
// 工厂和配置
export * from './factory'

View File

@ -3,76 +3,8 @@
* + * +
*/ */
// 静态导入所有 AI SDK 类型 import { type ProviderConfig } from './types'
import { type AmazonBedrockProviderSettings } from '@ai-sdk/amazon-bedrock'
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { type CerebrasProviderSettings } from '@ai-sdk/cerebras'
import { type CohereProviderSettings } from '@ai-sdk/cohere'
import { type DeepInfraProviderSettings } from '@ai-sdk/deepinfra'
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
import { type FalProviderSettings } from '@ai-sdk/fal'
import { type FireworksProviderSettings } from '@ai-sdk/fireworks'
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
import { type GoogleVertexProviderSettings } from '@ai-sdk/google-vertex/edge'
import { type GroqProviderSettings } from '@ai-sdk/groq'
import { type MistralProviderSettings } from '@ai-sdk/mistral'
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
import { type PerplexityProviderSettings } from '@ai-sdk/perplexity'
import { type ReplicateProviderSettings } from '@ai-sdk/replicate'
import { type TogetherAIProviderSettings } from '@ai-sdk/togetherai'
import { type VercelProviderSettings } from '@ai-sdk/vercel'
import { type XaiProviderSettings } from '@ai-sdk/xai'
import { type OpenRouterProviderSettings } from '@openrouter/ai-sdk-provider'
import { type AnthropicVertexProviderSettings } from 'anthropic-vertex-ai'
import { type OllamaProviderSettings } from 'ollama-ai-provider'
// 类型安全的 Provider Settings 映射
export type ProviderSettingsMap = {
openai: OpenAIProviderSettings
'openai-compatible': OpenAICompatibleProviderSettings
openrouter: OpenRouterProviderSettings
anthropic: AnthropicProviderSettings
google: GoogleGenerativeAIProviderSettings
'google-vertex': GoogleVertexProviderSettings
mistral: MistralProviderSettings
xai: XaiProviderSettings
azure: AzureOpenAIProviderSettings
bedrock: AmazonBedrockProviderSettings
cohere: CohereProviderSettings
groq: GroqProviderSettings
together: TogetherAIProviderSettings
fireworks: FireworksProviderSettings
deepseek: DeepSeekProviderSettings
cerebras: CerebrasProviderSettings
deepinfra: DeepInfraProviderSettings
replicate: ReplicateProviderSettings
perplexity: PerplexityProviderSettings
fal: FalProviderSettings
vercel: VercelProviderSettings
ollama: OllamaProviderSettings
'anthropic-vertex': AnthropicVertexProviderSettings
}
export type ProviderId = keyof ProviderSettingsMap & string
// 统一的 Provider 配置接口(所有都使用动态导入)
export interface ProviderConfig {
id: string
name: string
// 动态导入函数
import: () => Promise<any>
// 创建函数名称
creatorFunctionName: string
// 是否支持图片生成
supportsImageGeneration?: boolean
}
/**
* AI SDK Provider
* AI Providers
*/
export class AiProviderRegistry { export class AiProviderRegistry {
private static instance: AiProviderRegistry private static instance: AiProviderRegistry
private registry = new Map<string, ProviderConfig>() private registry = new Map<string, ProviderConfig>()
@ -123,20 +55,20 @@ export class AiProviderRegistry {
creatorFunctionName: 'createGoogleGenerativeAI', creatorFunctionName: 'createGoogleGenerativeAI',
supportsImageGeneration: true supportsImageGeneration: true
}, },
{ // {
id: 'google-vertex', // id: 'google-vertex',
name: 'Google Vertex AI', // name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex/edge'), // import: () => import('@ai-sdk/google-vertex/edge'),
creatorFunctionName: 'createVertex', // creatorFunctionName: 'createVertex',
supportsImageGeneration: true // supportsImageGeneration: true
}, // },
{ // {
id: 'mistral', // id: 'mistral',
name: 'Mistral AI', // name: 'Mistral AI',
import: () => import('@ai-sdk/mistral'), // import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral', // creatorFunctionName: 'createMistral',
supportsImageGeneration: false // supportsImageGeneration: false
}, // },
{ {
id: 'xai', id: 'xai',
name: 'xAI (Grok)', name: 'xAI (Grok)',
@ -151,112 +83,112 @@ export class AiProviderRegistry {
creatorFunctionName: 'createAzure', creatorFunctionName: 'createAzure',
supportsImageGeneration: true supportsImageGeneration: true
}, },
{ // {
id: 'bedrock', // id: 'bedrock',
name: 'Amazon Bedrock', // name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'), // import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock', // creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: false // supportsImageGeneration: false
}, // },
{ // {
id: 'cohere', // id: 'cohere',
name: 'Cohere', // name: 'Cohere',
import: () => import('@ai-sdk/cohere'), // import: () => import('@ai-sdk/cohere'),
creatorFunctionName: 'createCohere', // creatorFunctionName: 'createCohere',
supportsImageGeneration: false // supportsImageGeneration: false
}, // },
{ // {
id: 'groq', // id: 'groq',
name: 'Groq', // name: 'Groq',
import: () => import('@ai-sdk/groq'), // import: () => import('@ai-sdk/groq'),
creatorFunctionName: 'createGroq', // creatorFunctionName: 'createGroq',
supportsImageGeneration: false // supportsImageGeneration: false
}, // },
{ // {
id: 'together', // id: 'together',
name: 'Together.ai', // name: 'Together.ai',
import: () => import('@ai-sdk/togetherai'), // import: () => import('@ai-sdk/togetherai'),
creatorFunctionName: 'createTogetherAI', // creatorFunctionName: 'createTogetherAI',
supportsImageGeneration: true // supportsImageGeneration: true
}, // },
{ // {
id: 'fireworks', // id: 'fireworks',
name: 'Fireworks', // name: 'Fireworks',
import: () => import('@ai-sdk/fireworks'), // import: () => import('@ai-sdk/fireworks'),
creatorFunctionName: 'createFireworks', // creatorFunctionName: 'createFireworks',
supportsImageGeneration: true // supportsImageGeneration: true
}, // },
{ {
id: 'deepseek', id: 'deepseek',
name: 'DeepSeek', name: 'DeepSeek',
import: () => import('@ai-sdk/deepseek'), import: () => import('@ai-sdk/deepseek'),
creatorFunctionName: 'createDeepSeek', creatorFunctionName: 'createDeepSeek',
supportsImageGeneration: false supportsImageGeneration: false
}, }
{ // {
id: 'cerebras', // id: 'cerebras',
name: 'Cerebras', // name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'), // import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras', // creatorFunctionName: 'createCerebras',
supportsImageGeneration: false // supportsImageGeneration: false
}, // },
{ // {
id: 'deepinfra', // id: 'deepinfra',
name: 'DeepInfra', // name: 'DeepInfra',
import: () => import('@ai-sdk/deepinfra'), // import: () => import('@ai-sdk/deepinfra'),
creatorFunctionName: 'createDeepInfra', // creatorFunctionName: 'createDeepInfra',
supportsImageGeneration: false // supportsImageGeneration: false
}, // },
{ // {
id: 'replicate', // id: 'replicate',
name: 'Replicate', // name: 'Replicate',
import: () => import('@ai-sdk/replicate'), // import: () => import('@ai-sdk/replicate'),
creatorFunctionName: 'createReplicate', // creatorFunctionName: 'createReplicate',
supportsImageGeneration: true // supportsImageGeneration: true
}, // },
{ // {
id: 'perplexity', // id: 'perplexity',
name: 'Perplexity', // name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'), // import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity', // creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false // supportsImageGeneration: false
}, // },
{ // {
id: 'fal', // id: 'fal',
name: 'Fal AI', // name: 'Fal AI',
import: () => import('@ai-sdk/fal'), // import: () => import('@ai-sdk/fal'),
creatorFunctionName: 'createFal', // creatorFunctionName: 'createFal',
supportsImageGeneration: false // supportsImageGeneration: false
}, // },
{ // {
id: 'vercel', // id: 'vercel',
name: 'Vercel', // name: 'Vercel',
import: () => import('@ai-sdk/vercel'), // import: () => import('@ai-sdk/vercel'),
creatorFunctionName: 'createVercel' // creatorFunctionName: 'createVercel'
}, // },
// 社区 Providers (5个) // 社区 Providers (5个)
{ // {
id: 'ollama', // id: 'ollama',
name: 'Ollama', // name: 'Ollama',
import: () => import('ollama-ai-provider'), // import: () => import('ollama-ai-provider'),
creatorFunctionName: 'createOllama', // creatorFunctionName: 'createOllama',
supportsImageGeneration: false // supportsImageGeneration: false
}, // },
{ // {
id: 'anthropic-vertex', // id: 'anthropic-vertex',
name: 'Anthropic Vertex AI', // name: 'Anthropic Vertex AI',
import: () => import('anthropic-vertex-ai'), // import: () => import('anthropic-vertex-ai'),
creatorFunctionName: 'createAnthropicVertex', // creatorFunctionName: 'createAnthropicVertex',
supportsImageGeneration: false // supportsImageGeneration: false
}, // },
{ // {
id: 'openrouter', // id: 'openrouter',
name: 'OpenRouter', // name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'), // import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter', // creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: false // supportsImageGeneration: false
} // }
] ]
providers.forEach((config) => { providers.forEach((config) => {
@ -289,6 +221,16 @@ export class AiProviderRegistry {
* Provider * Provider
*/ */
public registerProvider(config: ProviderConfig): void { public registerProvider(config: ProviderConfig): void {
// 验证:必须提供 creator 或 (import + creatorFunctionName)
if (!config.creator && !(config.import && config.creatorFunctionName)) {
throw new Error('Must provide either creator function or import configuration')
}
// 验证:不能同时提供两种方式
if (config.creator && config.import) {
console.warn('Both creator and import provided, creator will take precedence')
}
this.registry.set(config.id, config) this.registry.set(config.id, config)
} }
@ -299,21 +241,21 @@ export class AiProviderRegistry {
this.registry.clear() this.registry.clear()
} }
/** // /**
* // * 获取兼容现有实现的注册表格式
*/ // */
public getCompatibleRegistry(): Record<string, { import: () => Promise<any>; creatorFunctionName: string }> { // public getCompatibleRegistry(): Record<string, { import: () => Promise<any>; creatorFunctionName: string }> {
const compatibleRegistry: Record<string, { import: () => Promise<any>; creatorFunctionName: string }> = {} // const compatibleRegistry: Record<string, { import: () => Promise<any>; creatorFunctionName: string }> = {}
this.getAllProviders().forEach((provider) => { // this.getAllProviders().forEach((provider) => {
compatibleRegistry[provider.id] = { // compatibleRegistry[provider.id] = {
import: provider.import, // import: provider.import,
creatorFunctionName: provider.creatorFunctionName // creatorFunctionName: provider.creatorFunctionName
} // }
}) // })
return compatibleRegistry // return compatibleRegistry
} // }
} }
// 导出单例实例 // 导出单例实例
@ -326,31 +268,4 @@ export const isProviderSupported = (id: string) => aiProviderRegistry.isSupporte
export const registerProvider = (config: ProviderConfig) => aiProviderRegistry.registerProvider(config) export const registerProvider = (config: ProviderConfig) => aiProviderRegistry.registerProvider(config)
// 兼容现有实现的导出 // 兼容现有实现的导出
export const PROVIDER_REGISTRY = aiProviderRegistry.getCompatibleRegistry() // export const PROVIDER_REGISTRY = aiProviderRegistry.getCompatibleRegistry()
// 重新导出所有类型供外部使用
export type {
AmazonBedrockProviderSettings,
AnthropicProviderSettings,
AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings,
CerebrasProviderSettings,
CohereProviderSettings,
DeepInfraProviderSettings,
DeepSeekProviderSettings,
FalProviderSettings,
FireworksProviderSettings,
GoogleGenerativeAIProviderSettings,
GoogleVertexProviderSettings,
GroqProviderSettings,
MistralProviderSettings,
OllamaProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
OpenRouterProviderSettings,
PerplexityProviderSettings,
ReplicateProviderSettings,
TogetherAIProviderSettings,
VercelProviderSettings,
XaiProviderSettings
}

View File

@ -1,14 +1,35 @@
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
import { type XaiProviderSettings } from '@ai-sdk/xai'
/** /**
* Provider * Provider
* 使 AI SDK * 使 AI SDK
*/ */
export type ProviderId = keyof ProviderSettingsMap & string
// Provider 配置接口(简化版) // Provider 配置接口 - 支持灵活的创建方式
export interface ProviderConfig { export interface ProviderConfig {
id: string id: string
name: string name: string
import: () => Promise<any>
creatorFunctionName: string // 方式一:直接提供 creator 函数(推荐用于自定义)
creator?: (options: any) => any
// 方式二:动态导入 + 函数名(用于包导入)
import?: () => Promise<any>
creatorFunctionName?: string
// 图片生成支持
supportsImageGeneration?: boolean
imageCreator?: (options: any) => any
// 可选的验证函数
validateOptions?: (options: any) => boolean
} }
// API 客户端工厂接口 // API 客户端工厂接口
@ -45,3 +66,57 @@ export interface CacheStats {
keys: string[] keys: string[]
lastCleanup?: Date lastCleanup?: Date
} }
// 类型安全的 Provider Settings 映射
export type ProviderSettingsMap = {
openai: OpenAIProviderSettings
'openai-compatible': OpenAICompatibleProviderSettings
// openrouter: OpenRouterProviderSettings
anthropic: AnthropicProviderSettings
google: GoogleGenerativeAIProviderSettings
// 'google-vertex': GoogleVertexProviderSettings
// mistral: MistralProviderSettings
xai: XaiProviderSettings
azure: AzureOpenAIProviderSettings
// bedrock: AmazonBedrockProviderSettings
// cohere: CohereProviderSettings
// groq: GroqProviderSettings
// together: TogetherAIProviderSettings
// fireworks: FireworksProviderSettings
deepseek: DeepSeekProviderSettings
// cerebras: CerebrasProviderSettings
// deepinfra: DeepInfraProviderSettings
// replicate: ReplicateProviderSettings
// perplexity: PerplexityProviderSettings
// fal: FalProviderSettings
// vercel: VercelProviderSettings
// ollama: OllamaProviderSettings
// 'anthropic-vertex': AnthropicVertexProviderSettings
}
// 重新导出所有类型供外部使用
export type {
// AmazonBedrockProviderSettings,
AnthropicProviderSettings,
// AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings,
// CerebrasProviderSettings,
// CohereProviderSettings,
// DeepInfraProviderSettings,
DeepSeekProviderSettings,
// FalProviderSettings,
// FireworksProviderSettings,
GoogleGenerativeAIProviderSettings,
// GoogleVertexProviderSettings,
// GroqProviderSettings,
// MistralProviderSettings,
// OllamaProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
// OpenRouterProviderSettings,
// PerplexityProviderSettings,
// ReplicateProviderSettings,
// TogetherAIProviderSettings,
// VercelProviderSettings,
XaiProviderSettings
}

View File

@ -7,19 +7,14 @@
export { RuntimeExecutor } from './executor' export { RuntimeExecutor } from './executor'
// 导出类型 // 导出类型
export type { export type { RuntimeConfig } from './types'
ExecutionOptions,
// 向后兼容的类型别名
ExecutorConfig,
RuntimeConfig
} from './types'
// === 便捷工厂函数 === // === 便捷工厂函数 ===
import { LanguageModelV2Middleware } from '@ai-sdk/provider' import { LanguageModelV2Middleware } from '@ai-sdk/provider'
import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { type AiPlugin } from '../plugins' import { type AiPlugin } from '../plugins'
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
import { RuntimeExecutor } from './executor' import { RuntimeExecutor } from './executor'
/** /**

View File

@ -1,8 +1,8 @@
import { LanguageModel } from 'ai' import { LanguageModel } from 'ai'
import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { type AiPlugin, createContext, PluginManager } from '../plugins' import { type AiPlugin, createContext, PluginManager } from '../plugins'
import { isProviderSupported } from '../providers/registry' import { isProviderSupported } from '../providers/registry'
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
/** /**
* AI * AI

View File

@ -1,9 +1,9 @@
/** /**
* Runtime * Runtime
*/ */
import { type ProviderId } from '../../types' import { type ModelConfig } from '../models/types'
import { type ModelConfig } from '../models'
import { type AiPlugin } from '../plugins' import { type AiPlugin } from '../plugins'
import { type ProviderId } from '../providers/types'
/** /**
* *
@ -13,14 +13,3 @@ export interface RuntimeConfig<T extends ProviderId = ProviderId> {
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' } providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
plugins?: AiPlugin[] plugins?: AiPlugin[]
} }
/**
*
*/
export interface ExecutionOptions {
// 未来可以添加执行级别的选项
// 比如:超时设置、重试机制等
}
// 保留旧类型以保持向后兼容
export interface ExecutorConfig<T extends ProviderId = ProviderId> extends RuntimeConfig<T> {}

View File

@ -9,14 +9,15 @@ import {
getSupportedProviders as factoryGetSupportedProviders getSupportedProviders as factoryGetSupportedProviders
} from './core/models' } from './core/models'
import { aiProviderRegistry, isProviderSupported } from './core/providers/registry' import { aiProviderRegistry, isProviderSupported } from './core/providers/registry'
import type { ProviderId } from './core/providers/types'
import type { ProviderSettingsMap } from './core/providers/types'
import { createExecutor } from './core/runtime' import { createExecutor } from './core/runtime'
import type { ProviderId, ProviderSettingsMap } from './types'
// ==================== 主要用户接口 ==================== // ==================== 主要用户接口 ====================
export { createExecutor, createOpenAICompatibleExecutor } from './core/runtime' export { createExecutor, createOpenAICompatibleExecutor } from './core/runtime'
// ==================== 高级API ==================== // ==================== 高级API ====================
export { createModel, type ModelConfig } from './core/models' export { createModel } from './core/models'
// ==================== 插件系统 ==================== // ==================== 插件系统 ====================
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins' export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins'
@ -30,43 +31,27 @@ export {
createImageModel, createImageModel,
getProviderInfo as getClientInfo, getProviderInfo as getClientInfo,
getSupportedProviders, getSupportedProviders,
ProviderCreationError ModelCreationError
} from './core/models' } from './core/models'
export { aiProviderRegistry } from './core/providers/registry' export { aiProviderRegistry } from './core/providers/registry'
// ==================== 类型定义 ==================== // ==================== 类型定义 ====================
export type { ProviderConfig } from './core/providers/registry' export type { ProviderConfig } from './core/providers/types'
export type { ProviderError } from './core/providers/types' export type { ProviderError } from './core/providers/types'
export type { export type {
AmazonBedrockProviderSettings,
AnthropicProviderSettings, AnthropicProviderSettings,
AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings, AzureOpenAIProviderSettings,
CerebrasProviderSettings,
CohereProviderSettings,
DeepInfraProviderSettings,
DeepSeekProviderSettings, DeepSeekProviderSettings,
FalProviderSettings,
FireworksProviderSettings,
GenerateObjectParams, GenerateObjectParams,
GenerateTextParams, GenerateTextParams,
GoogleGenerativeAIProviderSettings, GoogleGenerativeAIProviderSettings,
GoogleVertexProviderSettings,
GroqProviderSettings,
MistralProviderSettings,
OllamaProviderSettings,
OpenAICompatibleProviderSettings, OpenAICompatibleProviderSettings,
OpenAIProviderSettings, OpenAIProviderSettings,
OpenRouterProviderSettings,
PerplexityProviderSettings,
ProviderId, ProviderId,
ProviderSettings, ProviderSettings,
ProviderSettingsMap, ProviderSettingsMap,
ReplicateProviderSettings,
StreamObjectParams, StreamObjectParams,
StreamTextParams, StreamTextParams,
TogetherAIProviderSettings,
VercelProviderSettings,
XaiProviderSettings XaiProviderSettings
} from './types' } from './types'
export * as aiSdk from 'ai' export * as aiSdk from 'ai'

View File

@ -1,6 +1,6 @@
import { generateObject, generateText, streamObject, streamText } from 'ai' import { generateObject, generateText, streamObject, streamText } from 'ai'
import type { ProviderSettingsMap } from './core/providers/registry' import type { ProviderSettingsMap } from './core/providers/types'
// ProviderSettings 是所有 Provider Settings 的联合类型 // ProviderSettings 是所有 Provider Settings 的联合类型
export type ProviderSettings = ProviderSettingsMap[keyof ProviderSettingsMap] export type ProviderSettings = ProviderSettingsMap[keyof ProviderSettingsMap]
@ -12,32 +12,16 @@ export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'm
// 重新导出 ProviderSettingsMap 中的所有类型 // 重新导出 ProviderSettingsMap 中的所有类型
export type { export type {
AmazonBedrockProviderSettings,
AnthropicProviderSettings, AnthropicProviderSettings,
AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings, AzureOpenAIProviderSettings,
CerebrasProviderSettings,
CohereProviderSettings,
DeepInfraProviderSettings,
DeepSeekProviderSettings, DeepSeekProviderSettings,
FalProviderSettings,
FireworksProviderSettings,
GoogleGenerativeAIProviderSettings, GoogleGenerativeAIProviderSettings,
GoogleVertexProviderSettings,
GroqProviderSettings,
MistralProviderSettings,
OllamaProviderSettings,
OpenAICompatibleProviderSettings, OpenAICompatibleProviderSettings,
OpenAIProviderSettings, OpenAIProviderSettings,
OpenRouterProviderSettings,
PerplexityProviderSettings,
ProviderId, ProviderId,
ProviderSettingsMap, ProviderSettingsMap,
ReplicateProviderSettings,
TogetherAIProviderSettings,
VercelProviderSettings,
XaiProviderSettings XaiProviderSettings
} from './core/providers/registry' } from './core/providers/types'
// 重新导出插件类型 // 重新导出插件类型
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types' export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'

View File

@ -17,7 +17,7 @@ import {
type ProviderSettingsMap, type ProviderSettingsMap,
StreamTextParams StreamTextParams
} from '@cherrystudio/ai-core' } from '@cherrystudio/ai-core'
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core' import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
import { isDedicatedImageGenerationModel } from '@renderer/config/models' import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService' import { getProviderByModel } from '@renderer/services/AssistantService'