mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
refactor: enhance model configuration and plugin execution
- Simplified the `createModel` function to directly accept the `ModelConfig` object, improving clarity. - Updated `createBaseModel` to include `extraModelConfig` for extended configuration options. - Introduced `executeConfigureContext` method in `PluginManager` to handle context configuration for plugins. - Adjusted type definitions in `types.ts` to ensure consistency with the new configuration structure. - Refactored plugin execution methods in `PluginEngine` to utilize the resolved model directly, enhancing the flow of data through the plugin system.
This commit is contained in:
parent
c72156b2da
commit
0a908a334b
@ -27,24 +27,28 @@ export class ProviderCreationError extends Error {
|
||||
export async function createBaseModel<T extends ProviderId>({
|
||||
providerId,
|
||||
modelId,
|
||||
providerSettings
|
||||
providerSettings,
|
||||
extraModelConfig
|
||||
// middlewares
|
||||
}: {
|
||||
providerId: T
|
||||
modelId: string
|
||||
providerSettings: ProviderSettingsMap[T]
|
||||
extraModelConfig?: any
|
||||
// middlewares?: LanguageModelV1Middleware[]
|
||||
}): Promise<LanguageModelV2>
|
||||
|
||||
export async function createBaseModel({
|
||||
providerId,
|
||||
modelId,
|
||||
providerSettings
|
||||
providerSettings,
|
||||
extraModelConfig
|
||||
// middlewares
|
||||
}: {
|
||||
providerId: string
|
||||
modelId: string
|
||||
providerSettings: ProviderSettingsMap['openai-compatible']
|
||||
extraModelConfig?: any
|
||||
// middlewares?: LanguageModelV1Middleware[]
|
||||
}): Promise<LanguageModelV2>
|
||||
|
||||
|
||||
@ -16,11 +16,7 @@ export async function createModel(config: ModelConfig): Promise<LanguageModelV2>
|
||||
validateModelConfig(config)
|
||||
|
||||
// 1. 创建基础模型
|
||||
const baseModel = await createBaseModel({
|
||||
providerId: config.providerId,
|
||||
modelId: config.modelId,
|
||||
providerSettings: config.options
|
||||
})
|
||||
const baseModel = await createBaseModel(config)
|
||||
|
||||
// 2. 应用中间件(如果有)
|
||||
return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel
|
||||
@ -43,7 +39,7 @@ function validateModelConfig(config: ModelConfig): void {
|
||||
if (!config.modelId) {
|
||||
throw new Error('ModelConfig: modelId is required')
|
||||
}
|
||||
if (!config.options) {
|
||||
if (!config.providerSettings) {
|
||||
throw new Error('ModelConfig: providerSettings is required')
|
||||
}
|
||||
}
|
||||
|
||||
@ -5,9 +5,10 @@ import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import type { ProviderId, ProviderSettingsMap } from '../../types'
|
||||
|
||||
export interface ModelConfig {
|
||||
providerId: ProviderId
|
||||
export interface ModelConfig<T extends ProviderId = ProviderId> {
|
||||
providerId: T
|
||||
modelId: string
|
||||
options: ProviderSettingsMap[ProviderId]
|
||||
providerSettings: ProviderSettingsMap[T]
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
extraModelConfig?: Record<string, any>
|
||||
}
|
||||
|
||||
@ -71,7 +71,7 @@ export class PluginManager {
|
||||
* 执行 Sequential 钩子 - 链式数据转换
|
||||
*/
|
||||
async executeSequential<T>(
|
||||
hookName: 'transformParams' | 'transformResult' | 'configureModel',
|
||||
hookName: 'transformParams' | 'transformResult',
|
||||
initialValue: T,
|
||||
context: AiRequestContext
|
||||
): Promise<T> {
|
||||
@ -87,6 +87,18 @@ export class PluginManager {
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 ConfigureContext 钩子 - 串行配置上下文
|
||||
*/
|
||||
async executeConfigureContext(context: AiRequestContext): Promise<void> {
|
||||
for (const plugin of this.plugins) {
|
||||
const hook = plugin.configureContext
|
||||
if (hook) {
|
||||
await hook(context)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 Parallel 钩子 - 并行副作用
|
||||
*/
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { ProviderId } from '../providers/registry'
|
||||
|
||||
@ -32,13 +32,13 @@ export interface AiPlugin {
|
||||
enforce?: 'pre' | 'post'
|
||||
|
||||
// 【First】首个钩子 - 只执行第一个返回值的插件
|
||||
resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise<string | null>
|
||||
resolveModel?: (modelId: string, context: AiRequestContext) => Promise<LanguageModel | null> | LanguageModel | null
|
||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
||||
|
||||
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
||||
configureContext?: (context: AiRequestContext) => void | Promise<void>
|
||||
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
|
||||
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
|
||||
configureModel?: (model: any, context: AiRequestContext) => any | Promise<any>
|
||||
|
||||
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
|
||||
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
|
||||
|
||||
@ -7,7 +7,7 @@ import { generateObject, generateText, LanguageModel, streamObject, streamText }
|
||||
|
||||
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
||||
import { createModel, getProviderInfo } from '../models'
|
||||
import { type AiPlugin } from '../plugins'
|
||||
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
||||
import { PluginEngine } from './pluginEngine'
|
||||
import { type RuntimeConfig } from './types'
|
||||
|
||||
@ -28,6 +28,19 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
|
||||
}
|
||||
|
||||
createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
|
||||
return definePlugin({
|
||||
name: '_internal_resolveModel',
|
||||
enforce: 'post',
|
||||
|
||||
resolveModel: async (modelId: string, context: AiRequestContext) => {
|
||||
// 从 context 中读取由用户插件注入的 extraModelConfig
|
||||
const extraModelConfig = context.extraModelConfig || {}
|
||||
return await this.resolveModel(modelId, middlewares, extraModelConfig)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// === 高阶重载:直接使用模型 ===
|
||||
|
||||
/**
|
||||
@ -59,14 +72,14 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamText>> {
|
||||
const model = await this.resolveModel(modelOrId, options?.middlewares)
|
||||
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
||||
|
||||
// 2. 执行插件处理
|
||||
return this.pluginEngine.executeStreamWithPlugins(
|
||||
'streamText',
|
||||
typeof modelOrId === 'string' ? modelOrId : model.modelId,
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (finalModelId, transformedParams, streamTransforms) => {
|
||||
async (model, transformedParams, streamTransforms) => {
|
||||
const experimental_transform =
|
||||
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
|
||||
|
||||
@ -110,13 +123,13 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateText>> {
|
||||
const model = await this.resolveModel(modelOrId, options?.middlewares)
|
||||
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'generateText',
|
||||
typeof modelOrId === 'string' ? modelOrId : model.modelId,
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (finalModelId, transformedParams) => {
|
||||
async (model, transformedParams) => {
|
||||
return await generateText({ model, ...transformedParams })
|
||||
}
|
||||
)
|
||||
@ -151,13 +164,13 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateObject>> {
|
||||
const model = await this.resolveModel(modelOrId, options?.middlewares)
|
||||
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'generateObject',
|
||||
typeof modelOrId === 'string' ? modelOrId : model.modelId,
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (finalModelId, transformedParams) => {
|
||||
async (model, transformedParams) => {
|
||||
return await generateObject({ model, ...transformedParams })
|
||||
}
|
||||
)
|
||||
@ -192,13 +205,13 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamObject>> {
|
||||
const model = await this.resolveModel(modelOrId, options?.middlewares)
|
||||
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'streamObject',
|
||||
typeof modelOrId === 'string' ? modelOrId : model.modelId,
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (finalModelId, transformedParams) => {
|
||||
async (model, transformedParams) => {
|
||||
return await streamObject({ model, ...transformedParams })
|
||||
}
|
||||
)
|
||||
@ -211,15 +224,17 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
*/
|
||||
private async resolveModel(
|
||||
modelOrId: LanguageModel,
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
middlewares?: LanguageModelV2Middleware[],
|
||||
extraModelConfig?: Record<string, any>
|
||||
): Promise<LanguageModelV2> {
|
||||
if (typeof modelOrId === 'string') {
|
||||
// 字符串modelId,需要创建模型
|
||||
return await createModel({
|
||||
providerId: this.config.providerId,
|
||||
modelId: modelOrId,
|
||||
options: this.config.providerSettings,
|
||||
middlewares
|
||||
providerSettings: this.config.providerSettings,
|
||||
middlewares,
|
||||
extraModelConfig
|
||||
})
|
||||
} else {
|
||||
// 已经是模型,直接返回
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import { LanguageModel } from 'ai'
|
||||
|
||||
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
||||
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
||||
import { isProviderSupported } from '../providers/registry'
|
||||
@ -63,7 +65,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
methodName: string,
|
||||
modelId: string,
|
||||
params: TParams,
|
||||
executor: (finalModelId: string, transformedParams: TParams) => Promise<TResult>,
|
||||
executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 使用正确的createContext创建请求上下文
|
||||
@ -79,18 +81,23 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型别名
|
||||
const resolvedModelId = await this.pluginManager.executeFirst<string>('resolveModel', modelId, context)
|
||||
const finalModelId = resolvedModelId || modelId
|
||||
// 2. 解析模型
|
||||
const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
if (!model) {
|
||||
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 执行具体的 API 调用
|
||||
const result = await executor(finalModelId, transformedParams)
|
||||
const result = await executor(model, transformedParams)
|
||||
|
||||
// 5. 转换结果(对于非流式调用)
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
@ -114,7 +121,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
methodName: string,
|
||||
modelId: string,
|
||||
params: TParams,
|
||||
executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
|
||||
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 创建请求上下文
|
||||
@ -130,12 +137,18 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型别名
|
||||
const resolvedModelId = await this.pluginManager.executeFirst<string>('resolveModel', modelId, context)
|
||||
const finalModelId = resolvedModelId || modelId
|
||||
// 2. 解析模型
|
||||
const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
|
||||
if (!model) {
|
||||
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
@ -144,12 +157,14 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context)
|
||||
|
||||
// 5. 执行流式 API 调用
|
||||
const result = await executor(finalModelId, transformedParams, streamTransforms)
|
||||
const result = await executor(model, transformedParams, streamTransforms)
|
||||
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件)
|
||||
await this.pluginManager.executeParallel('onRequestEnd', context, { stream: true })
|
||||
|
||||
return result
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 7. 触发错误事件
|
||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user