feat: combine to general reranker (#5818)

* feat: combine to general reranker

* chore

* chore: set not support provider

* chore: add i18n
This commit is contained in:
Chen Tao 2025-05-10 13:42:54 +08:00 committed by GitHub
parent ac0651a9f3
commit 03c562bf5b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 85 additions and 201 deletions

View File

@ -17,6 +17,10 @@ export default abstract class BaseReranker {
* Get Rerank Request Url
*/
protected getRerankUrl() {
if (this.base.rerankModelProvider === 'dashscope') {
return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
}
let baseURL = this.base?.rerankBaseURL?.endsWith('/')
? this.base.rerankBaseURL.slice(0, -1)
: this.base.rerankBaseURL
@ -28,6 +32,56 @@ export default abstract class BaseReranker {
return `${baseURL}/rerank`
}
/**
* Get Rerank Request Body
*/
protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
  // Builds the JSON payload for a rerank request. The payload layout depends on
  // the configured provider:
  //   - 'voyageai'  : flat body, result limit sent as `top_k`
  //   - 'dashscope' : query/documents nested under `input`, limit under `parameters.top_n`
  //   - anything else: flat body, result limit sent as `top_n`
  const model = this.base.rerankModel
  // A missing (or otherwise falsy) topN falls back to 5.
  const limit = this.base.topN || 5
  // Only the text of each chunk is sent to the provider.
  const documents = searchResults.map((chunk) => chunk.pageContent)

  switch (this.base.rerankModelProvider) {
    case 'voyageai':
      return { model, query, documents, top_k: limit }
    case 'dashscope':
      return {
        model,
        input: { query, documents },
        parameters: { top_n: limit }
      }
    default:
      return { model, query, documents, top_n: limit }
  }
}
/**
* Extract Rerank Result
*/
protected extractRerankResult(data: any) {
  // Picks the result list out of the provider's response body. Each provider
  // places it under a different key:
  //   - 'dashscope' : data.output.results
  //   - 'voyageai'  : data.data
  //   - anything else: data.results
  switch (this.base.rerankModelProvider) {
    case 'dashscope':
      return data.output.results
    case 'voyageai':
      return data.data
    default:
      return data.results
  }
}
/**
* Get Rerank Result
* @param searchResults

View File

@ -1,58 +0,0 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import axiosProxy from '@main/services/AxiosProxy'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
/**
 * One entry of the `output.results` array in a Dashscope rerank response.
 */
interface DashscopeRerankResultItem {
document: {
text: string
}
// NOTE(review): presumably the position of this document in the submitted `documents` array — confirm against the Dashscope API docs.
index: number
relevance_score: number
}
/**
 * Shape of the response body expected from the Dashscope text-rerank endpoint.
 * Only `output.results` is read by DashscopeReranker.rerank.
 */
interface DashscopeRerankResponse {
output: {
results: DashscopeRerankResultItem[]
}
usage: {
total_tokens: number
}
request_id: string
}
/**
 * Reranker that posts to the fixed Dashscope text-rerank endpoint.
 */
export default class DashscopeReranker extends BaseReranker {
  constructor(base: KnowledgeBaseParams) {
    super(base)
  }

  /**
   * Sends the query and the text of every search result to Dashscope and
   * passes the returned ranking to `getRerankResult`.
   *
   * @param query - The user query to rank documents against.
   * @param searchResults - Candidate chunks; only `pageContent` is sent.
   * @returns Whatever `getRerankResult` produces for the provider ranking.
   * @throws Error wrapping the original message plus formatted request details
   *         when the request or result handling fails.
   */
  public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
    const url = 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
    const documents = searchResults.map((chunk) => chunk.pageContent)
    const requestBody = {
      model: this.base.rerankModel,
      input: { query, documents },
      parameters: {
        // Document details are requested as well, although the scores are what matters.
        return_documents: true,
        // Falls back to 5 when topN is not specified.
        top_n: this.base.topN || 5
      }
    }

    try {
      const response = await axiosProxy.axios.post<DashscopeRerankResponse>(url, requestBody, {
        headers: this.defaultHeaders()
      })
      return this.getRerankResult(searchResults, response.data.output.results)
    } catch (error: any) {
      const errorDetails = this.formatErrorMessage(url, error, requestBody)
      console.error('Dashscope Reranker API 错误:', errorDetails)
      throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
    }
  }
}

View File

@ -1,14 +0,0 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
/**
 * Placeholder reranker: its `rerank` method always rejects.
 */
export default class DefaultReranker extends BaseReranker {
constructor(base: KnowledgeBaseParams) {
super(base)
}
/**
 * Always rejects with 'Method not implemented.'.
 */
async rerank(): Promise<ExtractChunkData[]> {
throw new Error('Method not implemented.')
}
}

View File

@ -4,7 +4,7 @@ import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
export default class JinaReranker extends BaseReranker {
export default class GeneralReranker extends BaseReranker {
constructor(base: KnowledgeBaseParams) {
super(base)
}
@ -12,21 +12,15 @@ export default class JinaReranker extends BaseReranker {
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
const url = this.getRerankUrl()
const requestBody = {
model: this.base.rerankModel,
query,
documents: searchResults.map((doc) => doc.pageContent),
top_n: this.base.topN
}
const requestBody = this.getRerankRequestBody(query, searchResults)
try {
const { data } = await AxiosProxy.axios.post(url, requestBody, { headers: this.defaultHeaders() })
const rerankResults = data.results
const rerankResults = this.extractRerankResult(data)
return this.getRerankResult(searchResults, rerankResults)
} catch (error: any) {
const errorDetails = this.formatErrorMessage(url, error, requestBody)
console.error('Jina Reranker API Error:', errorDetails)
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
}
}

View File

@ -1,13 +1,12 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
import RerankerFactory from './RerankerFactory'
import GeneralReranker from './GeneralReranker'
export default class Reranker {
private sdk: BaseReranker
private sdk: GeneralReranker
constructor(base: KnowledgeBaseParams) {
this.sdk = RerankerFactory.create(base)
this.sdk = new GeneralReranker(base)
}
public async rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> {
return this.sdk.rerank(query, searchResults)

View File

@ -1,23 +0,0 @@
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
import DashscopeReranker from './DashscopeReranker'
import DefaultReranker from './DefaultReranker'
import JinaReranker from './JinaReranker'
import SiliconFlowReranker from './SiliconFlowReranker'
import VoyageReranker from './VoyageReranker'
/**
 * Chooses the reranker implementation matching a knowledge base's
 * `rerankModelProvider`.
 */
export default class RerankerFactory {
  /**
   * @param base - Knowledge base parameters; `rerankModelProvider` selects the class.
   * @returns A provider-specific reranker, or a DefaultReranker when the
   *          provider is not one of 'silicon', 'jina', 'voyageai', 'dashscope'.
   */
  static create(base: KnowledgeBaseParams): BaseReranker {
    switch (base.rerankModelProvider) {
      case 'silicon':
        return new SiliconFlowReranker(base)
      case 'jina':
        return new JinaReranker(base)
      case 'voyageai':
        return new VoyageReranker(base)
      case 'dashscope':
        return new DashscopeReranker(base)
      default:
        return new DefaultReranker(base)
    }
  }
}

View File

@ -1,36 +0,0 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import axiosProxy from '@main/services/AxiosProxy'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
/**
 * Reranker for the SiliconFlow provider; posts to the URL from `getRerankUrl()`.
 */
export default class SiliconFlowReranker extends BaseReranker {
  constructor(base: KnowledgeBaseParams) {
    super(base)
  }

  /**
   * Sends the query and chunk texts to the rerank endpoint and passes the
   * `results` field of the response to `getRerankResult`.
   *
   * @param query - The user query to rank documents against.
   * @param searchResults - Candidate chunks; only `pageContent` is sent.
   * @returns Whatever `getRerankResult` produces for the provider ranking.
   * @throws Error wrapping the original message plus formatted request details.
   */
  public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
    const url = this.getRerankUrl()
    const { rerankModel, topN, chunkSize, chunkOverlap } = this.base
    const requestBody = {
      model: rerankModel,
      query,
      documents: searchResults.map((chunk) => chunk.pageContent),
      top_n: topN,
      // The knowledge base's chunking settings are forwarded as-is.
      max_chunks_per_doc: chunkSize,
      overlap_tokens: chunkOverlap
    }

    try {
      const response = await axiosProxy.axios.post(url, requestBody, { headers: this.defaultHeaders() })
      return this.getRerankResult(searchResults, response.data.results)
    } catch (error: any) {
      const errorDetails = this.formatErrorMessage(url, error, requestBody)
      console.error('SiliconFlow Reranker API 错误:', errorDetails)
      throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
    }
  }
}

View File

@ -1,40 +0,0 @@
import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import axiosProxy from '@main/services/AxiosProxy'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
/**
 * Reranker for the Voyage AI provider; posts to the URL from `getRerankUrl()`.
 */
export default class VoyageReranker extends BaseReranker {
  constructor(base: KnowledgeBaseParams) {
    super(base)
  }

  /**
   * Sends the query and chunk texts to the rerank endpoint and passes the
   * `data` field of the response body to `getRerankResult`.
   *
   * @param query - The user query to rank documents against.
   * @param searchResults - Candidate chunks; only `pageContent` is sent.
   * @returns Whatever `getRerankResult` produces for the provider ranking.
   * @throws Error wrapping the original message plus formatted request details.
   */
  public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
    const url = this.getRerankUrl()
    const requestBody = {
      model: this.base.rerankModel,
      query,
      documents: searchResults.map((chunk) => chunk.pageContent),
      // The result limit is sent as `top_k` for this provider.
      top_k: this.base.topN,
      return_documents: false,
      truncation: true
    }

    try {
      // Headers are passed as a shallow copy of the defaults.
      const headers = { ...this.defaultHeaders() }
      const response = await axiosProxy.axios.post(url, requestBody, { headers })
      return this.getRerankResult(searchResults, response.data.data)
    } catch (error: any) {
      const errorDetails = this.formatErrorMessage(url, error, requestBody)
      console.error('Voyage Reranker API Error:', errorDetails)
      throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
    }
  }
}

View File

@ -95,7 +95,8 @@ export function getProviderLogo(providerId: string) {
return PROVIDER_LOGO_MAP[providerId as keyof typeof PROVIDER_LOGO_MAP]
}
export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope', 'aihubmix']
// export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope', 'aihubmix']
export const NOT_SUPPORTED_REANK_PROVIDERS = ['ollama']
export const PROVIDER_CONFIG = {
openai: {

View File

@ -703,6 +703,7 @@
"pinned": "Pinned",
"rerank_model": "Reordering Model",
"rerank_model_support_provider": "Currently, the reordering model only supports some providers ({{provider}})",
"rerank_model_not_support_provider": "Currently, the reordering model does not support this provider ({{provider}})",
"rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add.",
"search": "Search models...",
"stream_output": "Stream output",
@ -1639,4 +1640,4 @@
"visualization": "Visualization"
}
}
}
}

View File

@ -717,7 +717,8 @@
"text": "テキスト",
"vision": "画像",
"websearch": "ウェブ検索"
}
},
"rerank_model_not_support_provider": "現在、並べ替えモデルはこのプロバイダー ({{provider}}) をサポートしていません。"
},
"navbar": {
"expand": "ダイアログを展開",
@ -1639,4 +1640,4 @@
"visualization": "可視化"
}
}
}
}

