mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-09 06:49:02 +08:00
feat: enhance AI Core runtime with advanced model handling and middleware support
- Introduced new high-level APIs for model creation and configuration, improving usability for advanced users. - Enhanced the RuntimeExecutor to support both direct model usage and model ID resolution, allowing for more flexible execution options. - Updated existing methods to accept middleware configurations, streamlining the integration of custom processing logic. - Refactored the plugin system to better accommodate middleware, enhancing the overall extensibility of the AI Core. - Improved documentation to reflect the new capabilities and usage patterns for the runtime APIs.
This commit is contained in:
parent
7d8ed3a737
commit
e4c0ea035f
@ -2,7 +2,7 @@
|
|||||||
* 运行时执行器
|
* 运行时执行器
|
||||||
* 专注于插件化的AI调用处理
|
* 专注于插件化的AI调用处理
|
||||||
*/
|
*/
|
||||||
import { generateObject, generateText, LanguageModelV1, streamObject, streamText } from 'ai'
|
import { generateObject, generateText, LanguageModel, LanguageModelV1Middleware, streamObject, streamText } from 'ai'
|
||||||
|
|
||||||
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
||||||
import { createModel, getProviderInfo } from '../models'
|
import { createModel, getProviderInfo } from '../models'
|
||||||
@ -28,24 +28,43 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
this.pluginClient = new PluginEngine(config.providerId, config.plugins || [])
|
this.pluginClient = new PluginEngine(config.providerId, config.plugins || [])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// === 高阶重载:直接使用模型 ===
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 流式文本生成 - 使用modelId自动创建模型
|
* 流式文本生成 - 使用已创建的模型(高级用法)
|
||||||
|
*/
|
||||||
|
async streamText(
|
||||||
|
model: LanguageModel,
|
||||||
|
params: Omit<Parameters<typeof streamText>[0], 'model'>
|
||||||
|
): Promise<ReturnType<typeof streamText>>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 流式文本生成 - 使用modelId + 可选middleware(灵活用法)
|
||||||
*/
|
*/
|
||||||
async streamText(
|
async streamText(
|
||||||
modelId: string,
|
modelId: string,
|
||||||
params: Omit<Parameters<typeof streamText>[0], 'model'>
|
params: Omit<Parameters<typeof streamText>[0], 'model'>,
|
||||||
|
options?: {
|
||||||
|
middlewares?: LanguageModelV1Middleware[]
|
||||||
|
}
|
||||||
|
): Promise<ReturnType<typeof streamText>>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 流式文本生成 - 内部实现(统一处理重载)
|
||||||
|
*/
|
||||||
|
async streamText(
|
||||||
|
modelOrId: LanguageModel | string,
|
||||||
|
params: Omit<Parameters<typeof streamText>[0], 'model'>,
|
||||||
|
options?: {
|
||||||
|
middlewares?: LanguageModelV1Middleware[]
|
||||||
|
}
|
||||||
): Promise<ReturnType<typeof streamText>> {
|
): Promise<ReturnType<typeof streamText>> {
|
||||||
// 1. 使用 createModel 创建模型
|
const model = await this.resolveModel(modelOrId, options?.middlewares)
|
||||||
const model = await createModel({
|
|
||||||
providerId: this.config.providerId,
|
|
||||||
modelId,
|
|
||||||
options: this.config.options
|
|
||||||
})
|
|
||||||
|
|
||||||
// 2. 执行插件处理
|
// 2. 执行插件处理
|
||||||
return this.pluginClient.executeStreamWithPlugins(
|
return this.pluginClient.executeStreamWithPlugins(
|
||||||
'streamText',
|
'streamText',
|
||||||
modelId,
|
typeof modelOrId === 'string' ? modelOrId : model.modelId,
|
||||||
params,
|
params,
|
||||||
async (finalModelId, transformedParams, streamTransforms) => {
|
async (finalModelId, transformedParams, streamTransforms) => {
|
||||||
const experimental_transform =
|
const experimental_transform =
|
||||||
@ -60,46 +79,42 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// === 其他方法的重载 ===
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 流式文本生成 - 直接使用已创建的模型
|
* 生成文本 - 使用已创建的模型
|
||||||
*/
|
*/
|
||||||
async streamTextWithModel(
|
async generateText(
|
||||||
model: LanguageModelV1,
|
model: LanguageModel,
|
||||||
params: Omit<Parameters<typeof streamText>[0], 'model'>
|
params: Omit<Parameters<typeof generateText>[0], 'model'>
|
||||||
): Promise<ReturnType<typeof streamText>> {
|
): Promise<ReturnType<typeof generateText>>
|
||||||
return this.pluginClient.executeStreamWithPlugins(
|
|
||||||
'streamText',
|
|
||||||
model.modelId,
|
|
||||||
params,
|
|
||||||
async (finalModelId, transformedParams, streamTransforms) => {
|
|
||||||
const experimental_transform =
|
|
||||||
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
|
|
||||||
|
|
||||||
return await streamText({
|
|
||||||
model,
|
|
||||||
...transformedParams,
|
|
||||||
experimental_transform
|
|
||||||
})
|
|
||||||
}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 生成文本
|
* 生成文本 - 使用modelId + 可选middleware
|
||||||
*/
|
*/
|
||||||
async generateText(
|
async generateText(
|
||||||
modelId: string,
|
modelId: string,
|
||||||
params: Omit<Parameters<typeof generateText>[0], 'model'>
|
params: Omit<Parameters<typeof generateText>[0], 'model'>,
|
||||||
|
options?: {
|
||||||
|
middlewares?: LanguageModelV1Middleware[]
|
||||||
|
}
|
||||||
|
): Promise<ReturnType<typeof generateText>>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 生成文本 - 内部实现
|
||||||
|
*/
|
||||||
|
async generateText(
|
||||||
|
modelOrId: LanguageModel | string,
|
||||||
|
params: Omit<Parameters<typeof generateText>[0], 'model'>,
|
||||||
|
options?: {
|
||||||
|
middlewares?: LanguageModelV1Middleware[]
|
||||||
|
}
|
||||||
): Promise<ReturnType<typeof generateText>> {
|
): Promise<ReturnType<typeof generateText>> {
|
||||||
const model = await createModel({
|
const model = await this.resolveModel(modelOrId, options?.middlewares)
|
||||||
providerId: this.config.providerId,
|
|
||||||
modelId,
|
|
||||||
options: this.config.options
|
|
||||||
})
|
|
||||||
|
|
||||||
return this.pluginClient.executeWithPlugins(
|
return this.pluginClient.executeWithPlugins(
|
||||||
'generateText',
|
'generateText',
|
||||||
modelId,
|
typeof modelOrId === 'string' ? modelOrId : model.modelId,
|
||||||
params,
|
params,
|
||||||
async (finalModelId, transformedParams) => {
|
async (finalModelId, transformedParams) => {
|
||||||
return await generateText({ model, ...transformedParams })
|
return await generateText({ model, ...transformedParams })
|
||||||
@ -108,21 +123,39 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 生成结构化对象
|
* 生成结构化对象 - 使用已创建的模型
|
||||||
*/
|
*/
|
||||||
async generateObject(
|
async generateObject(
|
||||||
modelId: string,
|
model: LanguageModel,
|
||||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>
|
params: Omit<Parameters<typeof generateObject>[0], 'model'>
|
||||||
|
): Promise<ReturnType<typeof generateObject>>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 生成结构化对象 - 使用modelId + 可选middleware
|
||||||
|
*/
|
||||||
|
async generateObject(
|
||||||
|
modelOrId: string,
|
||||||
|
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
|
||||||
|
options?: {
|
||||||
|
middlewares?: LanguageModelV1Middleware[]
|
||||||
|
}
|
||||||
|
): Promise<ReturnType<typeof generateObject>>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 生成结构化对象 - 内部实现
|
||||||
|
*/
|
||||||
|
async generateObject(
|
||||||
|
modelOrId: LanguageModel | string,
|
||||||
|
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
|
||||||
|
options?: {
|
||||||
|
middlewares?: LanguageModelV1Middleware[]
|
||||||
|
}
|
||||||
): Promise<ReturnType<typeof generateObject>> {
|
): Promise<ReturnType<typeof generateObject>> {
|
||||||
const model = await createModel({
|
const model = await this.resolveModel(modelOrId, options?.middlewares)
|
||||||
providerId: this.config.providerId,
|
|
||||||
modelId,
|
|
||||||
options: this.config.options
|
|
||||||
})
|
|
||||||
|
|
||||||
return this.pluginClient.executeWithPlugins(
|
return this.pluginClient.executeWithPlugins(
|
||||||
'generateObject',
|
'generateObject',
|
||||||
modelId,
|
typeof modelOrId === 'string' ? modelOrId : model.modelId,
|
||||||
params,
|
params,
|
||||||
async (finalModelId, transformedParams) => {
|
async (finalModelId, transformedParams) => {
|
||||||
return await generateObject({ model, ...transformedParams })
|
return await generateObject({ model, ...transformedParams })
|
||||||
@ -131,27 +164,69 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 流式生成结构化对象
|
* 流式生成结构化对象 - 使用已创建的模型
|
||||||
|
*/
|
||||||
|
async streamObject(
|
||||||
|
model: LanguageModel,
|
||||||
|
params: Omit<Parameters<typeof streamObject>[0], 'model'>
|
||||||
|
): Promise<ReturnType<typeof streamObject>>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 流式生成结构化对象 - 使用modelId + 可选middleware
|
||||||
*/
|
*/
|
||||||
async streamObject(
|
async streamObject(
|
||||||
modelId: string,
|
modelId: string,
|
||||||
params: Omit<Parameters<typeof streamObject>[0], 'model'>
|
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
|
||||||
|
options?: {
|
||||||
|
middlewares?: LanguageModelV1Middleware[]
|
||||||
|
}
|
||||||
|
): Promise<ReturnType<typeof streamObject>>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 流式生成结构化对象 - 内部实现
|
||||||
|
*/
|
||||||
|
async streamObject(
|
||||||
|
modelOrId: LanguageModel | string,
|
||||||
|
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
|
||||||
|
options?: {
|
||||||
|
middlewares?: LanguageModelV1Middleware[]
|
||||||
|
}
|
||||||
): Promise<ReturnType<typeof streamObject>> {
|
): Promise<ReturnType<typeof streamObject>> {
|
||||||
const model = await createModel({
|
const model = await this.resolveModel(modelOrId, options?.middlewares)
|
||||||
providerId: this.config.providerId,
|
|
||||||
modelId,
|
|
||||||
options: this.config.options
|
|
||||||
})
|
|
||||||
|
|
||||||
return this.pluginClient.executeWithPlugins(
|
return this.pluginClient.executeWithPlugins(
|
||||||
'streamObject',
|
'streamObject',
|
||||||
modelId,
|
typeof modelOrId === 'string' ? modelOrId : model.modelId,
|
||||||
params,
|
params,
|
||||||
async (finalModelId, transformedParams) => {
|
async (finalModelId, transformedParams) => {
|
||||||
return await streamObject({ model, ...transformedParams })
|
return await streamObject({ model, ...transformedParams })
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// === 辅助方法 ===
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析模型:如果是字符串则创建模型,如果是模型则直接返回
|
||||||
|
*/
|
||||||
|
private async resolveModel(
|
||||||
|
modelOrId: LanguageModel | string,
|
||||||
|
middlewares?: LanguageModelV1Middleware[]
|
||||||
|
): Promise<LanguageModel> {
|
||||||
|
if (typeof modelOrId === 'string') {
|
||||||
|
// 字符串modelId,需要创建模型
|
||||||
|
return await createModel({
|
||||||
|
providerId: this.config.providerId,
|
||||||
|
modelId: modelOrId,
|
||||||
|
options: this.config.options,
|
||||||
|
middlewares
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// 已经是模型,直接返回
|
||||||
|
return modelOrId
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取客户端信息
|
* 获取客户端信息
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -16,6 +16,8 @@ export type {
|
|||||||
|
|
||||||
// === 便捷工厂函数 ===
|
// === 便捷工厂函数 ===
|
||||||
|
|
||||||
|
import { LanguageModelV1Middleware } from 'ai'
|
||||||
|
|
||||||
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
||||||
import { type AiPlugin } from '../plugins'
|
import { type AiPlugin } from '../plugins'
|
||||||
import { RuntimeExecutor } from './executor'
|
import { RuntimeExecutor } from './executor'
|
||||||
@ -44,59 +46,63 @@ export function createOpenAICompatibleExecutor(
|
|||||||
// === 直接调用API(无需创建executor实例)===
|
// === 直接调用API(无需创建executor实例)===
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 直接流式文本生成
|
* 直接流式文本生成 - 支持middlewares
|
||||||
*/
|
*/
|
||||||
export async function streamText<T extends ProviderId>(
|
export async function streamText<T extends ProviderId>(
|
||||||
providerId: T,
|
providerId: T,
|
||||||
options: ProviderSettingsMap[T],
|
options: ProviderSettingsMap[T],
|
||||||
modelId: string,
|
modelId: string,
|
||||||
params: Parameters<RuntimeExecutor<T>['streamText']>[1],
|
params: Parameters<RuntimeExecutor<T>['streamText']>[1],
|
||||||
plugins?: AiPlugin[]
|
plugins?: AiPlugin[],
|
||||||
|
middlewares?: LanguageModelV1Middleware[]
|
||||||
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
|
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
|
||||||
const executor = createExecutor(providerId, options, plugins)
|
const executor = createExecutor(providerId, options, plugins)
|
||||||
return executor.streamText(modelId, params)
|
return executor.streamText(modelId, params, { middlewares })
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 直接生成文本
|
* 直接生成文本 - 支持middlewares
|
||||||
*/
|
*/
|
||||||
export async function generateText<T extends ProviderId>(
|
export async function generateText<T extends ProviderId>(
|
||||||
providerId: T,
|
providerId: T,
|
||||||
options: ProviderSettingsMap[T],
|
options: ProviderSettingsMap[T],
|
||||||
modelId: string,
|
modelId: string,
|
||||||
params: Parameters<RuntimeExecutor<T>['generateText']>[1],
|
params: Parameters<RuntimeExecutor<T>['generateText']>[1],
|
||||||
plugins?: AiPlugin[]
|
plugins?: AiPlugin[],
|
||||||
|
middlewares?: LanguageModelV1Middleware[]
|
||||||
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
|
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
|
||||||
const executor = createExecutor(providerId, options, plugins)
|
const executor = createExecutor(providerId, options, plugins)
|
||||||
return executor.generateText(modelId, params)
|
return executor.generateText(modelId, params, { middlewares })
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 直接生成结构化对象
|
* 直接生成结构化对象 - 支持middlewares
|
||||||
*/
|
*/
|
||||||
export async function generateObject<T extends ProviderId>(
|
export async function generateObject<T extends ProviderId>(
|
||||||
providerId: T,
|
providerId: T,
|
||||||
options: ProviderSettingsMap[T],
|
options: ProviderSettingsMap[T],
|
||||||
modelId: string,
|
modelId: string,
|
||||||
params: Parameters<RuntimeExecutor<T>['generateObject']>[1],
|
params: Parameters<RuntimeExecutor<T>['generateObject']>[1],
|
||||||
plugins?: AiPlugin[]
|
plugins?: AiPlugin[],
|
||||||
|
middlewares?: LanguageModelV1Middleware[]
|
||||||
): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
|
): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
|
||||||
const executor = createExecutor(providerId, options, plugins)
|
const executor = createExecutor(providerId, options, plugins)
|
||||||
return executor.generateObject(modelId, params)
|
return executor.generateObject(modelId, params, { middlewares })
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 直接流式生成结构化对象
|
* 直接流式生成结构化对象 - 支持middlewares
|
||||||
*/
|
*/
|
||||||
export async function streamObject<T extends ProviderId>(
|
export async function streamObject<T extends ProviderId>(
|
||||||
providerId: T,
|
providerId: T,
|
||||||
options: ProviderSettingsMap[T],
|
options: ProviderSettingsMap[T],
|
||||||
modelId: string,
|
modelId: string,
|
||||||
params: Parameters<RuntimeExecutor<T>['streamObject']>[1],
|
params: Parameters<RuntimeExecutor<T>['streamObject']>[1],
|
||||||
plugins?: AiPlugin[]
|
plugins?: AiPlugin[],
|
||||||
|
middlewares?: LanguageModelV1Middleware[]
|
||||||
): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
|
): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
|
||||||
const executor = createExecutor(providerId, options, plugins)
|
const executor = createExecutor(providerId, options, plugins)
|
||||||
return executor.streamObject(modelId, params)
|
return executor.streamObject(modelId, params, { middlewares })
|
||||||
}
|
}
|
||||||
|
|
||||||
// === Agent 功能预留 ===
|
// === Agent 功能预留 ===
|
||||||
|
|||||||
@ -55,17 +55,6 @@ export class PluginEngine<T extends ProviderId = ProviderId> {
|
|||||||
return this.pluginManager.getPlugins()
|
return this.pluginManager.getPlugins()
|
||||||
}
|
}
|
||||||
|
|
||||||
// /**
|
|
||||||
// * 使用core模块创建模型(包含中间件)
|
|
||||||
// */
|
|
||||||
// async createModelWithMiddlewares(modelId: string): Promise<any> {
|
|
||||||
// // 使用core模块的resolveConfig解析配置
|
|
||||||
// const config = resolveConfig(this.providerId, modelId, this.options, this.pluginManager.getPlugins())
|
|
||||||
|
|
||||||
// // 使用core模块创建包装好的模型
|
|
||||||
// return createModelFromConfig(config)
|
|
||||||
// }
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 执行带插件的操作(非流式)
|
* 执行带插件的操作(非流式)
|
||||||
* 提供给AiExecutor使用
|
* 提供给AiExecutor使用
|
||||||
|
|||||||
@ -15,6 +15,9 @@ import { ProviderId, type ProviderSettingsMap } from './types'
|
|||||||
// ==================== 主要用户接口 ====================
|
// ==================== 主要用户接口 ====================
|
||||||
export { createExecutor, createOpenAICompatibleExecutor } from './core/runtime'
|
export { createExecutor, createOpenAICompatibleExecutor } from './core/runtime'
|
||||||
|
|
||||||
|
// ==================== 高级API ====================
|
||||||
|
export { createModel, type ModelConfig } from './core/models'
|
||||||
|
|
||||||
// ==================== 插件系统 ====================
|
// ==================== 插件系统 ====================
|
||||||
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './core/plugins'
|
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './core/plugins'
|
||||||
export { createContext, definePlugin, PluginManager } from './core/plugins'
|
export { createContext, definePlugin, PluginManager } from './core/plugins'
|
||||||
|
|||||||
@ -9,8 +9,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import {
|
import {
|
||||||
AiClient,
|
createExecutor,
|
||||||
createClient,
|
|
||||||
ProviderConfigFactory,
|
ProviderConfigFactory,
|
||||||
type ProviderId,
|
type ProviderId,
|
||||||
type ProviderSettingsMap,
|
type ProviderSettingsMap,
|
||||||
@ -101,7 +100,7 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default class ModernAiProvider {
|
export default class ModernAiProvider {
|
||||||
private modernClient?: AiClient
|
private modernExecutor?: ReturnType<typeof createExecutor>
|
||||||
private legacyProvider: LegacyAiProvider
|
private legacyProvider: LegacyAiProvider
|
||||||
private provider: Provider
|
private provider: Provider
|
||||||
|
|
||||||
@ -109,9 +108,10 @@ export default class ModernAiProvider {
|
|||||||
this.provider = provider
|
this.provider = provider
|
||||||
this.legacyProvider = new LegacyAiProvider(provider)
|
this.legacyProvider = new LegacyAiProvider(provider)
|
||||||
|
|
||||||
|
// TODO:如果后续在调用completions时需要切换provider的话,
|
||||||
// 初始化时不构建中间件,等到需要时再构建
|
// 初始化时不构建中间件,等到需要时再构建
|
||||||
const config = providerToAiSdkConfig(provider)
|
const config = providerToAiSdkConfig(provider)
|
||||||
this.modernClient = createClient(config.providerId, config.options)
|
this.modernExecutor = createExecutor(config.providerId, config.options)
|
||||||
}
|
}
|
||||||
|
|
||||||
public async completions(
|
public async completions(
|
||||||
@ -145,7 +145,7 @@ export default class ModernAiProvider {
|
|||||||
params: StreamTextParams,
|
params: StreamTextParams,
|
||||||
middlewareConfig: AiSdkMiddlewareConfig
|
middlewareConfig: AiSdkMiddlewareConfig
|
||||||
): Promise<CompletionsResult> {
|
): Promise<CompletionsResult> {
|
||||||
if (!this.modernClient) {
|
if (!this.modernExecutor) {
|
||||||
throw new Error('Modern AI SDK client not initialized')
|
throw new Error('Modern AI SDK client not initialized')
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -160,26 +160,25 @@ export default class ModernAiProvider {
|
|||||||
|
|
||||||
// 动态构建中间件数组
|
// 动态构建中间件数组
|
||||||
const middlewares = buildAiSdkMiddlewares(finalConfig)
|
const middlewares = buildAiSdkMiddlewares(finalConfig)
|
||||||
console.log(
|
console.log('构建的中间件:', middlewares.length)
|
||||||
'构建的中间件:',
|
|
||||||
middlewares.map((m) => m.name)
|
|
||||||
)
|
|
||||||
|
|
||||||
// 创建带有中间件的客户端
|
|
||||||
const config = providerToAiSdkConfig(this.provider)
|
|
||||||
const clientWithMiddlewares = createClient(config.providerId, config.options, middlewares)
|
|
||||||
|
|
||||||
|
// 创建带有中间件的执行器
|
||||||
if (middlewareConfig.onChunk) {
|
if (middlewareConfig.onChunk) {
|
||||||
// 流式处理 - 使用适配器
|
// 流式处理 - 使用适配器
|
||||||
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
|
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
|
||||||
const streamResult = await clientWithMiddlewares.streamText(modelId, {
|
const streamResult = await this.modernExecutor.streamText(
|
||||||
...params,
|
modelId,
|
||||||
experimental_transform: smoothStream({
|
{
|
||||||
delayInMs: 80,
|
...params,
|
||||||
// 中文3个字符一个chunk,英文一个单词一个chunk
|
experimental_transform: smoothStream({
|
||||||
chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
|
delayInMs: 80,
|
||||||
})
|
// 中文3个字符一个chunk,英文一个单词一个chunk
|
||||||
})
|
chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
|
||||||
|
})
|
||||||
|
},
|
||||||
|
middlewares.length > 0 ? { middlewares } : undefined
|
||||||
|
)
|
||||||
|
|
||||||
const finalText = await adapter.processStream(streamResult)
|
const finalText = await adapter.processStream(streamResult)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -187,7 +186,11 @@ export default class ModernAiProvider {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 流式处理但没有 onChunk 回调
|
// 流式处理但没有 onChunk 回调
|
||||||
const streamResult = await clientWithMiddlewares.streamText(modelId, params)
|
const streamResult = await this.modernExecutor.streamText(
|
||||||
|
modelId,
|
||||||
|
params,
|
||||||
|
middlewares.length > 0 ? { middlewares } : undefined
|
||||||
|
)
|
||||||
const finalText = await streamResult.text
|
const finalText = await streamResult.text
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -1,4 +1,8 @@
|
|||||||
import { AiPlugin, extractReasoningMiddleware, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
|
import {
|
||||||
|
extractReasoningMiddleware,
|
||||||
|
LanguageModelV1Middleware,
|
||||||
|
simulateStreamingMiddleware
|
||||||
|
} from '@cherrystudio/ai-core'
|
||||||
import { isReasoningModel } from '@renderer/config/models'
|
import { isReasoningModel } from '@renderer/config/models'
|
||||||
import type { Model, Provider } from '@renderer/types'
|
import type { Model, Provider } from '@renderer/types'
|
||||||
import type { Chunk } from '@renderer/types/chunk'
|
import type { Chunk } from '@renderer/types/chunk'
|
||||||
@ -21,7 +25,10 @@ export interface AiSdkMiddlewareConfig {
|
|||||||
/**
|
/**
|
||||||
* 具名的 AI SDK 中间件
|
* 具名的 AI SDK 中间件
|
||||||
*/
|
*/
|
||||||
export type NamedAiSdkMiddleware = AiPlugin
|
export interface NamedAiSdkMiddleware {
|
||||||
|
name: string
|
||||||
|
middleware: LanguageModelV1Middleware
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* AI SDK 中间件建造者
|
* AI SDK 中间件建造者
|
||||||
@ -69,7 +76,14 @@ export class AiSdkMiddlewareBuilder {
|
|||||||
/**
|
/**
|
||||||
* 构建最终的中间件数组
|
* 构建最终的中间件数组
|
||||||
*/
|
*/
|
||||||
public build(): NamedAiSdkMiddleware[] {
|
public build(): LanguageModelV1Middleware[] {
|
||||||
|
return this.middlewares.map((m) => m.middleware)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取具名中间件数组(用于调试)
|
||||||
|
*/
|
||||||
|
public buildNamed(): NamedAiSdkMiddleware[] {
|
||||||
return [...this.middlewares]
|
return [...this.middlewares]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,14 +107,14 @@ export class AiSdkMiddlewareBuilder {
|
|||||||
* 根据配置构建AI SDK中间件的工厂函数
|
* 根据配置构建AI SDK中间件的工厂函数
|
||||||
* 这里要注意构建顺序,因为有些中间件需要依赖其他中间件的结果
|
* 这里要注意构建顺序,因为有些中间件需要依赖其他中间件的结果
|
||||||
*/
|
*/
|
||||||
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): NamedAiSdkMiddleware[] {
|
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV1Middleware[] {
|
||||||
const builder = new AiSdkMiddlewareBuilder()
|
const builder = new AiSdkMiddlewareBuilder()
|
||||||
|
|
||||||
// 1. 思考模型且有onChunk回调时添加思考时间中间件
|
// 1. 思考模型且有onChunk回调时添加思考时间中间件
|
||||||
if (config.onChunk && config.model && isReasoningModel(config.model)) {
|
if (config.onChunk && config.model && isReasoningModel(config.model)) {
|
||||||
builder.add({
|
builder.add({
|
||||||
name: 'thinking-time',
|
name: 'thinking-time',
|
||||||
aiSdkMiddlewares: [thinkingTimeMiddleware()]
|
middleware: thinkingTimeMiddleware()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,7 +135,7 @@ export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): NamedAiSdk
|
|||||||
if (config.streamOutput === false) {
|
if (config.streamOutput === false) {
|
||||||
builder.add({
|
builder.add({
|
||||||
name: 'simulate-streaming',
|
name: 'simulate-streaming',
|
||||||
aiSdkMiddlewares: [simulateStreamingMiddleware()]
|
middleware: simulateStreamingMiddleware()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,7 +156,7 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
|
|||||||
case 'openai':
|
case 'openai':
|
||||||
builder.add({
|
builder.add({
|
||||||
name: 'thinking-tag-extraction',
|
name: 'thinking-tag-extraction',
|
||||||
aiSdkMiddlewares: [extractReasoningMiddleware({ tagName: 'think' })]
|
middleware: extractReasoningMiddleware({ tagName: 'think' })
|
||||||
})
|
})
|
||||||
break
|
break
|
||||||
case 'gemini':
|
case 'gemini':
|
||||||
@ -183,8 +197,12 @@ export function createDefaultAiSdkMiddlewareBuilder(config: AiSdkMiddlewareConfi
|
|||||||
const builder = new AiSdkMiddlewareBuilder()
|
const builder = new AiSdkMiddlewareBuilder()
|
||||||
const defaultMiddlewares = buildAiSdkMiddlewares(config)
|
const defaultMiddlewares = buildAiSdkMiddlewares(config)
|
||||||
|
|
||||||
defaultMiddlewares.forEach((middleware) => {
|
// 将普通中间件数组转换为具名中间件并添加
|
||||||
builder.add(middleware)
|
defaultMiddlewares.forEach((middleware, index) => {
|
||||||
|
builder.add({
|
||||||
|
name: `default-middleware-${index}`,
|
||||||
|
middleware
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
return builder
|
return builder
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user