Merge branch 'feat/aisdk-package' of https://github.com/CherryHQ/cherry-studio into feat/aisdk-package

This commit is contained in:
suyao 2025-07-07 18:43:27 +08:00
commit 342c5ab82c
No known key found for this signature in database
7 changed files with 85 additions and 42 deletions

View File

@ -27,24 +27,28 @@ export class ProviderCreationError extends Error {
export async function createBaseModel<T extends ProviderId>({
providerId,
modelId,
providerSettings
providerSettings,
extraModelConfig
// middlewares
}: {
providerId: T
modelId: string
providerSettings: ProviderSettingsMap[T]
extraModelConfig?: any
// middlewares?: LanguageModelV1Middleware[]
}): Promise<LanguageModelV2>
export async function createBaseModel({
providerId,
modelId,
providerSettings
providerSettings,
extraModelConfig
// middlewares
}: {
providerId: string
modelId: string
providerSettings: ProviderSettingsMap['openai-compatible']
extraModelConfig?: any
// middlewares?: LanguageModelV1Middleware[]
}): Promise<LanguageModelV2>

View File

@ -16,11 +16,7 @@ export async function createModel(config: ModelConfig): Promise<LanguageModelV2>
validateModelConfig(config)
// 1. 创建基础模型
const baseModel = await createBaseModel({
providerId: config.providerId,
modelId: config.modelId,
providerSettings: config.options
})
const baseModel = await createBaseModel(config)
// 2. 应用中间件(如果有)
return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel
@ -43,7 +39,7 @@ function validateModelConfig(config: ModelConfig): void {
if (!config.modelId) {
throw new Error('ModelConfig: modelId is required')
}
if (!config.options) {
if (!config.providerSettings) {
throw new Error('ModelConfig: providerSettings is required')
}
}

View File

@ -5,9 +5,10 @@ import { LanguageModelV2Middleware } from '@ai-sdk/provider'
import type { ProviderId, ProviderSettingsMap } from '../../types'
export interface ModelConfig {
providerId: ProviderId
export interface ModelConfig<T extends ProviderId = ProviderId> {
providerId: T
modelId: string
options: ProviderSettingsMap[ProviderId]
providerSettings: ProviderSettingsMap[T]
middlewares?: LanguageModelV2Middleware[]
extraModelConfig?: Record<string, any>
}

View File

@ -71,7 +71,7 @@ export class PluginManager {
* Sequential -
*/
async executeSequential<T>(
hookName: 'transformParams' | 'transformResult' | 'configureModel',
hookName: 'transformParams' | 'transformResult',
initialValue: T,
context: AiRequestContext
): Promise<T> {
@ -87,6 +87,18 @@ export class PluginManager {
return result
}
/**
* ConfigureContext -
*/
async executeConfigureContext(context: AiRequestContext): Promise<void> {
for (const plugin of this.plugins) {
const hook = plugin.configureContext
if (hook) {
await hook(context)
}
}
}
/**
* Parallel -
*/

View File

@ -1,4 +1,4 @@
import type { TextStreamPart, ToolSet } from 'ai'
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
import { ProviderId } from '../providers/registry'
@ -32,13 +32,13 @@ export interface AiPlugin {
enforce?: 'pre' | 'post'
// 【First】首个钩子 - 只执行第一个返回值的插件
resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise<string | null>
resolveModel?: (modelId: string, context: AiRequestContext) => Promise<LanguageModel | null> | LanguageModel | null
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
// 【Sequential】串行钩子 - 链式执行,支持数据转换
configureContext?: (context: AiRequestContext) => void | Promise<void>
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
configureModel?: (model: any, context: AiRequestContext) => any | Promise<any>
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
onRequestStart?: (context: AiRequestContext) => void | Promise<void>

View File

@ -7,7 +7,7 @@ import { generateObject, generateText, LanguageModel, streamObject, streamText }
import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { createModel, getProviderInfo } from '../models'
import { type AiPlugin } from '../plugins'
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
import { PluginEngine } from './pluginEngine'
import { type RuntimeConfig } from './types'
@ -28,6 +28,19 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
}
createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
return definePlugin({
name: '_internal_resolveModel',
enforce: 'post',
resolveModel: async (modelId: string, context: AiRequestContext) => {
// 从 context 中读取由用户插件注入的 extraModelConfig
const extraModelConfig = context.extraModelConfig || {}
return await this.resolveModel(modelId, middlewares, extraModelConfig)
}
})
}
// === 高阶重载:直接使用模型 ===
/**
@ -59,14 +72,14 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof streamText>> {
const model = await this.resolveModel(modelOrId, options?.middlewares)
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
// 2. 执行插件处理
return this.pluginEngine.executeStreamWithPlugins(
'streamText',
typeof modelOrId === 'string' ? modelOrId : model.modelId,
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
params,
async (finalModelId, transformedParams, streamTransforms) => {
async (model, transformedParams, streamTransforms) => {
const experimental_transform =
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
@ -110,13 +123,13 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof generateText>> {
const model = await this.resolveModel(modelOrId, options?.middlewares)
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
return this.pluginEngine.executeWithPlugins(
'generateText',
typeof modelOrId === 'string' ? modelOrId : model.modelId,
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
params,
async (finalModelId, transformedParams) => {
async (model, transformedParams) => {
return await generateText({ model, ...transformedParams })
}
)
@ -151,13 +164,13 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof generateObject>> {
const model = await this.resolveModel(modelOrId, options?.middlewares)
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
return this.pluginEngine.executeWithPlugins(
'generateObject',
typeof modelOrId === 'string' ? modelOrId : model.modelId,
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
params,
async (finalModelId, transformedParams) => {
async (model, transformedParams) => {
return await generateObject({ model, ...transformedParams })
}
)
@ -192,13 +205,13 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
middlewares?: LanguageModelV2Middleware[]
}
): Promise<ReturnType<typeof streamObject>> {
const model = await this.resolveModel(modelOrId, options?.middlewares)
this.pluginEngine.use(this.createResolveModelPlugin(options?.middlewares))
return this.pluginEngine.executeWithPlugins(
'streamObject',
typeof modelOrId === 'string' ? modelOrId : model.modelId,
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
params,
async (finalModelId, transformedParams) => {
async (model, transformedParams) => {
return await streamObject({ model, ...transformedParams })
}
)
@ -211,15 +224,17 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
*/
private async resolveModel(
modelOrId: LanguageModel,
middlewares?: LanguageModelV2Middleware[]
middlewares?: LanguageModelV2Middleware[],
extraModelConfig?: Record<string, any>
): Promise<LanguageModelV2> {
if (typeof modelOrId === 'string') {
// 字符串modelId需要创建模型
return await createModel({
providerId: this.config.providerId,
modelId: modelOrId,
options: this.config.providerSettings,
middlewares
providerSettings: this.config.providerSettings,
middlewares,
extraModelConfig
})
} else {
// 已经是模型,直接返回

View File

@ -1,3 +1,5 @@
import { LanguageModel } from 'ai'
import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { type AiPlugin, createContext, PluginManager } from '../plugins'
import { isProviderSupported } from '../providers/registry'
@ -63,7 +65,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
methodName: string,
modelId: string,
params: TParams,
executor: (finalModelId: string, transformedParams: TParams) => Promise<TResult>,
executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
_context?: ReturnType<typeof createContext>
): Promise<TResult> {
// 使用正确的createContext创建请求上下文
@ -79,18 +81,23 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
}
try {
// 0. 配置上下文
await this.pluginManager.executeConfigureContext(context)
// 1. 触发请求开始事件
await this.pluginManager.executeParallel('onRequestStart', context)
// 2. 解析模型别名
const resolvedModelId = await this.pluginManager.executeFirst<string>('resolveModel', modelId, context)
const finalModelId = resolvedModelId || modelId
// 2. 解析模型
const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
if (!model) {
throw new Error(`Failed to resolve model: ${modelId}`)
}
// 3. 转换请求参数
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
// 4. 执行具体的 API 调用
const result = await executor(finalModelId, transformedParams)
const result = await executor(model, transformedParams)
// 5. 转换结果(对于非流式调用)
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
@ -114,7 +121,7 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
methodName: string,
modelId: string,
params: TParams,
executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
_context?: ReturnType<typeof createContext>
): Promise<TResult> {
// 创建请求上下文
@ -130,12 +137,18 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
}
try {
// 0. 配置上下文
await this.pluginManager.executeConfigureContext(context)
// 1. 触发请求开始事件
await this.pluginManager.executeParallel('onRequestStart', context)
// 2. 解析模型别名
const resolvedModelId = await this.pluginManager.executeFirst<string>('resolveModel', modelId, context)
const finalModelId = resolvedModelId || modelId
// 2. 解析模型
const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
if (!model) {
throw new Error(`Failed to resolve model: ${modelId}`)
}
// 3. 转换请求参数
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
@ -144,12 +157,14 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context)
// 5. 执行流式 API 调用
const result = await executor(finalModelId, transformedParams, streamTransforms)
const result = await executor(model, transformedParams, streamTransforms)
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
// 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件)
await this.pluginManager.executeParallel('onRequestEnd', context, { stream: true })
return result
return transformedResult
} catch (error) {
// 7. 触发错误事件
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)