View File

@ -717,7 +717,8 @@
"text": "Текст",
"vision": "Визуальные",
"websearch": "Веб-поисковые"
}
},
"rerank_model_not_support_provider": "В настоящее время модель переупорядочивания не поддерживает этого провайдера ({{provider}})"
},
"navbar": {
"expand": "Развернуть диалоговое окно",
@ -1639,4 +1640,4 @@
"visualization": "Визуализация"
}
}
}
}

View File

@ -703,6 +703,7 @@
"pinned": "已固定",
"rerank_model": "重排模型",
"rerank_model_support_provider": "目前重排序模型仅支持部分服务商 ({{provider}})",
"rerank_model_not_support_provider": "目前重排序模型不支持该服务商 ({{provider}})",
"rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加",
"search": "搜索模型...",
"stream_output": "流式输出",
@ -1639,4 +1640,4 @@
"visualization": "可视化"
}
}
}
}

View File

@ -717,7 +717,8 @@
"text": "文字",
"vision": "視覺",
"websearch": "網路搜尋"
}
},
"rerank_model_not_support_provider": "目前,重新排序模型不支援此提供者({{provider}})"
},
"navbar": {
"expand": "伸縮對話框",
@ -1639,4 +1640,4 @@
"visualization": "視覺化"
}
}
}
}

