mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
feat: implement knowledge search tool and enhance search orchestration logic
- Added a new `knowledgeSearchTool` to facilitate knowledge base searches based on user queries and intent analysis. - Refactored `analyzeSearchIntent` to simplify message context construction and improve prompt formatting. - Introduced a flag to prevent concurrent analysis processes in `searchOrchestrationPlugin`. - Updated tool configuration logic to conditionally add the knowledge search tool based on the presence of knowledge bases and user settings. - Cleaned up commented-out code for better readability and maintainability.
This commit is contained in:
parent
33db455e32
commit
0310648445
@ -23,6 +23,7 @@ import { isEmpty } from 'lodash'
|
||||
import { z } from 'zod'
|
||||
|
||||
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||
import { knowledgeSearchTool } from '../tools/KnowledgeSearchTool'
|
||||
import { memorySearchTool } from '../tools/MemorySearchTool'
|
||||
import { webSearchTool } from '../tools/WebSearchTool'
|
||||
|
||||
@ -61,6 +62,7 @@ const SearchIntentAnalysisSchema = z.object({
|
||||
|
||||
type SearchIntentResult = z.infer<typeof SearchIntentAnalysisSchema>
|
||||
|
||||
let isAnalyzing = false
|
||||
/**
|
||||
* 🧠 意图分析函数 - 使用结构化输出重构
|
||||
*/
|
||||
@ -104,21 +106,12 @@ async function analyzeSearchIntent(
|
||||
schema = SearchIntentAnalysisSchema
|
||||
}
|
||||
|
||||
// 构建消息上下文
|
||||
const messages = lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage]
|
||||
console.log('messagesmessagesmessagesmessagesmessagesmessagesmessages', messages)
|
||||
// 格式化消息为提示词期望的格式
|
||||
// const chatHistory =
|
||||
// messages.length > 1
|
||||
// ? messages
|
||||
// .slice(0, -1)
|
||||
// .map((msg) => `${msg.role}: ${getMainTextContent(msg)}`)
|
||||
// .join('\n')
|
||||
// : ''
|
||||
// const question = getMainTextContent(lastUserMessage) || ''
|
||||
// 构建消息上下文 - 简化逻辑
|
||||
const chatHistory = lastAnswer ? `assistant: ${getMessageContent(lastAnswer)}` : ''
|
||||
const question = getMessageContent(lastUserMessage) || ''
|
||||
|
||||
// // 使用模板替换变量
|
||||
// const formattedPrompt = prompt.replace('{chat_history}', chatHistory).replace('{question}', question)
|
||||
// 使用模板替换变量
|
||||
const formattedPrompt = prompt.replace('{chat_history}', chatHistory).replace('{question}', question)
|
||||
|
||||
// 获取模型和provider信息
|
||||
const model = assistant.model || getDefaultModel()
|
||||
@ -130,7 +123,12 @@ async function analyzeSearchIntent(
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await context?.executor?.generateObject(model.id, { schema, prompt })
|
||||
isAnalyzing = true
|
||||
const result = await context?.executor?.generateObject(model.id, {
|
||||
schema,
|
||||
prompt: formattedPrompt
|
||||
})
|
||||
isAnalyzing = false
|
||||
console.log('result', context)
|
||||
const parsedResult = result?.object as SearchIntentResult
|
||||
|
||||
@ -165,7 +163,7 @@ async function storeConversationMemory(messages: ModelMessage[], assistant: Assi
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
|
||||
if (!globalMemoryEnabled || !assistant.enableMemory) {
|
||||
console.log('Memory storage is disabled')
|
||||
// console.log('Memory storage is disabled')
|
||||
return
|
||||
}
|
||||
|
||||
@ -237,10 +235,10 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
* 🔍 Step 1: 意图识别阶段
|
||||
*/
|
||||
onRequestStart: async (context: AiRequestContext) => {
|
||||
if (isAnalyzing) return
|
||||
console.log('🧠 [SearchOrchestration] Starting intent analysis...', context.requestId)
|
||||
|
||||
try {
|
||||
// 从参数中提取信息
|
||||
const messages = context.originalParams.messages
|
||||
|
||||
if (!messages || messages.length === 0) {
|
||||
@ -256,7 +254,9 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
|
||||
// 判断是否需要各种搜索
|
||||
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||
console.log('knowledgeBaseIds', knowledgeBaseIds)
|
||||
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
console.log('hasKnowledgeBase', hasKnowledgeBase)
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||
const globalMemoryEnabled = selectGlobalMemoryEnabled(store.getState())
|
||||
|
||||
@ -266,12 +266,11 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
|
||||
console.log('🧠 [SearchOrchestration] Search capabilities:', {
|
||||
shouldWebSearch,
|
||||
shouldKnowledgeSearch,
|
||||
hasKnowledgeBase,
|
||||
shouldMemorySearch
|
||||
})
|
||||
|
||||
// 执行意图分析
|
||||
if (shouldWebSearch || shouldKnowledgeSearch) {
|
||||
if (shouldWebSearch || hasKnowledgeBase) {
|
||||
const analysisResult = await analyzeSearchIntent(lastUserMessage, assistant, {
|
||||
shouldWebSearch,
|
||||
shouldKnowledgeSearch,
|
||||
@ -295,15 +294,15 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
* 🔧 Step 2: 工具配置阶段
|
||||
*/
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
if (isAnalyzing) return
|
||||
console.log('🔧 [SearchOrchestration] Configuring tools based on intent...', context.requestId)
|
||||
|
||||
try {
|
||||
const analysisResult = intentAnalysisResults[context.requestId]
|
||||
console.log('analysisResult', analysisResult)
|
||||
if (!analysisResult || !assistant) {
|
||||
console.log('🔧 [SearchOrchestration] No analysis result or assistant, skipping tool configuration')
|
||||
return params
|
||||
}
|
||||
// if (!analysisResult || !assistant) {
|
||||
// console.log('🔧 [SearchOrchestration] No analysis result or assistant, skipping tool configuration')
|
||||
// return params
|
||||
// }
|
||||
|
||||
// 确保 tools 对象存在
|
||||
if (!params.tools) {
|
||||
@ -311,7 +310,7 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
}
|
||||
|
||||
// 🌐 网络搜索工具配置
|
||||
if (analysisResult.websearch && assistant.webSearchProviderId) {
|
||||
if (analysisResult?.websearch && assistant.webSearchProviderId) {
|
||||
const needsSearch = analysisResult.websearch.question && analysisResult.websearch.question[0] !== 'not_needed'
|
||||
|
||||
if (needsSearch) {
|
||||
@ -321,14 +320,27 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
}
|
||||
|
||||
// 📚 知识库搜索工具配置
|
||||
if (analysisResult.knowledge) {
|
||||
const needsKnowledgeSearch =
|
||||
analysisResult.knowledge.question && analysisResult.knowledge.question[0] !== 'not_needed'
|
||||
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||
|
||||
if (needsKnowledgeSearch) {
|
||||
console.log('📚 [SearchOrchestration] Adding knowledge search tool')
|
||||
// TODO: 添加知识库搜索工具
|
||||
// params.tools['builtin_knowledge_search'] = knowledgeSearchTool(assistant.knowledge_bases)
|
||||
if (hasKnowledgeBase) {
|
||||
if (knowledgeRecognition === 'off') {
|
||||
// off 模式:直接添加知识库搜索工具,跳过意图识别
|
||||
console.log('📚 [SearchOrchestration] Adding knowledge search tool (force mode)')
|
||||
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(assistant)
|
||||
params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' }
|
||||
} else {
|
||||
// on 模式:根据意图识别结果决定是否添加工具
|
||||
const needsKnowledgeSearch =
|
||||
analysisResult?.knowledge &&
|
||||
analysisResult.knowledge.question &&
|
||||
analysisResult.knowledge.question[0] !== 'not_needed'
|
||||
|
||||
if (needsKnowledgeSearch) {
|
||||
console.log('📚 [SearchOrchestration] Adding knowledge search tool (intent-based)')
|
||||
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(assistant)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -355,7 +367,6 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
console.log('💾 [SearchOrchestration] Starting memory storage...', context.requestId)
|
||||
|
||||
try {
|
||||
const assistant = context.originalParams.assistant
|
||||
const messages = context.originalParams.messages
|
||||
|
||||
if (messages && assistant) {
|
||||
|
||||
100
src/renderer/src/aiCore/tools/KnowledgeSearchTool.tsx
Normal file
100
src/renderer/src/aiCore/tools/KnowledgeSearchTool.tsx
Normal file
@ -0,0 +1,100 @@
|
||||
import { processKnowledgeSearch } from '@renderer/services/KnowledgeService'
|
||||
import type { Assistant, KnowledgeReference } from '@renderer/types'
|
||||
import { ExtractResults } from '@renderer/utils/extract'
|
||||
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
|
||||
import { isEmpty } from 'lodash'
|
||||
import { z } from 'zod'
|
||||
|
||||
// Schema definitions - 添加 userMessage 字段来获取用户消息
|
||||
const KnowledgeSearchInputSchema = z.object({
|
||||
query: z.string().describe('The search query for knowledge base'),
|
||||
rewrite: z.string().optional().describe('Optional rewritten query with alternative phrasing'),
|
||||
userMessage: z.string().describe('The original user message content for direct search mode')
|
||||
})
|
||||
|
||||
export type KnowledgeSearchToolInput = InferToolInput<ReturnType<typeof knowledgeSearchTool>>
|
||||
export type KnowledgeSearchToolOutput = InferToolOutput<ReturnType<typeof knowledgeSearchTool>>
|
||||
|
||||
/**
|
||||
* 知识库搜索工具
|
||||
* 基于 ApiService.ts 中的 searchKnowledgeBase 逻辑实现
|
||||
*/
|
||||
export const knowledgeSearchTool = (assistant: Assistant) => {
|
||||
return tool({
|
||||
name: 'builtin_knowledge_search',
|
||||
description: 'Search the knowledge base for relevant information',
|
||||
inputSchema: KnowledgeSearchInputSchema,
|
||||
execute: async ({ query, rewrite, userMessage }) => {
|
||||
console.log('🔍 [KnowledgeSearchTool] Executing search:', { query, rewrite, userMessage })
|
||||
|
||||
try {
|
||||
// 获取助手的知识库配置
|
||||
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||
|
||||
// 检查是否有知识库
|
||||
if (!hasKnowledgeBase) {
|
||||
console.log('🔍 [KnowledgeSearchTool] No knowledge bases found for assistant')
|
||||
return []
|
||||
}
|
||||
|
||||
// 构建搜索条件 - 复制原逻辑
|
||||
let searchCriteria: { question: string[]; rewrite: string }
|
||||
|
||||
if (knowledgeRecognition === 'off') {
|
||||
// 直接模式:使用用户消息内容 (类似原逻辑的 getMainTextContent(lastUserMessage))
|
||||
const directContent = userMessage || query || 'search'
|
||||
searchCriteria = {
|
||||
question: [directContent],
|
||||
rewrite: directContent
|
||||
}
|
||||
console.log('🔍 [KnowledgeSearchTool] Direct mode - using user message:', directContent)
|
||||
} else {
|
||||
// 自动模式:使用意图识别的结果 (类似原逻辑的 extractResults.knowledge)
|
||||
searchCriteria = {
|
||||
question: [query],
|
||||
rewrite: rewrite || query
|
||||
}
|
||||
console.log('🔍 [KnowledgeSearchTool] Auto mode - using intent analysis result')
|
||||
}
|
||||
|
||||
// 检查是否需要搜索
|
||||
if (searchCriteria.question[0] === 'not_needed') {
|
||||
console.log('🔍 [KnowledgeSearchTool] Search not needed')
|
||||
return []
|
||||
}
|
||||
|
||||
// 构建 ExtractResults 对象 - 与原逻辑一致
|
||||
const extractResults: ExtractResults = {
|
||||
websearch: undefined,
|
||||
knowledge: searchCriteria
|
||||
}
|
||||
|
||||
console.log('🔍 [KnowledgeSearchTool] Search criteria:', searchCriteria)
|
||||
console.log('🔍 [KnowledgeSearchTool] Knowledge base IDs:', knowledgeBaseIds)
|
||||
|
||||
// 执行知识库搜索
|
||||
const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds)
|
||||
|
||||
console.log('🔍 [KnowledgeSearchTool] Search results:', knowledgeReferences)
|
||||
|
||||
// 返回结果数组
|
||||
return knowledgeReferences.map((ref: KnowledgeReference) => ({
|
||||
id: ref.id,
|
||||
content: ref.content,
|
||||
sourceUrl: ref.sourceUrl,
|
||||
type: ref.type,
|
||||
file: ref.file
|
||||
}))
|
||||
} catch (error) {
|
||||
console.error('🔍 [KnowledgeSearchTool] Search failed:', error)
|
||||
|
||||
// 返回空数组而不是抛出错误,避免中断对话流程
|
||||
return []
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
export default knowledgeSearchTool
|
||||
@ -6,7 +6,7 @@ import MessageTool from './MessageTool'
|
||||
interface Props {
|
||||
block: ToolMessageBlock
|
||||
}
|
||||
|
||||
// TODO: 知识库tool
|
||||
export default function MessageTools({ block }: Props) {
|
||||
const toolResponse = block.metadata?.rawMcpToolResponse
|
||||
if (!toolResponse) return null
|
||||
|
||||
Loading…
Reference in New Issue
Block a user