Merge branch 'feat/aisdk-package' of https://github.com/CherryHQ/cherry-studio into feat/aisdk-package

This commit is contained in:
suyao 2025-07-07 02:09:01 +08:00
commit f20d964be3
No known key found for this signature in database
11 changed files with 132 additions and 162 deletions

View File

@ -1,31 +0,0 @@
/**
*
* model给用户
*/
import { LanguageModel } from 'ai'
import { wrapModelWithMiddlewares } from '../middleware'
import { createBaseModel } from './ProviderCreator'
import { ModelCreationRequest, ResolvedConfig } from './types'
/**
*
*/
export async function createModelFromConfig(config: ResolvedConfig): Promise<LanguageModel> {
// 使用ProviderCreator创建基础模型不应用中间件
const baseModel = await createBaseModel(config.provider.id, config.model.id, config.provider.options)
// 在creation层应用中间件用户不直接接触原始model
return wrapModelWithMiddlewares(baseModel, config.middlewares)
}
/**
*
*/
export async function createModel(request: ModelCreationRequest): Promise<LanguageModel> {
// 使用ProviderCreator创建基础模型不应用中间件
const baseModel = await createBaseModel(request.providerId, request.modelId, request.options)
const middlewares = request.middlewares || []
return wrapModelWithMiddlewares(baseModel, middlewares)
}

View File

