fix(knowledge): remove topN

This commit is contained in:
eeee0717 2025-05-18 09:14:57 +08:00 committed by 亢奋猫
parent 53cff11726
commit 04c3911243
5 changed files with 76 additions and 109 deletions

View File

@ -38,7 +38,7 @@ export default abstract class BaseReranker {
protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
const provider = this.base.rerankModelProvider
const documents = searchResults.map((doc) => doc.pageContent)
const topN = this.base.topN || 10
const topN = this.base.documentCount
if (provider === 'voyageai') {
return {

View File

@ -1,8 +1,7 @@
import { CopyOutlined } from '@ant-design/icons'
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { TopView } from '@renderer/components/TopView'
import { DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant'
import { getFileFromUrl, getKnowledgeBaseParams } from '@renderer/services/KnowledgeService'
import { searchKnowledgeBase } from '@renderer/services/KnowledgeService'
import { FileType, KnowledgeBase } from '@renderer/types'
import { Input, List, message, Modal, Spin, Tooltip, Typography } from 'antd'
import { useRef, useState } from 'react'
@ -38,29 +37,8 @@ const PopupContainer: React.FC<Props> = ({ base, resolve }) => {
setSearchKeyword(value.trim())
setLoading(true)
try {
const searchResults = await window.api.knowledgeBase.search({
search: value,
base: getKnowledgeBaseParams(base)
})
let rerankResult = searchResults
if (base.rerankModel) {
rerankResult = await window.api.knowledgeBase.rerank({
search: value,
base: getKnowledgeBaseParams(base),
results: searchResults
})
}
const results = await Promise.all(
rerankResult.map(async (item) => {
const file = await getFileFromUrl(item.metadata.source)
return { ...item, file }
})
)
const filteredResults = results.filter((item) => {
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
return item.score >= threshold
})
setResults(filteredResults)
const searchResults = await searchKnowledgeBase(value, base)
setResults(searchResults)
} catch (error) {
console.error('Search failed:', error)
} finally {

View File

@ -29,7 +29,6 @@ interface FormData {
chunkOverlap?: number
threshold?: number
rerankModel?: string
topN?: number
}
interface Props extends ShowParams {
@ -95,8 +94,7 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
threshold: values.threshold ?? undefined,
rerankModel: values.rerankModel
? providers.flatMap((p) => p.models).find((m) => getModelUniqId(m) === values.rerankModel)
: undefined,
topN: values.topN
: undefined
}
updateKnowledgeBase(newBase)
setOpen(false)
@ -283,23 +281,6 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
<InputNumber placeholder={t('knowledge.threshold_placeholder')} step={0.1} style={{ width: '100%' }} />
</Form.Item>
<Form.Item
name="topN"
label={t('knowledge.topN')}
layout="horizontal"
initialValue={base.topN}
rules={[
{
validator(_, value) {
if (value && (value < 0 || value > 30)) {
return Promise.reject(new Error(t('knowledge.topN_too_large_or_small')))
}
return Promise.resolve()
}
}
]}>
<InputNumber placeholder={t('knowledge.topN_placeholder')} style={{ width: '100%' }} />
</Form.Item>
<Alert
message={t('knowledge.chunk_size_change_warning')}
type="warning"

View File

@ -47,8 +47,8 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams
rerankBaseURL: rerankHost,
rerankApiKey: rerankAiProvider.getApiKey() || 'secret',
rerankModel: base.rerankModel?.id,
rerankModelProvider: base.rerankModel?.provider,
topN: base.topN
rerankModelProvider: base.rerankModel?.provider
// topN: base.topN
}
}
@ -88,6 +88,51 @@ export const getKnowledgeSourceUrl = async (item: ExtractChunkData & { file: Fil
return item.metadata.source
}
export const searchKnowledgeBase = async (
query: string,
base: KnowledgeBase,
rewrite?: string
): Promise<Array<ExtractChunkData & { file: FileType | null }>> => {
try {
const baseParams = getKnowledgeBaseParams(base)
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
// 执行搜索
const searchResults = await window.api.knowledgeBase.search({
search: query,
base: baseParams
})
// 过滤阈值不达标的结果
const filteredResults = searchResults.filter((item) => item.score >= threshold)
// 如果有rerank模型执行重排
let rerankResults = filteredResults
if (base.rerankModel && filteredResults.length > 0) {
rerankResults = await window.api.knowledgeBase.rerank({
search: rewrite || query,
base: baseParams,
results: filteredResults
})
}
// 限制文档数量
const limitedResults = rerankResults.slice(0, documentCount)
// 处理文件信息
return await Promise.all(
limitedResults.map(async (item) => {
const file = await getFileFromUrl(item.metadata.source)
return { ...item, file }
})
)
} catch (error) {
Logger.error(`Error searching knowledge base ${base.name}:`, error)
return []
}
}
export const processKnowledgeSearch = async (
extractResults: ExtractResults,
knowledgeBaseIds: string[] | undefined
@ -100,6 +145,7 @@ export const processKnowledgeSearch = async (
Logger.log('No valid question found in extractResults.knowledge')
return []
}
const questions = extractResults.knowledge.question
const rewrite = extractResults.knowledge.rewrite
@ -109,73 +155,35 @@ export const processKnowledgeSearch = async (
return []
}
const referencesPromises = bases.map(async (base) => {
try {
const baseParams = getKnowledgeBaseParams(base)
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
// 为每个知识库执行多问题搜索
const baseSearchPromises = bases.map(async (base) => {
// 为每个问题搜索并合并结果
const allResults = await Promise.all(questions.map((question) => searchKnowledgeBase(question, base, rewrite)))
const allSearchResultsPromises = questions.map((question) =>
window.api.knowledgeBase
.search({
search: question,
base: baseParams
})
.then((results) =>
results.filter((item) => {
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
return item.score >= threshold
})
)
)
// 合并结果并去重
const flatResults = allResults.flat()
const uniqueResults = Array.from(
new Map(flatResults.map((item) => [item.metadata.uniqueId || item.pageContent, item])).values()
).sort((a, b) => b.score - a.score)
const allSearchResults = await Promise.all(allSearchResultsPromises)
const searchResults = Array.from(
new Map(allSearchResults.flat().map((item) => [item.metadata.uniqueId || item.pageContent, item])).values()
).sort((a, b) => b.score - a.score)
Logger.log(`Knowledge base ${base.name} search results:`, searchResults)
let rerankResults = searchResults
if (base.rerankModel && searchResults.length > 0) {
rerankResults = await window.api.knowledgeBase.rerank({
search: rewrite,
base: baseParams,
results: searchResults
})
}
if (rerankResults.length > 0) {
rerankResults = rerankResults.slice(0, documentCount)
}
const processdResults = await Promise.all(
rerankResults.map(async (item) => {
const file = await getFileFromUrl(item.metadata.source)
return { ...item, file }
})
)
return await Promise.all(
processdResults.map(async (item, index) => {
// const baseItem = base.items.find((i) => i.uniqueId === item.metadata.uniqueLoaderId)
return {
id: index + 1, // 搜索多个库会导致ID重复
// 转换为引用格式
return await Promise.all(
uniqueResults.map(
async (item, index) =>
({
id: index + 1,
content: item.pageContent,
sourceUrl: await getKnowledgeSourceUrl(item),
type: 'file' // 需要映射 baseItem.type是'localPathLoader' -> 'file'
} as KnowledgeReference
})
type: 'file'
}) as KnowledgeReference
)
} catch (error) {
Logger.error(`Error searching knowledge base ${base.name}:`, error)
return []
}
)
})
const resultsPerBase = await Promise.all(referencesPromises)
// 汇总所有知识库的结果
const resultsPerBase = await Promise.all(baseSearchPromises)
const allReferencesRaw = resultsPerBase.flat().filter((ref): ref is KnowledgeReference => !!ref)
// 重新为引用分配ID
return allReferencesRaw.map((ref, index) => ({
...ref,

View File

@ -372,7 +372,7 @@ export interface KnowledgeBase {
chunkOverlap?: number
threshold?: number
rerankModel?: Model
topN?: number
// topN?: number
}
export type KnowledgeBaseParams = {
@ -388,7 +388,7 @@ export type KnowledgeBaseParams = {
rerankBaseURL?: string
rerankModel?: string
rerankModelProvider?: string
topN?: number
documentCount?: number
}
export type GenerateImageParams = {