diff --git a/src/main/reranker/BaseReranker.ts b/src/main/reranker/BaseReranker.ts index a88d0883ae..f956a0573f 100644 --- a/src/main/reranker/BaseReranker.ts +++ b/src/main/reranker/BaseReranker.ts @@ -38,7 +38,7 @@ export default abstract class BaseReranker { protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) { const provider = this.base.rerankModelProvider const documents = searchResults.map((doc) => doc.pageContent) - const topN = this.base.topN || 10 + const topN = this.base.documentCount if (provider === 'voyageai') { return { diff --git a/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx b/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx index fd76d870c5..34f750c908 100644 --- a/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx +++ b/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx @@ -1,8 +1,7 @@ import { CopyOutlined } from '@ant-design/icons' import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' import { TopView } from '@renderer/components/TopView' -import { DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant' -import { getFileFromUrl, getKnowledgeBaseParams } from '@renderer/services/KnowledgeService' +import { searchKnowledgeBase } from '@renderer/services/KnowledgeService' import { FileType, KnowledgeBase } from '@renderer/types' import { Input, List, message, Modal, Spin, Tooltip, Typography } from 'antd' import { useRef, useState } from 'react' @@ -38,29 +37,8 @@ const PopupContainer: React.FC = ({ base, resolve }) => { setSearchKeyword(value.trim()) setLoading(true) try { - const searchResults = await window.api.knowledgeBase.search({ - search: value, - base: getKnowledgeBaseParams(base) - }) - let rerankResult = searchResults - if (base.rerankModel) { - rerankResult = await window.api.knowledgeBase.rerank({ - search: value, - base: getKnowledgeBaseParams(base), - results: searchResults - }) - } - const results = await Promise.all( - rerankResult.map(async (item) => { - const file = await getFileFromUrl(item.metadata.source) - return { ...item, file } - }) - ) - const filteredResults = results.filter((item) => { - const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD - return item.score >= threshold - }) - setResults(filteredResults) + const searchResults = await searchKnowledgeBase(value, base) + setResults(searchResults) } catch (error) { console.error('Search failed:', error) } finally { diff --git a/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx b/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx index f409990094..237c4a8bcb 100644 --- a/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx +++ b/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx @@ -29,7 +29,6 @@ interface FormData { chunkOverlap?: number threshold?: number rerankModel?: string - topN?: number } interface Props extends ShowParams { @@ -95,8 +94,7 @@ const PopupContainer: React.FC = ({ base: _base, resolve }) => { threshold: values.threshold ?? undefined, rerankModel: values.rerankModel ? providers.flatMap((p) => p.models).find((m) => getModelUniqId(m) === values.rerankModel) - : undefined, - topN: values.topN + : undefined } updateKnowledgeBase(newBase) setOpen(false) @@ -283,23 +281,6 @@ const PopupContainer: React.FC = ({ base: _base, resolve }) => { - 30)) { - return Promise.reject(new Error(t('knowledge.topN_too_large_or_small'))) - } - return Promise.resolve() - } - } - ]}> - - > => { + try { + const baseParams = getKnowledgeBaseParams(base) + const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT + const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD + + // 执行搜索 + const searchResults = await window.api.knowledgeBase.search({ + search: query, + base: baseParams + }) + + // 过滤阈值不达标的结果 + const filteredResults = searchResults.filter((item) => item.score >= threshold) + + // 如果有rerank模型,执行重排 + let rerankResults = filteredResults + if (base.rerankModel && filteredResults.length > 0) { + rerankResults = await window.api.knowledgeBase.rerank({ + search: rewrite || query, + base: baseParams, + results: filteredResults + }) + } + + // 限制文档数量 + const limitedResults = rerankResults.slice(0, documentCount) + + // 处理文件信息 + return await Promise.all( + limitedResults.map(async (item) => { + const file = await getFileFromUrl(item.metadata.source) + return { ...item, file } + }) + ) + } catch (error) { + Logger.error(`Error searching knowledge base ${base.name}:`, error) + return [] + } +} + export const processKnowledgeSearch = async ( extractResults: ExtractResults, knowledgeBaseIds: string[] | undefined @@ -100,6 +145,7 @@ export const processKnowledgeSearch = async ( Logger.log('No valid question found in extractResults.knowledge') return [] } + const questions = extractResults.knowledge.question const rewrite = extractResults.knowledge.rewrite @@ -109,73 +155,35 @@ export const processKnowledgeSearch = async ( return [] } - const referencesPromises = bases.map(async (base) => { - try { - const baseParams = getKnowledgeBaseParams(base) - const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT + // 为每个知识库执行多问题搜索 + const baseSearchPromises = bases.map(async (base) => { + // 为每个问题搜索并合并结果 + const allResults = await Promise.all(questions.map((question) => searchKnowledgeBase(question, base, rewrite))) - const allSearchResultsPromises = questions.map((question) => - window.api.knowledgeBase - .search({ - search: question, - base: baseParams - }) - .then((results) => - results.filter((item) => { - const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD - return item.score >= threshold - }) - ) - ) + // 合并结果并去重 + const flatResults = allResults.flat() + const uniqueResults = Array.from( + new Map(flatResults.map((item) => [item.metadata.uniqueId || item.pageContent, item])).values() + ).sort((a, b) => b.score - a.score) - const allSearchResults = await Promise.all(allSearchResultsPromises) - - const searchResults = Array.from( - new Map(allSearchResults.flat().map((item) => [item.metadata.uniqueId || item.pageContent, item])).values() - ).sort((a, b) => b.score - a.score) - - Logger.log(`Knowledge base ${base.name} search results:`, searchResults) - - let rerankResults = searchResults - if (base.rerankModel && searchResults.length > 0) { - rerankResults = await window.api.knowledgeBase.rerank({ - search: rewrite, - base: baseParams, - results: searchResults - }) - } - - if (rerankResults.length > 0) { - rerankResults = rerankResults.slice(0, documentCount) - } - - const processdResults = await Promise.all( - rerankResults.map(async (item) => { - const file = await getFileFromUrl(item.metadata.source) - return { ...item, file } - }) - ) - - return await Promise.all( - processdResults.map(async (item, index) => { - // const baseItem = base.items.find((i) => i.uniqueId === item.metadata.uniqueLoaderId) - return { - id: index + 1, // 搜索多个库会导致ID重复 + // 转换为引用格式 + return await Promise.all( + uniqueResults.map( + async (item, index) => + ({ + id: index + 1, content: item.pageContent, sourceUrl: await getKnowledgeSourceUrl(item), - type: 'file' // 需要映射 baseItem.type是'localPathLoader' -> 'file' - } as KnowledgeReference - }) + type: 'file' + }) as KnowledgeReference ) - } catch (error) { - Logger.error(`Error searching knowledge base ${base.name}:`, error) - return [] - } + ) }) - const resultsPerBase = await Promise.all(referencesPromises) - + // 汇总所有知识库的结果 + const resultsPerBase = await Promise.all(baseSearchPromises) const allReferencesRaw = resultsPerBase.flat().filter((ref): ref is KnowledgeReference => !!ref) + // 重新为引用分配ID return allReferencesRaw.map((ref, index) => ({ ...ref, diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index ac10c11b3d..0169bfd83f 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -372,7 +372,7 @@ export interface KnowledgeBase { chunkOverlap?: number threshold?: number rerankModel?: Model - topN?: number + // topN?: number } export type KnowledgeBaseParams = { @@ -388,7 +388,7 @@ export type KnowledgeBaseParams = { rerankBaseURL?: string rerankModel?: string rerankModelProvider?: string - topN?: number + documentCount?: number } export type GenerateImageParams = {