mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-11 16:39:15 +08:00
132 lines
3.2 KiB
TypeScript
132 lines
3.2 KiB
TypeScript
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||
import { KnowledgeBaseParams } from '@types'
|
||
|
||
/**
 * Shared base for rerank clients.
 *
 * Holds the knowledge base configuration and provides the provider-specific
 * plumbing (endpoint URL, request body, response extraction, score merging,
 * headers, error formatting) that concrete rerankers build on. Three request
 * and response shapes are handled: `dashscope`, `voyageai`, and a default
 * shape used for every other provider.
 */
export default abstract class BaseReranker {
  /** Knowledge base parameters: rerank provider, model, base URL, API key and document count. */
  protected base: KnowledgeBaseParams

  /**
   * @param base Knowledge base parameters; `rerankModel` must be set.
   * @throws Error when `base.rerankModel` is missing or empty.
   */
  constructor(base: KnowledgeBaseParams) {
    if (!base.rerankModel) {
      throw new Error('Rerank model is required')
    }
    this.base = base
  }

  /**
   * Rerank `searchResults` for `query`. Implemented by concrete subclasses.
   *
   * @param query The query the results are ranked against.
   * @param searchResults Candidate chunks to rerank.
   */
  abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]>

  /**
   * Get Rerank Request Url
   *
   * `dashscope` always uses its fixed text-rerank endpoint. Every other
   * provider uses `<rerankBaseURL>/v1/rerank`; trailing slashes on the
   * configured base URL are ignored and `/v1` is appended only when the base
   * URL does not already end with it.
   *
   * @returns The absolute URL to send the rerank request to.
   * @throws Error when a non-dashscope provider has no usable `rerankBaseURL`
   *   (previously this produced the bogus URL `"undefined/rerank"`).
   */
  protected getRerankUrl() {
    if (this.base.rerankModelProvider === 'dashscope') {
      return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
    }

    // Drop every trailing slash so "host/", "host//" and "host" are treated alike.
    const trimmed = (this.base.rerankBaseURL ?? '').replace(/\/+$/, '')
    if (!trimmed) {
      throw new Error('Rerank base URL is required')
    }

    // The path must include /v1, otherwise the request ends in a 404.
    const baseURL = trimmed.endsWith('/v1') ? trimmed : `${trimmed}/v1`

    return `${baseURL}/rerank`
  }

  /**
   * Get Rerank Request Body
   *
   * Sends each result's `pageContent` as a document and `documentCount` as
   * the result limit, using the field layout of the configured provider:
   * - `voyageai`: flat body, limit sent as `top_k`
   * - `dashscope`: `query`/`documents` nested under `input`, limit as `parameters.top_n`
   * - anything else: flat body, limit sent as `top_n`
   *
   * @param query The query to rank documents against.
   * @param searchResults Candidate chunks whose `pageContent` is sent.
   * @returns The JSON-serialisable request payload.
   */
  protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
    const model = this.base.rerankModel
    const documents = searchResults.map((doc) => doc.pageContent)
    const topN = this.base.documentCount

    switch (this.base.rerankModelProvider) {
      case 'voyageai':
        return {
          model,
          query,
          documents,
          top_k: topN
        }
      case 'dashscope':
        return {
          model,
          input: {
            query,
            documents
          },
          parameters: {
            top_n: topN
          }
        }
      default:
        return {
          model,
          query,
          documents,
          top_n: topN
        }
    }
  }

  /**
   * Extract Rerank Result
   *
   * Reads the result list from the provider-specific location in the
   * response: `output.results` for `dashscope`, `data` for `voyageai`,
   * `results` otherwise.
   *
   * @param data The parsed response payload.
   * @returns The list of rerank entries found in the payload.
   * @throws Error when the payload does not contain a result array at the
   *   expected location (instead of an opaque TypeError further down).
   */
  protected extractRerankResult(data: any) {
    const provider = this.base.rerankModelProvider

    let results: any
    if (provider === 'dashscope') {
      results = data?.output?.results
    } else if (provider === 'voyageai') {
      results = data?.data
    } else {
      results = data?.results
    }

    if (!Array.isArray(results)) {
      throw new Error(`Unexpected rerank response format from provider "${provider ?? 'default'}"`)
    }

    // Returned through the `any` local so the inferred return type stays `any`, as before.
    return results
  }

  /**
   * Get Rerank Result
   *
   * Attaches each rerank score to the search result it refers to (matched by
   * position), drops results the reranker did not return, and orders the rest
   * by descending score. The input array is not modified.
   *
   * @param searchResults The original search results, in request order.
   * @param rerankResults Rerank entries; `index` points into `searchResults`.
   * @returns Copies of the matched results with `score` replaced, best first.
   * @protected
   */
  protected getRerankResult(
    searchResults: ExtractChunkData[],
    rerankResults: Array<{
      index: number
      relevance_score: number
    }>
  ) {
    // A missing/zero/NaN relevance score is recorded as 0; later duplicates of an index win.
    const scoreByIndex = new Map<number, number>()
    for (const { index, relevance_score } of rerankResults) {
      scoreByIndex.set(index, relevance_score || 0)
    }

    const ranked: ExtractChunkData[] = []
    searchResults.forEach((doc, index) => {
      const score = scoreByIndex.get(index)
      if (score !== undefined) {
        ranked.push({ ...doc, score })
      }
    })

    // `ranked` is a fresh array, so sorting in place does not touch the caller's data.
    return ranked.sort((a, b) => b.score - a.score)
  }

  /**
   * Build the HTTP headers for a rerank request: a bearer token taken from
   * `rerankApiKey` and a JSON content type.
   */
  public defaultHeaders() {
    return {
      Authorization: `Bearer ${this.base.rerankApiKey}`,
      'Content-Type': 'application/json'
    }
  }

  /**
   * Render a failed request as pretty-printed JSON for error messages.
   *
   * Tolerates any thrown value: `null`/`undefined` no longer cause a
   * secondary TypeError, and a thrown string is kept as the message.
   *
   * @param url The URL that was requested.
   * @param error The caught value; `message`, `response.status` and
   *   `response.statusText` are read when present.
   * @param requestBody The payload that was sent.
   * @returns A JSON string with `url`, `message`, `status`, `statusText` and `requestBody`.
   */
  protected formatErrorMessage(url: string, error: any, requestBody: any) {
    const errorDetails = {
      url,
      message: typeof error === 'string' ? error : error?.message,
      status: error?.response?.status,
      statusText: error?.response?.statusText,
      requestBody
    }
    return JSON.stringify(errorDetails, null, 2)
  }
}
|