feat: combine to general reranker (#5818)

* feat: combine to general reranker

* chore

* chore: set not support provider

* chore: add i18n
This commit is contained in:
Chen Tao 2025-05-10 13:42:54 +08:00 committed by GitHub
parent ac0651a9f3
commit 03c562bf5b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 85 additions and 201 deletions

View File

@ -17,6 +17,10 @@ export default abstract class BaseReranker {
* Get Rerank Request Url * Get Rerank Request Url
*/ */
protected getRerankUrl() { protected getRerankUrl() {
if (this.base.rerankModelProvider === 'dashscope') {
return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
}
let baseURL = this.base?.rerankBaseURL?.endsWith('/') let baseURL = this.base?.rerankBaseURL?.endsWith('/')
? this.base.rerankBaseURL.slice(0, -1) ? this.base.rerankBaseURL.slice(0, -1)
: this.base.rerankBaseURL : this.base.rerankBaseURL
@ -28,6 +32,56 @@ export default abstract class BaseReranker {
return `${baseURL}/rerank` return `${baseURL}/rerank`
} }
/**
 * Build the JSON payload for the rerank request.
 *
 * The layout depends on `this.base.rerankModelProvider`:
 * - `voyageai`: flat body, result limit sent as `top_k`
 * - `dashscope`: query and documents nested under `input`, limit under `parameters.top_n`
 * - any other provider: flat body, result limit sent as `top_n`
 *
 * @param query Search query the documents are ranked against
 * @param searchResults Chunks whose `pageContent` values are sent as the documents
 * @returns The request body object for the configured provider
 */
protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
  const model = this.base.rerankModel
  const documents = searchResults.map(({ pageContent }) => pageContent)
  // Any falsy topN (unset or 0) falls back to 5
  const limit = this.base.topN || 5

  switch (this.base.rerankModelProvider) {
    case 'voyageai':
      return { model, query, documents, top_k: limit }
    case 'dashscope':
      return {
        model,
        input: { query, documents },
        parameters: { top_n: limit }
      }
    default:
      return { model, query, documents, top_n: limit }
  }
}
/**
 * Pick the rerank result list out of a provider response payload.
 *
 * The location depends on `this.base.rerankModelProvider`:
 * `data.output.results` for `dashscope`, `data.data` for `voyageai`,
 * and `data.results` for every other provider.
 *
 * @param data Response payload returned by the rerank endpoint
 * @returns Whatever the payload holds at the provider-specific location
 */
protected extractRerankResult(data: any) {
  switch (this.base.rerankModelProvider) {
    case 'dashscope':
      return data.output.results
    case 'voyageai':
      return data.data
    default:
      return data.results
  }
}
/** /**
* Get Rerank Result * Get Rerank Result
* @param searchResults * @param searchResults

View File

@ -1,58 +0,0 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import axiosProxy from '@main/services/AxiosProxy'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
/** One entry of `output.results` in a Dashscope rerank response. */
type DashscopeRerankResultItem = {
  document: { text: string }
  index: number
  relevance_score: number
}

/** Response payload shape used when posting to the Dashscope rerank endpoint. */
type DashscopeRerankResponse = {
  output: { results: DashscopeRerankResultItem[] }
  usage: { total_tokens: number }
  request_id: string
}
/**
 * Reranker that posts to the fixed Dashscope text-rerank endpoint.
 */
export default class DashscopeReranker extends BaseReranker {
  constructor(base: KnowledgeBaseParams) {
    super(base)
  }

  /**
   * Send the query and the page content of every search result to Dashscope,
   * then hand the returned `output.results` to `getRerankResult`.
   *
   * @param query Search query the documents are ranked against
   * @param searchResults Chunks to rerank
   * @returns The value produced by `getRerankResult`
   * @throws Error carrying the original message plus the formatted request details
   */
  public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
    const endpoint = 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
    const documents = searchResults.map(({ pageContent }) => pageContent)
    const payload = {
      model: this.base.rerankModel,
      input: { query, documents },
      parameters: {
        // Ask for document details in the response; the scores are what is primarily used
        return_documents: true,
        // Fall back to 5 when topN is not specified
        top_n: this.base.topN || 5
      }
    }

    try {
      const response = await axiosProxy.axios.post<DashscopeRerankResponse>(endpoint, payload, {
        headers: this.defaultHeaders()
      })
      return this.getRerankResult(searchResults, response.data.output.results)
    } catch (error: any) {
      const details = this.formatErrorMessage(endpoint, error, payload)
      console.error('Dashscope Reranker API 错误:', details)
      throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${details}`)
    }
  }
}

View File

@ -1,14 +0,0 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
/**
 * Placeholder reranker whose `rerank` always rejects.
 */
export default class DefaultReranker extends BaseReranker {
  constructor(base: KnowledgeBaseParams) {
    super(base)
  }

  /**
   * @throws Error always, with the message 'Method not implemented.'
   */
  async rerank(): Promise<ExtractChunkData[]> {
    const notImplemented = new Error('Method not implemented.')
    throw notImplemented
  }
}

View File

@ -4,7 +4,7 @@ import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker' import BaseReranker from './BaseReranker'
export default class JinaReranker extends BaseReranker { export default class GeneralReranker extends BaseReranker {
constructor(base: KnowledgeBaseParams) { constructor(base: KnowledgeBaseParams) {
super(base) super(base)
} }
@ -12,21 +12,15 @@ export default class JinaReranker extends BaseReranker {
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => { public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
const url = this.getRerankUrl() const url = this.getRerankUrl()
const requestBody = { const requestBody = this.getRerankRequestBody(query, searchResults)
model: this.base.rerankModel,
query,
documents: searchResults.map((doc) => doc.pageContent),
top_n: this.base.topN
}
try { try {
const { data } = await AxiosProxy.axios.post(url, requestBody, { headers: this.defaultHeaders() }) const { data } = await AxiosProxy.axios.post(url, requestBody, { headers: this.defaultHeaders() })
const rerankResults = data.results const rerankResults = this.extractRerankResult(data)
return this.getRerankResult(searchResults, rerankResults) return this.getRerankResult(searchResults, rerankResults)
} catch (error: any) { } catch (error: any) {
const errorDetails = this.formatErrorMessage(url, error, requestBody) const errorDetails = this.formatErrorMessage(url, error, requestBody)
console.error('Jina Reranker API Error:', errorDetails)
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`) throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
} }
} }

