cherry-studio/src/main/reranker/BaseReranker.ts
2025-05-18 10:02:05 +08:00

132 lines
3.2 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
/**
 * Shared plumbing for rerank clients: builds the request URL, request body and
 * headers for the configured provider, and maps the provider's response back
 * onto the original search results.
 *
 * Subclasses implement {@link BaseReranker.rerank} and perform the actual request.
 */
export default abstract class BaseReranker {
  /** Knowledge base parameters this reranker reads its model, provider, URL, key and document count from. */
  protected base: KnowledgeBaseParams

  /**
   * @param base Knowledge base parameters; `rerankModel` must be set.
   * @throws Error when `base.rerankModel` is missing.
   */
  constructor(base: KnowledgeBaseParams) {
    if (!base.rerankModel) {
      throw new Error('Rerank model is required')
    }
    this.base = base
  }

  /**
   * Rerank `searchResults` against `query`.
   *
   * @param query The search query.
   * @param searchResults Candidate chunks to rerank.
   * @returns The reranked chunks.
   */
  abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]>

  /**
   * Get Rerank Request Url
   *
   * The `dashscope` provider always uses its fixed endpoint. For every other
   * provider the URL is derived from `rerankBaseURL`: surrounding whitespace and
   * all trailing slashes are removed, `/v1` is appended unless already present,
   * and `/rerank` is appended last.
   *
   * @returns The full rerank endpoint URL.
   * @throws Error when a non-dashscope provider has no usable `rerankBaseURL`.
   */
  protected getRerankUrl() {
    if (this.base.rerankModelProvider === 'dashscope') {
      return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
    }

    const configuredURL = this.base.rerankBaseURL
    // Strip every trailing slash (not just one) so "host/v1//" is still
    // recognised as already ending in "/v1".
    let baseURL = typeof configuredURL === 'string' ? configuredURL.trim().replace(/\/+$/, '') : ''

    // Fail early with a clear message instead of producing "undefined/rerank".
    if (!baseURL) {
      throw new Error('Rerank base URL is required')
    }

    // The path must include /v1, otherwise the endpoint responds with 404.
    if (!baseURL.endsWith('/v1')) {
      baseURL = `${baseURL}/v1`
    }

    return `${baseURL}/rerank`
  }

  /**
   * Get Rerank Request Body
   *
   * The payload shape depends on `rerankModelProvider`:
   * - `voyageai`: flat body, result limit sent as `top_k`.
   * - `dashscope`: query/documents nested under `input`, limit under `parameters.top_n`.
   * - anything else: flat body, result limit sent as `top_n`.
   *
   * @param query The search query.
   * @param searchResults Chunks whose `pageContent` is sent as the documents list.
   * @returns The provider-specific request body.
   */
  protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
    const provider = this.base.rerankModelProvider
    const documents = searchResults.map((doc) => doc.pageContent)
    const topN = this.base.documentCount

    if (provider === 'voyageai') {
      return {
        model: this.base.rerankModel,
        query,
        documents,
        top_k: topN
      }
    } else if (provider === 'dashscope') {
      return {
        model: this.base.rerankModel,
        input: {
          query,
          documents
        },
        parameters: {
          top_n: topN
        }
      }
    } else {
      return {
        model: this.base.rerankModel,
        query,
        documents,
        top_n: topN
      }
    }
  }

  /**
   * Extract Rerank Result
   *
   * Picks the result list out of the response payload: `data.output.results`
   * for `dashscope`, `data.data` for `voyageai`, `data.results` otherwise.
   *
   * @param data The parsed response payload.
   * @returns The provider's result list as found in the payload.
   */
  protected extractRerankResult(data: any) {
    const provider = this.base.rerankModelProvider
    if (provider === 'dashscope') {
      return data.output.results
    } else if (provider === 'voyageai') {
      return data.data
    } else {
      return data.results
    }
  }

  /**
   * Get Rerank Result
   *
   * Attaches each rerank score to the search result at the matching index,
   * drops search results that received no score, and sorts by score descending.
   *
   * @param searchResults The chunks that were sent for reranking.
   * @param rerankResults Index/score pairs; `index` refers to a position in `searchResults`.
   * @returns Scored chunks, highest score first.
   * @protected
   */
  protected getRerankResult(
    searchResults: ExtractChunkData[],
    rerankResults: Array<{
      index: number
      relevance_score: number
    }>
  ) {
    // A missing or otherwise falsy score is treated as 0 so the sort below
    // always compares numbers.
    const resultMap = new Map(rerankResults.map((result) => [result.index, result.relevance_score || 0]))

    return searchResults
      .map((doc: ExtractChunkData, index: number) => {
        const score = resultMap.get(index)
        if (score === undefined) return undefined

        return {
          ...doc,
          score
        }
      })
      .filter((doc): doc is ExtractChunkData => doc !== undefined)
      .sort((a, b) => b.score - a.score)
  }

  /**
   * Default HTTP headers for a rerank request.
   *
   * @returns A bearer `Authorization` header built from `rerankApiKey` and a JSON content type.
   */
  public defaultHeaders() {
    return {
      Authorization: `Bearer ${this.base.rerankApiKey}`,
      'Content-Type': 'application/json'
    }
  }

  /**
   * Serialise the details of a failed request as pretty-printed JSON.
   *
   * @param url The URL that was requested.
   * @param error The caught error; may be any thrown value, including `null`/`undefined`.
   * @param requestBody The body that was sent.
   * @returns A JSON string with url, message, status, statusText and requestBody.
   */
  protected formatErrorMessage(url: string, error: any, requestBody: any) {
    const errorDetails = {
      url: url,
      // Null-safe: a rejection reason is not guaranteed to be an Error object,
      // and this reporting path must not throw itself.
      message: error?.message ?? String(error),
      status: error?.response?.status,
      statusText: error?.response?.statusText,
      requestBody: requestBody
    }
    return JSON.stringify(errorDetails, null, 2)
  }
}