refactor: streamline model configuration and factory functions

- Updated the `createModel` function to accept a simplified `ModelConfig` interface, enhancing clarity and usability.
- Refactored `createBaseModel` to destructure parameters for better readability and maintainability.
- Removed the `ModelCreator.ts` file as its functionality has been integrated into the factory functions.
- Adjusted type definitions in `types.ts` to reflect changes in model configuration structure, ensuring consistency across the codebase.
This commit is contained in:
lizhixuan 2025-07-07 00:34:32 +08:00
parent 547e5785c0
commit c92475b6bf
11 changed files with 169 additions and 167 deletions

View File

@ -1,31 +0,0 @@
/**
*
* model给用户
*/
import { LanguageModel } from 'ai'
import { wrapModelWithMiddlewares } from '../middleware'
import { createBaseModel } from './ProviderCreator'
import { ModelCreationRequest, ResolvedConfig } from './types'
/**
*
*/
export async function createModelFromConfig(config: ResolvedConfig): Promise<LanguageModel> {
// 使用ProviderCreator创建基础模型不应用中间件
const baseModel = await createBaseModel(config.provider.id, config.model.id, config.provider.options)
// 在creation层应用中间件用户不直接接触原始model
return wrapModelWithMiddlewares(baseModel, config.middlewares)
}
/**
*
*/
export async function createModel(request: ModelCreationRequest): Promise<LanguageModel> {
// 使用ProviderCreator创建基础模型不应用中间件
const baseModel = await createBaseModel(request.providerId, request.modelId, request.options)
const middlewares = request.middlewares || []
return wrapModelWithMiddlewares(baseModel, middlewares)
}

View File

