mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-09 23:10:20 +08:00
feat: enhance ToolCallChunkHandler with detailed chunk handling and remove unused plugins
- Updated the `handleToolCallCreated` method to support additional chunk types with optional provider metadata.
- Removed the deprecated `smoothReasoningPlugin` and `textPlugin` files to clean up the codebase.
- Cleaned up unused type imports in `tool.ts` for improved clarity and maintainability.
This commit is contained in:
parent
fcc8836c95
commit
addd5ffdfa
@ -8,6 +8,7 @@ import { ToolCallUnion, ToolResultUnion, ToolSet } from '@cherrystudio/ai-core'
|
|||||||
import Logger from '@renderer/config/logger'
|
import Logger from '@renderer/config/logger'
|
||||||
import { BaseTool, MCPToolResponse, ToolCallResponse } from '@renderer/types'
|
import { BaseTool, MCPToolResponse, ToolCallResponse } from '@renderer/types'
|
||||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||||
|
import { type ProviderMetadata } from 'ai'
|
||||||
// import type {
|
// import type {
|
||||||
// AnthropicSearchOutput,
|
// AnthropicSearchOutput,
|
||||||
// WebSearchPluginConfig
|
// WebSearchPluginConfig
|
||||||
@ -40,7 +41,27 @@ export class ToolCallChunkHandler {
|
|||||||
// this.onChunk = callback
|
// this.onChunk = callback
|
||||||
// }
|
// }
|
||||||
|
|
||||||
handleToolCallCreated(chunk: { type: 'tool-input-start' | 'tool-input-delta' | 'tool-input-end' }): void {
|
handleToolCallCreated(
|
||||||
|
chunk:
|
||||||
|
| {
|
||||||
|
type: 'tool-input-start'
|
||||||
|
id: string
|
||||||
|
toolName: string
|
||||||
|
providerMetadata?: ProviderMetadata
|
||||||
|
providerExecuted?: boolean
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: 'tool-input-end'
|
||||||
|
id: string
|
||||||
|
providerMetadata?: ProviderMetadata
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: 'tool-input-delta'
|
||||||
|
id: string
|
||||||
|
delta: string
|
||||||
|
providerMetadata?: ProviderMetadata
|
||||||
|
}
|
||||||
|
): void {
|
||||||
switch (chunk.type) {
|
switch (chunk.type) {
|
||||||
case 'tool-input-start': {
|
case 'tool-input-start': {
|
||||||
// 能拿到说明是mcpTool
|
// 能拿到说明是mcpTool
|
||||||
|
|||||||
@ -1,152 +0,0 @@
|
|||||||
// // 可能会废弃,在流上做delay还是有问题
|
|
||||||
|
|
||||||
// import { definePlugin } from '@cherrystudio/ai-core'
|
|
||||||
|
|
||||||
// const chunkingRegex = /([\u4E00-\u9FFF])|\S+\s+/
|
|
||||||
// const delayInMs = 50
|
|
||||||
|
|
||||||
// export default definePlugin({
|
|
||||||
// name: 'reasoningPlugin',
|
|
||||||
|
|
||||||
// transformStream: () => () => {
|
|
||||||
// // === smoothing 状态 ===
|
|
||||||
// let buffer = ''
|
|
||||||
|
|
||||||
// // === 时间跟踪状态 ===
|
|
||||||
// let thinkingStartTime = performance.now()
|
|
||||||
// let hasStartedThinking = false
|
|
||||||
// let accumulatedThinkingContent = ''
|
|
||||||
|
|
||||||
// // === 日志计数器 ===
|
|
||||||
// let chunkCount = 0
|
|
||||||
// let delayCount = 0
|
|
||||||
|
|
||||||
// const delay = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms))
|
|
||||||
|
|
||||||
// // 收集所有当前可匹配的chunks
|
|
||||||
// const collectMatches = (inputBuffer: string) => {
|
|
||||||
// const matches: string[] = []
|
|
||||||
// let tempBuffer = inputBuffer
|
|
||||||
// let match
|
|
||||||
|
|
||||||
// // 重置regex状态
|
|
||||||
// chunkingRegex.lastIndex = 0
|
|
||||||
|
|
||||||
// while ((match = chunkingRegex.exec(tempBuffer)) !== null) {
|
|
||||||
// matches.push(match[0])
|
|
||||||
// tempBuffer = tempBuffer.slice(match.index + match[0].length)
|
|
||||||
// // 重置regex以从头开始匹配剩余内容
|
|
||||||
// chunkingRegex.lastIndex = 0
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return {
|
|
||||||
// matches,
|
|
||||||
// remaining: tempBuffer
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return new TransformStream({
|
|
||||||
// async transform(chunk, controller) {
|
|
||||||
// if (chunk.type !== 'reasoning') {
|
|
||||||
// // === 处理 reasoning 结束 ===
|
|
||||||
// if (hasStartedThinking && accumulatedThinkingContent) {
|
|
||||||
// console.log(
|
|
||||||
// `[ReasoningPlugin] Ending reasoning. Final stats: chunks=${chunkCount}, delays=${delayCount}, efficiency=${(chunkCount / Math.max(delayCount, 1)).toFixed(2)}x`
|
|
||||||
// )
|
|
||||||
|
|
||||||
// // 先输出剩余的 buffer
|
|
||||||
// if (buffer.length > 0) {
|
|
||||||
// console.log(`[ReasoningPlugin] Flushing remaining buffer: "${buffer}"`)
|
|
||||||
// controller.enqueue({
|
|
||||||
// type: 'reasoning',
|
|
||||||
// textDelta: buffer,
|
|
||||||
// thinking_millsec: performance.now() - thinkingStartTime
|
|
||||||
// })
|
|
||||||
// buffer = ''
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // 生成 reasoning-signature
|
|
||||||
// controller.enqueue({
|
|
||||||
// type: 'reasoning-signature',
|
|
||||||
// text: accumulatedThinkingContent,
|
|
||||||
// thinking_millsec: performance.now() - thinkingStartTime
|
|
||||||
// })
|
|
||||||
|
|
||||||
// // 重置状态
|
|
||||||
// accumulatedThinkingContent = ''
|
|
||||||
// hasStartedThinking = false
|
|
||||||
// thinkingStartTime = 0
|
|
||||||
// chunkCount = 0
|
|
||||||
// delayCount = 0
|
|
||||||
// }
|
|
||||||
|
|
||||||
// controller.enqueue(chunk)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // === 处理 reasoning 类型 ===
|
|
||||||
|
|
||||||
// // 1. 时间跟踪逻辑
|
|
||||||
// if (!hasStartedThinking) {
|
|
||||||
// hasStartedThinking = true
|
|
||||||
// thinkingStartTime = performance.now()
|
|
||||||
// console.log(`[ReasoningPlugin] Starting reasoning session`)
|
|
||||||
// }
|
|
||||||
// accumulatedThinkingContent += chunk.textDelta
|
|
||||||
|
|
||||||
// // 2. 动态Smooth处理逻辑
|
|
||||||
// const beforeBuffer = buffer
|
|
||||||
// buffer += chunk.textDelta
|
|
||||||
|
|
||||||
// console.log(`[ReasoningPlugin] Received chunk: "${chunk.textDelta}", buffer: "${beforeBuffer}" → "${buffer}"`)
|
|
||||||
|
|
||||||
// // 收集所有当前可以匹配的chunks
|
|
||||||
// const { matches, remaining } = collectMatches(buffer)
|
|
||||||
|
|
||||||
// if (matches.length > 0) {
|
|
||||||
// console.log(
|
|
||||||
// `[ReasoningPlugin] Collected ${matches.length} matches: [${matches.map((m) => `"${m}"`).join(', ')}], remaining: "${remaining}"`
|
|
||||||
// )
|
|
||||||
|
|
||||||
// // 批量输出所有匹配的chunks
|
|
||||||
// for (const matchText of matches) {
|
|
||||||
// controller.enqueue({
|
|
||||||
// type: 'reasoning',
|
|
||||||
// textDelta: matchText,
|
|
||||||
// thinking_millsec: performance.now() - thinkingStartTime
|
|
||||||
// })
|
|
||||||
// chunkCount++
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // 更新buffer为剩余内容
|
|
||||||
// buffer = remaining
|
|
||||||
|
|
||||||
// // 只等待一次,而不是每个chunk都等待
|
|
||||||
// delayCount++
|
|
||||||
// console.log(
|
|
||||||
// `[ReasoningPlugin] Delaying ${delayInMs}ms (delay #${delayCount}, efficiency: ${(chunkCount / delayCount).toFixed(2)} chunks/delay)`
|
|
||||||
// )
|
|
||||||
// const delayStart = performance.now()
|
|
||||||
// await delay(delayInMs)
|
|
||||||
// const actualDelay = performance.now() - delayStart
|
|
||||||
// console.log(`[ReasoningPlugin] Delay completed: expected=${delayInMs}ms, actual=${actualDelay.toFixed(1)}ms`)
|
|
||||||
// } else {
|
|
||||||
// console.log(`[ReasoningPlugin] No matches found, keeping in buffer: "${buffer}"`)
|
|
||||||
// }
|
|
||||||
// // 如果没有匹配,保留在buffer中等待下次数据
|
|
||||||
// },
|
|
||||||
|
|
||||||
// // === flush 处理剩余 buffer ===
|
|
||||||
// flush(controller) {
|
|
||||||
// if (buffer.length > 0) {
|
|
||||||
// console.log(`[ReasoningPlugin] Final flush: "${buffer}"`)
|
|
||||||
// controller.enqueue({
|
|
||||||
// type: 'reasoning',
|
|
||||||
// textDelta: buffer,
|
|
||||||
// thinking_millsec: hasStartedThinking ? performance.now() - thinkingStartTime : 0
|
|
||||||
// })
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// })
|
|
||||||
// }
|
|
||||||
// })
|
|
||||||
@ -1,13 +0,0 @@
|
|||||||
// // 可能会废弃,在流上做delay还是有问题
|
|
||||||
|
|
||||||
// import { definePlugin, smoothStream } from '@cherrystudio/ai-core'
|
|
||||||
|
|
||||||
// export default definePlugin({
|
|
||||||
// name: 'textPlugin',
|
|
||||||
// transformStream: () =>
|
|
||||||
// smoothStream({
|
|
||||||
// delayInMs: 50,
|
|
||||||
// // 中文3个字符一个chunk,英文一个单词一个chunk
|
|
||||||
// chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
|
|
||||||
// })
|
|
||||||
// })
|
|
||||||
140
src/renderer/src/aiCore/tools/MemorySearchTool.ts
Normal file
140
src/renderer/src/aiCore/tools/MemorySearchTool.ts
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
import store from '@renderer/store'
|
||||||
|
import { selectCurrentUserId, selectGlobalMemoryEnabled, selectMemoryConfig } from '@renderer/store/memory'
|
||||||
|
import type { Assistant } from '@renderer/types'
|
||||||
|
import { type InferToolOutput, tool } from 'ai'
|
||||||
|
import { z } from 'zod'
|
||||||
|
|
||||||
|
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||||
|
|
||||||
|
/**
 * 🧠 Basic memory search tool.
 *
 * A simple memory lookup that the model can invoke on its own. The tool
 * resolves to the list returned by `MemoryProcessor.searchRelevantMemories`
 * when that list is non-empty; every other path (global memory switched off,
 * LLM or embedder client missing from the memory config, no matches, or a
 * thrown error) resolves to an empty array instead of rejecting.
 */
export const memorySearchTool = () => {
  // Input contract: a free-text query plus an optional result cap (1-20, default 5).
  const inputSchema = z.object({
    query: z.string().describe('Search query to find relevant memories'),
    limit: z.number().min(1).max(20).default(5).describe('Maximum number of memories to return')
  })

  return tool({
    name: 'builtin_memory_search',
    description: 'Search through conversation memories and stored facts for relevant context',
    inputSchema,
    execute: async (input) => {
      const { query, limit = 5 } = input
      console.log('🧠 [memorySearchTool] Searching memories:', { query, limit })

      try {
        // Guard 1: the global memory feature must be enabled.
        if (!selectGlobalMemoryEnabled(store.getState())) {
          return []
        }

        // Guard 2: both the LLM client and the embedder client must be configured.
        const config = selectMemoryConfig(store.getState())
        const modelsConfigured = Boolean(config.llmApiClient) && Boolean(config.embedderApiClient)
        if (!modelsConfigured) {
          console.warn('Memory search skipped: embedding or LLM model not configured')
          return []
        }

        // This tool is not bound to an assistant, so the fixed 'default' id is used.
        const userId = selectCurrentUserId(store.getState())
        const searchConfig = MemoryProcessor.getProcessorConfig(config, 'default', userId)

        const processor = new MemoryProcessor()
        const found = await processor.searchRelevantMemories(query, searchConfig, limit)

        if (!(found?.length > 0)) {
          return []
        }

        console.log('🧠 [memorySearchTool] Found memories:', found.length)
        return found
      } catch (error) {
        // Best effort: a failed search is logged and reported as "no memories".
        console.error('🧠 [memorySearchTool] Error:', error)
        return []
      }
    }
  })
}
|
||||||
|
|
||||||
|
/**
 * 🧠 Smart memory search tool (with context extraction).
 *
 * Intended to derive search keywords from the user message and conversation
 * history. As written, the user message content is used verbatim as the
 * search text, and `lastAnswer` is accepted by the schema but only logged.
 *
 * The tool always resolves to `{ extractedKeywords, searchResults }`. When no
 * search is performed (memory disabled globally or for this assistant, models
 * not configured, empty content, or a thrown error) `extractedKeywords`
 * carries a short explanatory message and `searchResults` is empty.
 *
 * @param assistant - Assistant whose `enableMemory` flag gates the search and
 *   whose `id` is passed to `MemoryProcessor.getProcessorConfig`.
 */
export const memorySearchToolWithExtraction = (assistant: Assistant) => {
  // Shape shared by the user message and the optional last assistant answer,
  // differing only in the description attached to `content`.
  const messageShape = (contentDescription: string) =>
    z.object({
      content: z.string().describe(contentDescription),
      role: z.enum(['user', 'assistant', 'system']).describe('Message role')
    })

  // Builds the result returned whenever there is nothing to report.
  const emptyResult = (extractedKeywords: string) => {
    return {
      extractedKeywords,
      searchResults: []
    }
  }

  return tool({
    name: 'memory_search_with_extraction',
    description: 'Search memories with automatic keyword extraction from conversation context',
    inputSchema: z.object({
      userMessage: messageShape('The main content of the user message'),
      lastAnswer: messageShape('The main content of the last assistant response').optional()
    }),
    execute: async ({ userMessage, lastAnswer }) => {
      console.log('🧠 [memorySearchToolWithExtraction] Processing:', { userMessage, lastAnswer })

      try {
        // Memory must be enabled both globally and on this assistant.
        const memoryAllowed = selectGlobalMemoryEnabled(store.getState()) && assistant.enableMemory
        if (!memoryAllowed) {
          return emptyResult('Memory search disabled')
        }

        // Both the LLM client and the embedder client must be configured.
        const config = selectMemoryConfig(store.getState())
        const modelsConfigured = Boolean(config.llmApiClient) && Boolean(config.embedderApiClient)
        if (!modelsConfigured) {
          console.warn('Memory search skipped: embedding or LLM model not configured')
          return emptyResult('Memory models not configured')
        }

        // 🔍 Use the user message content as the search keywords.
        const searchText = userMessage.content
        if (!searchText) {
          return emptyResult('No content to search')
        }

        const userId = selectCurrentUserId(store.getState())
        const searchConfig = MemoryProcessor.getProcessorConfig(config, assistant.id, userId)

        const processor = new MemoryProcessor()
        // Limit to the top 5 most relevant memories.
        const found = await processor.searchRelevantMemories(searchText, searchConfig, 5)

        if (!(found?.length > 0)) {
          return emptyResult(searchText)
        }

        console.log('🧠 [memorySearchToolWithExtraction] Found memories:', found.length)
        return {
          extractedKeywords: searchText,
          searchResults: found
        }
      } catch (error) {
        // Best effort: a failed search is logged and reported as an empty result.
        console.error('🧠 [memorySearchToolWithExtraction] Error:', error)
        return emptyResult('Search failed')
      }
    }
  })
}
|
||||||
|
|
||||||
|
export type MemorySearchToolOutput = InferToolOutput<ReturnType<typeof memorySearchTool>>
|
||||||
|
export type MemorySearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof memorySearchToolWithExtraction>>
|
||||||
@ -1,6 +1,3 @@
|
|||||||
import { WebSearchToolOutputSchema } from '@cherrystudio/ai-core/built-in/plugins'
|
|
||||||
import type { WebSearchToolOutput } from '@renderer/aiCore/tools/WebSearchTool'
|
|
||||||
|
|
||||||
import type { MCPToolInputSchema } from './index'
|
import type { MCPToolInputSchema } from './index'
|
||||||
|
|
||||||
export type ToolType = 'builtin' | 'provider' | 'mcp'
|
export type ToolType = 'builtin' | 'provider' | 'mcp'
|
||||||
@ -12,15 +9,15 @@ export interface BaseTool {
|
|||||||
type: ToolType
|
type: ToolType
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ToolCallResponse {
|
// export interface ToolCallResponse {
|
||||||
id: string
|
// id: string
|
||||||
toolName: string
|
// toolName: string
|
||||||
arguments: Record<string, unknown> | undefined
|
// arguments: Record<string, unknown> | undefined
|
||||||
status: 'invoking' | 'completed' | 'error'
|
// status: 'invoking' | 'completed' | 'error'
|
||||||
result?: any // AI SDK的工具执行结果
|
// result?: any // AI SDK的工具执行结果
|
||||||
error?: string
|
// error?: string
|
||||||
providerExecuted?: boolean // 标识是Provider端执行还是客户端执行
|
// providerExecuted?: boolean // 标识是Provider端执行还是客户端执行
|
||||||
}
|
// }
|
||||||
|
|
||||||
export interface BuiltinTool extends BaseTool {
|
export interface BuiltinTool extends BaseTool {
|
||||||
inputSchema: MCPToolInputSchema
|
inputSchema: MCPToolInputSchema
|
||||||
@ -33,5 +30,3 @@ export interface MCPTool extends BaseTool {
|
|||||||
inputSchema: MCPToolInputSchema
|
inputSchema: MCPToolInputSchema
|
||||||
type: 'mcp'
|
type: 'mcp'
|
||||||
}
|
}
|
||||||
|
|
||||||
export type WebSearchToolOutputSchema = WebSearchToolOutput | WebSearchToolOutputSchema
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user