import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import type { KnowledgeBaseParams } from '@types'

/**
 * Common building blocks for rerank clients.
 *
 * Subclasses implement {@link BaseReranker.rerank}; this base class supplies the
 * provider-specific request URL, request body, response extraction and the
 * merge of rerank scores back onto the original search results.
 */
export default abstract class BaseReranker {
  protected base: KnowledgeBaseParams

  /**
   * @param base Knowledge base parameters; `rerankModel` must be set.
   * @throws Error when `base.rerankModel` is missing or empty.
   */
  constructor(base: KnowledgeBaseParams) {
    if (!base.rerankModel) {
      throw new Error('Rerank model is required')
    }
    this.base = base
  }

  /**
   * Rerank `searchResults` against `query`.
   *
   * @param query The user query the documents are scored against.
   * @param searchResults Candidate chunks to rerank.
   * @returns The reranked chunks.
   */
  abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]>

  /**
   * Get Rerank Request Url
   *
   * - `dashscope` always uses its fixed text-rerank endpoint.
   * - A base URL ending in `/` is used as-is with `rerank` appended.
   * - Otherwise `/v1` is appended unless the base URL already ends in `/v1`,
   *   followed by `/rerank`.
   *
   * @throws Error when a base URL is needed but `rerankBaseURL` is not set.
   */
  protected getRerankUrl(): string {
    if (this.base.rerankModelProvider === 'dashscope') {
      return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
    }

    const baseURL = this.base.rerankBaseURL
    if (!baseURL) {
      // Without this guard the template literals below would yield "undefined/rerank".
      throw new Error('Rerank base URL is required')
    }

    if (baseURL.endsWith('/')) {
      // A trailing `/` forces rerankBaseURL to be used verbatim (no `/v1` is inserted).
      return `${baseURL}rerank`
    }

    if (baseURL.endsWith('/v1')) {
      return `${baseURL}/rerank`
    }

    return `${baseURL}/v1/rerank`
  }

  /**
   * Get Rerank Request Body
   *
   * The document texts are taken from each result's `pageContent`, and the
   * result limit from `documentCount`. The field layout depends on the provider:
   * `voyageai` uses `top_k`, `dashscope` nests the payload under
   * `input`/`parameters`, everything else uses a flat body with `top_n`.
   *
   * @param query The user query.
   * @param searchResults Candidate chunks whose `pageContent` is sent for scoring.
   */
  protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
    const provider = this.base.rerankModelProvider
    const documents = searchResults.map((doc) => doc.pageContent)
    const topN = this.base.documentCount

    if (provider === 'voyageai') {
      return {
        model: this.base.rerankModel,
        query,
        documents,
        top_k: topN
      }
    } else if (provider === 'dashscope') {
      return {
        model: this.base.rerankModel,
        input: {
          query,
          documents
        },
        parameters: {
          top_n: topN
        }
      }
    } else {
      return {
        model: this.base.rerankModel,
        query,
        documents,
        top_n: topN
      }
    }
  }

  /**
   * Extract Rerank Result
   *
   * Reads the result list from the provider-specific location in the response:
   * `output.results` for `dashscope`, `data` for `voyageai`, `results` otherwise.
   *
   * @param data Parsed response payload.
   * @throws Error when the expected result list is absent or not an array.
   */
  protected extractRerankResult(data: any) {
    const provider = this.base.rerankModelProvider

    let results: unknown
    if (provider === 'dashscope') {
      results = data?.output?.results
    } else if (provider === 'voyageai') {
      results = data?.data
    } else {
      results = data?.results
    }

    if (!Array.isArray(results)) {
      // Fail here with a clear message instead of a TypeError further down the pipeline.
      throw new Error('Invalid rerank response: result list is missing')
    }
    return results
  }

  /**
   * Get Rerank Result
   *
   * Attaches each rerank score to the search result at the matching index,
   * drops results that received no score, and orders by score descending.
   *
   * @param searchResults Original search results, addressed by position.
   * @param rerankResults Scores keyed by the index into `searchResults`.
   * @protected
   */
  protected getRerankResult(
    searchResults: ExtractChunkData[],
    rerankResults: Array<{
      index: number
      relevance_score: number
    }>
  ): ExtractChunkData[] {
    // A missing (or otherwise falsy, e.g. NaN) relevance score is treated as 0.
    const resultMap = new Map(rerankResults.map((result) => [result.index, result.relevance_score || 0]))

    return searchResults
      .map((doc: ExtractChunkData, index: number) => {
        const score = resultMap.get(index)
        if (score === undefined) return undefined

        return {
          ...doc,
          score
        }
      })
      .filter((doc): doc is ExtractChunkData => doc !== undefined)
      .sort((a, b) => b.score - a.score)
  }

  /**
   * Headers sent with rerank requests: a bearer token built from `rerankApiKey`
   * and a JSON content type.
   */
  public defaultHeaders() {
    return {
      Authorization: `Bearer ${this.base.rerankApiKey}`,
      'Content-Type': 'application/json'
    }
  }

  /**
   * Serialize the details of a failed request as pretty-printed JSON.
   *
   * @param url The request URL.
   * @param error The caught error; `message` and `response.status`/`statusText` are read if present.
   * @param requestBody The body that was sent.
   */
  protected formatErrorMessage(url: string, error: any, requestBody: any) {
    const errorDetails = {
      url: url,
      message: error?.message,
      status: error?.response?.status,
      statusText: error?.response?.statusText,
      requestBody: requestBody
    }
    return JSON.stringify(errorDetails, null, 2)
  }
}