refactor: enhance model configuration and plugin execution

- Simplified the `createModel` function to directly accept the `ModelConfig` object, improving clarity.
- Updated `createBaseModel` to include `extraModelConfig` for extended configuration options.
- Introduced `executeConfigureContext` method in `PluginManager` to handle context configuration for plugins.
- Adjusted type definitions in `types.ts` to ensure consistency with the new configuration structure.
- Refactored plugin execution methods in `PluginEngine` to utilize the resolved model directly, enhancing the flow of data through the plugin system.
This commit is contained in:
MyPrototypeWhat 2025-07-07 18:33:47 +08:00
parent c72156b2da
commit 0a908a334b
7 changed files with 85 additions and 42 deletions

View File

@ -27,24 +27,28 @@ export class ProviderCreationError extends Error {
export async function createBaseModel<T extends ProviderId>({ export async function createBaseModel<T extends ProviderId>({
providerId, providerId,
modelId, modelId,
providerSettings providerSettings,
extraModelConfig
// middlewares // middlewares
}: { }: {
providerId: T providerId: T
modelId: string modelId: string
providerSettings: ProviderSettingsMap[T] providerSettings: ProviderSettingsMap[T]
extraModelConfig?: any
// middlewares?: LanguageModelV1Middleware[] // middlewares?: LanguageModelV1Middleware[]
}): Promise<LanguageModelV2> }): Promise<LanguageModelV2>
export async function createBaseModel({ export async function createBaseModel({
providerId, providerId,
modelId, modelId,
providerSettings providerSettings,
extraModelConfig
// middlewares // middlewares
}: { }: {
providerId: string providerId: string
modelId: string modelId: string
providerSettings: ProviderSettingsMap['openai-compatible'] providerSettings: ProviderSettingsMap['openai-compatible']
extraModelConfig?: any
// middlewares?: LanguageModelV1Middleware[] // middlewares?: LanguageModelV1Middleware[]
}): Promise<LanguageModelV2> }): Promise<LanguageModelV2>

View File

@ -16,11 +16,7 @@ export async function createModel(config: ModelConfig): Promise<LanguageModelV2>
validateModelConfig(config) validateModelConfig(config)
// 1. 创建基础模型 // 1. 创建基础模型
const baseModel = await createBaseModel({ const baseModel = await createBaseModel(config)
providerId: config.providerId,
modelId: config.modelId,
providerSettings: config.options
})
// 2. 应用中间件(如果有) // 2. 应用中间件(如果有)
return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel
@ -43,7 +39,7 @@ function validateModelConfig(config: ModelConfig): void {
if (!config.modelId) { if (!config.modelId) {
throw new Error('ModelConfig: modelId is required') throw new Error('ModelConfig: modelId is required')
} }
if (!config.options) { if (!config.providerSettings) {
throw new Error('ModelConfig: providerSettings is required') throw new Error('ModelConfig: providerSettings is required')
} }
} }

View File

@ -5,9 +5,10 @@ import { LanguageModelV2Middleware } from '@ai-sdk/provider'
import type { ProviderId, ProviderSettingsMap } from '../../types' import type { ProviderId, ProviderSettingsMap } from '../../types'
export interface ModelConfig { export interface ModelConfig<T extends ProviderId = ProviderId> {
providerId: ProviderId providerId: T
modelId: string modelId: string
options: ProviderSettingsMap[ProviderId] providerSettings: ProviderSettingsMap[T]
middlewares?: LanguageModelV2Middleware[] middlewares?: LanguageModelV2Middleware[]
extraModelConfig?: Record<string, any>
} }

View File

