feat: enhance ModernAiProvider with new reasoning plugins and dynamic middleware construction

- Introduced `reasoningTimePlugin` and `smoothReasoningPlugin` to improve reasoning content handling and processing.
- Refactored `ModernAiProvider` to dynamically build plugin arrays based on middleware configuration, enhancing flexibility.
- Removed the obsolete `ThinkingTimeMiddleware` to streamline middleware management.
- Updated `buildAiSdkMiddlewares` to reflect changes in middleware handling and improve clarity in the configuration process.
- Enhanced logging for better visibility into plugin and middleware configurations during execution.
This commit is contained in:
suyao 2025-06-27 15:10:47 +08:00
parent 87f803b0d3
commit f61da8c2d6
No known key found for this signature in database
5 changed files with 109 additions and 143 deletions

View File

@ -9,6 +9,7 @@
*/
import {
AiPlugin,
createExecutor,
ProviderConfigFactory,
type ProviderId,
@ -26,7 +27,8 @@ import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
import LegacyAiProvider from './index'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
import { CompletionsResult } from './middleware/schemas'
import reasonPlugin from './plugins/reasonPlugin'
import reasoningTimePlugin from './plugins/reasoningTimePlugin'
import smoothReasoningPlugin from './plugins/smoothReasoningPlugin'
import textPlugin from './plugins/textPlugin'
import { getAiSdkProviderId } from './provider/factory'
@ -103,27 +105,69 @@ function isModernSdkSupported(provider: Provider, model?: Model): boolean {
}
export default class ModernAiProvider {
private modernExecutor?: ReturnType<typeof createExecutor>
private legacyProvider: LegacyAiProvider
private provider: Provider
private config: ReturnType<typeof providerToAiSdkConfig>
constructor(provider: Provider) {
this.provider = provider
this.legacyProvider = new LegacyAiProvider(provider)
// TODO:如果后续在调用completions时需要切换provider的话,
// 初始化时不构建中间件,等到需要时再构建
const config = providerToAiSdkConfig(provider)
// 只保存配置不预先创建executor
this.config = providerToAiSdkConfig(provider)
console.log('[Modern AI Provider] Creating executor with MCP Prompt plugin enabled')
}
this.modernExecutor = createExecutor(config.providerId, config.options, [
reasonPlugin({
delayInMs: 80,
chunkingRegex: /([\u4E00-\u9FFF]{3})|\S+\s+/
}),
textPlugin
])
/**
 * Builds the AI plugin array for a single request based on the middleware configuration.
 */
private buildPlugins(middlewareConfig: AiSdkMiddlewareConfig) {
const plugins: AiPlugin[] = []
const model = middlewareConfig.model
// 1. 总是添加通用插件
plugins.push(textPlugin)
// 2. 推理模型时添加推理插件
if (model && middlewareConfig.enableReasoning) {
plugins.push(
smoothReasoningPlugin({
delayInMs: 80,
chunkingRegex: /([\u4E00-\u9FFF]{3})|\S+\s+/
}),
reasoningTimePlugin()
)
}
// 3. 启用Prompt工具调用时添加工具插件
if (middlewareConfig.enableTool) {
plugins.push(
createMCPPromptPlugin({
enabled: true,
createSystemMessage: (systemPrompt, params, context) => {
console.log('createSystemMessage_context', context.isRecursiveCall)
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
if (context.isRecursiveCall) {
return null
}
params.messages = [
{
role: 'assistant',
content: systemPrompt
},
...params.messages
]
return null
}
return systemPrompt
}
})
)
}
console.log(
'最终插件列表:',
plugins.map((p) => p.name)
)
return plugins
}
public async completions(
@ -131,37 +175,25 @@ export default class ModernAiProvider {
params: StreamTextParams,
middlewareConfig: AiSdkMiddlewareConfig
): Promise<CompletionsResult> {
// const model = params.assistant.model
// 检查是否应该使用现代化客户端
// if (this.modernClient && model && isModernSdkSupported(this.provider, model)) {
// try {
console.log('completions', modelId, params, middlewareConfig)
return await this.modernCompletions(modelId, params, middlewareConfig)
// } catch (error) {
// console.warn('Modern client failed, falling back to legacy:', error)
// fallback到原有实现
// }
// }
// 使用原有实现
// return this.legacyProvider.completions(params, options)
}
/**
 * Completions implementation backed by the modern AI SDK executor.
 */
private async modernCompletions(
modelId: string,
params: StreamTextParams,
middlewareConfig: AiSdkMiddlewareConfig
): Promise<CompletionsResult> {
if (!this.modernExecutor) {
throw new Error('Modern AI SDK client not initialized')
}
try {
// 根据条件构建插件数组
const plugins = this.buildPlugins(middlewareConfig)
// 用构建好的插件数组创建executor
const executor = createExecutor(this.config.providerId, this.config.options, plugins)
// 动态构建中间件数组
const middlewares = buildAiSdkMiddlewares(middlewareConfig)
console.log('构建的中间件:', middlewares)
@ -170,31 +202,8 @@ export default class ModernAiProvider {
if (middlewareConfig.onChunk) {
// 流式处理 - 使用适配器
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
// 创建MCP Prompt插件
if (middlewareConfig.enableTool) {
const mcpPromptPlugin = createMCPPromptPlugin({
enabled: true,
createSystemMessage: (systemPrompt, params, context) => {
console.log('createSystemMessage_context', context.isRecursiveCall)
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
if (context.isRecursiveCall) {
return null
}
params.messages = [
{
role: 'assistant',
content: systemPrompt
},
...params.messages
]
return null
}
return systemPrompt
}
})
this.modernExecutor.pluginEngine.use(mcpPromptPlugin)
}
const streamResult = await this.modernExecutor.streamText(
const streamResult = await executor.streamText(
modelId,
params,
middlewares.length > 0 ? { middlewares } : undefined
@ -207,7 +216,7 @@ export default class ModernAiProvider {
}
} else {
// 流式处理但没有 onChunk 回调
const streamResult = await this.modernExecutor.streamText(
const streamResult = await executor.streamText(
modelId,
params,
middlewares.length > 0 ? { middlewares } : undefined

View File

@ -3,12 +3,9 @@ import {
LanguageModelV1Middleware,
simulateStreamingMiddleware
} from '@cherrystudio/ai-core'
import { isReasoningModel } from '@renderer/config/models'
import type { MCPTool, Model, Provider } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import thinkingTimeMiddleware from './ThinkingTimeMiddleware'
/**
 * AI SDK middleware construction utilities.
 */
@ -112,34 +109,24 @@ export class AiSdkMiddlewareBuilder {
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): LanguageModelV1Middleware[] {
const builder = new AiSdkMiddlewareBuilder()
// 1. 思考模型且有onChunk回调时添加思考时间中间件
if (config.onChunk && config.model && isReasoningModel(config.model)) {
builder.add({
name: 'thinking-time',
middleware: thinkingTimeMiddleware()
})
}
// 2. 可以在这里根据其他条件添加更多中间件
// 例如工具调用、Web搜索等相关中间件
// 3. 根据provider添加特定中间件
// 1. 根据provider添加特定中间件
if (config.provider) {
addProviderSpecificMiddlewares(builder, config)
}
// 4. 根据模型类型添加特定中间件
// 2. 根据模型类型添加特定中间件
if (config.model) {
addModelSpecificMiddlewares(builder, config)
}
// 5. 非流式输出时添加模拟流中间件
// 3. 非流式输出时添加模拟流中间件
if (config.streamOutput === false) {
builder.add({
name: 'simulate-streaming',
middleware: simulateStreamingMiddleware()
})
}
console.log('builder.build()', builder.buildNamed())
return builder.build()
}

View File

@ -1,67 +0,0 @@
import { LanguageModelV1Middleware, LanguageModelV1StreamPart } from '@cherrystudio/ai-core'
/**
 * AI SDK middleware that measures the LLM's "thinking time" (akin to time to
 * first non-reasoning token). How it works:
 * 1. Wraps the model's `stream`.
 * 2. Pipes it through a `TransformStream`.
 * 3. Inspects each chunk; reasoning chunks start/extend the timing window.
 * 4. When reasoning ends, the total "thinking time" is attached to a
 *    thinking-complete (`reasoning-signature`) chunk.
 */
export default function thinkingTimeMiddleware(): LanguageModelV1Middleware {
return {
wrapStream: async ({ doStream }) => {
let hasThinkingContent = false
let thinkingStartTime = 0
let accumulatedThinkingContent = ''
const { stream, ...reset } = await doStream()
const transformStream = new TransformStream<LanguageModelV1StreamPart, any>({
transform(chunk, controller) {
if (chunk.type === 'reasoning' || chunk.type === 'redacted-reasoning') {
if (!hasThinkingContent) {
hasThinkingContent = true
thinkingStartTime = performance.now()
}
accumulatedThinkingContent += chunk.textDelta || ''
// 将所有 chunk 原样传递下去
controller.enqueue({ ...chunk, thinking_millsec: performance.now() - thinkingStartTime })
} else {
if (hasThinkingContent && thinkingStartTime > 0) {
const thinkingTime = performance.now() - thinkingStartTime
const thinkingCompleteChunk = {
type: 'reasoning-signature',
text: accumulatedThinkingContent,
thinking_millsec: thinkingTime
}
controller.enqueue(thinkingCompleteChunk)
hasThinkingContent = false
thinkingStartTime = 0
accumulatedThinkingContent = ''
}
controller.enqueue(chunk)
}
},
flush(controller) {
// 如果流的末尾都是 reasoning也需要发送 complete 事件
if (hasThinkingContent && thinkingStartTime > 0) {
const thinkingTime = Date.now() - thinkingStartTime
const thinkingCompleteChunk = {
type: 'reasoning-signature',
text: accumulatedThinkingContent,
thinking_millsec: thinkingTime
}
controller.enqueue(thinkingCompleteChunk)
}
controller.terminate()
}
})
return {
stream: stream.pipeThrough(transformStream),
...reset
}
}
}
}

View File

@ -0,0 +1,37 @@
import { definePlugin } from '@cherrystudio/ai-core'
export default definePlugin(() => ({
name: 'reasoningTimePlugin',
transformStream: () => () => {
let thinkingStartTime = 0
let hasStartedThinking = false
let accumulatedThinkingContent = ''
return new TransformStream({
transform(chunk, controller) {
if (chunk.type === 'reasoning') {
if (!hasStartedThinking) {
hasStartedThinking = true
thinkingStartTime = performance.now()
}
accumulatedThinkingContent += chunk.textDelta
controller.enqueue({
...chunk,
thinking_millsec: performance.now() - thinkingStartTime
})
} else if (hasStartedThinking && accumulatedThinkingContent) {
controller.enqueue({
type: 'reasoning-signature',
text: accumulatedThinkingContent,
thinking_millsec: performance.now() - thinkingStartTime
})
accumulatedThinkingContent = ''
hasStartedThinking = false
thinkingStartTime = 0
controller.enqueue(chunk)
} else {
controller.enqueue(chunk)
}
}
})
}
}))

View File

@ -1,7 +1,7 @@
import { definePlugin } from '@cherrystudio/ai-core'
export default definePlugin(({ delayInMs, chunkingRegex }: { delayInMs: number; chunkingRegex: RegExp }) => ({
name: 'reasonPlugin',
name: 'smoothReasoningPlugin',
transformStream: () => () => {
let buffer = ''