mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-23 18:10:26 +08:00
feat: support dashscope reranker (#5725)
This commit is contained in:
parent
afcb0eeae2
commit
0f34bde749
58
src/main/reranker/DashscopeReranker.ts
Normal file
58
src/main/reranker/DashscopeReranker.ts
Normal file
@ -0,0 +1,58 @@
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import axiosProxy from '@main/services/AxiosProxy'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
interface DashscopeRerankResultItem {
|
||||
document: {
|
||||
text: string
|
||||
}
|
||||
index: number
|
||||
relevance_score: number
|
||||
}
|
||||
|
||||
interface DashscopeRerankResponse {
|
||||
output: {
|
||||
results: DashscopeRerankResultItem[]
|
||||
}
|
||||
usage: {
|
||||
total_tokens: number
|
||||
}
|
||||
request_id: string
|
||||
}
|
||||
|
||||
export default class DashscopeReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
|
||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||
const url = 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
|
||||
|
||||
const requestBody = {
|
||||
model: this.base.rerankModel,
|
||||
input: {
|
||||
query,
|
||||
documents: searchResults.map((doc) => doc.pageContent)
|
||||
},
|
||||
parameters: {
|
||||
return_documents: true, // Recommended to be true to get document details if needed, though scores are primary
|
||||
top_n: this.base.topN || 5 // Default to 5 if topN is not specified, as per API example
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const { data } = await axiosProxy.axios.post<DashscopeRerankResponse>(url, requestBody, {
|
||||
headers: this.defaultHeaders()
|
||||
})
|
||||
|
||||
const rerankResults = data.output.results
|
||||
return this.getRerankResult(searchResults, rerankResults)
|
||||
} catch (error: any) {
|
||||
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
||||
console.error('Dashscope Reranker API 错误:', errorDetails)
|
||||
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,6 +1,7 @@
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
import DashscopeReranker from './DashscopeReranker'
|
||||
import DefaultReranker from './DefaultReranker'
|
||||
import JinaReranker from './JinaReranker'
|
||||
import SiliconFlowReranker from './SiliconFlowReranker'
|
||||
@ -14,6 +15,8 @@ export default class RerankerFactory {
|
||||
return new JinaReranker(base)
|
||||
} else if (base.rerankModelProvider === 'voyageai') {
|
||||
return new VoyageReranker(base)
|
||||
} else if (base.rerankModelProvider === 'dashscope') {
|
||||
return new DashscopeReranker(base)
|
||||
}
|
||||
return new DefaultReranker(base)
|
||||
}
|
||||
|
||||
@ -95,7 +95,7 @@ export function getProviderLogo(providerId: string) {
|
||||
return PROVIDER_LOGO_MAP[providerId as keyof typeof PROVIDER_LOGO_MAP]
|
||||
}
|
||||
|
||||
export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai']
|
||||
export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope']
|
||||
|
||||
export const PROVIDER_CONFIG = {
|
||||
openai: {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user