mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 21:01:32 +08:00
feat: enhance AI SDK middleware integration and support
- Added AiSdkMiddlewareBuilder for dynamic middleware construction based on various conditions. - Updated ModernAiProvider to utilize new middleware configuration, improving flexibility in handling completions. - Refactored ApiService to pass middleware configuration during AI completions, enabling better control over processing. - Introduced new README documentation for the middleware builder, outlining usage and supported conditions.
This commit is contained in:
parent
1bccfd3170
commit
3771b24b52
@ -79,6 +79,7 @@ export type {
|
||||
ToolExecutionError,
|
||||
ToolResult
|
||||
} from 'ai'
|
||||
export { defaultSettingsMiddleware, extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
|
||||
|
||||
// 重新导出所有 Provider Settings 类型
|
||||
export type {
|
||||
|
||||
@ -24,7 +24,7 @@ import { Chunk } from '@renderer/types/chunk'
|
||||
import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
|
||||
// 引入原有的AiProvider作为fallback
|
||||
import LegacyAiProvider from './index'
|
||||
import thinkingTimeMiddleware from './middleware/aisdk/ThinkingTimeMiddleware'
|
||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
|
||||
import { CompletionsResult } from './middleware/schemas'
|
||||
// 引入参数转换模块
|
||||
|
||||
@ -88,7 +88,7 @@ function providerToAiSdkConfig(provider: Provider): {
|
||||
*/
|
||||
function isModernSdkSupported(provider: Provider, model?: Model): boolean {
|
||||
// 目前支持主要的providers
|
||||
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai']
|
||||
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai']
|
||||
|
||||
// 检查provider类型
|
||||
if (!supportedProviders.includes(provider.type)) {
|
||||
@ -108,21 +108,19 @@ export default class ModernAiProvider {
|
||||
private legacyProvider: LegacyAiProvider
|
||||
private provider: Provider
|
||||
|
||||
constructor(provider: Provider, onChunk?: (chunk: Chunk) => void) {
|
||||
constructor(provider: Provider) {
|
||||
this.provider = provider
|
||||
this.legacyProvider = new LegacyAiProvider(provider)
|
||||
|
||||
// 初始化时不构建中间件,等到需要时再构建
|
||||
const config = providerToAiSdkConfig(provider)
|
||||
this.modernClient = createClient(
|
||||
config.providerId,
|
||||
config.options,
|
||||
onChunk ? [{ name: 'thinking-time', aiSdkMiddlewares: [thinkingTimeMiddleware(onChunk)] }] : undefined
|
||||
)
|
||||
this.modernClient = createClient(config.providerId, config.options)
|
||||
}
|
||||
|
||||
public async completions(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
middlewareConfig: AiSdkMiddlewareConfig,
|
||||
onChunk?: (chunk: Chunk) => void
|
||||
): Promise<CompletionsResult> {
|
||||
// const model = params.assistant.model
|
||||
@ -131,7 +129,7 @@ export default class ModernAiProvider {
|
||||
// if (this.modernClient && model && isModernSdkSupported(this.provider, model)) {
|
||||
// try {
|
||||
console.log('completions', modelId, params, onChunk)
|
||||
return await this.modernCompletions(modelId, params, onChunk)
|
||||
return await this.modernCompletions(modelId, params, middlewareConfig)
|
||||
// } catch (error) {
|
||||
// console.warn('Modern client failed, falling back to legacy:', error)
|
||||
// fallback到原有实现
|
||||
@ -144,22 +142,41 @@ export default class ModernAiProvider {
|
||||
|
||||
/**
|
||||
* 使用现代化AI SDK的completions实现
|
||||
* 使用 AiSdkUtils 工具模块进行参数构建
|
||||
* 使用建造者模式动态构建中间件
|
||||
*/
|
||||
private async modernCompletions(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
onChunk?: (chunk: Chunk) => void
|
||||
middlewareConfig: AiSdkMiddlewareConfig
|
||||
): Promise<CompletionsResult> {
|
||||
if (!this.modernClient) {
|
||||
throw new Error('Modern AI SDK client not initialized')
|
||||
}
|
||||
|
||||
try {
|
||||
if (onChunk) {
|
||||
// 合并传入的配置和实例配置
|
||||
const finalConfig: AiSdkMiddlewareConfig = {
|
||||
...middlewareConfig,
|
||||
provider: this.provider,
|
||||
// 工具相关信息从 params 中获取
|
||||
enableTool: params.tools !== undefined && Array.isArray(params.tools) && params.tools.length > 0
|
||||
}
|
||||
|
||||
// 动态构建中间件数组
|
||||
const middlewares = buildAiSdkMiddlewares(finalConfig)
|
||||
console.log(
|
||||
'构建的中间件:',
|
||||
middlewares.map((m) => m.name)
|
||||
)
|
||||
|
||||
// 创建带有中间件的客户端
|
||||
const config = providerToAiSdkConfig(this.provider)
|
||||
const clientWithMiddlewares = createClient(config.providerId, config.options, middlewares)
|
||||
|
||||
if (middlewareConfig.onChunk) {
|
||||
// 流式处理 - 使用适配器
|
||||
const adapter = new AiSdkToChunkAdapter(onChunk)
|
||||
const streamResult = await this.modernClient.streamText(modelId, params)
|
||||
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
|
||||
const streamResult = await clientWithMiddlewares.streamText(modelId, params)
|
||||
const finalText = await adapter.processStream(streamResult)
|
||||
|
||||
return {
|
||||
@ -167,7 +184,7 @@ export default class ModernAiProvider {
|
||||
}
|
||||
} else {
|
||||
// 流式处理但没有 onChunk 回调
|
||||
const streamResult = await this.modernClient.streamText(modelId, params)
|
||||
const streamResult = await clientWithMiddlewares.streamText(modelId, params)
|
||||
const finalText = await streamResult.text
|
||||
|
||||
return {
|
||||
|
||||
@ -0,0 +1,188 @@
|
||||
import { AiPlugin, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
|
||||
import { isReasoningModel } from '@renderer/config/models'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
|
||||
import thinkingTimeMiddleware from './ThinkingTimeMiddleware'
|
||||
|
||||
/**
|
||||
* AI SDK 中间件配置项
|
||||
*/
|
||||
export interface AiSdkMiddlewareConfig {
|
||||
streamOutput?: boolean
|
||||
onChunk?: (chunk: Chunk) => void
|
||||
model?: Model
|
||||
provider?: Provider
|
||||
enableReasoning?: boolean
|
||||
enableTool?: boolean
|
||||
enableWebSearch?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* 具名的 AI SDK 中间件
|
||||
*/
|
||||
export type NamedAiSdkMiddleware = AiPlugin
|
||||
|
||||
/**
|
||||
* AI SDK 中间件建造者
|
||||
* 用于根据不同条件动态构建中间件数组
|
||||
*/
|
||||
export class AiSdkMiddlewareBuilder {
|
||||
private middlewares: NamedAiSdkMiddleware[] = []
|
||||
|
||||
/**
|
||||
* 添加具名中间件
|
||||
*/
|
||||
public add(namedMiddleware: NamedAiSdkMiddleware): this {
|
||||
this.middlewares.push(namedMiddleware)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 在指定位置插入中间件
|
||||
*/
|
||||
public insertAfter(targetName: string, middleware: NamedAiSdkMiddleware): this {
|
||||
const index = this.middlewares.findIndex((m) => m.name === targetName)
|
||||
if (index !== -1) {
|
||||
this.middlewares.splice(index + 1, 0, middleware)
|
||||
} else {
|
||||
console.warn(`AiSdkMiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`)
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否包含指定名称的中间件
|
||||
*/
|
||||
public has(name: string): boolean {
|
||||
return this.middlewares.some((m) => m.name === name)
|
||||
}
|
||||
|
||||
/**
|
||||
* 移除指定名称的中间件
|
||||
*/
|
||||
public remove(name: string): this {
|
||||
this.middlewares = this.middlewares.filter((m) => m.name !== name)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建最终的中间件数组
|
||||
*/
|
||||
public build(): NamedAiSdkMiddleware[] {
|
||||
return [...this.middlewares]
|
||||
}
|
||||
|
||||
/**
|
||||
* 清空所有中间件
|
||||
*/
|
||||
public clear(): this {
|
||||
this.middlewares = []
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取中间件总数
|
||||
*/
|
||||
public get length(): number {
|
||||
return this.middlewares.length
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据配置构建AI SDK中间件的工厂函数
|
||||
* 这里要注意构建顺序,因为有些中间件需要依赖其他中间件的结果
|
||||
*/
|
||||
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): NamedAiSdkMiddleware[] {
|
||||
const builder = new AiSdkMiddlewareBuilder()
|
||||
|
||||
// 1. 思考模型且有onChunk回调时添加思考时间中间件
|
||||
if (config.onChunk && config.model && isReasoningModel(config.model)) {
|
||||
builder.add({
|
||||
name: 'thinking-time',
|
||||
aiSdkMiddlewares: [thinkingTimeMiddleware(config.onChunk)]
|
||||
})
|
||||
}
|
||||
|
||||
// 2. 可以在这里根据其他条件添加更多中间件
|
||||
// 例如:工具调用、Web搜索等相关中间件
|
||||
|
||||
// 3. 根据provider添加特定中间件
|
||||
if (config.provider) {
|
||||
addProviderSpecificMiddlewares(builder, config)
|
||||
}
|
||||
|
||||
// 4. 根据模型类型添加特定中间件
|
||||
if (config.model) {
|
||||
addModelSpecificMiddlewares(builder, config)
|
||||
}
|
||||
|
||||
// 5. 非流式输出时添加模拟流中间件
|
||||
if (config.streamOutput === false) {
|
||||
builder.add({
|
||||
name: 'simulate-streaming',
|
||||
aiSdkMiddlewares: [simulateStreamingMiddleware()]
|
||||
})
|
||||
}
|
||||
|
||||
return builder.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加provider特定的中间件
|
||||
*/
|
||||
function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
|
||||
if (!config.provider) return
|
||||
|
||||
// 根据不同provider添加特定中间件
|
||||
switch (config.provider.type) {
|
||||
case 'anthropic':
|
||||
// Anthropic特定中间件
|
||||
break
|
||||
case 'openai':
|
||||
// OpenAI特定中间件
|
||||
break
|
||||
case 'gemini':
|
||||
// Gemini特定中间件
|
||||
break
|
||||
default:
|
||||
// 其他provider的通用处理
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加模型特定的中间件
|
||||
*/
|
||||
function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
|
||||
if (!config.model) return
|
||||
|
||||
// 可以根据模型ID或特性添加特定中间件
|
||||
// 例如:图像生成模型、多模态模型等
|
||||
|
||||
// 示例:某些模型需要特殊处理
|
||||
if (config.model.id.includes('dalle') || config.model.id.includes('midjourney')) {
|
||||
// 图像生成相关中间件
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建一个预配置的中间件建造者
|
||||
*/
|
||||
export function createAiSdkMiddlewareBuilder(): AiSdkMiddlewareBuilder {
|
||||
return new AiSdkMiddlewareBuilder()
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建一个带有默认中间件的建造者
|
||||
*/
|
||||
export function createDefaultAiSdkMiddlewareBuilder(config: AiSdkMiddlewareConfig): AiSdkMiddlewareBuilder {
|
||||
const builder = new AiSdkMiddlewareBuilder()
|
||||
const defaultMiddlewares = buildAiSdkMiddlewares(config)
|
||||
|
||||
defaultMiddlewares.forEach((middleware) => {
|
||||
builder.add(middleware)
|
||||
})
|
||||
|
||||
return builder
|
||||
}
|
||||
140
src/renderer/src/aiCore/middleware/aisdk/README.md
Normal file
140
src/renderer/src/aiCore/middleware/aisdk/README.md
Normal file
@ -0,0 +1,140 @@
|
||||
# AI SDK 中间件建造者
|
||||
|
||||
## 概述
|
||||
|
||||
`AiSdkMiddlewareBuilder` 是一个用于动态构建 AI SDK 中间件数组的建造者模式实现。它可以根据不同的条件(如流式输出、思考模型、provider类型等)自动构建合适的中间件组合。
|
||||
|
||||
## 使用方式
|
||||
|
||||
### 基本用法
|
||||
|
||||
```typescript
|
||||
import { buildAiSdkMiddlewares, type AiSdkMiddlewareConfig } from './AiSdkMiddlewareBuilder'
|
||||
|
||||
// 配置中间件参数
|
||||
const config: AiSdkMiddlewareConfig = {
|
||||
streamOutput: false, // 非流式输出
|
||||
onChunk: chunkHandler, // chunk回调函数
|
||||
model: currentModel, // 当前模型
|
||||
provider: currentProvider, // 当前provider
|
||||
enableReasoning: true, // 启用推理
|
||||
enableTool: false, // 禁用工具
|
||||
enableWebSearch: false // 禁用网页搜索
|
||||
}
|
||||
|
||||
// 构建中间件数组
|
||||
const middlewares = buildAiSdkMiddlewares(config)
|
||||
|
||||
// 创建带有中间件的客户端
|
||||
const client = createClient(providerId, options, middlewares)
|
||||
```
|
||||
|
||||
### 手动构建
|
||||
|
||||
```typescript
|
||||
import { AiSdkMiddlewareBuilder, createAiSdkMiddlewareBuilder } from './AiSdkMiddlewareBuilder'
|
||||
|
||||
const builder = createAiSdkMiddlewareBuilder()
|
||||
|
||||
// 添加特定中间件
|
||||
builder.add({
|
||||
name: 'custom-middleware',
|
||||
aiSdkMiddlewares: [customMiddleware()]
|
||||
})
|
||||
|
||||
// 检查是否包含某个中间件
|
||||
if (builder.has('thinking-time')) {
|
||||
console.log('已包含思考时间中间件')
|
||||
}
|
||||
|
||||
// 移除不需要的中间件
|
||||
builder.remove('simulate-streaming')
|
||||
|
||||
// 构建最终数组
|
||||
const middlewares = builder.build()
|
||||
```
|
||||
|
||||
## 支持的条件
|
||||
|
||||
### 1. 流式输出控制
|
||||
|
||||
- **streamOutput = false**: 自动添加 `simulateStreamingMiddleware`
|
||||
- **streamOutput = true**: 使用原生流式处理
|
||||
|
||||
### 2. 思考模型处理
|
||||
|
||||
- **条件**: `onChunk` 存在 && `isReasoningModel(model)` 为 true
|
||||
- **效果**: 自动添加 `thinkingTimeMiddleware`
|
||||
|
||||
### 3. Provider 特定中间件
|
||||
|
||||
根据不同的 provider 类型添加特定中间件:
|
||||
|
||||
- **anthropic**: Anthropic 特定处理
|
||||
- **openai**: OpenAI 特定处理
|
||||
- **gemini**: Gemini 特定处理
|
||||
|
||||
### 4. 模型特定中间件
|
||||
|
||||
根据模型特性添加中间件:
|
||||
|
||||
- **图像生成模型**: 添加图像处理相关中间件
|
||||
- **多模态模型**: 添加多模态处理中间件
|
||||
|
||||
## 扩展指南
|
||||
|
||||
### 添加新的条件判断
|
||||
|
||||
在 `buildAiSdkMiddlewares` 函数中添加新的条件:
|
||||
|
||||
```typescript
|
||||
// 例如:添加缓存中间件
|
||||
if (config.enableCache) {
|
||||
builder.add({
|
||||
name: 'cache',
|
||||
aiSdkMiddlewares: [cacheMiddleware(config.cacheOptions)]
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
### 添加 Provider 特定处理
|
||||
|
||||
在 `addProviderSpecificMiddlewares` 函数中添加:
|
||||
|
||||
```typescript
|
||||
case 'custom-provider':
|
||||
builder.add({
|
||||
name: 'custom-provider-middleware',
|
||||
aiSdkMiddlewares: [customProviderMiddleware()]
|
||||
})
|
||||
break
|
||||
```
|
||||
|
||||
### 添加模型特定处理
|
||||
|
||||
在 `addModelSpecificMiddlewares` 函数中添加:
|
||||
|
||||
```typescript
|
||||
if (config.model.id.includes('custom-model')) {
|
||||
builder.add({
|
||||
name: 'custom-model-middleware',
|
||||
aiSdkMiddlewares: [customModelMiddleware()]
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
## 中间件执行顺序
|
||||
|
||||
中间件按照添加顺序执行:
|
||||
|
||||
1. **simulate-streaming** (如果 streamOutput = false)
|
||||
2. **thinking-time** (如果是思考模型且有 onChunk)
|
||||
3. **provider-specific** (根据 provider 类型)
|
||||
4. **model-specific** (根据模型类型)
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 中间件的执行顺序很重要,确保按正确顺序添加
|
||||
2. 避免添加冲突的中间件
|
||||
3. 某些中间件可能有依赖关系,需要确保依赖的中间件先添加
|
||||
4. 建议在开发环境下启用日志,以便调试中间件构建过程
|
||||
@ -3,6 +3,7 @@
|
||||
*/
|
||||
|
||||
import { StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder'
|
||||
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
|
||||
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
|
||||
import {
|
||||
@ -293,7 +294,7 @@ export async function fetchChatCompletion({
|
||||
onChunkReceived: (chunk: Chunk) => void
|
||||
}) {
|
||||
const provider = getAssistantProvider(assistant)
|
||||
const AI = new AiProviderNew(provider, onChunkReceived)
|
||||
const AI = new AiProviderNew(provider)
|
||||
|
||||
const mcpTools = await fetchMcpTools(assistant)
|
||||
|
||||
@ -303,9 +304,17 @@ export async function fetchChatCompletion({
|
||||
requestOptions: options
|
||||
})
|
||||
|
||||
const middlewareConfig: AiSdkMiddlewareConfig = {
|
||||
streamOutput: assistant.settings?.streamOutput ?? true,
|
||||
onChunk: onChunkReceived,
|
||||
model: assistant.model,
|
||||
provider: provider,
|
||||
enableReasoning: assistant.settings?.reasoning_effort !== undefined
|
||||
}
|
||||
|
||||
// --- Call AI Completions ---
|
||||
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
await AI.completions(modelId, aiSdkParams, onChunkReceived)
|
||||
await AI.completions(modelId, aiSdkParams, middlewareConfig)
|
||||
}
|
||||
|
||||
interface FetchTranslateProps {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user