mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-08 14:29:15 +08:00
fix(knowledge): remove topN
This commit is contained in:
parent
e51a37cc74
commit
ba88a24455
@ -38,7 +38,7 @@ export default abstract class BaseReranker {
|
|||||||
protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
|
protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
|
||||||
const provider = this.base.rerankModelProvider
|
const provider = this.base.rerankModelProvider
|
||||||
const documents = searchResults.map((doc) => doc.pageContent)
|
const documents = searchResults.map((doc) => doc.pageContent)
|
||||||
const topN = this.base.topN || 10
|
const topN = this.base.documentCount
|
||||||
|
|
||||||
if (provider === 'voyageai') {
|
if (provider === 'voyageai') {
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
import { CopyOutlined } from '@ant-design/icons'
|
import { CopyOutlined } from '@ant-design/icons'
|
||||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||||
import { TopView } from '@renderer/components/TopView'
|
import { TopView } from '@renderer/components/TopView'
|
||||||
import { DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant'
|
import { searchKnowledgeBase } from '@renderer/services/KnowledgeService'
|
||||||
import { getFileFromUrl, getKnowledgeBaseParams } from '@renderer/services/KnowledgeService'
|
|
||||||
import { FileType, KnowledgeBase } from '@renderer/types'
|
import { FileType, KnowledgeBase } from '@renderer/types'
|
||||||
import { Input, List, message, Modal, Spin, Tooltip, Typography } from 'antd'
|
import { Input, List, message, Modal, Spin, Tooltip, Typography } from 'antd'
|
||||||
import { useRef, useState } from 'react'
|
import { useRef, useState } from 'react'
|
||||||
@ -38,29 +37,8 @@ const PopupContainer: React.FC<Props> = ({ base, resolve }) => {
|
|||||||
setSearchKeyword(value.trim())
|
setSearchKeyword(value.trim())
|
||||||
setLoading(true)
|
setLoading(true)
|
||||||
try {
|
try {
|
||||||
const searchResults = await window.api.knowledgeBase.search({
|
const searchResults = await searchKnowledgeBase(value, base)
|
||||||
search: value,
|
setResults(searchResults)
|
||||||
base: getKnowledgeBaseParams(base)
|
|
||||||
})
|
|
||||||
let rerankResult = searchResults
|
|
||||||
if (base.rerankModel) {
|
|
||||||
rerankResult = await window.api.knowledgeBase.rerank({
|
|
||||||
search: value,
|
|
||||||
base: getKnowledgeBaseParams(base),
|
|
||||||
results: searchResults
|
|
||||||
})
|
|
||||||
}
|
|
||||||
const results = await Promise.all(
|
|
||||||
rerankResult.map(async (item) => {
|
|
||||||
const file = await getFileFromUrl(item.metadata.source)
|
|
||||||
return { ...item, file }
|
|
||||||
})
|
|
||||||
)
|
|
||||||
const filteredResults = results.filter((item) => {
|
|
||||||
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
|
|
||||||
return item.score >= threshold
|
|
||||||
})
|
|
||||||
setResults(filteredResults)
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Search failed:', error)
|
console.error('Search failed:', error)
|
||||||
} finally {
|
} finally {
|
||||||
|
|||||||
@ -29,7 +29,6 @@ interface FormData {
|
|||||||
chunkOverlap?: number
|
chunkOverlap?: number
|
||||||
threshold?: number
|
threshold?: number
|
||||||
rerankModel?: string
|
rerankModel?: string
|
||||||
topN?: number
|
|
||||||
}
|
}
|
||||||
|
|
||||||
interface Props extends ShowParams {
|
interface Props extends ShowParams {
|
||||||
@ -95,8 +94,7 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
|||||||
threshold: values.threshold ?? undefined,
|
threshold: values.threshold ?? undefined,
|
||||||
rerankModel: values.rerankModel
|
rerankModel: values.rerankModel
|
||||||
? providers.flatMap((p) => p.models).find((m) => getModelUniqId(m) === values.rerankModel)
|
? providers.flatMap((p) => p.models).find((m) => getModelUniqId(m) === values.rerankModel)
|
||||||
: undefined,
|
: undefined
|
||||||
topN: values.topN
|
|
||||||
}
|
}
|
||||||
updateKnowledgeBase(newBase)
|
updateKnowledgeBase(newBase)
|
||||||
setOpen(false)
|
setOpen(false)
|
||||||
@ -283,23 +281,6 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
|||||||
<InputNumber placeholder={t('knowledge.threshold_placeholder')} step={0.1} style={{ width: '100%' }} />
|
<InputNumber placeholder={t('knowledge.threshold_placeholder')} step={0.1} style={{ width: '100%' }} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|
||||||
<Form.Item
|
|
||||||
name="topN"
|
|
||||||
label={t('knowledge.topN')}
|
|
||||||
layout="horizontal"
|
|
||||||
initialValue={base.topN}
|
|
||||||
rules={[
|
|
||||||
{
|
|
||||||
validator(_, value) {
|
|
||||||
if (value && (value < 0 || value > 30)) {
|
|
||||||
return Promise.reject(new Error(t('knowledge.topN_too_large_or_small')))
|
|
||||||
}
|
|
||||||
return Promise.resolve()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]}>
|
|
||||||
<InputNumber placeholder={t('knowledge.topN_placeholder')} style={{ width: '100%' }} />
|
|
||||||
</Form.Item>
|
|
||||||
<Alert
|
<Alert
|
||||||
message={t('knowledge.chunk_size_change_warning')}
|
message={t('knowledge.chunk_size_change_warning')}
|
||||||
type="warning"
|
type="warning"
|
||||||
|
|||||||
@ -47,8 +47,8 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams
|
|||||||
rerankBaseURL: rerankHost,
|
rerankBaseURL: rerankHost,
|
||||||
rerankApiKey: rerankAiProvider.getApiKey() || 'secret',
|
rerankApiKey: rerankAiProvider.getApiKey() || 'secret',
|
||||||
rerankModel: base.rerankModel?.id,
|
rerankModel: base.rerankModel?.id,
|
||||||
rerankModelProvider: base.rerankModel?.provider,
|
rerankModelProvider: base.rerankModel?.provider
|
||||||
topN: base.topN
|
// topN: base.topN
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -88,6 +88,51 @@ export const getKnowledgeSourceUrl = async (item: ExtractChunkData & { file: Fil
|
|||||||
return item.metadata.source
|
return item.metadata.source
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const searchKnowledgeBase = async (
|
||||||
|
query: string,
|
||||||
|
base: KnowledgeBase,
|
||||||
|
rewrite?: string
|
||||||
|
): Promise<Array<ExtractChunkData & { file: FileType | null }>> => {
|
||||||
|
try {
|
||||||
|
const baseParams = getKnowledgeBaseParams(base)
|
||||||
|
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
|
||||||
|
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
|
||||||
|
|
||||||
|
// 执行搜索
|
||||||
|
const searchResults = await window.api.knowledgeBase.search({
|
||||||
|
search: query,
|
||||||
|
base: baseParams
|
||||||
|
})
|
||||||
|
|
||||||
|
// 过滤阈值不达标的结果
|
||||||
|
const filteredResults = searchResults.filter((item) => item.score >= threshold)
|
||||||
|
|
||||||
|
// 如果有rerank模型,执行重排
|
||||||
|
let rerankResults = filteredResults
|
||||||
|
if (base.rerankModel && filteredResults.length > 0) {
|
||||||
|
rerankResults = await window.api.knowledgeBase.rerank({
|
||||||
|
search: rewrite || query,
|
||||||
|
base: baseParams,
|
||||||
|
results: filteredResults
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 限制文档数量
|
||||||
|
const limitedResults = rerankResults.slice(0, documentCount)
|
||||||
|
|
||||||
|
// 处理文件信息
|
||||||
|
return await Promise.all(
|
||||||
|
limitedResults.map(async (item) => {
|
||||||
|
const file = await getFileFromUrl(item.metadata.source)
|
||||||
|
return { ...item, file }
|
||||||
|
})
|
||||||
|
)
|
||||||
|
} catch (error) {
|
||||||
|
Logger.error(`Error searching knowledge base ${base.name}:`, error)
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export const processKnowledgeSearch = async (
|
export const processKnowledgeSearch = async (
|
||||||
extractResults: ExtractResults,
|
extractResults: ExtractResults,
|
||||||
knowledgeBaseIds: string[] | undefined
|
knowledgeBaseIds: string[] | undefined
|
||||||
@ -100,6 +145,7 @@ export const processKnowledgeSearch = async (
|
|||||||
Logger.log('No valid question found in extractResults.knowledge')
|
Logger.log('No valid question found in extractResults.knowledge')
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
|
|
||||||
const questions = extractResults.knowledge.question
|
const questions = extractResults.knowledge.question
|
||||||
const rewrite = extractResults.knowledge.rewrite
|
const rewrite = extractResults.knowledge.rewrite
|
||||||
|
|
||||||
@ -109,73 +155,35 @@ export const processKnowledgeSearch = async (
|
|||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
|
|
||||||
const referencesPromises = bases.map(async (base) => {
|
// 为每个知识库执行多问题搜索
|
||||||
try {
|
const baseSearchPromises = bases.map(async (base) => {
|
||||||
const baseParams = getKnowledgeBaseParams(base)
|
// 为每个问题搜索并合并结果
|
||||||
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
|
const allResults = await Promise.all(questions.map((question) => searchKnowledgeBase(question, base, rewrite)))
|
||||||
|
|
||||||
const allSearchResultsPromises = questions.map((question) =>
|
// 合并结果并去重
|
||||||
window.api.knowledgeBase
|
const flatResults = allResults.flat()
|
||||||
.search({
|
const uniqueResults = Array.from(
|
||||||
search: question,
|
new Map(flatResults.map((item) => [item.metadata.uniqueId || item.pageContent, item])).values()
|
||||||
base: baseParams
|
).sort((a, b) => b.score - a.score)
|
||||||
})
|
|
||||||
.then((results) =>
|
|
||||||
results.filter((item) => {
|
|
||||||
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
|
|
||||||
return item.score >= threshold
|
|
||||||
})
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
const allSearchResults = await Promise.all(allSearchResultsPromises)
|
// 转换为引用格式
|
||||||
|
return await Promise.all(
|
||||||
const searchResults = Array.from(
|
uniqueResults.map(
|
||||||
new Map(allSearchResults.flat().map((item) => [item.metadata.uniqueId || item.pageContent, item])).values()
|
async (item, index) =>
|
||||||
).sort((a, b) => b.score - a.score)
|
({
|
||||||
|
id: index + 1,
|
||||||
Logger.log(`Knowledge base ${base.name} search results:`, searchResults)
|
|
||||||
|
|
||||||
let rerankResults = searchResults
|
|
||||||
if (base.rerankModel && searchResults.length > 0) {
|
|
||||||
rerankResults = await window.api.knowledgeBase.rerank({
|
|
||||||
search: rewrite,
|
|
||||||
base: baseParams,
|
|
||||||
results: searchResults
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if (rerankResults.length > 0) {
|
|
||||||
rerankResults = rerankResults.slice(0, documentCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
const processdResults = await Promise.all(
|
|
||||||
rerankResults.map(async (item) => {
|
|
||||||
const file = await getFileFromUrl(item.metadata.source)
|
|
||||||
return { ...item, file }
|
|
||||||
})
|
|
||||||
)
|
|
||||||
|
|
||||||
return await Promise.all(
|
|
||||||
processdResults.map(async (item, index) => {
|
|
||||||
// const baseItem = base.items.find((i) => i.uniqueId === item.metadata.uniqueLoaderId)
|
|
||||||
return {
|
|
||||||
id: index + 1, // 搜索多个库会导致ID重复
|
|
||||||
content: item.pageContent,
|
content: item.pageContent,
|
||||||
sourceUrl: await getKnowledgeSourceUrl(item),
|
sourceUrl: await getKnowledgeSourceUrl(item),
|
||||||
type: 'file' // 需要映射 baseItem.type是'localPathLoader' -> 'file'
|
type: 'file'
|
||||||
} as KnowledgeReference
|
}) as KnowledgeReference
|
||||||
})
|
|
||||||
)
|
)
|
||||||
} catch (error) {
|
)
|
||||||
Logger.error(`Error searching knowledge base ${base.name}:`, error)
|
|
||||||
return []
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
const resultsPerBase = await Promise.all(referencesPromises)
|
// 汇总所有知识库的结果
|
||||||
|
const resultsPerBase = await Promise.all(baseSearchPromises)
|
||||||
const allReferencesRaw = resultsPerBase.flat().filter((ref): ref is KnowledgeReference => !!ref)
|
const allReferencesRaw = resultsPerBase.flat().filter((ref): ref is KnowledgeReference => !!ref)
|
||||||
|
|
||||||
// 重新为引用分配ID
|
// 重新为引用分配ID
|
||||||
return allReferencesRaw.map((ref, index) => ({
|
return allReferencesRaw.map((ref, index) => ({
|
||||||
...ref,
|
...ref,
|
||||||
|
|||||||
@ -372,7 +372,7 @@ export interface KnowledgeBase {
|
|||||||
chunkOverlap?: number
|
chunkOverlap?: number
|
||||||
threshold?: number
|
threshold?: number
|
||||||
rerankModel?: Model
|
rerankModel?: Model
|
||||||
topN?: number
|
// topN?: number
|
||||||
}
|
}
|
||||||
|
|
||||||
export type KnowledgeBaseParams = {
|
export type KnowledgeBaseParams = {
|
||||||
@ -388,7 +388,7 @@ export type KnowledgeBaseParams = {
|
|||||||
rerankBaseURL?: string
|
rerankBaseURL?: string
|
||||||
rerankModel?: string
|
rerankModel?: string
|
||||||
rerankModelProvider?: string
|
rerankModelProvider?: string
|
||||||
topN?: number
|
documentCount?: number
|
||||||
}
|
}
|
||||||
|
|
||||||
export type GenerateImageParams = {
|
export type GenerateImageParams = {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user