diff --git a/packages/aiCore/src/core/models/ProviderCreator.ts b/packages/aiCore/src/core/models/ProviderCreator.ts index 177c6f377c..3cc9c07f09 100644 --- a/packages/aiCore/src/core/models/ProviderCreator.ts +++ b/packages/aiCore/src/core/models/ProviderCreator.ts @@ -27,24 +27,28 @@ export class ProviderCreationError extends Error { export async function createBaseModel({ providerId, modelId, - providerSettings + providerSettings, + extraModelConfig // middlewares }: { providerId: T modelId: string providerSettings: ProviderSettingsMap[T] + extraModelConfig?: any // middlewares?: LanguageModelV1Middleware[] }): Promise export async function createBaseModel({ providerId, modelId, - providerSettings + providerSettings, + extraModelConfig // middlewares }: { providerId: string modelId: string providerSettings: ProviderSettingsMap['openai-compatible'] + extraModelConfig?: any // middlewares?: LanguageModelV1Middleware[] }): Promise diff --git a/packages/aiCore/src/core/models/factory.ts b/packages/aiCore/src/core/models/factory.ts index 9010514374..7f7dc4ab81 100644 --- a/packages/aiCore/src/core/models/factory.ts +++ b/packages/aiCore/src/core/models/factory.ts @@ -16,11 +16,7 @@ export async function createModel(config: ModelConfig): Promise validateModelConfig(config) // 1. 创建基础模型 - const baseModel = await createBaseModel({ - providerId: config.providerId, - modelId: config.modelId, - providerSettings: config.options - }) + const baseModel = await createBaseModel(config) // 2. 应用中间件(如果有) return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel @@ -43,7 +39,7 @@ function validateModelConfig(config: ModelConfig): void { if (!config.modelId) { throw new Error('ModelConfig: modelId is required') } - if (!config.options) { + if (!config.providerSettings) { throw new Error('ModelConfig: providerSettings is required') } } diff --git a/packages/aiCore/src/core/models/types.ts b/packages/aiCore/src/core/models/types.ts index 2bb98a67e8..47623a6359 100644 --- a/packages/aiCore/src/core/models/types.ts +++ b/packages/aiCore/src/core/models/types.ts @@ -5,9 +5,10 @@ import { LanguageModelV2Middleware } from '@ai-sdk/provider' import type { ProviderId, ProviderSettingsMap } from '../../types' -export interface ModelConfig { - providerId: ProviderId +export interface ModelConfig { + providerId: T modelId: string - options: ProviderSettingsMap[ProviderId] + providerSettings: ProviderSettingsMap[T] middlewares?: LanguageModelV2Middleware[] + extraModelConfig?: Record } diff --git a/packages/aiCore/src/core/plugins/manager.ts b/packages/aiCore/src/core/plugins/manager.ts index 5c381aa964..ccde1ccc87 100644 --- a/packages/aiCore/src/core/plugins/manager.ts +++ b/packages/aiCore/src/core/plugins/manager.ts @@ -71,7 +71,7 @@ export class PluginManager { * 执行 Sequential 钩子 - 链式数据转换 */ async executeSequential( - hookName: 'transformParams' | 'transformResult' | 'configureModel', + hookName: 'transformParams' | 'transformResult', initialValue: T, context: AiRequestContext ): Promise { @@ -87,6 +87,18 @@ export class PluginManager { return result } + /** + * 执行 ConfigureContext 钩子 - 串行配置上下文 + */ + async executeConfigureContext(context: AiRequestContext): Promise { + for (const plugin of this.plugins) { + const hook = plugin.configureContext + if (hook) { + await hook(context) + } + } + } + /** * 执行 Parallel 钩子 - 并行副作用 */ diff --git a/packages/aiCore/src/core/plugins/types.ts b/packages/aiCore/src/core/plugins/types.ts index b1dd00cbd5..53bb6f6fe1 100644 --- a/packages/aiCore/src/core/plugins/types.ts +++ b/packages/aiCore/src/core/plugins/types.ts @@ -1,4 +1,4 @@ -import type { TextStreamPart, ToolSet } from 'ai' +import type { LanguageModel, TextStreamPart, ToolSet } from 'ai' import { ProviderId } from '../providers/registry' @@ -32,13 +32,13 @@ export interface AiPlugin { enforce?: 'pre' | 'post' // 【First】首个钩子 - 只执行第一个返回值的插件 - resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise + resolveModel?: (modelId: string, context: AiRequestContext) => Promise | LanguageModel | null loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise // 【Sequential】串行钩子 - 链式执行,支持数据转换 + configureContext?: (context: AiRequestContext) => void | Promise transformParams?: (params: any, context: AiRequestContext) => any | Promise transformResult?: (result: any, context: AiRequestContext) => any | Promise - configureModel?: (model: any, context: AiRequestContext) => any | Promise // 【Parallel】并行钩子 - 不依赖顺序,用于副作用 onRequestStart?: (context: AiRequestContext) => void | Promise diff --git a/packages/aiCore/src/core/runtime/executor.ts b/packages/aiCore/src/core/runtime/executor.ts index e1b9e9267f..46ca3af2b5 100644 --- a/packages/aiCore/src/core/runtime/executor.ts +++ b/packages/aiCore/src/core/runtime/executor.ts @@ -7,7 +7,7 @@ import { generateObject, generateText, LanguageModel, streamObject, streamText } import { type ProviderId, type ProviderSettingsMap } from '../../types' import { createModel, getProviderInfo } from '../models' -import { type AiPlugin } from '../plugins' +import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins' import { PluginEngine } from './pluginEngine' import { type RuntimeConfig } from './types' @@ -28,6 +28,19 @@ export class RuntimeExecutor { this.pluginEngine = new PluginEngine(config.providerId, config.plugins || []) } + createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) { + return definePlugin({ + name: '_internal_resolveModel', + enforce: 'post', + + resolveModel: async (modelId: string, context: AiRequestContext) => { + // 从 context 中读取由用户插件注入的 extraModelConfig + const extraModelConfig = context.extraModelConfig || {} + return await this.resolveModel(modelId, middlewares, extraModelConfig) + } + }) + } + // === 高阶重载:直接使用模型 === /** @@ -59,14 +72,14 @@ export class RuntimeExecutor { middlewares?: LanguageModelV2Middleware[] } ): Promise> { - const model = await this.resolveModel(modelOrId, options?.middlewares) + this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares)) // 2. 执行插件处理 return this.pluginEngine.executeStreamWithPlugins( 'streamText', - typeof modelOrId === 'string' ? modelOrId : model.modelId, + typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId, params, - async (finalModelId, transformedParams, streamTransforms) => { + async (model, transformedParams, streamTransforms) => { const experimental_transform = params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined) @@ -110,13 +123,13 @@ export class RuntimeExecutor { middlewares?: LanguageModelV2Middleware[] } ): Promise> { - const model = await this.resolveModel(modelOrId, options?.middlewares) + this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares)) return this.pluginEngine.executeWithPlugins( 'generateText', - typeof modelOrId === 'string' ? modelOrId : model.modelId, + typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId, params, - async (finalModelId, transformedParams) => { + async (model, transformedParams) => { return await generateText({ model, ...transformedParams }) } ) @@ -151,13 +164,13 @@ export class RuntimeExecutor { middlewares?: LanguageModelV2Middleware[] } ): Promise> { - const model = await this.resolveModel(modelOrId, options?.middlewares) + this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares)) return this.pluginEngine.executeWithPlugins( 'generateObject', - typeof modelOrId === 'string' ? modelOrId : model.modelId, + typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId, params, - async (finalModelId, transformedParams) => { + async (model, transformedParams) => { return await generateObject({ model, ...transformedParams }) } ) @@ -192,13 +205,13 @@ export class RuntimeExecutor { middlewares?: LanguageModelV2Middleware[] } ): Promise> { - const model = await this.resolveModel(modelOrId, options?.middlewares) + this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares)) return this.pluginEngine.executeWithPlugins( 'streamObject', - typeof modelOrId === 'string' ? modelOrId : model.modelId, + typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId, params, - async (finalModelId, transformedParams) => { + async (model, transformedParams) => { return await streamObject({ model, ...transformedParams }) } ) @@ -211,15 +224,17 @@ export class RuntimeExecutor { */ private async resolveModel( modelOrId: LanguageModel, - middlewares?: LanguageModelV2Middleware[] + middlewares?: LanguageModelV2Middleware[], + extraModelConfig?: Record ): Promise { if (typeof modelOrId === 'string') { // 字符串modelId,需要创建模型 return await createModel({ providerId: this.config.providerId, modelId: modelOrId, - options: this.config.providerSettings, - middlewares + providerSettings: this.config.providerSettings, + middlewares, + extraModelConfig }) } else { // 已经是模型,直接返回 diff --git a/packages/aiCore/src/core/runtime/pluginEngine.ts b/packages/aiCore/src/core/runtime/pluginEngine.ts index dd7f0b7d9b..68089c3891 100644 --- a/packages/aiCore/src/core/runtime/pluginEngine.ts +++ b/packages/aiCore/src/core/runtime/pluginEngine.ts @@ -1,3 +1,5 @@ +import { LanguageModel } from 'ai' + import { type ProviderId, type ProviderSettingsMap } from '../../types' import { type AiPlugin, createContext, PluginManager } from '../plugins' import { isProviderSupported } from '../providers/registry' @@ -63,7 +65,7 @@ export class PluginEngine { methodName: string, modelId: string, params: TParams, - executor: (finalModelId: string, transformedParams: TParams) => Promise, + executor: (model: LanguageModel, transformedParams: TParams) => Promise, _context?: ReturnType ): Promise { // 使用正确的createContext创建请求上下文 @@ -79,18 +81,23 @@ export class PluginEngine { } try { + // 0. 配置上下文 + await this.pluginManager.executeConfigureContext(context) + // 1. 触发请求开始事件 await this.pluginManager.executeParallel('onRequestStart', context) - // 2. 解析模型别名 - const resolvedModelId = await this.pluginManager.executeFirst('resolveModel', modelId, context) - const finalModelId = resolvedModelId || modelId + // 2. 解析模型 + const model = await this.pluginManager.executeFirst('resolveModel', modelId, context) + if (!model) { + throw new Error(`Failed to resolve model: ${modelId}`) + } // 3. 转换请求参数 const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context) // 4. 执行具体的 API 调用 - const result = await executor(finalModelId, transformedParams) + const result = await executor(model, transformedParams) // 5. 转换结果(对于非流式调用) const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context) @@ -114,7 +121,7 @@ export class PluginEngine { methodName: string, modelId: string, params: TParams, - executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise, + executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise, _context?: ReturnType ): Promise { // 创建请求上下文 @@ -130,12 +137,18 @@ export class PluginEngine { } try { + // 0. 配置上下文 + await this.pluginManager.executeConfigureContext(context) + // 1. 触发请求开始事件 await this.pluginManager.executeParallel('onRequestStart', context) - // 2. 解析模型别名 - const resolvedModelId = await this.pluginManager.executeFirst('resolveModel', modelId, context) - const finalModelId = resolvedModelId || modelId + // 2. 解析模型 + const model = await this.pluginManager.executeFirst('resolveModel', modelId, context) + + if (!model) { + throw new Error(`Failed to resolve model: ${modelId}`) + } // 3. 转换请求参数 const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context) @@ -144,12 +157,14 @@ export class PluginEngine { const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context) // 5. 执行流式 API 调用 - const result = await executor(finalModelId, transformedParams, streamTransforms) + const result = await executor(model, transformedParams, streamTransforms) + + const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context) // 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件) await this.pluginManager.executeParallel('onRequestEnd', context, { stream: true }) - return result + return transformedResult } catch (error) { // 7. 触发错误事件 await this.pluginManager.executeParallel('onError', context, undefined, error as Error)