mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-30 07:39:06 +08:00
fix(knowledge): remove topN
This commit is contained in:
parent
53cff11726
commit
04c3911243
@ -38,7 +38,7 @@ export default abstract class BaseReranker {
|
||||
protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
|
||||
const provider = this.base.rerankModelProvider
|
||||
const documents = searchResults.map((doc) => doc.pageContent)
|
||||
const topN = this.base.topN || 10
|
||||
const topN = this.base.documentCount
|
||||
|
||||
if (provider === 'voyageai') {
|
||||
return {
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import { CopyOutlined } from '@ant-design/icons'
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import { TopView } from '@renderer/components/TopView'
|
||||
import { DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant'
|
||||
import { getFileFromUrl, getKnowledgeBaseParams } from '@renderer/services/KnowledgeService'
|
||||
import { searchKnowledgeBase } from '@renderer/services/KnowledgeService'
|
||||
import { FileType, KnowledgeBase } from '@renderer/types'
|
||||
import { Input, List, message, Modal, Spin, Tooltip, Typography } from 'antd'
|
||||
import { useRef, useState } from 'react'
|
||||
@ -38,29 +37,8 @@ const PopupContainer: React.FC<Props> = ({ base, resolve }) => {
|
||||
setSearchKeyword(value.trim())
|
||||
setLoading(true)
|
||||
try {
|
||||
const searchResults = await window.api.knowledgeBase.search({
|
||||
search: value,
|
||||
base: getKnowledgeBaseParams(base)
|
||||
})
|
||||
let rerankResult = searchResults
|
||||
if (base.rerankModel) {
|
||||
rerankResult = await window.api.knowledgeBase.rerank({
|
||||
search: value,
|
||||
base: getKnowledgeBaseParams(base),
|
||||
results: searchResults
|
||||
})
|
||||
}
|
||||
const results = await Promise.all(
|
||||
rerankResult.map(async (item) => {
|
||||
const file = await getFileFromUrl(item.metadata.source)
|
||||
return { ...item, file }
|
||||
})
|
||||
)
|
||||
const filteredResults = results.filter((item) => {
|
||||
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
|
||||
return item.score >= threshold
|
||||
})
|
||||
setResults(filteredResults)
|
||||
const searchResults = await searchKnowledgeBase(value, base)
|
||||
setResults(searchResults)
|
||||
} catch (error) {
|
||||
console.error('Search failed:', error)
|
||||
} finally {
|
||||
|
||||
@ -29,7 +29,6 @@ interface FormData {
|
||||
chunkOverlap?: number
|
||||
threshold?: number
|
||||
rerankModel?: string
|
||||
topN?: number
|
||||
}
|
||||
|
||||
interface Props extends ShowParams {
|
||||
@ -95,8 +94,7 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
threshold: values.threshold ?? undefined,
|
||||
rerankModel: values.rerankModel
|
||||
? providers.flatMap((p) => p.models).find((m) => getModelUniqId(m) === values.rerankModel)
|
||||
: undefined,
|
||||
topN: values.topN
|
||||
: undefined
|
||||
}
|
||||
updateKnowledgeBase(newBase)
|
||||
setOpen(false)
|
||||
@ -283,23 +281,6 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
<InputNumber placeholder={t('knowledge.threshold_placeholder')} step={0.1} style={{ width: '100%' }} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="topN"
|
||||
label={t('knowledge.topN')}
|
||||
layout="horizontal"
|
||||
initialValue={base.topN}
|
||||
rules={[
|
||||
{
|
||||
validator(_, value) {
|
||||
if (value && (value < 0 || value > 30)) {
|
||||
return Promise.reject(new Error(t('knowledge.topN_too_large_or_small')))
|
||||
}
|
||||
return Promise.resolve()
|
||||
}
|
||||
}
|
||||
]}>
|
||||
<InputNumber placeholder={t('knowledge.topN_placeholder')} style={{ width: '100%' }} />
|
||||
</Form.Item>
|
||||
<Alert
|
||||
message={t('knowledge.chunk_size_change_warning')}
|
||||
type="warning"
|
||||
|
||||
@ -47,8 +47,8 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams
|
||||
rerankBaseURL: rerankHost,
|
||||
rerankApiKey: rerankAiProvider.getApiKey() || 'secret',
|
||||
rerankModel: base.rerankModel?.id,
|
||||
rerankModelProvider: base.rerankModel?.provider,
|
||||
topN: base.topN
|
||||
rerankModelProvider: base.rerankModel?.provider
|
||||
// topN: base.topN
|
||||
}
|
||||
}
|
||||
|
||||
@ -88,6 +88,51 @@ export const getKnowledgeSourceUrl = async (item: ExtractChunkData & { file: Fil
|
||||
return item.metadata.source
|
||||
}
|
||||
|
||||
export const searchKnowledgeBase = async (
|
||||
query: string,
|
||||
base: KnowledgeBase,
|
||||
rewrite?: string
|
||||
): Promise<Array<ExtractChunkData & { file: FileType | null }>> => {
|
||||
try {
|
||||
const baseParams = getKnowledgeBaseParams(base)
|
||||
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
|
||||
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
|
||||
|
||||
// 执行搜索
|
||||
const searchResults = await window.api.knowledgeBase.search({
|
||||
search: query,
|
||||
base: baseParams
|
||||
})
|
||||
|
||||
// 过滤阈值不达标的结果
|
||||
const filteredResults = searchResults.filter((item) => item.score >= threshold)
|
||||
|
||||
// 如果有rerank模型,执行重排
|
||||
let rerankResults = filteredResults
|
||||
if (base.rerankModel && filteredResults.length > 0) {
|
||||
rerankResults = await window.api.knowledgeBase.rerank({
|
||||
search: rewrite || query,
|
||||
base: baseParams,
|
||||
results: filteredResults
|
||||
})
|
||||
}
|
||||
|
||||
// 限制文档数量
|
||||
const limitedResults = rerankResults.slice(0, documentCount)
|
||||
|
||||
// 处理文件信息
|
||||
return await Promise.all(
|
||||
limitedResults.map(async (item) => {
|
||||
const file = await getFileFromUrl(item.metadata.source)
|
||||
return { ...item, file }
|
||||
})
|
||||
)
|
||||
} catch (error) {
|
||||
Logger.error(`Error searching knowledge base ${base.name}:`, error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
export const processKnowledgeSearch = async (
|
||||
extractResults: ExtractResults,
|
||||
knowledgeBaseIds: string[] | undefined
|
||||
@ -100,6 +145,7 @@ export const processKnowledgeSearch = async (
|
||||
Logger.log('No valid question found in extractResults.knowledge')
|
||||
return []
|
||||
}
|
||||
|
||||
const questions = extractResults.knowledge.question
|
||||
const rewrite = extractResults.knowledge.rewrite
|
||||
|
||||
@ -109,73 +155,35 @@ export const processKnowledgeSearch = async (
|
||||
return []
|
||||
}
|
||||
|
||||
const referencesPromises = bases.map(async (base) => {
|
||||
try {
|
||||
const baseParams = getKnowledgeBaseParams(base)
|
||||
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
|
||||
// 为每个知识库执行多问题搜索
|
||||
const baseSearchPromises = bases.map(async (base) => {
|
||||
// 为每个问题搜索并合并结果
|
||||
const allResults = await Promise.all(questions.map((question) => searchKnowledgeBase(question, base, rewrite)))
|
||||
|
||||
const allSearchResultsPromises = questions.map((question) =>
|
||||
window.api.knowledgeBase
|
||||
.search({
|
||||
search: question,
|
||||
base: baseParams
|
||||
})
|
||||
.then((results) =>
|
||||
results.filter((item) => {
|
||||
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
|
||||
return item.score >= threshold
|
||||
})
|
||||
)
|
||||
)
|
||||
// 合并结果并去重
|
||||
const flatResults = allResults.flat()
|
||||
const uniqueResults = Array.from(
|
||||
new Map(flatResults.map((item) => [item.metadata.uniqueId || item.pageContent, item])).values()
|
||||
).sort((a, b) => b.score - a.score)
|
||||
|
||||
const allSearchResults = await Promise.all(allSearchResultsPromises)
|
||||
|
||||
const searchResults = Array.from(
|
||||
new Map(allSearchResults.flat().map((item) => [item.metadata.uniqueId || item.pageContent, item])).values()
|
||||
).sort((a, b) => b.score - a.score)
|
||||
|
||||
Logger.log(`Knowledge base ${base.name} search results:`, searchResults)
|
||||
|
||||
let rerankResults = searchResults
|
||||
if (base.rerankModel && searchResults.length > 0) {
|
||||
rerankResults = await window.api.knowledgeBase.rerank({
|
||||
search: rewrite,
|
||||
base: baseParams,
|
||||
results: searchResults
|
||||
})
|
||||
}
|
||||
|
||||
if (rerankResults.length > 0) {
|
||||
rerankResults = rerankResults.slice(0, documentCount)
|
||||
}
|
||||
|
||||
const processdResults = await Promise.all(
|
||||
rerankResults.map(async (item) => {
|
||||
const file = await getFileFromUrl(item.metadata.source)
|
||||
return { ...item, file }
|
||||
})
|
||||
)
|
||||
|
||||
return await Promise.all(
|
||||
processdResults.map(async (item, index) => {
|
||||
// const baseItem = base.items.find((i) => i.uniqueId === item.metadata.uniqueLoaderId)
|
||||
return {
|
||||
id: index + 1, // 搜索多个库会导致ID重复
|
||||
// 转换为引用格式
|
||||
return await Promise.all(
|
||||
uniqueResults.map(
|
||||
async (item, index) =>
|
||||
({
|
||||
id: index + 1,
|
||||
content: item.pageContent,
|
||||
sourceUrl: await getKnowledgeSourceUrl(item),
|
||||
type: 'file' // 需要映射 baseItem.type是'localPathLoader' -> 'file'
|
||||
} as KnowledgeReference
|
||||
})
|
||||
type: 'file'
|
||||
}) as KnowledgeReference
|
||||
)
|
||||
} catch (error) {
|
||||
Logger.error(`Error searching knowledge base ${base.name}:`, error)
|
||||
return []
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
const resultsPerBase = await Promise.all(referencesPromises)
|
||||
|
||||
// 汇总所有知识库的结果
|
||||
const resultsPerBase = await Promise.all(baseSearchPromises)
|
||||
const allReferencesRaw = resultsPerBase.flat().filter((ref): ref is KnowledgeReference => !!ref)
|
||||
|
||||
// 重新为引用分配ID
|
||||
return allReferencesRaw.map((ref, index) => ({
|
||||
...ref,
|
||||
|
||||
@ -372,7 +372,7 @@ export interface KnowledgeBase {
|
||||
chunkOverlap?: number
|
||||
threshold?: number
|
||||
rerankModel?: Model
|
||||
topN?: number
|
||||
// topN?: number
|
||||
}
|
||||
|
||||
export type KnowledgeBaseParams = {
|
||||
@ -388,7 +388,7 @@ export type KnowledgeBaseParams = {
|
||||
rerankBaseURL?: string
|
||||
rerankModel?: string
|
||||
rerankModelProvider?: string
|
||||
topN?: number
|
||||
documentCount?: number
|
||||
}
|
||||
|
||||
export type GenerateImageParams = {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user