feat: enhance AI SDK middleware integration and support

- Added AiSdkMiddlewareBuilder for dynamic middleware construction based on various conditions.
- Updated ModernAiProvider to utilize new middleware configuration, improving flexibility in handling completions.
- Refactored ApiService to pass middleware configuration during AI completions, enabling better control over processing.
- Introduced new README documentation for the middleware builder, outlining usage and supported conditions.
This commit is contained in:
suyao 2025-06-20 15:31:41 +08:00 committed by MyPrototypeWhat
parent 1bccfd3170
commit 3771b24b52
5 changed files with 372 additions and 17 deletions

View File

@ -79,6 +79,7 @@ export type {
ToolExecutionError,
ToolResult
} from 'ai'
export { defaultSettingsMiddleware, extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
// 重新导出所有 Provider Settings 类型
export type {

View File

@ -24,7 +24,7 @@ import { Chunk } from '@renderer/types/chunk'
import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
// 引入原有的AiProvider作为fallback
import LegacyAiProvider from './index'
import thinkingTimeMiddleware from './middleware/aisdk/ThinkingTimeMiddleware'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
import { CompletionsResult } from './middleware/schemas'
// 引入参数转换模块
@ -88,7 +88,7 @@ function providerToAiSdkConfig(provider: Provider): {
*/
function isModernSdkSupported(provider: Provider, model?: Model): boolean {
// 目前支持主要的providers
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai']
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai']
// 检查provider类型
if (!supportedProviders.includes(provider.type)) {
@ -108,21 +108,19 @@ export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider
private provider: Provider
constructor(provider: Provider, onChunk?: (chunk: Chunk) => void) {
constructor(provider: Provider) {
this.provider = provider
this.legacyProvider = new LegacyAiProvider(provider)
// 初始化时不构建中间件,等到需要时再构建
const config = providerToAiSdkConfig(provider)
this.modernClient = createClient(
config.providerId,
config.options,
onChunk ? [{ name: 'thinking-time', aiSdkMiddlewares: [thinkingTimeMiddleware(onChunk)] }] : undefined
)
this.modernClient = createClient(config.providerId, config.options)
}
public async completions(
modelId: string,
params: StreamTextParams,
middlewareConfig: AiSdkMiddlewareConfig,
onChunk?: (chunk: Chunk) => void
): Promise<CompletionsResult> {
// const model = params.assistant.model
@ -131,7 +129,7 @@ export default class ModernAiProvider {
// if (this.modernClient && model && isModernSdkSupported(this.provider, model)) {
// try {
console.log('completions', modelId, params, onChunk)
return await this.modernCompletions(modelId, params, onChunk)
return await this.modernCompletions(modelId, params, middlewareConfig)
// } catch (error) {
// console.warn('Modern client failed, falling back to legacy:', error)
// fallback到原有实现
@ -144,22 +142,41 @@ export default class ModernAiProvider {
/**
 * Completions implementation based on the new AI SDK.
 * Middlewares are now built dynamically per call (AiSdkUtils is no longer used here).
 */
private async modernCompletions(
modelId: string,
params: StreamTextParams,
onChunk?: (chunk: Chunk) => void
middlewareConfig: AiSdkMiddlewareConfig
): Promise<CompletionsResult> {
if (!this.modernClient) {
throw new Error('Modern AI SDK client not initialized')
}
try {
if (onChunk) {
// 合并传入的配置和实例配置
const finalConfig: AiSdkMiddlewareConfig = {
...middlewareConfig,
provider: this.provider,
// 工具相关信息从 params 中获取
enableTool: params.tools !== undefined && Array.isArray(params.tools) && params.tools.length > 0
}
// 动态构建中间件数组
const middlewares = buildAiSdkMiddlewares(finalConfig)
console.log(
'构建的中间件:',
middlewares.map((m) => m.name)
)
// 创建带有中间件的客户端
const config = providerToAiSdkConfig(this.provider)
const clientWithMiddlewares = createClient(config.providerId, config.options, middlewares)
if (middlewareConfig.onChunk) {
// 流式处理 - 使用适配器
const adapter = new AiSdkToChunkAdapter(onChunk)
const streamResult = await this.modernClient.streamText(modelId, params)
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
const streamResult = await clientWithMiddlewares.streamText(modelId, params)
const finalText = await adapter.processStream(streamResult)
return {
@ -167,7 +184,7 @@ export default class ModernAiProvider {
}
} else {
// 流式处理但没有 onChunk 回调
const streamResult = await this.modernClient.streamText(modelId, params)
const streamResult = await clientWithMiddlewares.streamText(modelId, params)
const finalText = await streamResult.text
return {

View File

@ -0,0 +1,188 @@
import { AiPlugin, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
import { isReasoningModel } from '@renderer/config/models'
import type { Model, Provider } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import thinkingTimeMiddleware from './ThinkingTimeMiddleware'
/**
 * Configuration used to decide which AI SDK middlewares get built.
 * All fields are optional; an absent field simply skips the related middleware.
 */
export interface AiSdkMiddlewareConfig {
  // When explicitly false, a simulate-streaming middleware is appended (see buildAiSdkMiddlewares)
  streamOutput?: boolean
  // Streaming chunk callback; required for the thinking-time middleware to be added
  onChunk?: (chunk: Chunk) => void
  // Current model; drives the reasoning-model check and model-specific middlewares
  model?: Model
  // Current provider; drives provider-specific middleware selection
  provider?: Provider
  // NOTE(review): set by callers (e.g. from reasoning_effort) but not read by the builder yet — confirm intended use
  enableReasoning?: boolean
  // Derived from params.tools by ModernAiProvider; not read by the builder yet
  enableTool?: boolean
  // NOTE(review): declared but not read anywhere in this file yet
  enableWebSearch?: boolean
}
/**
 * A named AI SDK middleware entry (alias of AiPlugin from @cherrystudio/ai-core).
 * The `name` field is what the builder's has/remove/insertAfter look up.
 */
export type NamedAiSdkMiddleware = AiPlugin
/**
 * Builder that assembles an ordered list of named AI SDK middlewares.
 * Supports chained add/insert/remove operations before producing the final array.
 */
export class AiSdkMiddlewareBuilder {
  private items: NamedAiSdkMiddleware[] = []

  /**
   * Append a middleware to the end of the chain.
   */
  public add(namedMiddleware: NamedAiSdkMiddleware): this {
    this.items.push(namedMiddleware)
    return this
  }

  /**
   * Insert a middleware immediately after the one named `targetName`.
   * Warns and leaves the chain untouched when the target is not present.
   */
  public insertAfter(targetName: string, middleware: NamedAiSdkMiddleware): this {
    const pos = this.items.findIndex((entry) => entry.name === targetName)
    if (pos === -1) {
      console.warn(`AiSdkMiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`)
    } else {
      this.items.splice(pos + 1, 0, middleware)
    }
    return this
  }

  /**
   * Whether a middleware with the given name is currently registered.
   */
  public has(name: string): boolean {
    return this.items.findIndex((entry) => entry.name === name) !== -1
  }

  /**
   * Remove every middleware whose name matches.
   */
  public remove(name: string): this {
    this.items = this.items.filter((entry) => entry.name !== name)
    return this
  }

  /**
   * Produce a defensive copy of the middleware list, in insertion order.
   */
  public build(): NamedAiSdkMiddleware[] {
    return this.items.slice()
  }

  /**
   * Drop all registered middlewares.
   */
  public clear(): this {
    this.items = []
    return this
  }

  /**
   * Number of middlewares currently registered.
   */
  public get length(): number {
    return this.items.length
  }
}
/**
 * Factory that builds the AI SDK middleware array for a single request.
 *
 * Registration (and therefore execution) order:
 *   1. thinking-time       — reasoning model with an onChunk callback
 *   2. provider-specific   — selected by provider type
 *   3. model-specific      — selected by model id / characteristics
 *   4. simulate-streaming  — appended when streamOutput === false
 */
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): NamedAiSdkMiddleware[] {
  const builder = new AiSdkMiddlewareBuilder()
  const { onChunk, model, provider } = config

  // Thinking-time middleware needs both a chunk callback and a reasoning-capable model
  if (onChunk && model && isReasoningModel(model)) {
    builder.add({
      name: 'thinking-time',
      aiSdkMiddlewares: [thinkingTimeMiddleware(onChunk)]
    })
  }

  // Further condition-based middlewares (tool calling, web search, ...) can hook in here

  // Provider-specific middlewares
  if (provider) {
    addProviderSpecificMiddlewares(builder, config)
  }

  // Model-specific middlewares
  if (model) {
    addModelSpecificMiddlewares(builder, config)
  }

  // Explicit non-streaming requests get a simulated stream appended last
  if (config.streamOutput === false) {
    builder.add({
      name: 'simulate-streaming',
      aiSdkMiddlewares: [simulateStreamingMiddleware()]
    })
  }

  return builder.build()
}
/**
 * Register provider-specific middlewares on the builder.
 * Currently a placeholder dispatch: every branch is empty until
 * per-provider handling is implemented.
 */
function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
  const provider = config.provider
  if (!provider) return

  switch (provider.type) {
    case 'anthropic':
      // TODO: Anthropic-specific middlewares
      break
    case 'openai':
      // TODO: OpenAI-specific middlewares
      break
    case 'gemini':
      // TODO: Gemini-specific middlewares
      break
    default:
      // No generic handling for other provider types yet
      break
  }
}
/**
 * Register model-specific middlewares on the builder (placeholder).
 * Image-generation model ids (dalle / midjourney) are detected but
 * no middleware is registered for them yet.
 */
function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
  const model = config.model
  if (!model) return

  // Example hook: image-generation models would get dedicated middlewares here
  const looksLikeImageModel = model.id.includes('dalle') || model.id.includes('midjourney')
  if (looksLikeImageModel) {
    // TODO: register image-generation middlewares
  }
}
/**
 * Convenience factory for an empty AiSdkMiddlewareBuilder.
 */
export function createAiSdkMiddlewareBuilder(): AiSdkMiddlewareBuilder {
  const builder = new AiSdkMiddlewareBuilder()
  return builder
}
/**
 * Create a builder pre-populated with the default middlewares derived from `config`,
 * so callers can further customize (insert/remove) before building.
 */
export function createDefaultAiSdkMiddlewareBuilder(config: AiSdkMiddlewareConfig): AiSdkMiddlewareBuilder {
  const builder = new AiSdkMiddlewareBuilder()
  for (const middleware of buildAiSdkMiddlewares(config)) {
    builder.add(middleware)
  }
  return builder
}

View File

@ -0,0 +1,140 @@
# AI SDK 中间件建造者
## 概述
`AiSdkMiddlewareBuilder` 是一个用于动态构建 AI SDK 中间件数组的建造者模式实现。它可以根据不同的条件如流式输出、思考模型、provider类型等自动构建合适的中间件组合。
## 使用方式
### 基本用法
```typescript
import { buildAiSdkMiddlewares, type AiSdkMiddlewareConfig } from './AiSdkMiddlewareBuilder'
// 配置中间件参数
const config: AiSdkMiddlewareConfig = {
streamOutput: false, // 非流式输出
onChunk: chunkHandler, // chunk回调函数
model: currentModel, // 当前模型
provider: currentProvider, // 当前provider
enableReasoning: true, // 启用推理
enableTool: false, // 禁用工具
enableWebSearch: false // 禁用网页搜索
}
// 构建中间件数组
const middlewares = buildAiSdkMiddlewares(config)
// 创建带有中间件的客户端
const client = createClient(providerId, options, middlewares)
```
### 手动构建
```typescript
import { AiSdkMiddlewareBuilder, createAiSdkMiddlewareBuilder } from './AiSdkMiddlewareBuilder'
const builder = createAiSdkMiddlewareBuilder()
// 添加特定中间件
builder.add({
name: 'custom-middleware',
aiSdkMiddlewares: [customMiddleware()]
})
// 检查是否包含某个中间件
if (builder.has('thinking-time')) {
console.log('已包含思考时间中间件')
}
// 移除不需要的中间件
builder.remove('simulate-streaming')
// 构建最终数组
const middlewares = builder.build()
```
## 支持的条件
### 1. 流式输出控制
- **streamOutput = false**: 自动添加 `simulateStreamingMiddleware`
- **streamOutput = true**: 使用原生流式处理
### 2. 思考模型处理
- **条件**: `onChunk` 存在 && `isReasoningModel(model)` 为 true
- **效果**: 自动添加 `thinkingTimeMiddleware`
### 3. Provider 特定中间件
根据不同的 provider 类型添加特定中间件:
- **anthropic**: Anthropic 特定处理
- **openai**: OpenAI 特定处理
- **gemini**: Gemini 特定处理
### 4. 模型特定中间件
根据模型特性添加中间件:
- **图像生成模型**: 添加图像处理相关中间件
- **多模态模型**: 添加多模态处理中间件
## 扩展指南
### 添加新的条件判断
在 `buildAiSdkMiddlewares` 函数中添加新的条件:
```typescript
// 例如:添加缓存中间件
if (config.enableCache) {
builder.add({
name: 'cache',
aiSdkMiddlewares: [cacheMiddleware(config.cacheOptions)]
})
}
```
### 添加 Provider 特定处理
在 `addProviderSpecificMiddlewares` 函数中添加:
```typescript
case 'custom-provider':
builder.add({
name: 'custom-provider-middleware',
aiSdkMiddlewares: [customProviderMiddleware()]
})
break
```
### 添加模型特定处理
在 `addModelSpecificMiddlewares` 函数中添加:
```typescript
if (config.model.id.includes('custom-model')) {
builder.add({
name: 'custom-model-middleware',
aiSdkMiddlewares: [customModelMiddleware()]
})
}
```
## 中间件执行顺序
中间件按照添加顺序执行(与 `buildAiSdkMiddlewares` 中的注册顺序一致):
1. **thinking-time** (如果是思考模型且有 onChunk)
2. **provider-specific** (根据 provider 类型)
3. **model-specific** (根据模型类型)
4. **simulate-streaming** (如果 streamOutput = false,最后添加)
## 注意事项
1. 中间件的执行顺序很重要,确保按正确顺序添加
2. 避免添加冲突的中间件
3. 某些中间件可能有依赖关系,需要确保依赖的中间件先添加
4. 建议在开发环境下启用日志,以便调试中间件构建过程

View File

@ -3,6 +3,7 @@
*/
import { StreamTextParams } from '@cherrystudio/ai-core'
import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder'
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
import {
@ -293,7 +294,7 @@ export async function fetchChatCompletion({
onChunkReceived: (chunk: Chunk) => void
}) {
const provider = getAssistantProvider(assistant)
const AI = new AiProviderNew(provider, onChunkReceived)
const AI = new AiProviderNew(provider)
const mcpTools = await fetchMcpTools(assistant)
@ -303,9 +304,17 @@ export async function fetchChatCompletion({
requestOptions: options
})
const middlewareConfig: AiSdkMiddlewareConfig = {
streamOutput: assistant.settings?.streamOutput ?? true,
onChunk: onChunkReceived,
model: assistant.model,
provider: provider,
enableReasoning: assistant.settings?.reasoning_effort !== undefined
}
// --- Call AI Completions ---
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
await AI.completions(modelId, aiSdkParams, onChunkReceived)
await AI.completions(modelId, aiSdkParams, middlewareConfig)
}
interface FetchTranslateProps {