feat: support dashscope reranker (#5725)

This commit is contained in:
Chen Tao 2025-05-07 14:17:03 +08:00 committed by GitHub
parent afcb0eeae2
commit 0f34bde749
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 62 additions and 1 deletions

View File

@ -0,0 +1,58 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import axiosProxy from '@main/services/AxiosProxy'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
interface DashscopeRerankResultItem {
document: {
text: string
}
index: number
relevance_score: number
}
interface DashscopeRerankResponse {
output: {
results: DashscopeRerankResultItem[]
}
usage: {
total_tokens: number
}
request_id: string
}
export default class DashscopeReranker extends BaseReranker {
constructor(base: KnowledgeBaseParams) {
super(base)
}
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
const url = 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
const requestBody = {
model: this.base.rerankModel,
input: {
query,
documents: searchResults.map((doc) => doc.pageContent)
},
parameters: {
return_documents: true, // Recommended to be true to get document details if needed, though scores are primary
top_n: this.base.topN || 5 // Default to 5 if topN is not specified, as per API example
}
}
try {
const { data } = await axiosProxy.axios.post<DashscopeRerankResponse>(url, requestBody, {
headers: this.defaultHeaders()
})
const rerankResults = data.output.results
return this.getRerankResult(searchResults, rerankResults)
} catch (error: any) {
const errorDetails = this.formatErrorMessage(url, error, requestBody)
console.error('Dashscope Reranker API 错误:', errorDetails)
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
}
}
}

View File

@ -1,6 +1,7 @@
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
import DashscopeReranker from './DashscopeReranker'
import DefaultReranker from './DefaultReranker'
import JinaReranker from './JinaReranker'
import SiliconFlowReranker from './SiliconFlowReranker'
@ -14,6 +15,8 @@ export default class RerankerFactory {
return new JinaReranker(base)
} else if (base.rerankModelProvider === 'voyageai') {
return new VoyageReranker(base)
} else if (base.rerankModelProvider === 'dashscope') {
return new DashscopeReranker(base)
}
return new DefaultReranker(base)
}

View File

@ -95,7 +95,7 @@ export function getProviderLogo(providerId: string) {
return PROVIDER_LOGO_MAP[providerId as keyof typeof PROVIDER_LOGO_MAP]
}
export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai']
export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope']
export const PROVIDER_CONFIG = {
openai: {