mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-23 10:00:08 +08:00
feat: combine to general reranker (#5818)
* feat: combine to general reranker * chore * chore: set not support provider * chore: add i18n
This commit is contained in:
parent
ac0651a9f3
commit
03c562bf5b
@ -17,6 +17,10 @@ export default abstract class BaseReranker {
|
||||
* Get Rerank Request Url
|
||||
*/
|
||||
protected getRerankUrl() {
|
||||
if (this.base.rerankModelProvider === 'dashscope') {
|
||||
return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
|
||||
}
|
||||
|
||||
let baseURL = this.base?.rerankBaseURL?.endsWith('/')
|
||||
? this.base.rerankBaseURL.slice(0, -1)
|
||||
: this.base.rerankBaseURL
|
||||
@ -28,6 +32,56 @@ export default abstract class BaseReranker {
|
||||
return `${baseURL}/rerank`
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Rerank Request Body
|
||||
*/
|
||||
protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
|
||||
const provider = this.base.rerankModelProvider
|
||||
const documents = searchResults.map((doc) => doc.pageContent)
|
||||
const topN = this.base.topN || 5
|
||||
|
||||
if (provider === 'voyageai') {
|
||||
return {
|
||||
model: this.base.rerankModel,
|
||||
query,
|
||||
documents,
|
||||
top_k: topN
|
||||
}
|
||||
} else if (provider === 'dashscope') {
|
||||
return {
|
||||
model: this.base.rerankModel,
|
||||
input: {
|
||||
query,
|
||||
documents
|
||||
},
|
||||
parameters: {
|
||||
top_n: topN
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return {
|
||||
model: this.base.rerankModel,
|
||||
query,
|
||||
documents,
|
||||
top_n: topN
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract Rerank Result
|
||||
*/
|
||||
protected extractRerankResult(data: any) {
|
||||
const provider = this.base.rerankModelProvider
|
||||
if (provider === 'dashscope') {
|
||||
return data.output.results
|
||||
} else if (provider === 'voyageai') {
|
||||
return data.data
|
||||
} else {
|
||||
return data.results
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Rerank Result
|
||||
* @param searchResults
|
||||
|
||||
@ -1,58 +0,0 @@
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import axiosProxy from '@main/services/AxiosProxy'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
interface DashscopeRerankResultItem {
|
||||
document: {
|
||||
text: string
|
||||
}
|
||||
index: number
|
||||
relevance_score: number
|
||||
}
|
||||
|
||||
interface DashscopeRerankResponse {
|
||||
output: {
|
||||
results: DashscopeRerankResultItem[]
|
||||
}
|
||||
usage: {
|
||||
total_tokens: number
|
||||
}
|
||||
request_id: string
|
||||
}
|
||||
|
||||
export default class DashscopeReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
|
||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||
const url = 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
|
||||
|
||||
const requestBody = {
|
||||
model: this.base.rerankModel,
|
||||
input: {
|
||||
query,
|
||||
documents: searchResults.map((doc) => doc.pageContent)
|
||||
},
|
||||
parameters: {
|
||||
return_documents: true, // Recommended to be true to get document details if needed, though scores are primary
|
||||
top_n: this.base.topN || 5 // Default to 5 if topN is not specified, as per API example
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const { data } = await axiosProxy.axios.post<DashscopeRerankResponse>(url, requestBody, {
|
||||
headers: this.defaultHeaders()
|
||||
})
|
||||
|
||||
const rerankResults = data.output.results
|
||||
return this.getRerankResult(searchResults, rerankResults)
|
||||
} catch (error: any) {
|
||||
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
||||
console.error('Dashscope Reranker API 错误:', errorDetails)
|
||||
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,14 +0,0 @@
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
export default class DefaultReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
|
||||
async rerank(): Promise<ExtractChunkData[]> {
|
||||
throw new Error('Method not implemented.')
|
||||
}
|
||||
}
|
||||
@ -4,7 +4,7 @@ import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
export default class JinaReranker extends BaseReranker {
|
||||
export default class GeneralReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
@ -12,21 +12,15 @@ export default class JinaReranker extends BaseReranker {
|
||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||
const url = this.getRerankUrl()
|
||||
|
||||
const requestBody = {
|
||||
model: this.base.rerankModel,
|
||||
query,
|
||||
documents: searchResults.map((doc) => doc.pageContent),
|
||||
top_n: this.base.topN
|
||||
}
|
||||
const requestBody = this.getRerankRequestBody(query, searchResults)
|
||||
|
||||
try {
|
||||
const { data } = await AxiosProxy.axios.post(url, requestBody, { headers: this.defaultHeaders() })
|
||||
|
||||
const rerankResults = data.results
|
||||
const rerankResults = this.extractRerankResult(data)
|
||||
return this.getRerankResult(searchResults, rerankResults)
|
||||
} catch (error: any) {
|
||||
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
||||
console.error('Jina Reranker API Error:', errorDetails)
|
||||
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
|
||||
}
|
||||
}
|
||||
@ -1,13 +1,12 @@
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
import RerankerFactory from './RerankerFactory'
|
||||
import GeneralReranker from './GeneralReranker'
|
||||
|
||||
export default class Reranker {
|
||||
private sdk: BaseReranker
|
||||
private sdk: GeneralReranker
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
this.sdk = RerankerFactory.create(base)
|
||||
this.sdk = new GeneralReranker(base)
|
||||
}
|
||||
public async rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> {
|
||||
return this.sdk.rerank(query, searchResults)
|
||||
|
||||
@ -1,23 +0,0 @@
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
import DashscopeReranker from './DashscopeReranker'
|
||||
import DefaultReranker from './DefaultReranker'
|
||||
import JinaReranker from './JinaReranker'
|
||||
import SiliconFlowReranker from './SiliconFlowReranker'
|
||||
import VoyageReranker from './VoyageReranker'
|
||||
|
||||
export default class RerankerFactory {
|
||||
static create(base: KnowledgeBaseParams): BaseReranker {
|
||||
if (base.rerankModelProvider === 'silicon') {
|
||||
return new SiliconFlowReranker(base)
|
||||
} else if (base.rerankModelProvider === 'jina') {
|
||||
return new JinaReranker(base)
|
||||
} else if (base.rerankModelProvider === 'voyageai') {
|
||||
return new VoyageReranker(base)
|
||||
} else if (base.rerankModelProvider === 'dashscope') {
|
||||
return new DashscopeReranker(base)
|
||||
}
|
||||
return new DefaultReranker(base)
|
||||
}
|
||||
}
|
||||
@ -1,36 +0,0 @@
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import axiosProxy from '@main/services/AxiosProxy'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
export default class SiliconFlowReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
|
||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||
const url = this.getRerankUrl()
|
||||
|
||||
const requestBody = {
|
||||
model: this.base.rerankModel,
|
||||
query,
|
||||
documents: searchResults.map((doc) => doc.pageContent),
|
||||
top_n: this.base.topN,
|
||||
max_chunks_per_doc: this.base.chunkSize,
|
||||
overlap_tokens: this.base.chunkOverlap
|
||||
}
|
||||
|
||||
try {
|
||||
const { data } = await axiosProxy.axios.post(url, requestBody, { headers: this.defaultHeaders() })
|
||||
|
||||
const rerankResults = data.results
|
||||
return this.getRerankResult(searchResults, rerankResults)
|
||||
} catch (error: any) {
|
||||
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
||||
|
||||
console.error('SiliconFlow Reranker API 错误:', errorDetails)
|
||||
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,40 +0,0 @@
|
||||
import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import axiosProxy from '@main/services/AxiosProxy'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
export default class VoyageReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
|
||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||
const url = this.getRerankUrl()
|
||||
|
||||
const requestBody = {
|
||||
model: this.base.rerankModel,
|
||||
query,
|
||||
documents: searchResults.map((doc) => doc.pageContent),
|
||||
top_k: this.base.topN,
|
||||
return_documents: false,
|
||||
truncation: true
|
||||
}
|
||||
|
||||
try {
|
||||
const { data } = await axiosProxy.axios.post(url, requestBody, {
|
||||
headers: {
|
||||
...this.defaultHeaders()
|
||||
}
|
||||
})
|
||||
|
||||
const rerankResults = data.data
|
||||
return this.getRerankResult(searchResults, rerankResults)
|
||||
} catch (error: any) {
|
||||
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
||||
|
||||
console.error('Voyage Reranker API Error:', errorDetails)
|
||||
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -95,7 +95,8 @@ export function getProviderLogo(providerId: string) {
|
||||
return PROVIDER_LOGO_MAP[providerId as keyof typeof PROVIDER_LOGO_MAP]
|
||||
}
|
||||
|
||||
export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope', 'aihubmix']
|
||||
// export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope', 'aihubmix']
|
||||
export const NOT_SUPPORTED_REANK_PROVIDERS = ['ollama']
|
||||
|
||||
export const PROVIDER_CONFIG = {
|
||||
openai: {
|
||||
|
||||
@ -703,6 +703,7 @@
|
||||
"pinned": "Pinned",
|
||||
"rerank_model": "Reordering Model",
|
||||
"rerank_model_support_provider": "Currently, the reordering model only supports some providers ({{provider}})",
|
||||
"rerank_model_not_support_provider": "Currently, the reordering model does not support this provider ({{provider}})",
|
||||
"rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add.",
|
||||
"search": "Search models...",
|
||||
"stream_output": "Stream output",
|
||||
|
||||
@ -717,7 +717,8 @@
|
||||
"text": "テキスト",
|
||||
"vision": "画像",
|
||||
"websearch": "ウェブ検索"
|
||||
}
|
||||
},
|
||||
"rerank_model_not_support_provider": "現在、並べ替えモデルはこのプロバイダー ({{provider}}) をサポートしていません。"
|
||||
},
|
||||
"navbar": {
|
||||
"expand": "ダイアログを展開",
|
||||
|
||||
@ -717,7 +717,8 @@
|
||||
"text": "Текст",
|
||||
"vision": "Визуальные",
|
||||
"websearch": "Веб-поисковые"
|
||||
}
|
||||
},
|
||||
"rerank_model_not_support_provider": "В настоящее время модель переупорядочивания не поддерживает этого провайдера ({{provider}})"
|
||||
},
|
||||
"navbar": {
|
||||
"expand": "Развернуть диалоговое окно",
|
||||
|
||||
@ -703,6 +703,7 @@
|
||||
"pinned": "已固定",
|
||||
"rerank_model": "重排模型",
|
||||
"rerank_model_support_provider": "目前重排序模型仅支持部分服务商 ({{provider}})",
|
||||
"rerank_model_not_support_provider": "目前重排序模型不支持该服务商 ({{provider}})",
|
||||
"rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加",
|
||||
"search": "搜索模型...",
|
||||
"stream_output": "流式输出",
|
||||
|
||||
@ -717,7 +717,8 @@
|
||||
"text": "文字",
|
||||
"vision": "視覺",
|
||||
"websearch": "網路搜尋"
|
||||
}
|
||||
},
|
||||
"rerank_model_not_support_provider": "目前,重新排序模型不支援此提供者({{provider}})"
|
||||
},
|
||||
"navbar": {
|
||||
"expand": "伸縮對話框",
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import { TopView } from '@renderer/components/TopView'
|
||||
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
// import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
|
||||
import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import { SettingHelpText } from '@renderer/pages/settings'
|
||||
@ -67,7 +68,7 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
||||
|
||||
const rerankSelectOptions = providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.filter((p) => SUPPORTED_REANK_PROVIDERS.includes(p.id))
|
||||
.filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
@ -176,8 +177,8 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
||||
<Select style={{ width: '100%' }} options={rerankSelectOptions} placeholder={t('settings.models.empty')} />
|
||||
</Form.Item>
|
||||
<SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}>
|
||||
{t('models.rerank_model_support_provider', {
|
||||
provider: SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
|
||||
{t('models.rerank_model_not_support_provider', {
|
||||
provider: NOT_SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
|
||||
})}
|
||||
</SettingHelpText>
|
||||
<Form.Item
|
||||
|
||||
@ -3,7 +3,8 @@ import { TopView } from '@renderer/components/TopView'
|
||||
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
|
||||
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
// import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
import { useKnowledge } from '@renderer/hooks/useKnowledge'
|
||||
import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import { SettingHelpText } from '@renderer/pages/settings'
|
||||
@ -68,7 +69,7 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
|
||||
const rerankSelectOptions = providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.filter((p) => SUPPORTED_REANK_PROVIDERS.includes(p.id))
|
||||
.filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
@ -157,8 +158,8 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
/>
|
||||
</Form.Item>
|
||||
<SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}>
|
||||
{t('models.rerank_model_support_provider', {
|
||||
provider: SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
|
||||
{t('models.rerank_model_not_support_provider', {
|
||||
provider: NOT_SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
|
||||
})}
|
||||
</SettingHelpText>
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user