@ -71,7 +71,7 @@ export class PluginManager {
* Sequential - * Sequential -
*/ */
async executeSequential<T>( async executeSequential<T>(
hookName: 'transformParams' | 'transformResult' | 'configureModel', hookName: 'transformParams' | 'transformResult',
initialValue: T, initialValue: T,
context: AiRequestContext context: AiRequestContext
): Promise<T> { ): Promise<T> {
@ -87,6 +87,18 @@ export class PluginManager {
return result return result
} }
/**
* ConfigureContext -
*/
async executeConfigureContext(context: AiRequestContext): Promise<void> {
for (const plugin of this.plugins) {
const hook = plugin.configureContext
if (hook) {
await hook(context)
}
}
}
/** /**
* Parallel - * Parallel -
*/ */

View File

@ -1,4 +1,4 @@
import type { TextStreamPart, ToolSet } from 'ai' import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
import { ProviderId } from '../providers/registry' import { ProviderId } from '../providers/registry'
@ -32,13 +32,13 @@ export interface AiPlugin {
enforce?: 'pre' | 'post' enforce?: 'pre' | 'post'
// 【First】首个钩子 - 只执行第一个返回值的插件 // 【First】首个钩子 - 只执行第一个返回值的插件
resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise<string | null> resolveModel?: (modelId: string, context: AiRequestContext) => Promise<LanguageModel | null> | LanguageModel | null
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null> loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
// 【Sequential】串行钩子 - 链式执行,支持数据转换 // 【Sequential】串行钩子 - 链式执行,支持数据转换
configureContext?: (context: AiRequestContext) => void | Promise<void>
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any> transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any> transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
configureModel?: (model: any, context: AiRequestContext) => any | Promise<any>
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用 // 【Parallel】并行钩子 - 不依赖顺序,用于副作用
onRequestStart?: (context: AiRequestContext) => void | Promise<void> onRequestStart?: (context: AiRequestContext) => void | Promise<void>

View File

@ -7,7 +7,7 @@ import { generateObject, generateText, LanguageModel, streamObject, streamText }
import { type ProviderId, type ProviderSettingsMap } from '../../types' import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { createModel, getProviderInfo } from '../models' import { createModel, getProviderInfo } from '../models'
import { type AiPlugin } from '../plugins' import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
import { PluginEngine } from './pluginEngine' import { PluginEngine } from './pluginEngine'
import { type RuntimeConfig } from './types' import { type RuntimeConfig } from './types'
@ -28,6 +28,19 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || []) this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
} }
createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
return definePlugin({
name: '_internal_resolveModel',
enforce: 'post',
resolveModel: async (modelId: string, context: AiRequestContext) => {
// 从 context 中读取由用户插件注入的 extraModelConfig
const extraModelConfig = context.extraModelConfig || {}
return await this.resolveModel(modelId, middlewares, extraModelConfig)
}
})
}
// === 高阶重载:直接使用模型 === // === 高阶重载:直接使用模型 ===
/** /**
@ -59,14 +72,14 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
middlewares?: LanguageModelV2Middleware[] middlewares?: LanguageModelV2Middleware[]
} }
): Promise<ReturnType<typeof streamText>> { ): Promise<ReturnType<typeof streamText>> {
const model = await this.resolveModel(modelOrId, options?.middlewares) this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
// 2. 执行插件处理 // 2. 执行插件处理
return this.pluginEngine.executeStreamWithPlugins( return this.pluginEngine.executeStreamWithPlugins(
'streamText', 'streamText',
typeof modelOrId === 'string' ? modelOrId : model.modelId, typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
params, params,
async (finalModelId, transformedParams, streamTransforms) => { async (model, transformedParams, streamTransforms) => {
const experimental_transform = const experimental_transform =
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined) params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
@ -110,13 +123,13 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
middlewares?: LanguageModelV2Middleware[] middlewares?: LanguageModelV2Middleware[]
} }
): Promise<ReturnType<typeof generateText>> { ): Promise<ReturnType<typeof generateText>> {
const model = await this.resolveModel(modelOrId, options?.middlewares) this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
return this.pluginEngine.executeWithPlugins( return this.pluginEngine.executeWithPlugins(
'generateText', 'generateText',
typeof modelOrId === 'string' ? modelOrId : model.modelId, typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
params, params,
async (finalModelId, transformedParams) => { async (model, transformedParams) => {
return await generateText({ model, ...transformedParams }) return await generateText({ model, ...transformedParams })
} }
) )
@ -151,13 +164,13 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
middlewares?: LanguageModelV2Middleware[] middlewares?: LanguageModelV2Middleware[]
} }
): Promise<ReturnType<typeof generateObject>> { ): Promise<ReturnType<typeof generateObject>> {
const model = await this.resolveModel(modelOrId, options?.middlewares) this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
return this.pluginEngine.executeWithPlugins( return this.pluginEngine.executeWithPlugins(
'generateObject', 'generateObject',
typeof modelOrId === 'string' ? modelOrId : model.modelId, typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
params, params,
async (finalModelId, transformedParams) => { async (model, transformedParams) => {
return await generateObject({ model, ...transformedParams }) return await generateObject({ model, ...transformedParams })
} }
) )
@ -192,13 +205,13 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
middlewares?: LanguageModelV2Middleware[] middlewares?: LanguageModelV2Middleware[]
} }
): Promise<ReturnType<typeof streamObject>> { ): Promise<ReturnType<typeof streamObject>> {
const model = await this.resolveModel(modelOrId, options?.middlewares) this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
return this.pluginEngine.executeWithPlugins( return this.pluginEngine.executeWithPlugins(
'streamObject', 'streamObject',
typeof modelOrId === 'string' ? modelOrId : model.modelId, typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
params, params,
async (finalModelId, transformedParams) => { async (model, transformedParams) => {
return await streamObject({ model, ...transformedParams }) return await streamObject({ model, ...transformedParams })
} }
) )
@ -211,15 +224,17 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
*/ */
private async resolveModel( private async resolveModel(
modelOrId: LanguageModel, modelOrId: LanguageModel,
middlewares?: LanguageModelV2Middleware[] middlewares?: LanguageModelV2Middleware[],
extraModelConfig?: Record<string, any>
): Promise<LanguageModelV2> { ): Promise<LanguageModelV2> {
if (typeof modelOrId === 'string') { if (typeof modelOrId === 'string') {
// 字符串modelId需要创建模型 // 字符串modelId需要创建模型
return await createModel({ return await createModel({
providerId: this.config.providerId, providerId: this.config.providerId,
modelId: modelOrId, modelId: modelOrId,
options: this.config.providerSettings, providerSettings: this.config.providerSettings,
middlewares middlewares,
extraModelConfig
}) })
} else { } else {
// 已经是模型,直接返回 // 已经是模型,直接返回

View File

@ -1,3 +1,5 @@
import { LanguageModel } from 'ai'
import { type ProviderId, type ProviderSettingsMap } from '../../types' import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { type AiPlugin, createContext, PluginManager } from '../plugins' import { type AiPlugin, createContext, PluginManager } from '../plugins'
import { isProviderSupported } from '../providers/registry' import { isProviderSupported } from '../providers/registry'
@ -63,7 +65,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
methodName: string, methodName: string,
modelId: string, modelId: string,
params: TParams, params: TParams,
executor: (finalModelId: string, transformedParams: TParams) => Promise<TResult>, executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
_context?: ReturnType<typeof createContext> _context?: ReturnType<typeof createContext>
): Promise<TResult> { ): Promise<TResult> {
// 使用正确的createContext创建请求上下文 // 使用正确的createContext创建请求上下文
@ -79,18 +81,23 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
} }
try { try {
// 0. 配置上下文
await this.pluginManager.executeConfigureContext(context)
// 1. 触发请求开始事件 // 1. 触发请求开始事件
await this.pluginManager.executeParallel('onRequestStart', context) await this.pluginManager.executeParallel('onRequestStart', context)
// 2. 解析模型别名 // 2. 解析模型
const resolvedModelId = await this.pluginManager.executeFirst<string>('resolveModel', modelId, context) const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
const finalModelId = resolvedModelId || modelId if (!model) {
throw new Error(`Failed to resolve model: ${modelId}`)
}
// 3. 转换请求参数 // 3. 转换请求参数
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context) const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
// 4. 执行具体的 API 调用 // 4. 执行具体的 API 调用
const result = await executor(finalModelId, transformedParams) const result = await executor(model, transformedParams)
// 5. 转换结果(对于非流式调用) // 5. 转换结果(对于非流式调用)
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context) const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
@ -114,7 +121,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
methodName: string, methodName: string,
modelId: string, modelId: string,
params: TParams, params: TParams,
executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>, executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
_context?: ReturnType<typeof createContext> _context?: ReturnType<typeof createContext>
): Promise<TResult> { ): Promise<TResult> {
// 创建请求上下文 // 创建请求上下文
@ -130,12 +137,18 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
} }
try { try {
// 0. 配置上下文
await this.pluginManager.executeConfigureContext(context)
// 1. 触发请求开始事件 // 1. 触发请求开始事件
await this.pluginManager.executeParallel('onRequestStart', context) await this.pluginManager.executeParallel('onRequestStart', context)
// 2. 解析模型别名 // 2. 解析模型
const resolvedModelId = await this.pluginManager.executeFirst<string>('resolveModel', modelId, context) const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
const finalModelId = resolvedModelId || modelId
if (!model) {
throw new Error(`Failed to resolve model: ${modelId}`)
}
// 3. 转换请求参数 // 3. 转换请求参数
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context) const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
@ -144,12 +157,14 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context) const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context)
// 5. 执行流式 API 调用 // 5. 执行流式 API 调用
const result = await executor(finalModelId, transformedParams, streamTransforms) const result = await executor(model, transformedParams, streamTransforms)
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
// 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件) // 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件)
await this.pluginManager.executeParallel('onRequestEnd', context, { stream: true }) await this.pluginManager.executeParallel('onRequestEnd', context, { stream: true })
return result return transformedResult
} catch (error) { } catch (error) {
// 7. 触发错误事件 // 7. 触发错误事件
await this.pluginManager.executeParallel('onError', context, undefined, error as Error) await this.pluginManager.executeParallel('onError', context, undefined, error as Error)