From 3771b24b52bcb16ea02679ee2025de6fa56010f9 Mon Sep 17 00:00:00 2001
From: suyao
Date: Fri, 20 Jun 2025 15:31:41 +0800
Subject: [PATCH] feat: enhance AI SDK middleware integration and support

- Added AiSdkMiddlewareBuilder for dynamic middleware construction based on various conditions.
- Updated ModernAiProvider to utilize the new middleware configuration, improving flexibility in handling completions.
- Refactored ApiService to pass the middleware configuration during AI completions, enabling better control over processing.
- Introduced new README documentation for the middleware builder, outlining usage and supported conditions.
---
 packages/aiCore/src/index.ts              |   1 +
 src/renderer/src/aiCore/index_new.ts      |  47 +++--
 .../aisdk/AiSdkMiddlewareBuilder.ts       | 188 ++++++++++++++++++
 .../src/aiCore/middleware/aisdk/README.md | 140 +++++++++++++
 src/renderer/src/services/ApiService.ts   |  13 +-
 5 files changed, 372 insertions(+), 17 deletions(-)
 create mode 100644 src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts
 create mode 100644 src/renderer/src/aiCore/middleware/aisdk/README.md

diff --git a/packages/aiCore/src/index.ts b/packages/aiCore/src/index.ts
index 0165214613..98c3352b6d 100644
--- a/packages/aiCore/src/index.ts
+++ b/packages/aiCore/src/index.ts
@@ -79,6 +79,7 @@ export type {
   ToolExecutionError,
   ToolResult
 } from 'ai'
+export { defaultSettingsMiddleware, extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
 
 // Re-export all Provider Settings types
 export type {
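For reference, the three re-exported helpers are the stock language-model middlewares that ship with the AI SDK. A minimal sketch of how a consumer of `@cherrystudio/ai-core` could now apply one of them directly, assuming the upstream `ai` package's `wrapLanguageModel` API and an `@ai-sdk/openai` provider (the model id is illustrative):

```typescript
import { openai } from '@ai-sdk/openai'
import { extractReasoningMiddleware } from '@cherrystudio/ai-core'
// wrapLanguageModel itself comes from the upstream AI SDK
import { wrapLanguageModel } from 'ai'

// Wrap a model so that <think>...</think> blocks in the raw output are
// surfaced as separate reasoning parts instead of visible text.
export const reasoningModel = wrapLanguageModel({
  model: openai('o3-mini'), // illustrative model id
  middleware: extractReasoningMiddleware({ tagName: 'think' })
})
```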
diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts
index 559f9bb5f7..b8aec419f3 100644
--- a/src/renderer/src/aiCore/index_new.ts
+++ b/src/renderer/src/aiCore/index_new.ts
@@ -24,7 +24,7 @@ import { Chunk } from '@renderer/types/chunk'
 import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
 // Import the legacy AiProvider as a fallback
 import LegacyAiProvider from './index'
-import thinkingTimeMiddleware from './middleware/aisdk/ThinkingTimeMiddleware'
+import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
 import { CompletionsResult } from './middleware/schemas'
 
 // Import the parameter transformation module
@@ -88,7 +88,7 @@ function providerToAiSdkConfig(provider: Provider): {
  */
 function isModernSdkSupported(provider: Provider, model?: Model): boolean {
   // Currently supports the major providers
-  const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai']
+  const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai']
 
   // Check the provider type
   if (!supportedProviders.includes(provider.type)) {
@@ -108,21 +108,19 @@ export default class ModernAiProvider {
   private legacyProvider: LegacyAiProvider
   private provider: Provider
 
-  constructor(provider: Provider, onChunk?: (chunk: Chunk) => void) {
+  constructor(provider: Provider) {
     this.provider = provider
     this.legacyProvider = new LegacyAiProvider(provider)
 
+    // Do not build middlewares at construction time; build them when they are needed
     const config = providerToAiSdkConfig(provider)
-    this.modernClient = createClient(
-      config.providerId,
-      config.options,
-      onChunk ? [{ name: 'thinking-time', aiSdkMiddlewares: [thinkingTimeMiddleware(onChunk)] }] : undefined
-    )
+    this.modernClient = createClient(config.providerId, config.options)
   }
 
   public async completions(
     modelId: string,
     params: StreamTextParams,
+    middlewareConfig: AiSdkMiddlewareConfig,
     onChunk?: (chunk: Chunk) => void
   ): Promise<CompletionsResult> {
     // const model = params.assistant.model
@@ -131,7 +129,7 @@ export default class ModernAiProvider {
     // if (this.modernClient && model && isModernSdkSupported(this.provider, model)) {
     //   try {
     console.log('completions', modelId, params, onChunk)
-    return await this.modernCompletions(modelId, params, onChunk)
+    return await this.modernCompletions(modelId, params, middlewareConfig)
     // } catch (error) {
     //   console.warn('Modern client failed, falling back to legacy:', error)
     //   fall back to the legacy implementation
@@ -144,22 +142,41 @@ export default class ModernAiProvider {
 
   /**
    * Completions implementation based on the modern AI SDK
-   * Uses the AiSdkUtils helper module to build parameters
+   * Dynamically builds middlewares using the builder pattern
    */
   private async modernCompletions(
     modelId: string,
     params: StreamTextParams,
-    onChunk?: (chunk: Chunk) => void
+    middlewareConfig: AiSdkMiddlewareConfig
   ): Promise<CompletionsResult> {
     if (!this.modernClient) {
       throw new Error('Modern AI SDK client not initialized')
     }
 
     try {
-      if (onChunk) {
+      // Merge the incoming config with the instance config
+      const finalConfig: AiSdkMiddlewareConfig = {
+        ...middlewareConfig,
+        provider: this.provider,
+        // Tool-related information is derived from params
+        enableTool: params.tools !== undefined && Array.isArray(params.tools) && params.tools.length > 0
+      }
+
+      // Dynamically build the middleware array
+      const middlewares = buildAiSdkMiddlewares(finalConfig)
+      console.log(
+        'Built middlewares:',
+        middlewares.map((m) => m.name)
+      )
+
+      // Create a client with the middlewares applied
+      const config = providerToAiSdkConfig(this.provider)
+      const clientWithMiddlewares = createClient(config.providerId, config.options, middlewares)
+
+      if (middlewareConfig.onChunk) {
         // Streaming - use the adapter
-        const adapter = new AiSdkToChunkAdapter(onChunk)
-        const streamResult = await this.modernClient.streamText(modelId, params)
+        const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
+        const streamResult = await clientWithMiddlewares.streamText(modelId, params)
         const finalText = await adapter.processStream(streamResult)
 
         return {
@@ -167,7 +184,7 @@ export default class ModernAiProvider {
         }
       } else {
         // Streaming without an onChunk callback
-        const streamResult = await this.modernClient.streamText(modelId, params)
+        const streamResult = await clientWithMiddlewares.streamText(modelId, params)
         const finalText = await streamResult.text
 
         return {
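The net effect for callers is that the chunk callback and other per-request behaviour now travel through the middleware config rather than the provider constructor. A rough caller-side sketch under that assumption (the import path for `ModernAiProvider` and the helper function itself are illustrative, not part of the patch):

```typescript
import type { StreamTextParams } from '@cherrystudio/ai-core'
import ModernAiProvider from '@renderer/aiCore/index_new' // path assumed from the file location above
import type { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder'
import type { Model, Provider } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'

// Hypothetical helper showing the new call shape.
async function runCompletion(
  provider: Provider,
  model: Model,
  modelId: string,
  params: StreamTextParams,
  onChunk: (chunk: Chunk) => void
) {
  // The constructor no longer receives an onChunk callback.
  const AI = new ModernAiProvider(provider)

  // Per-request behaviour is described by the middleware config instead.
  const middlewareConfig: AiSdkMiddlewareConfig = {
    streamOutput: true,
    onChunk,
    model,
    provider
  }

  return AI.completions(modelId, params, middlewareConfig)
}
```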
diff --git a/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts
new file mode 100644
index 0000000000..d309a1112d
--- /dev/null
+++ b/src/renderer/src/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder.ts
@@ -0,0 +1,188 @@
+import { AiPlugin, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
+import { isReasoningModel } from '@renderer/config/models'
+import type { Model, Provider } from '@renderer/types'
+import type { Chunk } from '@renderer/types/chunk'
+
+import thinkingTimeMiddleware from './ThinkingTimeMiddleware'
+
+/**
+ * AI SDK middleware configuration options
+ */
+export interface AiSdkMiddlewareConfig {
+  streamOutput?: boolean
+  onChunk?: (chunk: Chunk) => void
+  model?: Model
+  provider?: Provider
+  enableReasoning?: boolean
+  enableTool?: boolean
+  enableWebSearch?: boolean
+}
+
+/**
+ * A named AI SDK middleware
+ */
+export type NamedAiSdkMiddleware = AiPlugin
+
+/**
+ * AI SDK middleware builder
+ * Dynamically assembles the middleware array based on different conditions
+ */
+export class AiSdkMiddlewareBuilder {
+  private middlewares: NamedAiSdkMiddleware[] = []
+
+  /**
+   * Add a named middleware
+   */
+  public add(namedMiddleware: NamedAiSdkMiddleware): this {
+    this.middlewares.push(namedMiddleware)
+    return this
+  }
+
+  /**
+   * Insert a middleware after the middleware with the given name
+   */
+  public insertAfter(targetName: string, middleware: NamedAiSdkMiddleware): this {
+    const index = this.middlewares.findIndex((m) => m.name === targetName)
+    if (index !== -1) {
+      this.middlewares.splice(index + 1, 0, middleware)
+    } else {
+      console.warn(`AiSdkMiddlewareBuilder: middleware named '${targetName}' not found, cannot insert`)
+    }
+    return this
+  }
+
+  /**
+   * Check whether a middleware with the given name is present
+   */
+  public has(name: string): boolean {
+    return this.middlewares.some((m) => m.name === name)
+  }
+
+  /**
+   * Remove the middleware with the given name
+   */
+  public remove(name: string): this {
+    this.middlewares = this.middlewares.filter((m) => m.name !== name)
+    return this
+  }
+
+  /**
+   * Build the final middleware array
+   */
+  public build(): NamedAiSdkMiddleware[] {
+    return [...this.middlewares]
+  }
+
+  /**
+   * Clear all middlewares
+   */
+  public clear(): this {
+    this.middlewares = []
+    return this
+  }
+
+  /**
+   * Get the total number of middlewares
+   */
+  public get length(): number {
+    return this.middlewares.length
+  }
+}
+
+/**
+ * Factory function that builds the AI SDK middlewares from a config
+ * Note the build order: some middlewares depend on the results of others
+ */
+export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): NamedAiSdkMiddleware[] {
+  const builder = new AiSdkMiddlewareBuilder()
+
+  // 1. Add the thinking-time middleware for reasoning models when an onChunk callback is provided
+  if (config.onChunk && config.model && isReasoningModel(config.model)) {
+    builder.add({
+      name: 'thinking-time',
+      aiSdkMiddlewares: [thinkingTimeMiddleware(config.onChunk)]
+    })
+  }
+
+  // 2. More middlewares can be added here based on other conditions,
+  //    e.g. tool-calling or web-search related middlewares
+
+  // 3. Add provider-specific middlewares
+  if (config.provider) {
+    addProviderSpecificMiddlewares(builder, config)
+  }
+
+  // 4. Add model-specific middlewares
+  if (config.model) {
+    addModelSpecificMiddlewares(builder, config)
+  }
+
+  // 5. Add the simulated-streaming middleware when stream output is disabled
+  if (config.streamOutput === false) {
+    builder.add({
+      name: 'simulate-streaming',
+      aiSdkMiddlewares: [simulateStreamingMiddleware()]
+    })
+  }
+
+  return builder.build()
+}
+
+/**
+ * Add provider-specific middlewares
+ */
+function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
+  if (!config.provider) return
+
+  // Add specific middlewares for different providers
+  switch (config.provider.type) {
+    case 'anthropic':
+      // Anthropic-specific middlewares
+      break
+    case 'openai':
+      // OpenAI-specific middlewares
+      break
+    case 'gemini':
+      // Gemini-specific middlewares
+      break
+    default:
+      // Generic handling for other providers
+      break
+  }
+}
+
+/**
+ * Add model-specific middlewares
+ */
+function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
+  if (!config.model) return
+
+  // Middlewares can be added based on the model ID or capabilities,
+  // e.g. image generation models, multimodal models, etc.
+
+  // Example: some models need special handling
+  if (config.model.id.includes('dalle') || config.model.id.includes('midjourney')) {
+    // Image-generation-related middlewares
+  }
+}
+
+/**
+ * Create an empty middleware builder
+ */
+export function createAiSdkMiddlewareBuilder(): AiSdkMiddlewareBuilder {
+  return new AiSdkMiddlewareBuilder()
+}
+
+/**
+ * Create a builder pre-populated with the default middlewares
+ */
+export function createDefaultAiSdkMiddlewareBuilder(config: AiSdkMiddlewareConfig): AiSdkMiddlewareBuilder {
+  const builder = new AiSdkMiddlewareBuilder()
+  const defaultMiddlewares = buildAiSdkMiddlewares(config)
+
+  defaultMiddlewares.forEach((middleware) => {
+    builder.add(middleware)
+  })
+
+  return builder
+}
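For illustration, a small sketch of how the builder's ordering helpers compose. The `custom-logging` and `custom-metrics` entries are hypothetical placeholders that follow the same `{ name, aiSdkMiddlewares }` shape used by the factory above:

```typescript
import { simulateStreamingMiddleware } from '@cherrystudio/ai-core'

import { createAiSdkMiddlewareBuilder } from './AiSdkMiddlewareBuilder'

const builder = createAiSdkMiddlewareBuilder()

builder
  .add({ name: 'simulate-streaming', aiSdkMiddlewares: [simulateStreamingMiddleware()] })
  // Hypothetical entry; any array of AI SDK middlewares works here.
  .add({ name: 'custom-logging', aiSdkMiddlewares: [] })

// Insert relative to an existing entry by name; a warning is logged if the target is missing.
builder.insertAfter('simulate-streaming', { name: 'custom-metrics', aiSdkMiddlewares: [] })

builder.remove('custom-logging')

console.log(builder.has('custom-logging')) // false
console.log(builder.build().map((m) => m.name)) // ['simulate-streaming', 'custom-metrics']
```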
diff --git a/src/renderer/src/aiCore/middleware/aisdk/README.md b/src/renderer/src/aiCore/middleware/aisdk/README.md
new file mode 100644
index 0000000000..7731d263c3
--- /dev/null
+++ b/src/renderer/src/aiCore/middleware/aisdk/README.md
@@ -0,0 +1,140 @@
+# AI SDK Middleware Builder
+
+## Overview
+
+`AiSdkMiddlewareBuilder` is a builder-pattern implementation for dynamically assembling the AI SDK middleware array. It builds an appropriate middleware combination based on conditions such as stream output, reasoning models, and provider type.
+
+## Usage
+
+### Basic usage
+
+```typescript
+import { buildAiSdkMiddlewares, type AiSdkMiddlewareConfig } from './AiSdkMiddlewareBuilder'
+
+// Configure the middleware parameters
+const config: AiSdkMiddlewareConfig = {
+  streamOutput: false, // non-streaming output
+  onChunk: chunkHandler, // chunk callback
+  model: currentModel, // current model
+  provider: currentProvider, // current provider
+  enableReasoning: true, // enable reasoning
+  enableTool: false, // disable tools
+  enableWebSearch: false // disable web search
+}
+
+// Build the middleware array
+const middlewares = buildAiSdkMiddlewares(config)
+
+// Create a client with the middlewares applied
+const client = createClient(providerId, options, middlewares)
+```
+
+### Manual construction
+
+```typescript
+import { AiSdkMiddlewareBuilder, createAiSdkMiddlewareBuilder } from './AiSdkMiddlewareBuilder'
+
+const builder = createAiSdkMiddlewareBuilder()
+
+// Add a specific middleware
+builder.add({
+  name: 'custom-middleware',
+  aiSdkMiddlewares: [customMiddleware()]
+})
+
+// Check whether a middleware is present
+if (builder.has('thinking-time')) {
+  console.log('thinking-time middleware already added')
+}
+
+// Remove an unneeded middleware
+builder.remove('simulate-streaming')
+
+// Build the final array
+const middlewares = builder.build()
+```
+
+## Supported conditions
+
+### 1. Stream output control
+
+- **streamOutput = false**: `simulateStreamingMiddleware` is added automatically
+- **streamOutput = true**: native streaming is used
+
+### 2. Reasoning model handling
+
+- **Condition**: `onChunk` is provided && `isReasoningModel(model)` is true
+- **Effect**: `thinkingTimeMiddleware` is added automatically
+
+### 3. Provider-specific middlewares
+
+Specific middlewares are added based on the provider type:
+
+- **anthropic**: Anthropic-specific handling
+- **openai**: OpenAI-specific handling
+- **gemini**: Gemini-specific handling
+
+### 4. Model-specific middlewares
+
+Middlewares are added based on model capabilities:
+
+- **Image generation models**: image-processing middlewares
+- **Multimodal models**: multimodal-processing middlewares
+
+## Extension guide
+
+### Adding a new condition
+
+Add the new condition in the `buildAiSdkMiddlewares` function:
+
+```typescript
+// Example: add a caching middleware
+if (config.enableCache) {
+  builder.add({
+    name: 'cache',
+    aiSdkMiddlewares: [cacheMiddleware(config.cacheOptions)]
+  })
+}
+```
+
+### Adding provider-specific handling
+
+Add it in the `addProviderSpecificMiddlewares` function:
+
+```typescript
+case 'custom-provider':
+  builder.add({
+    name: 'custom-provider-middleware',
+    aiSdkMiddlewares: [customProviderMiddleware()]
+  })
+  break
+```
+
+### Adding model-specific handling
+
+Add it in the `addModelSpecificMiddlewares` function:
+
+```typescript
+if (config.model.id.includes('custom-model')) {
+  builder.add({
+    name: 'custom-model-middleware',
+    aiSdkMiddlewares: [customModelMiddleware()]
+  })
+}
+```
+
+## Middleware execution order
+
+Middlewares execute in the order they are added:
+
+1. **thinking-time** (for reasoning models with an onChunk callback)
+2. **provider-specific** (based on provider type)
+3. **model-specific** (based on model type)
+4. **simulate-streaming** (if streamOutput = false)
+
+## Notes
+
+1. Middleware order matters; make sure middlewares are added in the correct order
+2. Avoid adding conflicting middlewares
+3. Some middlewares depend on others; add the dependencies first
+4. Enable logging in development to help debug the middleware build process
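As a quick sanity check of the execution order documented above, building middlewares for a reasoning model with streaming disabled should yield `thinking-time` first and `simulate-streaming` last. A sketch (the model, provider, and chunk handler are placeholders, as in the README examples):

```typescript
import { buildAiSdkMiddlewares } from './AiSdkMiddlewareBuilder'

const middlewares = buildAiSdkMiddlewares({
  streamOutput: false, // forces the simulate-streaming entry
  onChunk: chunkHandler, // placeholder chunk handler
  model: currentModel, // placeholder: a model for which isReasoningModel() returns true
  provider: currentProvider // placeholder provider
})

console.log(middlewares.map((m) => m.name))
// expected: ['thinking-time', 'simulate-streaming']
```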
diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts
index d8fe261497..a3e706c25e 100644
--- a/src/renderer/src/services/ApiService.ts
+++ b/src/renderer/src/services/ApiService.ts
@@ -3,6 +3,7 @@
  */
 
 import { StreamTextParams } from '@cherrystudio/ai-core'
+import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder'
 import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
 import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
 import {
@@ -293,7 +294,7 @@ export async function fetchChatCompletion({
   onChunkReceived: (chunk: Chunk) => void
 }) {
   const provider = getAssistantProvider(assistant)
-  const AI = new AiProviderNew(provider, onChunkReceived)
+  const AI = new AiProviderNew(provider)
 
   const mcpTools = await fetchMcpTools(assistant)
 
@@ -303,9 +304,17 @@ export async function fetchChatCompletion({
     requestOptions: options
   })
 
+  const middlewareConfig: AiSdkMiddlewareConfig = {
+    streamOutput: assistant.settings?.streamOutput ?? true,
+    onChunk: onChunkReceived,
+    model: assistant.model,
+    provider: provider,
+    enableReasoning: assistant.settings?.reasoning_effort !== undefined
+  }
+
   // --- Call AI Completions ---
   onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
-  await AI.completions(modelId, aiSdkParams, onChunkReceived)
+  await AI.completions(modelId, aiSdkParams, middlewareConfig)
 }
 
 interface FetchTranslateProps {
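The README's extension guide implies that project-specific entries can wrap any AI SDK language-model middleware. A hedged sketch of what such an entry might look like, assuming the AI SDK 4.x `LanguageModelV1Middleware` shape that `simulateStreamingMiddleware` and friends implement (the type name, the `timing` entry, and the builder wiring are illustrative):

```typescript
import type { LanguageModelV1Middleware } from 'ai' // type name assumed from AI SDK 4.x

// Hypothetical pass-through middleware that times non-streaming generations.
const timingMiddleware: LanguageModelV1Middleware = {
  wrapGenerate: async ({ doGenerate }) => {
    const start = performance.now()
    const result = await doGenerate()
    console.log(`generate took ${Math.round(performance.now() - start)}ms`)
    return result
  }
}

// Registered through the builder in the same shape as the built-in entries:
// builder.add({ name: 'timing', aiSdkMiddlewares: [timingMiddleware] })
```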