From c92475b6bf66b992a9aff856517cfe6415eb0866 Mon Sep 17 00:00:00 2001 From: lizhixuan Date: Mon, 7 Jul 2025 00:34:32 +0800 Subject: [PATCH] refactor: streamline model configuration and factory functions - Updated the `createModel` function to accept a simplified `ModelConfig` interface, enhancing clarity and usability. - Refactored `createBaseModel` to destructure parameters for better readability and maintainability. - Removed the `ModelCreator.ts` file as its functionality has been integrated into the factory functions. - Adjusted type definitions in `types.ts` to reflect changes in model configuration structure, ensuring consistency across the codebase. --- .../aiCore/src/core/models/ModelCreator.ts | 31 -------- .../aiCore/src/core/models/ProviderCreator.ts | 74 +++++++++++------- packages/aiCore/src/core/models/factory.ts | 17 ++--- packages/aiCore/src/core/models/index.ts | 4 +- packages/aiCore/src/core/models/types.ts | 62 ++++++++------- .../built-in/webSearchPlugin/helper.ts | 50 ++++++------- .../plugins/built-in/webSearchPlugin/index.ts | 75 +++++++++++-------- packages/aiCore/src/core/plugins/index.ts | 7 +- packages/aiCore/src/core/plugins/manager.ts | 8 +- packages/aiCore/src/core/runtime/executor.ts | 6 +- packages/aiCore/src/core/runtime/types.ts | 2 +- 11 files changed, 169 insertions(+), 167 deletions(-) delete mode 100644 packages/aiCore/src/core/models/ModelCreator.ts diff --git a/packages/aiCore/src/core/models/ModelCreator.ts b/packages/aiCore/src/core/models/ModelCreator.ts deleted file mode 100644 index ef63893d3c..0000000000 --- a/packages/aiCore/src/core/models/ModelCreator.ts +++ /dev/null @@ -1,31 +0,0 @@ -/** - * 模型创建器 - * 负责创建模型并应用中间件,不暴露原始model给用户 - */ -import { LanguageModel } from 'ai' - -import { wrapModelWithMiddlewares } from '../middleware' -import { createBaseModel } from './ProviderCreator' -import { ModelCreationRequest, ResolvedConfig } from './types' - -/** - * 根据解析后的配置创建包装好的模型 - */ -export async function createModelFromConfig(config: ResolvedConfig): Promise { - // 使用ProviderCreator创建基础模型(不应用中间件) - const baseModel = await createBaseModel(config.provider.id, config.model.id, config.provider.options) - - // 在creation层应用中间件,用户不直接接触原始model - return wrapModelWithMiddlewares(baseModel, config.middlewares) -} - -/** - * 直接根据请求参数创建模型 - */ -export async function createModel(request: ModelCreationRequest): Promise { - // 使用ProviderCreator创建基础模型(不应用中间件) - const baseModel = await createBaseModel(request.providerId, request.modelId, request.options) - - const middlewares = request.middlewares || [] - return wrapModelWithMiddlewares(baseModel, middlewares) -} diff --git a/packages/aiCore/src/core/models/ProviderCreator.ts b/packages/aiCore/src/core/models/ProviderCreator.ts index d17ce491c2..81d4774210 100644 --- a/packages/aiCore/src/core/models/ProviderCreator.ts +++ b/packages/aiCore/src/core/models/ProviderCreator.ts @@ -3,7 +3,7 @@ * 负责动态导入 AI SDK providers 并创建基础模型实例 */ import type { ImageModelV1 } from '@ai-sdk/provider' -import { type LanguageModelV1, LanguageModelV1Middleware, wrapLanguageModel } from 'ai' +import { type LanguageModelV1 } from 'ai' import { type ProviderId, type ProviderSettingsMap } from '../../types' import { isOpenAIChatCompletionOnlyModel } from '../../utils/model' @@ -25,26 +25,43 @@ export class ProviderCreationError extends Error { * 创建基础 AI SDK 模型实例 * 对于已知的 Provider 使用严格类型检查,未知的 Provider 默认使用 openai-compatible */ -export async function createBaseModel( - providerId: T, - modelId: string, - options: ProviderSettingsMap[T], - middlewares?: LanguageModelV1Middleware[] -): Promise +export async function createBaseModel({ + providerId, + modelId, + providerSettings + // middlewares +}: { + providerId: T + modelId: string + providerSettings: ProviderSettingsMap[T] + // middlewares?: LanguageModelV1Middleware[] +}): Promise -export async function createBaseModel( - providerId: string, - modelId: string, - options: ProviderSettingsMap['openai-compatible'], - middlewares?: LanguageModelV1Middleware[] -): Promise +export async function createBaseModel({ + providerId, + modelId, + providerSettings + // middlewares +}: { + providerId: string + modelId: string + providerSettings: ProviderSettingsMap['openai-compatible'] + // middlewares?: LanguageModelV1Middleware[] +}): Promise -export async function createBaseModel( - providerId: string, - modelId: string = 'default', - options: any, - middlewares?: LanguageModelV1Middleware[] -): Promise { +export async function createBaseModel({ + providerId, + modelId, + providerSettings, + // middlewares, + extraModelConfig +}: { + providerId: string + modelId: string + providerSettings: ProviderSettingsMap[ProviderId] + // middlewares?: LanguageModelV1Middleware[] + extraModelConfig?: any +}): Promise { try { // 对于不在注册表中的 provider,默认使用 openai-compatible const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible' @@ -67,7 +84,7 @@ export async function createBaseModel( ) } // 创建provider实例 - let provider = creatorFunction(options) + let provider = creatorFunction(providerSettings) // 加一个特判 if (providerConfig.id === 'openai' && !isOpenAIChatCompletionOnlyModel(modelId)) { @@ -75,15 +92,16 @@ export async function createBaseModel( } // 返回模型实例 if (typeof provider === 'function') { - let model: LanguageModelV1 = provider(modelId) + // extraModelConfig:例如google的useSearchGrounding + const model: LanguageModelV1 = provider(modelId, extraModelConfig) - // 应用 AI SDK 中间件 - if (middlewares && middlewares.length > 0) { - model = wrapLanguageModel({ - model: model, - middleware: middlewares - }) - } + // // 应用 AI SDK 中间件 + // if (middlewares && middlewares.length > 0) { + // model = wrapLanguageModel({ + // model: model, + // middleware: middlewares + // }) + // } return model } else { diff --git a/packages/aiCore/src/core/models/factory.ts b/packages/aiCore/src/core/models/factory.ts index acc539d436..67b6fa1edc 100644 --- a/packages/aiCore/src/core/models/factory.ts +++ b/packages/aiCore/src/core/models/factory.ts @@ -2,18 +2,11 @@ * 模型工厂函数 * 统一的模型创建和配置管理 */ -import { LanguageModel, LanguageModelV1Middleware } from 'ai' +import { LanguageModel } from 'ai' -import { type ProviderId, type ProviderSettingsMap } from '../../types' import { wrapModelWithMiddlewares } from '../middleware' import { createBaseModel } from './ProviderCreator' - -export interface ModelConfig { - providerId: ProviderId - modelId: string - options: ProviderSettingsMap[ProviderId] - middlewares?: LanguageModelV1Middleware[] -} +import { ModelConfig } from './types' /** * 创建模型 - 核心函数 @@ -22,7 +15,7 @@ export async function createModel(config: ModelConfig): Promise { validateModelConfig(config) // 1. 创建基础模型 - const baseModel = await createBaseModel(config.providerId, config.modelId, config.options) + const baseModel = await createBaseModel(config) // 2. 应用中间件(如果有) return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel @@ -45,7 +38,7 @@ function validateModelConfig(config: ModelConfig): void { if (!config.modelId) { throw new Error('ModelConfig: modelId is required') } - if (!config.options) { - throw new Error('ModelConfig: options is required') + if (!config.providerSettings) { + throw new Error('ModelConfig: providerSettings is required') } } diff --git a/packages/aiCore/src/core/models/index.ts b/packages/aiCore/src/core/models/index.ts index cee1eab374..89f6a31514 100644 --- a/packages/aiCore/src/core/models/index.ts +++ b/packages/aiCore/src/core/models/index.ts @@ -4,7 +4,7 @@ */ // 主要的模型创建API -export { createModel, createModels, type ModelConfig } from './factory' +export { createModel, createModels } from './factory' // 底层Provider创建功能(供高级用户使用) export { @@ -16,4 +16,4 @@ export { } from './ProviderCreator' // 保留原有类型 -export type { ModelCreationRequest, ResolvedConfig } from './types' +export type { ModelConfig } from './types' diff --git a/packages/aiCore/src/core/models/types.ts b/packages/aiCore/src/core/models/types.ts index e088d5d9f1..92b5f3139a 100644 --- a/packages/aiCore/src/core/models/types.ts +++ b/packages/aiCore/src/core/models/types.ts @@ -1,32 +1,44 @@ -/** - * Creation 模块类型定义 - */ +// /** +// * Creation 模块类型定义 +// */ +// import { LanguageModelV1Middleware } from 'ai' + +// import { ProviderId, ProviderSettingsMap } from '../../types' +// import { AiPlugin } from '../plugins' + +// /** +// * 模型创建请求 +// */ +// export interface ModelCreationRequest { +// providerId: ProviderId +// modelId: string +// options: ProviderSettingsMap[ProviderId] +// middlewares?: LanguageModelV1Middleware[] +// } + +// /** +// * 配置解析结果 +// */ +// export interface ResolvedConfig { +// provider: { +// id: ProviderId +// options: ProviderSettingsMap[ProviderId] +// } +// model: { +// id: string +// } +// plugins: AiPlugin[] +// middlewares: LanguageModelV1Middleware[] +// } + import { LanguageModelV1Middleware } from 'ai' -import { ProviderId, ProviderSettingsMap } from '../../types' -import { AiPlugin } from '../plugins' +import type { ProviderId, ProviderSettingsMap } from '../../types' -/** - * 模型创建请求 - */ -export interface ModelCreationRequest { +export interface ModelConfig { providerId: ProviderId modelId: string - options: ProviderSettingsMap[ProviderId] + providerSettings: ProviderSettingsMap[ProviderId] middlewares?: LanguageModelV1Middleware[] -} - -/** - * 配置解析结果 - */ -export interface ResolvedConfig { - provider: { - id: ProviderId - options: ProviderSettingsMap[ProviderId] - } - model: { - id: string - } - plugins: AiPlugin[] - middlewares: LanguageModelV1Middleware[] + extraModelConfig?: any } diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts index 7b4cee56fc..f3e2e7cf1e 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/helper.ts @@ -49,27 +49,29 @@ export function adaptOpenAIWebSearch(params: any, webSearchConfig: WebSearchConf } /** + * * 适配 Gemini 网络搜索 * 将 googleSearch 工具放入 providerOptions.google.tools */ -export function adaptGeminiWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any { - const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig - const googleSearchTool = { googleSearch: {} } +// export function adaptGeminiWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any { +// const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig +// const googleSearchTool = { googleSearch: {} } - const existingTools = Array.isArray(params.providerOptions?.google?.tools) ? params.providerOptions.google.tools : [] +// const existingTools = Array.isArray(params.providerOptions?.google?.tools) ? params.providerOptions.google.tools : [] - return { - ...params, - providerOptions: { - ...params.providerOptions, - google: { - ...params.providerOptions?.google, - tools: [...existingTools, googleSearchTool], - ...(config.extra || {}) - } - } - } -} +// return { +// ...params, +// providerOptions: { +// ...params.providerOptions, +// google: { +// ...params.providerOptions?.google, +// useSearchGrounding: true, +// // tools: [...existingTools, googleSearchTool], +// ...(config.extra || {}) +// } +// } +// } +// } /** * 适配 Anthropic 网络搜索 @@ -113,9 +115,10 @@ export function adaptWebSearchForProvider( case 'openai': return adaptOpenAIWebSearch(params, webSearchConfig) - case 'google': - case 'google-vertex': - return adaptGeminiWebSearch(params, webSearchConfig) + // google的需要通过插件,在创建model的时候传入参数 + // case 'google': + // case 'google-vertex': + // return adaptGeminiWebSearch(params, webSearchConfig) case 'anthropic': return adaptAnthropicWebSearch(params, webSearchConfig) @@ -125,12 +128,3 @@ export function adaptWebSearchForProvider( return params } } - -/** - * 检查 provider 是否支持网络搜索 - */ -export function isWebSearchSupported(providerId: string): boolean { - const supportedProviders = ['openai', 'google', 'google-vertex', 'anthropic'] - - return supportedProviders.includes(providerId) -} diff --git a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts index 4833f45f2f..86d4b4da0e 100644 --- a/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts +++ b/packages/aiCore/src/core/plugins/built-in/webSearchPlugin/index.ts @@ -5,7 +5,7 @@ import { definePlugin } from '../../' import type { AiRequestContext } from '../../types' -import { adaptWebSearchForProvider, isWebSearchSupported, type WebSearchConfig } from './helper' +import { adaptWebSearchForProvider, type WebSearchConfig } from './helper' /** * 网络搜索插件 @@ -14,42 +14,51 @@ import { adaptWebSearchForProvider, isWebSearchSupported, type WebSearchConfig } * options.ts 文件负责将高层级的设置(如 assistant.enableWebSearch) * 转换为 providerOptions 中的 webSearch: { enabled: true } 配置。 */ -export const webSearchPlugin = definePlugin({ - name: 'webSearch', +export const webSearchPlugin = (config) => + definePlugin({ + name: 'webSearch', + enforce: 'pre', - transformParams: async (params: any, context: AiRequestContext) => { - const { providerId } = context + // configureModel: async (modelConfig: any, context: AiRequestContext) => { + // if (context.providerId === 'google') { + // return { + // ...modelConfig + // } + // } + // return null + // }, - // 从 providerOptions 中提取 webSearch 配置 - const webSearchConfig = params.providerOptions?.[providerId]?.webSearch + transformParams: async (params: any, context: AiRequestContext) => { + const { providerId } = context - // 检查是否启用了网络搜索 (enabled: false 可用于显式禁用) - if (!webSearchConfig || (typeof webSearchConfig === 'object' && webSearchConfig.enabled === false)) { - return params + // 从 providerOptions 中提取 webSearch 配置 + const webSearchConfig = params.providerOptions?.[providerId]?.webSearch + + // 检查是否启用了网络搜索 (enabled: false 可用于显式禁用) + if (!webSearchConfig || (typeof webSearchConfig === 'object' && webSearchConfig.enabled === false)) { + return params + } + console.log('webSearchConfig', webSearchConfig) + // // 检查当前 provider 是否支持网络搜索 + // if (!isWebSearchSupported(providerId)) { + // // 对于不支持的 provider,只记录警告,不修改参数 + // console.warn( + // `[webSearchPlugin] Provider '${providerId}' does not support web search. Ignoring webSearch parameter.` + // ) + // return params + // } + + // 使用适配器函数处理网络搜索 + const adaptedParams = adaptWebSearchForProvider(params, providerId, webSearchConfig as WebSearchConfig | boolean) + // 清理原始的 webSearch 配置 + if (adaptedParams.providerOptions?.[providerId]) { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { webSearch, ...rest } = adaptedParams.providerOptions[providerId] + adaptedParams.providerOptions[providerId] = rest + } + return adaptedParams } - - // 检查当前 provider 是否支持网络搜索 - if (!isWebSearchSupported(providerId)) { - // 对于不支持的 provider,只记录警告,不修改参数 - console.warn( - `[webSearchPlugin] Provider '${providerId}' does not support web search. Ignoring webSearch parameter.` - ) - return params - } - - // 使用适配器函数处理网络搜索 - const adaptedParams = adaptWebSearchForProvider(params, providerId, webSearchConfig as WebSearchConfig | boolean) - - // 清理原始的 webSearch 配置 - if (adaptedParams.providerOptions?.[providerId]) { - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const { webSearch, ...rest } = adaptedParams.providerOptions[providerId] - adaptedParams.providerOptions[providerId] = rest - } - - return adaptedParams - } -}) + }) // 导出类型定义供开发者使用 export type { WebSearchConfig } from './helper' diff --git a/packages/aiCore/src/core/plugins/index.ts b/packages/aiCore/src/core/plugins/index.ts index 06833fbc99..5883c0a41d 100644 --- a/packages/aiCore/src/core/plugins/index.ts +++ b/packages/aiCore/src/core/plugins/index.ts @@ -1,12 +1,17 @@ // 核心类型和接口 export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './types' +import type { ProviderId } from '../../types' import type { AiPlugin, AiRequestContext } from './types' // 插件管理器 export { PluginManager } from './manager' // 工具函数 -export function createContext(providerId: string, modelId: string, originalParams: any): AiRequestContext { +export function createContext( + providerId: T, + modelId: string, + originalParams: any +): AiRequestContext { return { providerId, modelId, diff --git a/packages/aiCore/src/core/plugins/manager.ts b/packages/aiCore/src/core/plugins/manager.ts index d37976b288..5c381aa964 100644 --- a/packages/aiCore/src/core/plugins/manager.ts +++ b/packages/aiCore/src/core/plugins/manager.ts @@ -52,7 +52,7 @@ export class PluginManager { */ async executeFirst( hookName: 'resolveModel' | 'loadTemplate', - arg: string, + arg: any, context: AiRequestContext ): Promise { for (const plugin of this.plugins) { @@ -71,7 +71,7 @@ export class PluginManager { * 执行 Sequential 钩子 - 链式数据转换 */ async executeSequential( - hookName: 'transformParams' | 'transformResult', + hookName: 'transformParams' | 'transformResult' | 'configureModel', initialValue: T, context: AiRequestContext ): Promise { @@ -120,7 +120,9 @@ export class PluginManager { * 收集所有流转换器(返回数组,AI SDK 原生支持) */ collectStreamTransforms(params: any, context: AiRequestContext) { - return this.plugins.map((plugin) => plugin.transformStream?.(params, context)) + return this.plugins + .filter((plugin) => plugin.transformStream) + .map((plugin) => plugin.transformStream?.(params, context)) } /** diff --git a/packages/aiCore/src/core/runtime/executor.ts b/packages/aiCore/src/core/runtime/executor.ts index 2b30155938..7a815825a3 100644 --- a/packages/aiCore/src/core/runtime/executor.ts +++ b/packages/aiCore/src/core/runtime/executor.ts @@ -217,7 +217,7 @@ export class RuntimeExecutor { return await createModel({ providerId: this.config.providerId, modelId: modelOrId, - options: this.config.options, + providerSettings: this.config.providerSettings, middlewares }) } else { @@ -245,7 +245,7 @@ export class RuntimeExecutor { ): RuntimeExecutor { return new RuntimeExecutor({ providerId, - options, + providerSettings: options, plugins }) } @@ -259,7 +259,7 @@ export class RuntimeExecutor { ): RuntimeExecutor<'openai-compatible'> { return new RuntimeExecutor({ providerId: 'openai-compatible', - options, + providerSettings: options, plugins }) } diff --git a/packages/aiCore/src/core/runtime/types.ts b/packages/aiCore/src/core/runtime/types.ts index e4df242e3a..c5cb6cc7ef 100644 --- a/packages/aiCore/src/core/runtime/types.ts +++ b/packages/aiCore/src/core/runtime/types.ts @@ -9,7 +9,7 @@ import { type AiPlugin } from '../plugins' */ export interface RuntimeConfig { providerId: T - options: ProviderSettingsMap[T] + providerSettings: ProviderSettingsMap[T] plugins?: AiPlugin[] }