mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
refactor: streamline model configuration and factory functions
- Updated the `createModel` function to accept a simplified `ModelConfig` interface, enhancing clarity and usability. - Refactored `createBaseModel` to destructure parameters for better readability and maintainability. - Removed the `ModelCreator.ts` file as its functionality has been integrated into the factory functions. - Adjusted type definitions in `types.ts` to reflect changes in model configuration structure, ensuring consistency across the codebase.
This commit is contained in:
parent
547e5785c0
commit
c92475b6bf
@ -1,31 +0,0 @@
|
||||
/**
|
||||
* 模型创建器
|
||||
* 负责创建模型并应用中间件,不暴露原始model给用户
|
||||
*/
|
||||
import { LanguageModel } from 'ai'
|
||||
|
||||
import { wrapModelWithMiddlewares } from '../middleware'
|
||||
import { createBaseModel } from './ProviderCreator'
|
||||
import { ModelCreationRequest, ResolvedConfig } from './types'
|
||||
|
||||
/**
|
||||
* 根据解析后的配置创建包装好的模型
|
||||
*/
|
||||
export async function createModelFromConfig(config: ResolvedConfig): Promise<LanguageModel> {
|
||||
// 使用ProviderCreator创建基础模型(不应用中间件)
|
||||
const baseModel = await createBaseModel(config.provider.id, config.model.id, config.provider.options)
|
||||
|
||||
// 在creation层应用中间件,用户不直接接触原始model
|
||||
return wrapModelWithMiddlewares(baseModel, config.middlewares)
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接根据请求参数创建模型
|
||||
*/
|
||||
export async function createModel(request: ModelCreationRequest): Promise<LanguageModel> {
|
||||
// 使用ProviderCreator创建基础模型(不应用中间件)
|
||||
const baseModel = await createBaseModel(request.providerId, request.modelId, request.options)
|
||||
|
||||
const middlewares = request.middlewares || []
|
||||
return wrapModelWithMiddlewares(baseModel, middlewares)
|
||||
}
|
||||
@ -3,7 +3,7 @@
|
||||
* 负责动态导入 AI SDK providers 并创建基础模型实例
|
||||
*/
|
||||
import type { ImageModelV1 } from '@ai-sdk/provider'
|
||||
import { type LanguageModelV1, LanguageModelV1Middleware, wrapLanguageModel } from 'ai'
|
||||
import { type LanguageModelV1 } from 'ai'
|
||||
|
||||
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
||||
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
|
||||
@ -25,26 +25,43 @@ export class ProviderCreationError extends Error {
|
||||
* 创建基础 AI SDK 模型实例
|
||||
* 对于已知的 Provider 使用严格类型检查,未知的 Provider 默认使用 openai-compatible
|
||||
*/
|
||||
export async function createBaseModel<T extends ProviderId>(
|
||||
providerId: T,
|
||||
modelId: string,
|
||||
options: ProviderSettingsMap[T],
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
): Promise<LanguageModelV1>
|
||||
export async function createBaseModel<T extends ProviderId>({
|
||||
providerId,
|
||||
modelId,
|
||||
providerSettings
|
||||
// middlewares
|
||||
}: {
|
||||
providerId: T
|
||||
modelId: string
|
||||
providerSettings: ProviderSettingsMap[T]
|
||||
// middlewares?: LanguageModelV1Middleware[]
|
||||
}): Promise<LanguageModelV1>
|
||||
|
||||
export async function createBaseModel(
|
||||
providerId: string,
|
||||
modelId: string,
|
||||
options: ProviderSettingsMap['openai-compatible'],
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
): Promise<LanguageModelV1>
|
||||
export async function createBaseModel({
|
||||
providerId,
|
||||
modelId,
|
||||
providerSettings
|
||||
// middlewares
|
||||
}: {
|
||||
providerId: string
|
||||
modelId: string
|
||||
providerSettings: ProviderSettingsMap['openai-compatible']
|
||||
// middlewares?: LanguageModelV1Middleware[]
|
||||
}): Promise<LanguageModelV1>
|
||||
|
||||
export async function createBaseModel(
|
||||
providerId: string,
|
||||
modelId: string = 'default',
|
||||
options: any,
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
): Promise<LanguageModelV1> {
|
||||
export async function createBaseModel({
|
||||
providerId,
|
||||
modelId,
|
||||
providerSettings,
|
||||
// middlewares,
|
||||
extraModelConfig
|
||||
}: {
|
||||
providerId: string
|
||||
modelId: string
|
||||
providerSettings: ProviderSettingsMap[ProviderId]
|
||||
// middlewares?: LanguageModelV1Middleware[]
|
||||
extraModelConfig?: any
|
||||
}): Promise<LanguageModelV1> {
|
||||
try {
|
||||
// 对于不在注册表中的 provider,默认使用 openai-compatible
|
||||
const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible'
|
||||
@ -67,7 +84,7 @@ export async function createBaseModel(
|
||||
)
|
||||
}
|
||||
// 创建provider实例
|
||||
let provider = creatorFunction(options)
|
||||
let provider = creatorFunction(providerSettings)
|
||||
|
||||
// 加一个特判
|
||||
if (providerConfig.id === 'openai' && !isOpenAIChatCompletionOnlyModel(modelId)) {
|
||||
@ -75,15 +92,16 @@ export async function createBaseModel(
|
||||
}
|
||||
// 返回模型实例
|
||||
if (typeof provider === 'function') {
|
||||
let model: LanguageModelV1 = provider(modelId)
|
||||
// extraModelConfig:例如google的useSearchGrounding
|
||||
const model: LanguageModelV1 = provider(modelId, extraModelConfig)
|
||||
|
||||
// 应用 AI SDK 中间件
|
||||
if (middlewares && middlewares.length > 0) {
|
||||
model = wrapLanguageModel({
|
||||
model: model,
|
||||
middleware: middlewares
|
||||
})
|
||||
}
|
||||
// // 应用 AI SDK 中间件
|
||||
// if (middlewares && middlewares.length > 0) {
|
||||
// model = wrapLanguageModel({
|
||||
// model: model,
|
||||
// middleware: middlewares
|
||||
// })
|
||||
// }
|
||||
|
||||
return model
|
||||
} else {
|
||||
|
||||
@ -2,18 +2,11 @@
|
||||
* 模型工厂函数
|
||||
* 统一的模型创建和配置管理
|
||||
*/
|
||||
import { LanguageModel, LanguageModelV1Middleware } from 'ai'
|
||||
import { LanguageModel } from 'ai'
|
||||
|
||||
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
||||
import { wrapModelWithMiddlewares } from '../middleware'
|
||||
import { createBaseModel } from './ProviderCreator'
|
||||
|
||||
export interface ModelConfig {
|
||||
providerId: ProviderId
|
||||
modelId: string
|
||||
options: ProviderSettingsMap[ProviderId]
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
}
|
||||
import { ModelConfig } from './types'
|
||||
|
||||
/**
|
||||
* 创建模型 - 核心函数
|
||||
@ -22,7 +15,7 @@ export async function createModel(config: ModelConfig): Promise<LanguageModel> {
|
||||
validateModelConfig(config)
|
||||
|
||||
// 1. 创建基础模型
|
||||
const baseModel = await createBaseModel(config.providerId, config.modelId, config.options)
|
||||
const baseModel = await createBaseModel(config)
|
||||
|
||||
// 2. 应用中间件(如果有)
|
||||
return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel
|
||||
@ -45,7 +38,7 @@ function validateModelConfig(config: ModelConfig): void {
|
||||
if (!config.modelId) {
|
||||
throw new Error('ModelConfig: modelId is required')
|
||||
}
|
||||
if (!config.options) {
|
||||
throw new Error('ModelConfig: options is required')
|
||||
if (!config.providerSettings) {
|
||||
throw new Error('ModelConfig: providerSettings is required')
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
*/
|
||||
|
||||
// 主要的模型创建API
|
||||
export { createModel, createModels, type ModelConfig } from './factory'
|
||||
export { createModel, createModels } from './factory'
|
||||
|
||||
// 底层Provider创建功能(供高级用户使用)
|
||||
export {
|
||||
@ -16,4 +16,4 @@ export {
|
||||
} from './ProviderCreator'
|
||||
|
||||
// 保留原有类型
|
||||
export type { ModelCreationRequest, ResolvedConfig } from './types'
|
||||
export type { ModelConfig } from './types'
|
||||
|
||||
@ -1,32 +1,44 @@
|
||||
/**
|
||||
* Creation 模块类型定义
|
||||
*/
|
||||
// /**
|
||||
// * Creation 模块类型定义
|
||||
// */
|
||||
// import { LanguageModelV1Middleware } from 'ai'
|
||||
|
||||
// import { ProviderId, ProviderSettingsMap } from '../../types'
|
||||
// import { AiPlugin } from '../plugins'
|
||||
|
||||
// /**
|
||||
// * 模型创建请求
|
||||
// */
|
||||
// export interface ModelCreationRequest {
|
||||
// providerId: ProviderId
|
||||
// modelId: string
|
||||
// options: ProviderSettingsMap[ProviderId]
|
||||
// middlewares?: LanguageModelV1Middleware[]
|
||||
// }
|
||||
|
||||
// /**
|
||||
// * 配置解析结果
|
||||
// */
|
||||
// export interface ResolvedConfig {
|
||||
// provider: {
|
||||
// id: ProviderId
|
||||
// options: ProviderSettingsMap[ProviderId]
|
||||
// }
|
||||
// model: {
|
||||
// id: string
|
||||
// }
|
||||
// plugins: AiPlugin[]
|
||||
// middlewares: LanguageModelV1Middleware[]
|
||||
// }
|
||||
|
||||
import { LanguageModelV1Middleware } from 'ai'
|
||||
|
||||
import { ProviderId, ProviderSettingsMap } from '../../types'
|
||||
import { AiPlugin } from '../plugins'
|
||||
import type { ProviderId, ProviderSettingsMap } from '../../types'
|
||||
|
||||
/**
|
||||
* 模型创建请求
|
||||
*/
|
||||
export interface ModelCreationRequest {
|
||||
export interface ModelConfig {
|
||||
providerId: ProviderId
|
||||
modelId: string
|
||||
options: ProviderSettingsMap[ProviderId]
|
||||
providerSettings: ProviderSettingsMap[ProviderId]
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
}
|
||||
|
||||
/**
|
||||
* 配置解析结果
|
||||
*/
|
||||
export interface ResolvedConfig {
|
||||
provider: {
|
||||
id: ProviderId
|
||||
options: ProviderSettingsMap[ProviderId]
|
||||
}
|
||||
model: {
|
||||
id: string
|
||||
}
|
||||
plugins: AiPlugin[]
|
||||
middlewares: LanguageModelV1Middleware[]
|
||||
extraModelConfig?: any
|
||||
}
|
||||
|
||||
@ -49,27 +49,29 @@ export function adaptOpenAIWebSearch(params: any, webSearchConfig: WebSearchConf
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* 适配 Gemini 网络搜索
|
||||
* 将 googleSearch 工具放入 providerOptions.google.tools
|
||||
*/
|
||||
export function adaptGeminiWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any {
|
||||
const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig
|
||||
const googleSearchTool = { googleSearch: {} }
|
||||
// export function adaptGeminiWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any {
|
||||
// const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig
|
||||
// const googleSearchTool = { googleSearch: {} }
|
||||
|
||||
const existingTools = Array.isArray(params.providerOptions?.google?.tools) ? params.providerOptions.google.tools : []
|
||||
// const existingTools = Array.isArray(params.providerOptions?.google?.tools) ? params.providerOptions.google.tools : []
|
||||
|
||||
return {
|
||||
...params,
|
||||
providerOptions: {
|
||||
...params.providerOptions,
|
||||
google: {
|
||||
...params.providerOptions?.google,
|
||||
tools: [...existingTools, googleSearchTool],
|
||||
...(config.extra || {})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// return {
|
||||
// ...params,
|
||||
// providerOptions: {
|
||||
// ...params.providerOptions,
|
||||
// google: {
|
||||
// ...params.providerOptions?.google,
|
||||
// useSearchGrounding: true,
|
||||
// // tools: [...existingTools, googleSearchTool],
|
||||
// ...(config.extra || {})
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
/**
|
||||
* 适配 Anthropic 网络搜索
|
||||
@ -113,9 +115,10 @@ export function adaptWebSearchForProvider(
|
||||
case 'openai':
|
||||
return adaptOpenAIWebSearch(params, webSearchConfig)
|
||||
|
||||
case 'google':
|
||||
case 'google-vertex':
|
||||
return adaptGeminiWebSearch(params, webSearchConfig)
|
||||
// google的需要通过插件,在创建model的时候传入参数
|
||||
// case 'google':
|
||||
// case 'google-vertex':
|
||||
// return adaptGeminiWebSearch(params, webSearchConfig)
|
||||
|
||||
case 'anthropic':
|
||||
return adaptAnthropicWebSearch(params, webSearchConfig)
|
||||
@ -125,12 +128,3 @@ export function adaptWebSearchForProvider(
|
||||
return params
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查 provider 是否支持网络搜索
|
||||
*/
|
||||
export function isWebSearchSupported(providerId: string): boolean {
|
||||
const supportedProviders = ['openai', 'google', 'google-vertex', 'anthropic']
|
||||
|
||||
return supportedProviders.includes(providerId)
|
||||
}
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
import { definePlugin } from '../../'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import { adaptWebSearchForProvider, isWebSearchSupported, type WebSearchConfig } from './helper'
|
||||
import { adaptWebSearchForProvider, type WebSearchConfig } from './helper'
|
||||
|
||||
/**
|
||||
* 网络搜索插件
|
||||
@ -14,42 +14,51 @@ import { adaptWebSearchForProvider, isWebSearchSupported, type WebSearchConfig }
|
||||
* options.ts 文件负责将高层级的设置(如 assistant.enableWebSearch)
|
||||
* 转换为 providerOptions 中的 webSearch: { enabled: true } 配置。
|
||||
*/
|
||||
export const webSearchPlugin = definePlugin({
|
||||
name: 'webSearch',
|
||||
export const webSearchPlugin = (config) =>
|
||||
definePlugin({
|
||||
name: 'webSearch',
|
||||
enforce: 'pre',
|
||||
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
const { providerId } = context
|
||||
// configureModel: async (modelConfig: any, context: AiRequestContext) => {
|
||||
// if (context.providerId === 'google') {
|
||||
// return {
|
||||
// ...modelConfig
|
||||
// }
|
||||
// }
|
||||
// return null
|
||||
// },
|
||||
|
||||
// 从 providerOptions 中提取 webSearch 配置
|
||||
const webSearchConfig = params.providerOptions?.[providerId]?.webSearch
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
const { providerId } = context
|
||||
|
||||
// 检查是否启用了网络搜索 (enabled: false 可用于显式禁用)
|
||||
if (!webSearchConfig || (typeof webSearchConfig === 'object' && webSearchConfig.enabled === false)) {
|
||||
return params
|
||||
// 从 providerOptions 中提取 webSearch 配置
|
||||
const webSearchConfig = params.providerOptions?.[providerId]?.webSearch
|
||||
|
||||
// 检查是否启用了网络搜索 (enabled: false 可用于显式禁用)
|
||||
if (!webSearchConfig || (typeof webSearchConfig === 'object' && webSearchConfig.enabled === false)) {
|
||||
return params
|
||||
}
|
||||
console.log('webSearchConfig', webSearchConfig)
|
||||
// // 检查当前 provider 是否支持网络搜索
|
||||
// if (!isWebSearchSupported(providerId)) {
|
||||
// // 对于不支持的 provider,只记录警告,不修改参数
|
||||
// console.warn(
|
||||
// `[webSearchPlugin] Provider '${providerId}' does not support web search. Ignoring webSearch parameter.`
|
||||
// )
|
||||
// return params
|
||||
// }
|
||||
|
||||
// 使用适配器函数处理网络搜索
|
||||
const adaptedParams = adaptWebSearchForProvider(params, providerId, webSearchConfig as WebSearchConfig | boolean)
|
||||
// 清理原始的 webSearch 配置
|
||||
if (adaptedParams.providerOptions?.[providerId]) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const { webSearch, ...rest } = adaptedParams.providerOptions[providerId]
|
||||
adaptedParams.providerOptions[providerId] = rest
|
||||
}
|
||||
return adaptedParams
|
||||
}
|
||||
|
||||
// 检查当前 provider 是否支持网络搜索
|
||||
if (!isWebSearchSupported(providerId)) {
|
||||
// 对于不支持的 provider,只记录警告,不修改参数
|
||||
console.warn(
|
||||
`[webSearchPlugin] Provider '${providerId}' does not support web search. Ignoring webSearch parameter.`
|
||||
)
|
||||
return params
|
||||
}
|
||||
|
||||
// 使用适配器函数处理网络搜索
|
||||
const adaptedParams = adaptWebSearchForProvider(params, providerId, webSearchConfig as WebSearchConfig | boolean)
|
||||
|
||||
// 清理原始的 webSearch 配置
|
||||
if (adaptedParams.providerOptions?.[providerId]) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const { webSearch, ...rest } = adaptedParams.providerOptions[providerId]
|
||||
adaptedParams.providerOptions[providerId] = rest
|
||||
}
|
||||
|
||||
return adaptedParams
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// 导出类型定义供开发者使用
|
||||
export type { WebSearchConfig } from './helper'
|
||||
|
||||
@ -1,12 +1,17 @@
|
||||
// 核心类型和接口
|
||||
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './types'
|
||||
import type { ProviderId } from '../../types'
|
||||
import type { AiPlugin, AiRequestContext } from './types'
|
||||
|
||||
// 插件管理器
|
||||
export { PluginManager } from './manager'
|
||||
|
||||
// 工具函数
|
||||
export function createContext(providerId: string, modelId: string, originalParams: any): AiRequestContext {
|
||||
export function createContext<T extends ProviderId>(
|
||||
providerId: T,
|
||||
modelId: string,
|
||||
originalParams: any
|
||||
): AiRequestContext {
|
||||
return {
|
||||
providerId,
|
||||
modelId,
|
||||
|
||||
@ -52,7 +52,7 @@ export class PluginManager {
|
||||
*/
|
||||
async executeFirst<T>(
|
||||
hookName: 'resolveModel' | 'loadTemplate',
|
||||
arg: string,
|
||||
arg: any,
|
||||
context: AiRequestContext
|
||||
): Promise<T | null> {
|
||||
for (const plugin of this.plugins) {
|
||||
@ -71,7 +71,7 @@ export class PluginManager {
|
||||
* 执行 Sequential 钩子 - 链式数据转换
|
||||
*/
|
||||
async executeSequential<T>(
|
||||
hookName: 'transformParams' | 'transformResult',
|
||||
hookName: 'transformParams' | 'transformResult' | 'configureModel',
|
||||
initialValue: T,
|
||||
context: AiRequestContext
|
||||
): Promise<T> {
|
||||
@ -120,7 +120,9 @@ export class PluginManager {
|
||||
* 收集所有流转换器(返回数组,AI SDK 原生支持)
|
||||
*/
|
||||
collectStreamTransforms(params: any, context: AiRequestContext) {
|
||||
return this.plugins.map((plugin) => plugin.transformStream?.(params, context))
|
||||
return this.plugins
|
||||
.filter((plugin) => plugin.transformStream)
|
||||
.map((plugin) => plugin.transformStream?.(params, context))
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -217,7 +217,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
return await createModel({
|
||||
providerId: this.config.providerId,
|
||||
modelId: modelOrId,
|
||||
options: this.config.options,
|
||||
providerSettings: this.config.providerSettings,
|
||||
middlewares
|
||||
})
|
||||
} else {
|
||||
@ -245,7 +245,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
): RuntimeExecutor<T> {
|
||||
return new RuntimeExecutor({
|
||||
providerId,
|
||||
options,
|
||||
providerSettings: options,
|
||||
plugins
|
||||
})
|
||||
}
|
||||
@ -259,7 +259,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
): RuntimeExecutor<'openai-compatible'> {
|
||||
return new RuntimeExecutor({
|
||||
providerId: 'openai-compatible',
|
||||
options,
|
||||
providerSettings: options,
|
||||
plugins
|
||||
})
|
||||
}
|
||||
|
||||
@ -9,7 +9,7 @@ import { type AiPlugin } from '../plugins'
|
||||
*/
|
||||
export interface RuntimeConfig<T extends ProviderId = ProviderId> {
|
||||
providerId: T
|
||||
options: ProviderSettingsMap[T]
|
||||
providerSettings: ProviderSettingsMap[T]
|
||||
plugins?: AiPlugin[]
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user