mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
feat: enhance ModernAiProvider with new reasoning plugins and dynamic middleware construction
- Introduced `reasoningTimePlugin` and `smoothReasoningPlugin` to improve reasoning content handling and processing. - Refactored `ModernAiProvider` to dynamically build plugin arrays based on middleware configuration, enhancing flexibility. - Removed the obsolete `ThinkingTimeMiddleware` to streamline middleware management. - Updated `buildAiSdkMiddlewares` to reflect changes in middleware handling and improve clarity in the configuration process. - Enhanced logging for better visibility into plugin and middleware configurations during execution.
This commit is contained in:
parent
87f803b0d3
commit
f61da8c2d6
@ -9,6 +9,7 @@
|
||||
*/
|
||||
|
||||
import {
|
||||
AiPlugin,
|
||||
createExecutor,
|
||||
ProviderConfigFactory,
|
||||
type ProviderId,
|
||||
@ -26,7 +27,8 @@ import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
|
||||
import LegacyAiProvider from './index'
|
||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
|
||||
import { CompletionsResult } from './middleware/schemas'
|
||||
import reasonPlugin from './plugins/reasonPlugin'
|
||||
import reasoningTimePlugin from './plugins/reasoningTimePlugin'
|
||||
import smoothReasoningPlugin from './plugins/smoothReasoningPlugin'
|
||||
import textPlugin from './plugins/textPlugin'
|
||||
import { getAiSdkProviderId } from './provider/factory'
|
||||
|
||||
@ -103,27 +105,69 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean {
|
||||
}
|
||||
|
||||
export default class ModernAiProvider {
|
||||
private modernExecutor?: ReturnType<typeof createExecutor>
|
||||
private legacyProvider: LegacyAiProvider
|
||||
private provider: Provider
|
||||
private config: ReturnType<typeof providerToAiSdkConfig>
|
||||
|
||||
constructor(provider: Provider) {
|
||||
this.provider = provider
|
||||
this.legacyProvider = new LegacyAiProvider(provider)
|
||||
|
||||
// TODO:如果后续在调用completions时需要切换provider的话,
|
||||
// 初始化时不构建中间件,等到需要时再构建
|
||||
const config = providerToAiSdkConfig(provider)
|
||||
// 只保存配置,不预先创建executor
|
||||
this.config = providerToAiSdkConfig(provider)
|
||||
|
||||
console.log('[Modern AI Provider] Creating executor with MCP Prompt plugin enabled')
|
||||
}
|
||||
|
||||
this.modernExecutor = createExecutor(config.providerId, config.options, [
|
||||
reasonPlugin({
|
||||
delayInMs: 80,
|
||||
chunkingRegex: /([\u4E00-\u9FFF]{3})|\S+\s+/
|
||||
}),
|
||||
textPlugin
|
||||
])
|
||||
/**
|
||||
* 根据条件构建插件数组
|
||||
*/
|
||||
private buildPlugins(middlewareConfig: AiSdkMiddlewareConfig) {
|
||||
const plugins: AiPlugin[] = []
|
||||
const model = middlewareConfig.model
|
||||
// 1. 总是添加通用插件
|
||||
plugins.push(textPlugin)
|
||||
|
||||
// 2. 推理模型时添加推理插件
|
||||
if (model && middlewareConfig.enableReasoning) {
|
||||
plugins.push(
|
||||
smoothReasoningPlugin({
|
||||
delayInMs: 80,
|
||||
chunkingRegex: /([\u4E00-\u9FFF]{3})|\S+\s+/
|
||||
}),
|
||||
reasoningTimePlugin()
|
||||
)
|
||||
}
|
||||
|
||||
// 3. 启用Prompt工具调用时添加工具插件
|
||||
if (middlewareConfig.enableTool) {
|
||||
plugins.push(
|
||||
createMCPPromptPlugin({
|
||||
enabled: true,
|
||||
createSystemMessage: (systemPrompt, params, context) => {
|
||||
console.log('createSystemMessage_context', context.isRecursiveCall)
|
||||
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
|
||||
if (context.isRecursiveCall) {
|
||||
return null
|
||||
}
|
||||
params.messages = [
|
||||
{
|
||||
role: 'assistant',
|
||||
content: systemPrompt
|
||||
},
|
||||
...params.messages
|
||||
]
|
||||
return null
|
||||
}
|
||||
return systemPrompt
|
||||
}
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
console.log(
|
||||
'最终插件列表:',
|
||||
plugins.map((p) => p.name)
|
||||
)
|
||||
return plugins
|
||||
}
|
||||
|
||||
public async completions(
|
||||
@ -131,37 +175,25 @@ export default class ModernAiProvider {
|
||||
params: StreamTextParams,
|
||||
middlewareConfig: AiSdkMiddlewareConfig
|
||||
): Promise<CompletionsResult> {
|
||||
// const model = params.assistant.model
|
||||
|
||||
// 检查是否应该使用现代化客户端
|
||||
// if (this.modernClient && model && isModernSdkSupported(this.provider, model)) {
|
||||
// try {
|
||||
console.log('completions', modelId, params, middlewareConfig)
|
||||
return await this.modernCompletions(modelId, params, middlewareConfig)
|
||||
// } catch (error) {
|
||||
// console.warn('Modern client failed, falling back to legacy:', error)
|
||||
// fallback到原有实现
|
||||
// }
|
||||
// }
|
||||
|
||||
// 使用原有实现
|
||||
// return this.legacyProvider.completions(params, options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化AI SDK的completions实现
|
||||
* 使用建造者模式动态构建中间件
|
||||
*/
|
||||
private async modernCompletions(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
middlewareConfig: AiSdkMiddlewareConfig
|
||||
): Promise<CompletionsResult> {
|
||||
if (!this.modernExecutor) {
|
||||
throw new Error('Modern AI SDK client not initialized')
|
||||
}
|
||||
|
||||
try {
|
||||
// 根据条件构建插件数组
|
||||
const plugins = this.buildPlugins(middlewareConfig)
|
||||
|
||||
// 用构建好的插件数组创建executor
|
||||
const executor = createExecutor(this.config.providerId, this.config.options, plugins)
|
||||
|
||||
// 动态构建中间件数组
|
||||
const middlewares = buildAiSdkMiddlewares(middlewareConfig)
|
||||
console.log('构建的中间件:', middlewares)
|
||||
@ -170,31 +202,8 @@ export default class ModernAiProvider {
|
||||
if (middlewareConfig.onChunk) {
|
||||
// 流式处理 - 使用适配器
|
||||
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
|
||||
// 创建MCP Prompt插件
|
||||
if (middlewareConfig.enableTool) {
|
||||
const mcpPromptPlugin = createMCPPromptPlugin({
|
||||
enabled: true,
|
||||
createSystemMessage: (systemPrompt, params, context) => {
|
||||
console.log('createSystemMessage_context', context.isRecursiveCall)
|
||||
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
|
||||
if (context.isRecursiveCall) {
|
||||
return null
|
||||
}
|
||||
params.messages = [
|
||||
{
|
||||
role: 'assistant',
|
||||
content: systemPrompt
|
||||
},
|
||||
...params.messages
|
||||
]
|
||||
return null
|
||||
}
|
||||
return systemPrompt
|
||||
}
|
||||
})
|
||||
this.modernExecutor.pluginEngine.use(mcpPromptPlugin)
|
||||
}
|
||||
const streamResult = await this.modernExecutor.streamText(
|
||||
|
||||
const streamResult = await executor.streamText(
|
||||
modelId,
|
||||
params,
|
||||
middlewares.length > 0 ? { middlewares } : undefined
|
||||
@ -207,7 +216,7 @@ export default class ModernAiProvider {
|
||||
}
|
||||
} else {
|
||||
// 流式处理但没有 onChunk 回调
|
||||
const streamResult = await this.modernExecutor.streamText(
|
||||
const streamResult = await executor.streamText(
|
||||
modelId,
|
||||
params,
|
||||
middlewares.length > 0 ? { middlewares } : undefined
|
||||
|
||||
@ -3,12 +3,9 @@ import {
|
||||
LanguageModelV1Middleware,
|
||||
simulateStreamingMiddleware
|
||||
} from '@cherrystudio/ai-core'
|
||||
import { isReasoningModel } from '@renderer/config/models'
|
||||
import type { MCPTool, Model, Provider } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
|
||||
import thinkingTimeMiddleware from './ThinkingTimeMiddleware'
|
||||
|
||||
/**
|
||||
* AI SDK 中间件配置项
|
||||
*/
|
||||
@ -112,34 +109,24 @@ export class AiSdkMiddlewareBuilder {
|
||||
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV1Middleware[] {
|
||||
const builder = new AiSdkMiddlewareBuilder()
|
||||
|
||||
// 1. 思考模型且有onChunk回调时添加思考时间中间件
|
||||
if (config.onChunk && config.model && isReasoningModel(config.model)) {
|
||||
builder.add({
|
||||
name: 'thinking-time',
|
||||
middleware: thinkingTimeMiddleware()
|
||||
})
|
||||
}
|
||||
|
||||
// 2. 可以在这里根据其他条件添加更多中间件
|
||||
// 例如:工具调用、Web搜索等相关中间件
|
||||
|
||||
// 3. 根据provider添加特定中间件
|
||||
// 1. 根据provider添加特定中间件
|
||||
if (config.provider) {
|
||||
addProviderSpecificMiddlewares(builder, config)
|
||||
}
|
||||
|
||||
// 4. 根据模型类型添加特定中间件
|
||||
// 2. 根据模型类型添加特定中间件
|
||||
if (config.model) {
|
||||
addModelSpecificMiddlewares(builder, config)
|
||||
}
|
||||
|
||||
// 5. 非流式输出时添加模拟流中间件
|
||||
// 3. 非流式输出时添加模拟流中间件
|
||||
if (config.streamOutput === false) {
|
||||
builder.add({
|
||||
name: 'simulate-streaming',
|
||||
middleware: simulateStreamingMiddleware()
|
||||
})
|
||||
}
|
||||
|
||||
console.log('builder.build()', builder.buildNamed())
|
||||
return builder.build()
|
||||
}
|
||||
|
||||
@ -1,67 +0,0 @@
|
||||
import { LanguageModelV1Middleware, LanguageModelV1StreamPart } from '@cherrystudio/ai-core'
|
||||
|
||||
/**
|
||||
* 一个用于统计 LLM "思考时间"(Time to First Token)的 AI SDK 中间件。
|
||||
*
|
||||
* 工作原理:
|
||||
* 1. 在 `stream` 方法被调用时,记录一个起始时间。
|
||||
* 2. 它会创建一个新的 `TransformStream` 来代理原始的流。
|
||||
* 3. 当第一个数据块 (chunk) 从原始流中到达时,记录结束时间。
|
||||
* 4. 计算两者之差,即为 "思考时间"
|
||||
* 这里只处理了thinking_complete
|
||||
*/
|
||||
export default function thinkingTimeMiddleware(): LanguageModelV1Middleware {
|
||||
return {
|
||||
wrapStream: async ({ doStream }) => {
|
||||
let hasThinkingContent = false
|
||||
let thinkingStartTime = 0
|
||||
let accumulatedThinkingContent = ''
|
||||
const { stream, ...reset } = await doStream()
|
||||
const transformStream = new TransformStream<LanguageModelV1StreamPart, any>({
|
||||
transform(chunk, controller) {
|
||||
if (chunk.type === 'reasoning' || chunk.type === 'redacted-reasoning') {
|
||||
if (!hasThinkingContent) {
|
||||
hasThinkingContent = true
|
||||
thinkingStartTime = performance.now()
|
||||
}
|
||||
accumulatedThinkingContent += chunk.textDelta || ''
|
||||
// 将所有 chunk 原样传递下去
|
||||
controller.enqueue({ ...chunk, thinking_millsec: performance.now() - thinkingStartTime })
|
||||
} else {
|
||||
if (hasThinkingContent && thinkingStartTime > 0) {
|
||||
const thinkingTime = performance.now() - thinkingStartTime
|
||||
const thinkingCompleteChunk = {
|
||||
type: 'reasoning-signature',
|
||||
text: accumulatedThinkingContent,
|
||||
thinking_millsec: thinkingTime
|
||||
}
|
||||
controller.enqueue(thinkingCompleteChunk)
|
||||
hasThinkingContent = false
|
||||
thinkingStartTime = 0
|
||||
accumulatedThinkingContent = ''
|
||||
}
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
},
|
||||
flush(controller) {
|
||||
// 如果流的末尾都是 reasoning,也需要发送 complete 事件
|
||||
if (hasThinkingContent && thinkingStartTime > 0) {
|
||||
const thinkingTime = Date.now() - thinkingStartTime
|
||||
const thinkingCompleteChunk = {
|
||||
type: 'reasoning-signature',
|
||||
text: accumulatedThinkingContent,
|
||||
thinking_millsec: thinkingTime
|
||||
}
|
||||
controller.enqueue(thinkingCompleteChunk)
|
||||
}
|
||||
controller.terminate()
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
stream: stream.pipeThrough(transformStream),
|
||||
...reset
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
37
src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts
Normal file
37
src/renderer/src/aiCore/plugins/reasoningTimePlugin.ts
Normal file
@ -0,0 +1,37 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
|
||||
export default definePlugin(() => ({
|
||||
name: 'reasoningTimePlugin',
|
||||
transformStream: () => () => {
|
||||
let thinkingStartTime = 0
|
||||
let hasStartedThinking = false
|
||||
let accumulatedThinkingContent = ''
|
||||
return new TransformStream({
|
||||
transform(chunk, controller) {
|
||||
if (chunk.type === 'reasoning') {
|
||||
if (!hasStartedThinking) {
|
||||
hasStartedThinking = true
|
||||
thinkingStartTime = performance.now()
|
||||
}
|
||||
accumulatedThinkingContent += chunk.textDelta
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
thinking_millsec: performance.now() - thinkingStartTime
|
||||
})
|
||||
} else if (hasStartedThinking && accumulatedThinkingContent) {
|
||||
controller.enqueue({
|
||||
type: 'reasoning-signature',
|
||||
text: accumulatedThinkingContent,
|
||||
thinking_millsec: performance.now() - thinkingStartTime
|
||||
})
|
||||
accumulatedThinkingContent = ''
|
||||
hasStartedThinking = false
|
||||
thinkingStartTime = 0
|
||||
controller.enqueue(chunk)
|
||||
} else {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}))
|
||||
@ -1,7 +1,7 @@
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
|
||||
export default definePlugin(({ delayInMs, chunkingRegex }: { delayInMs: number; chunkingRegex: RegExp }) => ({
|
||||
name: 'reasonPlugin',
|
||||
name: 'smoothReasoningPlugin',
|
||||
|
||||
transformStream: () => () => {
|
||||
let buffer = ''
|
||||
Loading…
Reference in New Issue
Block a user