mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-01 09:49:03 +08:00
* refactor(reranker): 重构重排序功能以提高可维护性 - 将 BaseReranker 类中的公共逻辑提取到受保护的方法中 - 优化了 JinaReranker、SiliconFlowReranker 和 VoyageReranker 的实现 - 新增 getRerankUrl 和 getRerankResult 方法以提高代码复用性 - 简化了重排序结果的处理逻辑 * refactor(reranker): 将 formatErrorMessage 方法的访问权限改为受保护 - 将 formatErrorMessage 方法的访问权限从公共 (public) 改为受保护 (protected) - 这一更改限制了方法的访问范围,仅允许子类访问该方法 - 有助于提高代码的封装性和安全性
78 lines
2.0 KiB
TypeScript
78 lines
2.0 KiB
TypeScript
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||
import { KnowledgeBaseParams } from '@types'
|
||
|
||
/**
 * Common base for reranker implementations.
 *
 * Stores the knowledge-base parameters and offers shared helpers for building
 * the rerank endpoint URL, the request headers, merging scores returned by a
 * rerank call back into the search results, and formatting request errors.
 */
export default abstract class BaseReranker {
  protected base: KnowledgeBaseParams

  /**
   * @param base Knowledge-base parameters; `rerankModel` must be set.
   * @throws Error when `base.rerankModel` is falsy.
   */
  constructor(base: KnowledgeBaseParams) {
    if (!base.rerankModel) {
      throw new Error('Rerank model is required')
    }
    this.base = base
  }

  /**
   * Rerank `searchResults` for `query`. Implemented by each concrete reranker.
   */
  abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]>

  /**
   * Get Rerank Request Url
   *
   * Strips every trailing slash from `rerankBaseURL`, appends `/v1` when the
   * URL does not already end with it, and then appends `/rerank`.
   *
   * @returns The full rerank endpoint URL.
   * @throws Error when `rerankBaseURL` is missing or empty, instead of
   *   producing a malformed URL such as `undefined/rerank`.
   */
  protected getRerankUrl(): string {
    let baseURL = this.base.rerankBaseURL ?? ''
    while (baseURL.endsWith('/')) {
      baseURL = baseURL.slice(0, -1)
    }

    if (!baseURL) {
      throw new Error('Rerank base URL is required')
    }

    // The path must include /v1, otherwise the request ends in a 404.
    if (!baseURL.endsWith('/v1')) {
      baseURL = `${baseURL}/v1`
    }

    return `${baseURL}/rerank`
  }

  /**
   * Get Rerank Result
   *
   * Pairs each search result with the relevance score reported for its
   * position and returns the scored results ordered from highest to lowest
   * score. Search results whose position has no entry in `rerankResults` are
   * dropped; entries whose `index` matches no search result are ignored.
   * The input arrays and their elements are not modified.
   *
   * @param searchResults Documents that were submitted for reranking.
   * @param rerankResults Entries holding the position (`index`) of a document
   *   in `searchResults` and its `relevance_score`.
   * @returns Copies of the matched documents with `score` replaced by the
   *   relevance score, sorted by descending score.
   * @protected
   */
  protected getRerankResult(
    searchResults: ExtractChunkData[],
    rerankResults: Array<{
      index: number
      relevance_score: number
    }>
  ): ExtractChunkData[] {
    const scoreByIndex = new Map<number, number>()
    for (const { index, relevance_score } of rerankResults) {
      // `||` (not `??`) on purpose: a missing or NaN score also becomes 0,
      // which keeps the comparator below well-defined.
      scoreByIndex.set(index, relevance_score || 0)
    }

    const reranked: ExtractChunkData[] = []
    searchResults.forEach((doc, index) => {
      const score = scoreByIndex.get(index)
      if (score !== undefined) {
        reranked.push({ ...doc, score })
      }
    })

    return reranked.sort((a, b) => b.score - a.score)
  }

  /**
   * Build the HTTP headers for a rerank request: a bearer `Authorization`
   * header carrying `rerankApiKey` and a JSON `Content-Type`.
   */
  public defaultHeaders() {
    return {
      Authorization: `Bearer ${this.base.rerankApiKey}`,
      'Content-Type': 'application/json'
    }
  }

  /**
   * Serialize the details of a failed request as pretty-printed JSON.
   *
   * Safe to call with any thrown value: `null`/`undefined` no longer make the
   * formatter itself throw, and a thrown string is reported as the message.
   *
   * @param url Request URL.
   * @param error The caught value; `message`, `response.status` and
   *   `response.statusText` are read from it when present.
   * @param requestBody Body that was sent with the request.
   * @returns JSON text with `url`, `message`, `status`, `statusText` and
   *   `requestBody` (undefined fields are omitted by `JSON.stringify`).
   */
  protected formatErrorMessage(url: string, error: any, requestBody: any) {
    const errorDetails = {
      url: url,
      message: typeof error === 'string' ? error : error?.message,
      status: error?.response?.status,
      statusText: error?.response?.statusText,
      requestBody: requestBody
    }
    return JSON.stringify(errorDetails, null, 2)
  }
}
|