@ -25,26 +25,43 @@ export class ProviderCreationError extends Error {
* AI SDK
* Provider 使 Provider 使 openai-compatible
*/
export async function createBaseModel<T extends ProviderId>(
providerId: T,
modelId: string,
options: ProviderSettingsMap[T],
middlewares?: LanguageModelV2Middleware[]
): Promise<LanguageModelV2>
export async function createBaseModel<T extends ProviderId>({
providerId,
modelId,
providerSettings
// middlewares
}: {
providerId: T
modelId: string
providerSettings: ProviderSettingsMap[T]
// middlewares?: LanguageModelV1Middleware[]
}): Promise<LanguageModelV2>
export async function createBaseModel(
providerId: string,
modelId: string,
options: ProviderSettingsMap['openai-compatible'],
middlewares?: LanguageModelV2Middleware[]
): Promise<LanguageModelV2>
export async function createBaseModel({
providerId,
modelId,
providerSettings
// middlewares
}: {
providerId: string
modelId: string
providerSettings: ProviderSettingsMap['openai-compatible']
// middlewares?: LanguageModelV1Middleware[]
}): Promise<LanguageModelV2>
export async function createBaseModel(
providerId: string,
modelId: string = 'default',
options: any,
middlewares?: LanguageModelV2Middleware[]
): Promise<LanguageModelV2> {
export async function createBaseModel({
providerId,
modelId,
providerSettings,
// middlewares,
extraModelConfig
}: {
providerId: string
modelId: string
providerSettings: ProviderSettingsMap[ProviderId]
// middlewares?: LanguageModelV1Middleware[]
extraModelConfig?: any
}): Promise<LanguageModelV2> {
try {
// 对于不在注册表中的 provider默认使用 openai-compatible
const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible'
@ -67,7 +84,7 @@ export async function createBaseModel(
)
}
// 创建provider实例
let provider = creatorFunction(options)
let provider = creatorFunction(providerSettings)
// 加一个特判
if (providerConfig.id === 'openai' && !isOpenAIChatCompletionOnlyModel(modelId)) {
@ -75,15 +92,16 @@ export async function createBaseModel(
}
// 返回模型实例
if (typeof provider === 'function') {
let model: LanguageModelV2 = provider(modelId)
// extraModelConfig:例如google的useSearchGrounding
const model: LanguageModelV2 = provider(modelId, extraModelConfig)
// 应用 AI SDK 中间件
if (middlewares && middlewares.length > 0) {
model = wrapLanguageModel({
model: model,
middleware: middlewares
})
}
// // 应用 AI SDK 中间件
// if (middlewares && middlewares.length > 0) {
// model = wrapLanguageModel({
// model: model,
// middleware: middlewares
// })
// }
return model
} else {

View File

@ -5,16 +5,9 @@
import { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
import { LanguageModel } from 'ai'
import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { wrapModelWithMiddlewares } from '../middleware'
import { createBaseModel } from './ProviderCreator'
export interface ModelConfig {
providerId: ProviderId
modelId: string
options: ProviderSettingsMap[ProviderId]
middlewares?: LanguageModelV2Middleware[]
}
import { ModelConfig } from './types'
/**
* -
@ -23,7 +16,7 @@ export async function createModel(config: ModelConfig): Promise<LanguageModelV2>
validateModelConfig(config)
// 1. 创建基础模型
const baseModel = await createBaseModel(config.providerId, config.modelId, config.options)
const baseModel = await createBaseModel(config)
// 2. 应用中间件(如果有)
return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel
@ -46,7 +39,7 @@ function validateModelConfig(config: ModelConfig): void {
if (!config.modelId) {
throw new Error('ModelConfig: modelId is required')
}
if (!config.options) {
throw new Error('ModelConfig: options is required')
if (!config.providerSettings) {
throw new Error('ModelConfig: providerSettings is required')
}
}

View File

@ -4,7 +4,7 @@
*/
// 主要的模型创建API
export { createModel, createModels, type ModelConfig } from './factory'
export { createModel, createModels } from './factory'
// 底层Provider创建功能供高级用户使用
export {
@ -16,4 +16,4 @@ export {
} from './ProviderCreator'
// 保留原有类型
export type { ModelCreationRequest, ResolvedConfig } from './types'
export type { ModelConfig } from './types'

View File

@ -3,30 +3,11 @@
*/
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
import { ProviderId, ProviderSettingsMap } from '../../types'
import { AiPlugin } from '../plugins'
import type { ProviderId, ProviderSettingsMap } from '../../types'
/**
*
*/
export interface ModelCreationRequest {
export interface ModelConfig {
providerId: ProviderId
modelId: string
options: ProviderSettingsMap[ProviderId]
middlewares?: LanguageModelV2Middleware[]
}
/**
*
*/
export interface ResolvedConfig {
provider: {
id: ProviderId
options: ProviderSettingsMap[ProviderId]
}
model: {
id: string
}
plugins: AiPlugin[]
middlewares: LanguageModelV2Middleware[]
}

View File

@ -51,27 +51,29 @@ export function adaptOpenAIWebSearch(params: any, webSearchConfig: WebSearchConf
}
/**
*
* Gemini
* googleSearch providerOptions.google.tools
*/
export function adaptGeminiWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any {
const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig
const googleSearchTool = { googleSearch: {} }
// export function adaptGeminiWebSearch(params: any, webSearchConfig: WebSearchConfig | boolean): any {
// const config = typeof webSearchConfig === 'boolean' ? {} : webSearchConfig
// const googleSearchTool = { googleSearch: {} }
const existingTools = Array.isArray(params.providerOptions?.google?.tools) ? params.providerOptions.google.tools : []
// const existingTools = Array.isArray(params.providerOptions?.google?.tools) ? params.providerOptions.google.tools : []
return {
...params,
providerOptions: {
...params.providerOptions,
google: {
...params.providerOptions?.google,
tools: [...existingTools, googleSearchTool],
...(config.extra || {})
}
}
}
}
// return {
// ...params,
// providerOptions: {
// ...params.providerOptions,
// google: {
// ...params.providerOptions?.google,
// useSearchGrounding: true,
// // tools: [...existingTools, googleSearchTool],
// ...(config.extra || {})
// }
// }
// }
// }
/**
* Anthropic
@ -115,9 +117,10 @@ export function adaptWebSearchForProvider(
case 'openai':
return adaptOpenAIWebSearch(params, webSearchConfig)
case 'google':
case 'google-vertex':
return adaptGeminiWebSearch(params, webSearchConfig)
// google的需要通过插件在创建model的时候传入参数
// case 'google':
// case 'google-vertex':
// return adaptGeminiWebSearch(params, webSearchConfig)
case 'anthropic':
return adaptAnthropicWebSearch(params, webSearchConfig)
@ -127,12 +130,3 @@ export function adaptWebSearchForProvider(
return params
}
}
/**
* provider
*/
export function isWebSearchSupported(providerId: string): boolean {
const supportedProviders = ['openai', 'google', 'google-vertex', 'anthropic']
return supportedProviders.includes(providerId)
}

View File

@ -5,7 +5,7 @@
import { definePlugin } from '../../'
import type { AiRequestContext } from '../../types'
import { adaptWebSearchForProvider, isWebSearchSupported, type WebSearchConfig } from './helper'
import { adaptWebSearchForProvider, type WebSearchConfig } from './helper'
/**
*
@ -14,42 +14,51 @@ import { adaptWebSearchForProvider, isWebSearchSupported, type WebSearchConfig }
* options.ts assistant.enableWebSearch
* providerOptions webSearch: { enabled: true }
*/
export const webSearchPlugin = definePlugin({
name: 'webSearch',
export const webSearchPlugin = (config) =>
definePlugin({
name: 'webSearch',
enforce: 'pre',
transformParams: async (params: any, context: AiRequestContext) => {
const { providerId } = context
// configureModel: async (modelConfig: any, context: AiRequestContext) => {
// if (context.providerId === 'google') {
// return {
// ...modelConfig
// }
// }
// return null
// },
// 从 providerOptions 中提取 webSearch 配置
const webSearchConfig = params.providerOptions?.[providerId]?.webSearch
transformParams: async (params: any, context: AiRequestContext) => {
const { providerId } = context
// 检查是否启用了网络搜索 (enabled: false 可用于显式禁用)
if (!webSearchConfig || (typeof webSearchConfig === 'object' && webSearchConfig.enabled === false)) {
return params
// 从 providerOptions 中提取 webSearch 配置
const webSearchConfig = params.providerOptions?.[providerId]?.webSearch
// 检查是否启用了网络搜索 (enabled: false 可用于显式禁用)
if (!webSearchConfig || (typeof webSearchConfig === 'object' && webSearchConfig.enabled === false)) {
return params
}
console.log('webSearchConfig', webSearchConfig)
// // 检查当前 provider 是否支持网络搜索
// if (!isWebSearchSupported(providerId)) {
// // 对于不支持的 provider只记录警告不修改参数
// console.warn(
// `[webSearchPlugin] Provider '${providerId}' does not support web search. Ignoring webSearch parameter.`
// )
// return params
// }
// 使用适配器函数处理网络搜索
const adaptedParams = adaptWebSearchForProvider(params, providerId, webSearchConfig as WebSearchConfig | boolean)
// 清理原始的 webSearch 配置
if (adaptedParams.providerOptions?.[providerId]) {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { webSearch, ...rest } = adaptedParams.providerOptions[providerId]
adaptedParams.providerOptions[providerId] = rest
}
return adaptedParams
}
// 检查当前 provider 是否支持网络搜索
if (!isWebSearchSupported(providerId)) {
// 对于不支持的 provider只记录警告不修改参数
console.warn(
`[webSearchPlugin] Provider '${providerId}' does not support web search. Ignoring webSearch parameter.`
)
return params
}
// 使用适配器函数处理网络搜索
const adaptedParams = adaptWebSearchForProvider(params, providerId, webSearchConfig as WebSearchConfig | boolean)
// 清理原始的 webSearch 配置
if (adaptedParams.providerOptions?.[providerId]) {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { webSearch, ...rest } = adaptedParams.providerOptions[providerId]
adaptedParams.providerOptions[providerId] = rest
}
return adaptedParams
}
})
})
// 导出类型定义供开发者使用
export type { WebSearchConfig } from './helper'

View File

@ -1,13 +1,17 @@
// 核心类型和接口
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './types'
import { ProviderId } from '../providers/registry'
import type { ProviderId } from '../../types'
import type { AiPlugin, AiRequestContext } from './types'
// 插件管理器
export { PluginManager } from './manager'
// 工具函数
export function createContext(providerId: ProviderId, modelId: string, originalParams: any): AiRequestContext {
export function createContext<T extends ProviderId>(
providerId: T,
modelId: string,
originalParams: any
): AiRequestContext {
return {
providerId,
modelId,

View File

@ -52,7 +52,7 @@ export class PluginManager {
*/
async executeFirst<T>(
hookName: 'resolveModel' | 'loadTemplate',
arg: string,
arg: any,
context: AiRequestContext
): Promise<T | null> {
for (const plugin of this.plugins) {
@ -71,7 +71,7 @@ export class PluginManager {
* Sequential -
*/
async executeSequential<T>(
hookName: 'transformParams' | 'transformResult',
hookName: 'transformParams' | 'transformResult' | 'configureModel',
initialValue: T,
context: AiRequestContext
): Promise<T> {
@ -120,7 +120,9 @@ export class PluginManager {
* AI SDK
*/
collectStreamTransforms(params: any, context: AiRequestContext) {
return this.plugins.map((plugin) => plugin.transformStream?.(params, context))
return this.plugins
.filter((plugin) => plugin.transformStream)
.map((plugin) => plugin.transformStream?.(params, context))
}
/**

View File

@ -218,7 +218,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
return await createModel({
providerId: this.config.providerId,
modelId: modelOrId,
options: this.config.options,
providerSettings: this.config.providerSettings,
middlewares
})
} else {
@ -246,7 +246,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
): RuntimeExecutor<T> {
return new RuntimeExecutor({
providerId,
options,
providerSettings: options,
plugins
})
}
@ -260,7 +260,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
): RuntimeExecutor<'openai-compatible'> {
return new RuntimeExecutor({
providerId: 'openai-compatible',
options,
providerSettings: options,
plugins
})
}

View File

@ -9,7 +9,7 @@ import { type AiPlugin } from '../plugins'
*/
export interface RuntimeConfig<T extends ProviderId = ProviderId> {
providerId: T
options: ProviderSettingsMap[T]
providerSettings: ProviderSettingsMap[T]
plugins?: AiPlugin[]
}