View File

@ -1,7 +1,8 @@
import { TopView } from '@renderer/components/TopView'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
// import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
import { useProviders } from '@renderer/hooks/useProvider'
import { SettingHelpText } from '@renderer/pages/settings'
@ -67,7 +68,7 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
const rerankSelectOptions = providers
.filter((p) => p.models.length > 0)
.filter((p) => SUPPORTED_REANK_PROVIDERS.includes(p.id))
.filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
.map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
@ -176,8 +177,8 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
<Select style={{ width: '100%' }} options={rerankSelectOptions} placeholder={t('settings.models.empty')} />
</Form.Item>
<SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}>
{t('models.rerank_model_support_provider', {
provider: SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
{t('models.rerank_model_not_support_provider', {
provider: NOT_SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
})}
</SettingHelpText>
<Form.Item

View File

@ -3,7 +3,8 @@ import { TopView } from '@renderer/components/TopView'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
// import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
import { useKnowledge } from '@renderer/hooks/useKnowledge'
import { useProviders } from '@renderer/hooks/useProvider'
import { SettingHelpText } from '@renderer/pages/settings'
@ -68,7 +69,7 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
const rerankSelectOptions = providers
.filter((p) => p.models.length > 0)
.filter((p) => SUPPORTED_REANK_PROVIDERS.includes(p.id))
.filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
.map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
@ -157,8 +158,8 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
/>
</Form.Item>
<SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}>
{t('models.rerank_model_support_provider', {
provider: SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
{t('models.rerank_model_not_support_provider', {
provider: NOT_SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
})}
</SettingHelpText>