mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-07 13:59:28 +08:00
refactor: enhance model configuration and plugin execution
- Simplified the `createModel` function to directly accept the `ModelConfig` object, improving clarity. - Updated `createBaseModel` to include `extraModelConfig` for extended configuration options. - Introduced `executeConfigureContext` method in `PluginManager` to handle context configuration for plugins. - Adjusted type definitions in `types.ts` to ensure consistency with the new configuration structure. - Refactored plugin execution methods in `PluginEngine` to utilize the resolved model directly, enhancing the flow of data through the plugin system.
This commit is contained in:
parent
c72156b2da
commit
0a908a334b
@ -27,24 +27,28 @@ export class ProviderCreationError extends Error {
|
|||||||
export async function createBaseModel<T extends ProviderId>({
|
export async function createBaseModel<T extends ProviderId>({
|
||||||
providerId,
|
providerId,
|
||||||
modelId,
|
modelId,
|
||||||
providerSettings
|
providerSettings,
|
||||||
|
extraModelConfig
|
||||||
// middlewares
|
// middlewares
|
||||||
}: {
|
}: {
|
||||||
providerId: T
|
providerId: T
|
||||||
modelId: string
|
modelId: string
|
||||||
providerSettings: ProviderSettingsMap[T]
|
providerSettings: ProviderSettingsMap[T]
|
||||||
|
extraModelConfig?: any
|
||||||
// middlewares?: LanguageModelV1Middleware[]
|
// middlewares?: LanguageModelV1Middleware[]
|
||||||
}): Promise<LanguageModelV2>
|
}): Promise<LanguageModelV2>
|
||||||
|
|
||||||
export async function createBaseModel({
|
export async function createBaseModel({
|
||||||
providerId,
|
providerId,
|
||||||
modelId,
|
modelId,
|
||||||
providerSettings
|
providerSettings,
|
||||||
|
extraModelConfig
|
||||||
// middlewares
|
// middlewares
|
||||||
}: {
|
}: {
|
||||||
providerId: string
|
providerId: string
|
||||||
modelId: string
|
modelId: string
|
||||||
providerSettings: ProviderSettingsMap['openai-compatible']
|
providerSettings: ProviderSettingsMap['openai-compatible']
|
||||||
|
extraModelConfig?: any
|
||||||
// middlewares?: LanguageModelV1Middleware[]
|
// middlewares?: LanguageModelV1Middleware[]
|
||||||
}): Promise<LanguageModelV2>
|
}): Promise<LanguageModelV2>
|
||||||
|
|
||||||
|
|||||||
@ -16,11 +16,7 @@ export async function createModel(config: ModelConfig): Promise<LanguageModelV2>
|
|||||||
validateModelConfig(config)
|
validateModelConfig(config)
|
||||||
|
|
||||||
// 1. 创建基础模型
|
// 1. 创建基础模型
|
||||||
const baseModel = await createBaseModel({
|
const baseModel = await createBaseModel(config)
|
||||||
providerId: config.providerId,
|
|
||||||
modelId: config.modelId,
|
|
||||||
providerSettings: config.options
|
|
||||||
})
|
|
||||||
|
|
||||||
// 2. 应用中间件(如果有)
|
// 2. 应用中间件(如果有)
|
||||||
return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel
|
return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel
|
||||||
@ -43,7 +39,7 @@ function validateModelConfig(config: ModelConfig): void {
|
|||||||
if (!config.modelId) {
|
if (!config.modelId) {
|
||||||
throw new Error('ModelConfig: modelId is required')
|
throw new Error('ModelConfig: modelId is required')
|
||||||
}
|
}
|
||||||
if (!config.options) {
|
if (!config.providerSettings) {
|
||||||
throw new Error('ModelConfig: providerSettings is required')
|
throw new Error('ModelConfig: providerSettings is required')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,9 +5,10 @@ import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
|||||||
|
|
||||||
import type { ProviderId, ProviderSettingsMap } from '../../types'
|
import type { ProviderId, ProviderSettingsMap } from '../../types'
|
||||||
|
|
||||||
export interface ModelConfig {
|
export interface ModelConfig<T extends ProviderId = ProviderId> {
|
||||||
providerId: ProviderId
|
providerId: T
|
||||||
modelId: string
|
modelId: string
|
||||||
options: ProviderSettingsMap[ProviderId]
|
providerSettings: ProviderSettingsMap[T]
|
||||||
middlewares?: LanguageModelV2Middleware[]
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
extraModelConfig?: Record<string, any>
|
||||||
}
|
}
|
||||||
|
|||||||
@ -71,7 +71,7 @@ export class PluginManager {
|
|||||||
* 执行 Sequential 钩子 - 链式数据转换
|
* 执行 Sequential 钩子 - 链式数据转换
|
||||||
*/
|
*/
|
||||||
async executeSequential<T>(
|
async executeSequential<T>(
|
||||||
hookName: 'transformParams' | 'transformResult' | 'configureModel',
|
hookName: 'transformParams' | 'transformResult',
|
||||||
initialValue: T,
|
initialValue: T,
|
||||||
context: AiRequestContext
|
context: AiRequestContext
|
||||||
): Promise<T> {
|
): Promise<T> {
|
||||||
@ -87,6 +87,18 @@ export class PluginManager {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行 ConfigureContext 钩子 - 串行配置上下文
|
||||||
|
*/
|
||||||
|
async executeConfigureContext(context: AiRequestContext): Promise<void> {
|
||||||
|
for (const plugin of this.plugins) {
|
||||||
|
const hook = plugin.configureContext
|
||||||
|
if (hook) {
|
||||||
|
await hook(context)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 执行 Parallel 钩子 - 并行副作用
|
* 执行 Parallel 钩子 - 并行副作用
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import type { TextStreamPart, ToolSet } from 'ai'
|
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
|
||||||
|
|
||||||
import { ProviderId } from '../providers/registry'
|
import { ProviderId } from '../providers/registry'
|
||||||
|
|
||||||
@ -32,13 +32,13 @@ export interface AiPlugin {
|
|||||||
enforce?: 'pre' | 'post'
|
enforce?: 'pre' | 'post'
|
||||||
|
|
||||||
// 【First】首个钩子 - 只执行第一个返回值的插件
|
// 【First】首个钩子 - 只执行第一个返回值的插件
|
||||||
resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise<string | null>
|
resolveModel?: (modelId: string, context: AiRequestContext) => Promise<LanguageModel | null> | LanguageModel | null
|
||||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
||||||
|
|
||||||
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
||||||
|
configureContext?: (context: AiRequestContext) => void | Promise<void>
|
||||||
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
|
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
|
||||||
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
|
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
|
||||||
configureModel?: (model: any, context: AiRequestContext) => any | Promise<any>
|
|
||||||
|
|
||||||
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
|
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
|
||||||
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
|
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import { generateObject, generateText, LanguageModel, streamObject, streamText }
|
|||||||
|
|
||||||
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
||||||
import { createModel, getProviderInfo } from '../models'
|
import { createModel, getProviderInfo } from '../models'
|
||||||
import { type AiPlugin } from '../plugins'
|
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
||||||
import { PluginEngine } from './pluginEngine'
|
import { PluginEngine } from './pluginEngine'
|
||||||
import { type RuntimeConfig } from './types'
|
import { type RuntimeConfig } from './types'
|
||||||
|
|
||||||
@ -28,6 +28,19 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
|
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
|
||||||
|
return definePlugin({
|
||||||
|
name: '_internal_resolveModel',
|
||||||
|
enforce: 'post',
|
||||||
|
|
||||||
|
resolveModel: async (modelId: string, context: AiRequestContext) => {
|
||||||
|
// 从 context 中读取由用户插件注入的 extraModelConfig
|
||||||
|
const extraModelConfig = context.extraModelConfig || {}
|
||||||
|
return await this.resolveModel(modelId, middlewares, extraModelConfig)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// === 高阶重载:直接使用模型 ===
|
// === 高阶重载:直接使用模型 ===
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -59,14 +72,14 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
middlewares?: LanguageModelV2Middleware[]
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
}
|
}
|
||||||
): Promise<ReturnType<typeof streamText>> {
|
): Promise<ReturnType<typeof streamText>> {
|
||||||
const model = await this.resolveModel(modelOrId, options?.middlewares)
|
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
||||||
|
|
||||||
// 2. 执行插件处理
|
// 2. 执行插件处理
|
||||||
return this.pluginEngine.executeStreamWithPlugins(
|
return this.pluginEngine.executeStreamWithPlugins(
|
||||||
'streamText',
|
'streamText',
|
||||||
typeof modelOrId === 'string' ? modelOrId : model.modelId,
|
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||||
params,
|
params,
|
||||||
async (finalModelId, transformedParams, streamTransforms) => {
|
async (model, transformedParams, streamTransforms) => {
|
||||||
const experimental_transform =
|
const experimental_transform =
|
||||||
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
|
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
|
||||||
|
|
||||||
@ -110,13 +123,13 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
middlewares?: LanguageModelV2Middleware[]
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
}
|
}
|
||||||
): Promise<ReturnType<typeof generateText>> {
|
): Promise<ReturnType<typeof generateText>> {
|
||||||
const model = await this.resolveModel(modelOrId, options?.middlewares)
|
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
||||||
|
|
||||||
return this.pluginEngine.executeWithPlugins(
|
return this.pluginEngine.executeWithPlugins(
|
||||||
'generateText',
|
'generateText',
|
||||||
typeof modelOrId === 'string' ? modelOrId : model.modelId,
|
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||||
params,
|
params,
|
||||||
async (finalModelId, transformedParams) => {
|
async (model, transformedParams) => {
|
||||||
return await generateText({ model, ...transformedParams })
|
return await generateText({ model, ...transformedParams })
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -151,13 +164,13 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
middlewares?: LanguageModelV2Middleware[]
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
}
|
}
|
||||||
): Promise<ReturnType<typeof generateObject>> {
|
): Promise<ReturnType<typeof generateObject>> {
|
||||||
const model = await this.resolveModel(modelOrId, options?.middlewares)
|
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
||||||
|
|
||||||
return this.pluginEngine.executeWithPlugins(
|
return this.pluginEngine.executeWithPlugins(
|
||||||
'generateObject',
|
'generateObject',
|
||||||
typeof modelOrId === 'string' ? modelOrId : model.modelId,
|
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||||
params,
|
params,
|
||||||
async (finalModelId, transformedParams) => {
|
async (model, transformedParams) => {
|
||||||
return await generateObject({ model, ...transformedParams })
|
return await generateObject({ model, ...transformedParams })
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -192,13 +205,13 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
middlewares?: LanguageModelV2Middleware[]
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
}
|
}
|
||||||
): Promise<ReturnType<typeof streamObject>> {
|
): Promise<ReturnType<typeof streamObject>> {
|
||||||
const model = await this.resolveModel(modelOrId, options?.middlewares)
|
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
|
||||||
|
|
||||||
return this.pluginEngine.executeWithPlugins(
|
return this.pluginEngine.executeWithPlugins(
|
||||||
'streamObject',
|
'streamObject',
|
||||||
typeof modelOrId === 'string' ? modelOrId : model.modelId,
|
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||||
params,
|
params,
|
||||||
async (finalModelId, transformedParams) => {
|
async (model, transformedParams) => {
|
||||||
return await streamObject({ model, ...transformedParams })
|
return await streamObject({ model, ...transformedParams })
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -211,15 +224,17 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
*/
|
*/
|
||||||
private async resolveModel(
|
private async resolveModel(
|
||||||
modelOrId: LanguageModel,
|
modelOrId: LanguageModel,
|
||||||
middlewares?: LanguageModelV2Middleware[]
|
middlewares?: LanguageModelV2Middleware[],
|
||||||
|
extraModelConfig?: Record<string, any>
|
||||||
): Promise<LanguageModelV2> {
|
): Promise<LanguageModelV2> {
|
||||||
if (typeof modelOrId === 'string') {
|
if (typeof modelOrId === 'string') {
|
||||||
// 字符串modelId,需要创建模型
|
// 字符串modelId,需要创建模型
|
||||||
return await createModel({
|
return await createModel({
|
||||||
providerId: this.config.providerId,
|
providerId: this.config.providerId,
|
||||||
modelId: modelOrId,
|
modelId: modelOrId,
|
||||||
options: this.config.providerSettings,
|
providerSettings: this.config.providerSettings,
|
||||||
middlewares
|
middlewares,
|
||||||
|
extraModelConfig
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
// 已经是模型,直接返回
|
// 已经是模型,直接返回
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
import { LanguageModel } from 'ai'
|
||||||
|
|
||||||
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
||||||
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
||||||
import { isProviderSupported } from '../providers/registry'
|
import { isProviderSupported } from '../providers/registry'
|
||||||
@ -63,7 +65,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
|||||||
methodName: string,
|
methodName: string,
|
||||||
modelId: string,
|
modelId: string,
|
||||||
params: TParams,
|
params: TParams,
|
||||||
executor: (finalModelId: string, transformedParams: TParams) => Promise<TResult>,
|
executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
|
||||||
_context?: ReturnType<typeof createContext>
|
_context?: ReturnType<typeof createContext>
|
||||||
): Promise<TResult> {
|
): Promise<TResult> {
|
||||||
// 使用正确的createContext创建请求上下文
|
// 使用正确的createContext创建请求上下文
|
||||||
@ -79,18 +81,23 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
// 0. 配置上下文
|
||||||
|
await this.pluginManager.executeConfigureContext(context)
|
||||||
|
|
||||||
// 1. 触发请求开始事件
|
// 1. 触发请求开始事件
|
||||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||||
|
|
||||||
// 2. 解析模型别名
|
// 2. 解析模型
|
||||||
const resolvedModelId = await this.pluginManager.executeFirst<string>('resolveModel', modelId, context)
|
const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||||
const finalModelId = resolvedModelId || modelId
|
if (!model) {
|
||||||
|
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||||
|
}
|
||||||
|
|
||||||
// 3. 转换请求参数
|
// 3. 转换请求参数
|
||||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||||
|
|
||||||
// 4. 执行具体的 API 调用
|
// 4. 执行具体的 API 调用
|
||||||
const result = await executor(finalModelId, transformedParams)
|
const result = await executor(model, transformedParams)
|
||||||
|
|
||||||
// 5. 转换结果(对于非流式调用)
|
// 5. 转换结果(对于非流式调用)
|
||||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||||
@ -114,7 +121,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
|||||||
methodName: string,
|
methodName: string,
|
||||||
modelId: string,
|
modelId: string,
|
||||||
params: TParams,
|
params: TParams,
|
||||||
executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
|
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
|
||||||
_context?: ReturnType<typeof createContext>
|
_context?: ReturnType<typeof createContext>
|
||||||
): Promise<TResult> {
|
): Promise<TResult> {
|
||||||
// 创建请求上下文
|
// 创建请求上下文
|
||||||
@ -130,12 +137,18 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
// 0. 配置上下文
|
||||||
|
await this.pluginManager.executeConfigureContext(context)
|
||||||
|
|
||||||
// 1. 触发请求开始事件
|
// 1. 触发请求开始事件
|
||||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||||
|
|
||||||
// 2. 解析模型别名
|
// 2. 解析模型
|
||||||
const resolvedModelId = await this.pluginManager.executeFirst<string>('resolveModel', modelId, context)
|
const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||||
const finalModelId = resolvedModelId || modelId
|
|
||||||
|
if (!model) {
|
||||||
|
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||||
|
}
|
||||||
|
|
||||||
// 3. 转换请求参数
|
// 3. 转换请求参数
|
||||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||||
@ -144,12 +157,14 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
|||||||
const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context)
|
const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context)
|
||||||
|
|
||||||
// 5. 执行流式 API 调用
|
// 5. 执行流式 API 调用
|
||||||
const result = await executor(finalModelId, transformedParams, streamTransforms)
|
const result = await executor(model, transformedParams, streamTransforms)
|
||||||
|
|
||||||
|
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||||
|
|
||||||
// 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件)
|
// 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件)
|
||||||
await this.pluginManager.executeParallel('onRequestEnd', context, { stream: true })
|
await this.pluginManager.executeParallel('onRequestEnd', context, { stream: true })
|
||||||
|
|
||||||
return result
|
return transformedResult
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// 7. 触发错误事件
|
// 7. 触发错误事件
|
||||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user