refactor: streamline reasoning plugins and remove unused components

- Removed the `reasoningTimePlugin` and `mcpPromptPlugin` to simplify the plugin architecture.
- Updated the `smoothReasoningPlugin` to track thinking time itself (absorbing the removed `reasoningTimePlugin`), use built-in chunking settings, and lower the smoothing delay from 80 ms to 20 ms; the chunk shapes it now emits are sketched after the file summary below.
- Adjusted the `textPlugin` delay from 80 ms to 20 ms to match the new smoothing settings.
- Modified `ModernAiProvider` to register the updated `smoothReasoningPlugin` directly, without the removed plugins.
MyPrototypeWhat 2025-06-30 18:34:08 +08:00
parent 007de81928
commit cf5ed8e858
5 changed files with 67 additions and 316 deletions
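
For reference, the merged smoothReasoningPlugin annotates each reasoning chunk with the elapsed thinking time and emits a single signature chunk when reasoning ends. A sketch of those chunk shapes as read from the diffs below; the type alias names are illustrative, not taken from the codebase:

type ReasoningChunk = {
  type: 'reasoning'
  textDelta: string
  thinking_millsec: number // elapsed thinking time, measured by the plugin
}

type ReasoningSignatureChunk = {
  type: 'reasoning-signature'
  text: string // the full accumulated reasoning text
  thinking_millsec: number
}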

View File

@@ -28,7 +28,6 @@ import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
import LegacyAiProvider from './index'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
import { CompletionsResult } from './middleware/schemas'
import reasoningTimePlugin from './plugins/reasoningTimePlugin'
import smoothReasoningPlugin from './plugins/smoothReasoningPlugin'
import textPlugin from './plugins/textPlugin'
import { getAiSdkProviderId } from './provider/factory'
@@ -124,13 +123,7 @@ export default class ModernAiProvider {
// 2. Add the reasoning plugin when using a reasoning model
if (middlewareConfig.enableReasoning) {
plugins.push(
smoothReasoningPlugin({
delayInMs: 80,
chunkingRegex: /([\u4E00-\u9FFF]{3})|\S+\s+/
}),
reasoningTimePlugin
)
plugins.push(smoothReasoningPlugin)
}
// 3. Add the tool plugin when prompt tool calling is enabled

View File

@@ -1,257 +0,0 @@
import { definePlugin, StreamTextParams } from '@cherrystudio/ai-core'
import { buildSystemPromptWithTools } from '@renderer/aiCore/transformParameters'
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
import { Assistant, MCPTool, MCPToolResponse } from '@renderer/types'
import { parseAndCallTools } from '@renderer/utils/mcp-tools'
/**
* MCP Prompt plugin configuration
*/
export interface MCPPromptPluginConfig {
mcpTools: MCPTool[]
assistant: Assistant
onChunk: (chunk: any) => void
recursiveCall: (params: StreamTextParams) => Promise<{ stream?: ReadableStream; getText?: () => string }>
recursionDepth?: number // current recursion depth, defaults to 0
maxRecursionDepth?: number // maximum recursion depth, defaults to 20
}
/**
* MCP Prompt plugin: injects MCP tool descriptions into the system prompt and
* handles prompt-based tool calls by parsing the streamed response.
*/
export const createMCPPromptPlugin = definePlugin((config: MCPPromptPluginConfig) => {
return {
name: 'mcp-prompt-plugin',
// 1. Parameter transform: inject tool descriptions into the system prompt
transformParams: async (params: StreamTextParams) => {
const { mcpTools, assistant } = config
if (mcpTools.length === 0) {
return params
}
try {
// Reuse the existing system prompt construction logic
const enhancedSystemPrompt = await buildSystemPromptWithTools(params.system || '', mcpTools, assistant)
return {
...params,
system: enhancedSystemPrompt,
// Prompt mode does not use function calling
tools: undefined
}
} catch (error) {
console.error('Failed to build the system prompt:', error)
return params
}
},
// 2. Stream transform: detect and execute tool calls
transformStream: () => () => {
let fullResponseText = ''
let hasProcessedTools = false
return new TransformStream({
async transform(chunk, controller) {
try {
// Accumulate the full text response
if (chunk.type === 'text-delta') {
fullResponseText += chunk.textDelta
}
// Check for and handle tool calls when the stream finishes
if (chunk.type === 'finish' && !hasProcessedTools) {
hasProcessedTools = true
if (containsToolCallPattern(fullResponseText)) {
await processToolCallsAndRecurse(fullResponseText, config, controller)
return // Do not forward the finish chunk; the recursive call handles it
}
}
// Forward all other chunk types as usual
controller.enqueue(chunk)
} catch (error) {
console.error('MCP Prompt Plugin Transform Error:', error)
controller.error(error)
}
},
async flush(controller) {
// Final check when the stream ends
if (!hasProcessedTools && containsToolCallPattern(fullResponseText)) {
await processToolCallsAndRecurse(fullResponseText, config, controller)
}
}
})
}
}
})
/**
* Parse tool calls in the response text, execute them, and recurse with the results.
*/
async function processToolCallsAndRecurse(
responseText: string,
config: MCPPromptPluginConfig,
controller: TransformStreamDefaultController
) {
const { mcpTools, assistant, onChunk, recursionDepth = 0, maxRecursionDepth = 20 } = config
// Check whether the maximum recursion depth has been reached
if (recursionDepth >= maxRecursionDepth) {
console.log(`Reached the maximum recursion depth ${maxRecursionDepth}, stopping tool call processing`)
controller.enqueue({
type: 'text-delta',
textDelta: `\n\n[Reached the maximum tool call depth ${maxRecursionDepth}, no further calls]`
})
return
}
try {
console.log(`Tool call detected, processing... (recursion depth: ${recursionDepth}/${maxRecursionDepth})`)
const allToolResponses: MCPToolResponse[] = []
// Use the existing parseAndCallTools function directly
// It parses tool calls in the text, executes the tools, and fires onChunk
const toolResults = await parseAndCallTools(
responseText, // pass the full response text and let parseAndCallTools parse it
allToolResponses,
onChunk, // pass through the onChunk callback from the config
(mcpToolResponse, resp) => {
// Reuse the existing message conversion logic
return convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp)
},
assistant.model!,
mcpTools
)
console.log('Tool execution finished, result count:', toolResults.length)
// If there are tool results, build new messages and make a recursive call
if (toolResults.length > 0) {
await performRecursiveCall(responseText, toolResults, config, controller)
}
} catch (error) {
console.error('Tool call processing failed:', error)
// Emit the error message as a text chunk
controller.enqueue({
type: 'text-delta',
textDelta: `\n\n[Tool call error: ${error instanceof Error ? error.message : String(error)}]`
})
}
}
/**
* Build a new message history from the tool results and perform a recursive call.
*/
async function performRecursiveCall(
originalResponse: string,
toolResults: any[],
config: MCPPromptPluginConfig,
controller: TransformStreamDefaultController
) {
const { assistant, recursiveCall, recursionDepth = 0 } = config
try {
// Get the current message history (should come from context; empty array for now)
// TODO: implement retrieving the current messages from context
const currentMessages = getCurrentMessagesFromContext()
// Build the new message history
const newMessages = [
...currentMessages,
{
role: 'assistant' as const,
content: originalResponse
},
...toolResults // toolResults are already in the correct message format
]
console.log(`New message history built, message count: ${newMessages.length}, recursion depth: ${recursionDepth}`)
// Reuse the existing parameter construction logic
const { params: recursiveParams } = await buildStreamTextParams(newMessages, assistant, {
mcpTools: config.mcpTools,
enableTools: true
})
console.log(`Starting recursive call... (depth: ${recursionDepth + 1})`)
// Recursive call with an incremented recursion depth
const recursiveResult = await recursiveCall(recursiveParams)
// Forward the stream from the recursive result
if (recursiveResult.stream) {
const reader = recursiveResult.stream.getReader()
try {
while (true) {
const { done, value } = await reader.read()
if (done) break
controller.enqueue(value)
}
} finally {
reader.releaseLock()
}
} else if (recursiveResult.getText) {
// No stream, but a text result is available
const finalText = recursiveResult.getText()
controller.enqueue({
type: 'text-delta',
textDelta: finalText
})
}
console.log(`Recursive call finished (depth: ${recursionDepth + 1})`)
} catch (error) {
console.error('Recursive call failed:', error)
controller.enqueue({
type: 'text-delta',
textDelta: `\n\n[Recursive call error: ${error instanceof Error ? error.message : String(error)}]`
})
}
}
/**
* Check whether the response text contains a tool call pattern.
*/
function containsToolCallPattern(text: string): boolean {
const patterns = [
/<tool_use>/i,
/<tool_call>/i
// Add more patterns here for other formats in actual use
]
return patterns.some((pattern) => pattern.test(text))
}
/**
* Get the current message history.
* TODO: implement retrieving messages from the actual context.
*/
function getCurrentMessagesFromContext(): any[] {
// Retrieving the current message history from context still needs to be implemented
// Return an empty array for now; fill in later as needed
console.warn('getCurrentMessagesFromContext is not implemented yet, returning an empty array')
return []
}
/**
* Convert an MCP tool response into an SDK message param.
*/
function convertMcpToolResponseToSdkMessageParam(mcpToolResponse: MCPToolResponse, resp: any): any {
// The real conversion logic still needs to be implemented here
// Return a basic user message format for now
return {
role: 'user',
content: `Tool ${mcpToolResponse.tool.name} result: ${JSON.stringify(resp)}`
}
}
export default createMCPPromptPlugin
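
To illustrate the prompt-mode detection this removed plugin relied on, a minimal standalone check; only the <tool_use>/<tool_call> markers come from containsToolCallPattern above, the payload inside the tag is hypothetical:

const sampleResponse = [
  'Let me look that up.',
  '<tool_use>',
  '  {"name": "search", "arguments": {"query": "weather"}}', // hypothetical payload format
  '</tool_use>'
].join('\n')

const patterns = [/<tool_use>/i, /<tool_call>/i]
console.log(patterns.some((p) => p.test(sampleResponse))) // true -> the plugin would call processToolCallsAndRecurse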

View File

@@ -1,38 +0,0 @@
import { definePlugin } from '@cherrystudio/ai-core'
export default definePlugin({
name: 'reasoningTimePlugin',
transformStream: () => () => {
let thinkingStartTime = performance.now()
let hasStartedThinking = false
let accumulatedThinkingContent = ''
return new TransformStream({
transform(chunk, controller) {
if (chunk.type === 'reasoning') {
if (!hasStartedThinking) {
hasStartedThinking = true
// thinkingStartTime = performance.now()
}
accumulatedThinkingContent += chunk.textDelta
console.log('performance.now() - thinkingStartTime', performance.now() - thinkingStartTime)
controller.enqueue({
...chunk,
thinking_millsec: performance.now() - thinkingStartTime
})
} else if (hasStartedThinking && accumulatedThinkingContent) {
controller.enqueue({
type: 'reasoning-signature',
text: accumulatedThinkingContent,
thinking_millsec: performance.now() - thinkingStartTime
})
accumulatedThinkingContent = ''
hasStartedThinking = false
thinkingStartTime = 0
controller.enqueue(chunk)
} else {
controller.enqueue(chunk)
}
}
})
}
})

View File

@@ -1,42 +1,95 @@
import { definePlugin } from '@cherrystudio/ai-core'
export default definePlugin(({ delayInMs, chunkingRegex }: { delayInMs: number; chunkingRegex: RegExp }) => ({
name: 'smoothReasoningPlugin',
const chunkingRegex = /([\u4E00-\u9FFF])|\S+\s+/
const delayInMs = 20
export default definePlugin({
name: 'reasoningPlugin',
transformStream: () => () => {
// === Smoothing state ===
let buffer = ''
// === Timing state ===
let thinkingStartTime = performance.now()
let hasStartedThinking = false
let accumulatedThinkingContent = ''
const delay = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms))
const detectChunk = (buffer: string) => {
const match = chunkingRegex.exec(buffer)
if (!match) {
return null
}
if (!match) return null
return buffer.slice(0, match.index) + match?.[0]
}
return new TransformStream({
async transform(chunk, controller) {
if (chunk.type !== 'reasoning') {
if (buffer.length > 0) {
controller.enqueue({ type: 'reasoning', textDelta: buffer })
buffer = ''
// === Handle the end of reasoning ===
if (hasStartedThinking && accumulatedThinkingContent) {
// Flush the remaining buffer first
if (buffer.length > 0) {
controller.enqueue({
type: 'reasoning',
textDelta: buffer,
thinking_millsec: performance.now() - thinkingStartTime
})
buffer = ''
}
// Emit the reasoning-signature chunk
controller.enqueue({
type: 'reasoning-signature',
text: accumulatedThinkingContent,
thinking_millsec: performance.now() - thinkingStartTime
})
// Reset state
accumulatedThinkingContent = ''
hasStartedThinking = false
thinkingStartTime = 0
}
controller.enqueue(chunk)
return
}
// === Handle reasoning chunks ===
// 1. Time tracking
if (!hasStartedThinking) {
hasStartedThinking = true
thinkingStartTime = performance.now()
}
accumulatedThinkingContent += chunk.textDelta
// 2. Smoothing logic
buffer += chunk.textDelta
let match
while ((match = detectChunk(buffer)) != null) {
controller.enqueue({ type: 'reasoning', textDelta: match })
controller.enqueue({
type: 'reasoning',
textDelta: match,
thinking_millsec: performance.now() - thinkingStartTime
})
buffer = buffer.slice(match.length)
await delay(delayInMs)
}
},
// === flush: emit any remaining buffer ===
flush(controller) {
if (buffer.length > 0) {
controller.enqueue({
type: 'reasoning',
textDelta: buffer,
thinking_millsec: hasStartedThinking ? performance.now() - thinkingStartTime : 0
})
}
}
})
}
}))
})
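
A small standalone sketch of the chunking behaviour above, reusing the same regex and delay; it mirrors the plugin's internal logic rather than its plugin API, and the sample input is illustrative:

const chunkingRegex = /([\u4E00-\u9FFF])|\S+\s+/
const delay = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms))

const detectChunk = (buffer: string) => {
  const match = chunkingRegex.exec(buffer)
  if (!match) return null
  return buffer.slice(0, match.index) + match[0]
}

// Drains a buffer the way transform() does: one Chinese character or one English
// word per emitted chunk, with a 20 ms pause between chunks.
async function drain(buffer: string) {
  let match: string | null
  while ((match = detectChunk(buffer)) != null) {
    console.log(JSON.stringify(match))
    buffer = buffer.slice(match.length)
    await delay(20)
  }
  return buffer // anything left over is what flush() would emit
}

drain('思考中 reasoning step one ') // logs "思", "考", "中", " reasoning ", "step ", "one "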

View File

@@ -4,7 +4,7 @@ export default definePlugin({
name: 'textPlugin',
transformStream: () =>
smoothStream({
delayInMs: 80,
delayInMs: 20,
// Chinese: three characters per chunk; English: one word per chunk
chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
})
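
For comparison, the textPlugin regex groups Chinese three characters at a time while English is still split per word; a quick illustrative check with made-up sample strings:

const textChunking = /([\u4E00-\u9FFF]{3})|\S+\s+/
console.log(textChunking.exec('你好世界测试')?.[0]) // "你好世", three Chinese characters
console.log(textChunking.exec('hello world ')?.[0]) // "hello ", one English word plus trailing space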