View File

@ -1,13 +1,12 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types' import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker' import GeneralReranker from './GeneralReranker'
import RerankerFactory from './RerankerFactory'
export default class Reranker { export default class Reranker {
private sdk: BaseReranker private sdk: GeneralReranker
constructor(base: KnowledgeBaseParams) { constructor(base: KnowledgeBaseParams) {
this.sdk = RerankerFactory.create(base) this.sdk = new GeneralReranker(base)
} }
public async rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> { public async rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> {
return this.sdk.rerank(query, searchResults) return this.sdk.rerank(query, searchResults)

View File

@ -1,23 +0,0 @@
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
import DashscopeReranker from './DashscopeReranker'
import DefaultReranker from './DefaultReranker'
import JinaReranker from './JinaReranker'
import SiliconFlowReranker from './SiliconFlowReranker'
import VoyageReranker from './VoyageReranker'
/**
 * Chooses the reranker implementation for a knowledge base.
 */
export default class RerankerFactory {
  /**
   * Create the reranker matching `base.rerankModelProvider`.
   *
   * @param base Knowledge base parameters passed to the reranker constructor
   * @returns The provider-specific reranker, or a `DefaultReranker` for any other provider
   */
  static create(base: KnowledgeBaseParams): BaseReranker {
    switch (base.rerankModelProvider) {
      case 'silicon':
        return new SiliconFlowReranker(base)
      case 'jina':
        return new JinaReranker(base)
      case 'voyageai':
        return new VoyageReranker(base)
      case 'dashscope':
        return new DashscopeReranker(base)
      default:
        return new DefaultReranker(base)
    }
  }
}

View File

@ -1,36 +0,0 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import axiosProxy from '@main/services/AxiosProxy'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
/**
 * Reranker that posts to the URL returned by `getRerankUrl`, adding the
 * knowledge base's chunk size and overlap to the request body.
 */
export default class SiliconFlowReranker extends BaseReranker {
  constructor(base: KnowledgeBaseParams) {
    super(base)
  }

  /**
   * Send the query and the page content of every search result to the rerank
   * endpoint, then hand the returned `results` to `getRerankResult`.
   *
   * @param query Search query the documents are ranked against
   * @param searchResults Chunks to rerank
   * @returns The value produced by `getRerankResult`
   * @throws Error carrying the original message plus the formatted request details
   */
  public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
    const endpoint = this.getRerankUrl()
    const { rerankModel, topN, chunkSize, chunkOverlap } = this.base
    const payload = {
      model: rerankModel,
      query,
      documents: searchResults.map(({ pageContent }) => pageContent),
      top_n: topN,
      max_chunks_per_doc: chunkSize,
      overlap_tokens: chunkOverlap
    }

    try {
      const response = await axiosProxy.axios.post(endpoint, payload, { headers: this.defaultHeaders() })
      return this.getRerankResult(searchResults, response.data.results)
    } catch (error: any) {
      const details = this.formatErrorMessage(endpoint, error, payload)
      console.error('SiliconFlow Reranker API 错误:', details)
      throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${details}`)
    }
  }
}

View File

@ -1,40 +0,0 @@
import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import axiosProxy from '@main/services/AxiosProxy'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
/**
 * Reranker that posts to the URL returned by `getRerankUrl` with a body that
 * sends the result limit as `top_k` and reads the results from `data.data`.
 */
export default class VoyageReranker extends BaseReranker {
  constructor(base: KnowledgeBaseParams) {
    super(base)
  }

  /**
   * Send the query and the page content of every search result to the rerank
   * endpoint, then hand the returned `data` list to `getRerankResult`.
   *
   * @param query Search query the documents are ranked against
   * @param searchResults Chunks to rerank
   * @returns The value produced by `getRerankResult`
   * @throws Error carrying the original message plus the formatted request details
   */
  public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
    const endpoint = this.getRerankUrl()
    const payload = {
      model: this.base.rerankModel,
      query,
      documents: searchResults.map(({ pageContent }) => pageContent),
      top_k: this.base.topN,
      return_documents: false,
      truncation: true
    }

    try {
      // Pass a shallow copy of the default headers, as the request config
      const headers = { ...this.defaultHeaders() }
      const response = await axiosProxy.axios.post(endpoint, payload, { headers })
      return this.getRerankResult(searchResults, response.data.data)
    } catch (error: any) {
      const details = this.formatErrorMessage(endpoint, error, payload)
      console.error('Voyage Reranker API Error:', details)
      throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${details}`)
    }
  }
}

View File

@ -95,7 +95,8 @@ export function getProviderLogo(providerId: string) {
return PROVIDER_LOGO_MAP[providerId as keyof typeof PROVIDER_LOGO_MAP] return PROVIDER_LOGO_MAP[providerId as keyof typeof PROVIDER_LOGO_MAP]
} }
export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope', 'aihubmix'] // export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope', 'aihubmix']
export const NOT_SUPPORTED_REANK_PROVIDERS = ['ollama']
export const PROVIDER_CONFIG = { export const PROVIDER_CONFIG = {
openai: { openai: {

View File

@ -703,6 +703,7 @@
"pinned": "Pinned", "pinned": "Pinned",
"rerank_model": "Reordering Model", "rerank_model": "Reordering Model",
"rerank_model_support_provider": "Currently, the reordering model only supports some providers ({{provider}})", "rerank_model_support_provider": "Currently, the reordering model only supports some providers ({{provider}})",
"rerank_model_not_support_provider": "Currently, the reordering model does not support this provider ({{provider}})",
"rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add.", "rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add.",
"search": "Search models...", "search": "Search models...",
"stream_output": "Stream output", "stream_output": "Stream output",

View File

@ -717,7 +717,8 @@
"text": "テキスト", "text": "テキスト",
"vision": "画像", "vision": "画像",
"websearch": "ウェブ検索" "websearch": "ウェブ検索"
} },
"rerank_model_not_support_provider": "現在、並べ替えモデルはこのプロバイダー ({{provider}}) をサポートしていません。"
}, },
"navbar": { "navbar": {
"expand": "ダイアログを展開", "expand": "ダイアログを展開",

View File

@ -717,7 +717,8 @@
"text": "Текст", "text": "Текст",
"vision": "Визуальные", "vision": "Визуальные",
"websearch": "Веб-поисковые" "websearch": "Веб-поисковые"
} },
"rerank_model_not_support_provider": "В настоящее время модель переупорядочивания не поддерживает этого провайдера ({{provider}})"
}, },
"navbar": { "navbar": {
"expand": "Развернуть диалоговое окно", "expand": "Развернуть диалоговое окно",

View File

@ -703,6 +703,7 @@
"pinned": "已固定", "pinned": "已固定",
"rerank_model": "重排模型", "rerank_model": "重排模型",
"rerank_model_support_provider": "目前重排序模型仅支持部分服务商 ({{provider}})", "rerank_model_support_provider": "目前重排序模型仅支持部分服务商 ({{provider}})",
"rerank_model_not_support_provider": "目前重排序模型不支持该服务商 ({{provider}})",
"rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加", "rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加",
"search": "搜索模型...", "search": "搜索模型...",
"stream_output": "流式输出", "stream_output": "流式输出",

View File

@ -717,7 +717,8 @@
"text": "文字", "text": "文字",
"vision": "視覺", "vision": "視覺",
"websearch": "網路搜尋" "websearch": "網路搜尋"
} },
"rerank_model_not_support_provider": "目前,重新排序模型不支援此提供者({{provider}})"
}, },
"navbar": { "navbar": {
"expand": "伸縮對話框", "expand": "伸縮對話框",

View File

@ -1,7 +1,8 @@
import { TopView } from '@renderer/components/TopView' import { TopView } from '@renderer/components/TopView'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant' import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models' import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers' import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
// import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge' import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
import { useProviders } from '@renderer/hooks/useProvider' import { useProviders } from '@renderer/hooks/useProvider'
import { SettingHelpText } from '@renderer/pages/settings' import { SettingHelpText } from '@renderer/pages/settings'
@ -67,7 +68,7 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
const rerankSelectOptions = providers const rerankSelectOptions = providers
.filter((p) => p.models.length > 0) .filter((p) => p.models.length > 0)
.filter((p) => SUPPORTED_REANK_PROVIDERS.includes(p.id)) .filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
.map((p) => ({ .map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name, label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name, title: p.name,
@ -176,8 +177,8 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
<Select style={{ width: '100%' }} options={rerankSelectOptions} placeholder={t('settings.models.empty')} /> <Select style={{ width: '100%' }} options={rerankSelectOptions} placeholder={t('settings.models.empty')} />
</Form.Item> </Form.Item>
<SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}> <SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}>
{t('models.rerank_model_support_provider', { {t('models.rerank_model_not_support_provider', {
provider: SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`)) provider: NOT_SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
})} })}
</SettingHelpText> </SettingHelpText>
<Form.Item <Form.Item

View File

@ -3,7 +3,8 @@ import { TopView } from '@renderer/components/TopView'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant' import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings' import { getEmbeddingMaxContext } from '@renderer/config/embedings'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models' import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers' import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
// import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
import { useKnowledge } from '@renderer/hooks/useKnowledge' import { useKnowledge } from '@renderer/hooks/useKnowledge'
import { useProviders } from '@renderer/hooks/useProvider' import { useProviders } from '@renderer/hooks/useProvider'
import { SettingHelpText } from '@renderer/pages/settings' import { SettingHelpText } from '@renderer/pages/settings'
@ -68,7 +69,7 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
const rerankSelectOptions = providers const rerankSelectOptions = providers
.filter((p) => p.models.length > 0) .filter((p) => p.models.length > 0)
.filter((p) => SUPPORTED_REANK_PROVIDERS.includes(p.id)) .filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
.map((p) => ({ .map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name, label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name, title: p.name,
@ -157,8 +158,8 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
/> />
</Form.Item> </Form.Item>
<SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}> <SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}>
{t('models.rerank_model_support_provider', { {t('models.rerank_model_not_support_provider', {
provider: SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`)) provider: NOT_SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
})} })}
</SettingHelpText> </SettingHelpText>