@ -3,7 +3,7 @@
* AI SDK providers * AI SDK providers
*/ */
import type { ImageModelV1 } from '@ai-sdk/provider' import type { ImageModelV1 } from '@ai-sdk/provider'
import { type LanguageModelV1, LanguageModelV1Middleware, wrapLanguageModel } from 'ai' import { type LanguageModelV1 } from 'ai'
import { type ProviderId, type ProviderSettingsMap } from '../../types' import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model' import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
@ -25,26 +25,43 @@ export class ProviderCreationError extends Error {
* AI SDK * AI SDK
* Provider 使 Provider 使 openai-compatible * Provider 使 Provider 使 openai-compatible
*/ */
export async function createBaseModel<T extends ProviderId>( export async function createBaseModel<T extends ProviderId>({
providerId: T, providerId,
modelId: string, modelId,
options: ProviderSettingsMap[T], providerSettings
middlewares?: LanguageModelV1Middleware[] // middlewares
): Promise<LanguageModelV1> }: {
providerId: T
modelId: string
providerSettings: ProviderSettingsMap[T]
// middlewares?: LanguageModelV1Middleware[]
}): Promise<LanguageModelV1>
export async function createBaseModel( export async function createBaseModel({
providerId: string, providerId,
modelId: string, modelId,
options: ProviderSettingsMap['openai-compatible'], providerSettings
middlewares?: LanguageModelV1Middleware[] // middlewares
): Promise<LanguageModelV1> }: {
providerId: string
modelId: string
providerSettings: ProviderSettingsMap['openai-compatible']
// middlewares?: LanguageModelV1Middleware[]
}): Promise<LanguageModelV1>
export async function createBaseModel( export async function createBaseModel({
providerId: string, providerId,
modelId: string = 'default', modelId,
options: any, providerSettings,
middlewares?: LanguageModelV1Middleware[] // middlewares,
): Promise<LanguageModelV1> { extraModelConfig
}: {
providerId: string
modelId: string
providerSettings: ProviderSettingsMap[ProviderId]
// middlewares?: LanguageModelV1Middleware[]
extraModelConfig?: any
}): Promise<LanguageModelV1> {
try { try {
// 对于不在注册表中的 provider默认使用 openai-compatible // 对于不在注册表中的 provider默认使用 openai-compatible
const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible' const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible'
@ -67,7 +84,7 @@ export async function createBaseModel(
) )
} }
// 创建provider实例 // 创建provider实例
let provider = creatorFunction(options) let provider = creatorFunction(providerSettings)
// 加一个特判 // 加一个特判
if (providerConfig.id === 'openai' && !isOpenAIChatCompletionOnlyModel(modelId)) { if (providerConfig.id === 'openai' && !isOpenAIChatCompletionOnlyModel(modelId)) {
@ -75,15 +92,16 @@ export async function createBaseModel(
} }
// 返回模型实例 // 返回模型实例
if (typeof provider === 'function') { if (typeof provider === 'function') {
let model: LanguageModelV1 = provider(modelId) // extraModelConfig:例如google的useSearchGrounding
const model: LanguageModelV1 = provider(modelId, extraModelConfig)
// 应用 AI SDK 中间件 // // 应用 AI SDK 中间件
if (middlewares && middlewares.length > 0) { // if (middlewares && middlewares.length > 0) {
model = wrapLanguageModel({ // model = wrapLanguageModel({
model: model, // model: model,
middleware: middlewares // middleware: middlewares
}) // })
} // }
return model return model
} else { } else {

View File

@ -2,18 +2,11 @@
* *
* *
*/ */
import { LanguageModel, LanguageModelV1Middleware } from 'ai' import { LanguageModel } from 'ai'
import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { wrapModelWithMiddlewares } from '../middleware' import { wrapModelWithMiddlewares } from '../middleware'
import { createBaseModel } from './ProviderCreator' import { createBaseModel } from './ProviderCreator'
import { ModelConfig } from './types'
export interface ModelConfig {
providerId: ProviderId
modelId: string
options: ProviderSettingsMap[ProviderId]
middlewares?: LanguageModelV1Middleware[]
}
/** /**
* - * -
@ -22,7 +15,7 @@ export async function createModel(config: ModelConfig): Promise<LanguageModel> {
validateModelConfig(config) validateModelConfig(config)
// 1. 创建基础模型 // 1. 创建基础模型
const baseModel = await createBaseModel(config.providerId, config.modelId, config.options) const baseModel = await createBaseModel(config)
// 2. 应用中间件(如果有) // 2. 应用中间件(如果有)
return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel
@ -45,7 +38,7 @@ function validateModelConfig(config: ModelConfig): void {
if (!config.modelId) { if (!config.modelId) {
throw new Error('ModelConfig: modelId is required') throw new Error('ModelConfig: modelId is required')
} }
if (!config.options) { if (!config.providerSettings) {
throw new Error('ModelConfig: options is required') throw new Error('ModelConfig: providerSettings is required')
} }
} }

View File

@ -4,7 +4,7 @@
*/ */
// 主要的模型创建API // 主要的模型创建API
export { createModel, createModels, type ModelConfig } from './factory' export { createModel, createModels } from './factory'
// 底层Provider创建功能供高级用户使用 // 底层Provider创建功能供高级用户使用
export { export {
@ -16,4 +16,4 @@ export {
} from './ProviderCreator' } from './ProviderCreator'
// 保留原有类型 // 保留原有类型
export type { ModelCreationRequest, ResolvedConfig } from './types' export type { ModelConfig } from './types'

View File

@ -1,32 +1,44 @@
/** // /**
* Creation // * Creation 模块类型定义
*/ // */
// import { LanguageModelV1Middleware } from 'ai'
// import { ProviderId, ProviderSettingsMap } from '../../types'
// import { AiPlugin } from '../plugins'
// /**
// * 模型创建请求
// */
// export interface ModelCreationRequest {
// providerId: ProviderId
// modelId: string
// options: ProviderSettingsMap[ProviderId]
// middlewares?: LanguageModelV1Middleware[]
// }
// /**
// * 配置解析结果
// */
// export interface ResolvedConfig {
// provider: {
// id: ProviderId
// options: ProviderSettingsMap[ProviderId]
// }
// model: {
// id: string
// }
// plugins: AiPlugin[]
// middlewares: LanguageModelV1Middleware[]
// }
import { LanguageModelV1Middleware } from 'ai' import { LanguageModelV1Middleware } from 'ai'
import { ProviderId, ProviderSettingsMap } from '../../types' import type { ProviderId, ProviderSettingsMap } from '../../types'
import { AiPlugin } from '../plugins'
/** export interface ModelConfig {
*
*/
export interface ModelCreationRequest {
providerId: ProviderId providerId: ProviderId
modelId: string modelId: string
options: ProviderSettingsMap[ProviderId] providerSettings: ProviderSettingsMap[ProviderId]
middlewares?: LanguageModelV1Middleware[] middlewares?: LanguageModelV1Middleware[]
} extraModelConfig?: any
/**
*
*/
export interface ResolvedConfig {
provider: {
id: ProviderId
options: ProviderSettingsMap[ProviderId]
}
model: {
id: string
}
plugins: AiPlugin[]
middlewares: LanguageModelV1Middleware[]
} }

View File

@ -49,27 +49,29 @@ export function adaptOpenAIWebSearch(params: any, webSearchConfig: WebSearchConf
} }
/** /**
*
* Gemini * Gemini
* googleSearch providerOptions.google.tools * googleSearch providerOptions.google.tools
*/ */
export function adaptGeminiWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any { // export function adaptGeminiWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any {
const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig // const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig
const googleSearchTool = { googleSearch: {} } // const googleSearchTool = { googleSearch: {} }
const existingTools = Array.isArray(params.providerOptions?.google?.tools) ? params.providerOptions.google.tools : [] // const existingTools = Array.isArray(params.providerOptions?.google?.tools) ? params.providerOptions.google.tools : []
return { // return {
...params, // ...params,
providerOptions: { // providerOptions: {
...params.providerOptions, // ...params.providerOptions,
google: { // google: {
...params.providerOptions?.google, // ...params.providerOptions?.google,
tools: [...existingTools, googleSearchTool], // useSearchGrounding: true,
...(config.extra || {}) // // tools: [...existingTools, googleSearchTool],
} // ...(config.extra || {})
} // }
} // }
} // }
// }
/** /**
* Anthropic * Anthropic
@ -113,9 +115,10 @@ export function adaptWebSearchForProvider(
case 'openai': case 'openai':
return adaptOpenAIWebSearch(params, webSearchConfig) return adaptOpenAIWebSearch(params, webSearchConfig)
case 'google': // google的需要通过插件在创建model的时候传入参数
case 'google-vertex': // case 'google':
return adaptGeminiWebSearch(params, webSearchConfig) // case 'google-vertex':
// return adaptGeminiWebSearch(params, webSearchConfig)
case 'anthropic': case 'anthropic':
return adaptAnthropicWebSearch(params, webSearchConfig) return adaptAnthropicWebSearch(params, webSearchConfig)
@ -125,12 +128,3 @@ export function adaptWebSearchForProvider(
return params return params
} }
} }
/**
* provider
*/
export function isWebSearchSupported(providerId: string): boolean {
const supportedProviders = ['openai', 'google', 'google-vertex', 'anthropic']
return supportedProviders.includes(providerId)
}

View File

@ -5,7 +5,7 @@
import { definePlugin } from '../../' import { definePlugin } from '../../'
import type { AiRequestContext } from '../../types' import type { AiRequestContext } from '../../types'
import { adaptWebSearchForProvider, isWebSearchSupported, type WebSearchConfig } from './helper' import { adaptWebSearchForProvider, type WebSearchConfig } from './helper'
/** /**
* *
@ -14,42 +14,51 @@ import { adaptWebSearchForProvider, isWebSearchSupported, type WebSearchConfig }
* options.ts assistant.enableWebSearch * options.ts assistant.enableWebSearch
* providerOptions webSearch: { enabled: true } * providerOptions webSearch: { enabled: true }
*/ */
export const webSearchPlugin = definePlugin({ export const webSearchPlugin = (config) =>
name: 'webSearch', definePlugin({
name: 'webSearch',
enforce: 'pre',
transformParams: async (params: any, context: AiRequestContext) => { // configureModel: async (modelConfig: any, context: AiRequestContext) => {
const { providerId } = context // if (context.providerId === 'google') {
// return {
// ...modelConfig
// }
// }
// return null
// },
// 从 providerOptions 中提取 webSearch 配置 transformParams: async (params: any, context: AiRequestContext) => {
const webSearchConfig = params.providerOptions?.[providerId]?.webSearch const { providerId } = context
// 检查是否启用了网络搜索 (enabled: false 可用于显式禁用) // 从 providerOptions 中提取 webSearch 配置
if (!webSearchConfig || (typeof webSearchConfig === 'object' && webSearchConfig.enabled === false)) { const webSearchConfig = params.providerOptions?.[providerId]?.webSearch
return params
// 检查是否启用了网络搜索 (enabled: false 可用于显式禁用)
if (!webSearchConfig || (typeof webSearchConfig === 'object' && webSearchConfig.enabled === false)) {
return params
}
console.log('webSearchConfig', webSearchConfig)
// // 检查当前 provider 是否支持网络搜索
// if (!isWebSearchSupported(providerId)) {
// // 对于不支持的 provider只记录警告不修改参数
// console.warn(
// `[webSearchPlugin] Provider '${providerId}' does not support web search. Ignoring webSearch parameter.`
// )
// return params
// }
// 使用适配器函数处理网络搜索
const adaptedParams = adaptWebSearchForProvider(params, providerId, webSearchConfig as WebSearchConfig | boolean)
// 清理原始的 webSearch 配置
if (adaptedParams.providerOptions?.[providerId]) {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { webSearch, ...rest } = adaptedParams.providerOptions[providerId]
adaptedParams.providerOptions[providerId] = rest
}
return adaptedParams
} }
})
// 检查当前 provider 是否支持网络搜索
if (!isWebSearchSupported(providerId)) {
// 对于不支持的 provider只记录警告不修改参数
console.warn(
`[webSearchPlugin] Provider '${providerId}' does not support web search. Ignoring webSearch parameter.`
)
return params
}
// 使用适配器函数处理网络搜索
const adaptedParams = adaptWebSearchForProvider(params, providerId, webSearchConfig as WebSearchConfig | boolean)
// 清理原始的 webSearch 配置
if (adaptedParams.providerOptions?.[providerId]) {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { webSearch, ...rest } = adaptedParams.providerOptions[providerId]
adaptedParams.providerOptions[providerId] = rest
}
return adaptedParams
}
})
// 导出类型定义供开发者使用 // 导出类型定义供开发者使用
export type { WebSearchConfig } from './helper' export type { WebSearchConfig } from './helper'

View File

@ -1,12 +1,17 @@
// 核心类型和接口 // 核心类型和接口
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './types' export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './types'
import type { ProviderId } from '../../types'
import type { AiPlugin, AiRequestContext } from './types' import type { AiPlugin, AiRequestContext } from './types'
// 插件管理器 // 插件管理器
export { PluginManager } from './manager' export { PluginManager } from './manager'
// 工具函数 // 工具函数
export function createContext(providerId: string, modelId: string, originalParams: any): AiRequestContext { export function createContext<T extends ProviderId>(
providerId: T,
modelId: string,
originalParams: any
): AiRequestContext {
return { return {
providerId, providerId,
modelId, modelId,

View File

@ -52,7 +52,7 @@ export class PluginManager {
*/ */
async executeFirst<T>( async executeFirst<T>(
hookName: 'resolveModel' | 'loadTemplate', hookName: 'resolveModel' | 'loadTemplate',
arg: string, arg: any,
context: AiRequestContext context: AiRequestContext
): Promise<T | null> { ): Promise<T | null> {
for (const plugin of this.plugins) { for (const plugin of this.plugins) {
@ -71,7 +71,7 @@ export class PluginManager {
* Sequential - * Sequential -
*/ */
async executeSequential<T>( async executeSequential<T>(
hookName: 'transformParams' | 'transformResult', hookName: 'transformParams' | 'transformResult' | 'configureModel',
initialValue: T, initialValue: T,
context: AiRequestContext context: AiRequestContext
): Promise<T> { ): Promise<T> {
@ -120,7 +120,9 @@ export class PluginManager {
* AI SDK * AI SDK
*/ */
collectStreamTransforms(params: any, context: AiRequestContext) { collectStreamTransforms(params: any, context: AiRequestContext) {
return this.plugins.map((plugin) => plugin.transformStream?.(params, context)) return this.plugins
.filter((plugin) => plugin.transformStream)
.map((plugin) => plugin.transformStream?.(params, context))
} }
/** /**

View File

@ -217,7 +217,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
return await createModel({ return await createModel({
providerId: this.config.providerId, providerId: this.config.providerId,
modelId: modelOrId, modelId: modelOrId,
options: this.config.options, providerSettings: this.config.providerSettings,
middlewares middlewares
}) })
} else { } else {
@ -245,7 +245,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
): RuntimeExecutor<T> { ): RuntimeExecutor<T> {
return new RuntimeExecutor({ return new RuntimeExecutor({
providerId, providerId,
options, providerSettings: options,
plugins plugins
}) })
} }
@ -259,7 +259,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
): RuntimeExecutor<'openai-compatible'> { ): RuntimeExecutor<'openai-compatible'> {
return new RuntimeExecutor({ return new RuntimeExecutor({
providerId: 'openai-compatible', providerId: 'openai-compatible',
options, providerSettings: options,
plugins plugins
}) })
} }

View File

@ -9,7 +9,7 @@ import { type AiPlugin } from '../plugins'
*/ */
export interface RuntimeConfig<T extends ProviderId = ProviderId> { export interface RuntimeConfig<T extends ProviderId = ProviderId> {
providerId: T providerId: T
options: ProviderSettingsMap[T] providerSettings: ProviderSettingsMap[T]
plugins?: AiPlugin[] plugins?: AiPlugin[]
} }