From 0f34bde749ef20e53f130aa8717f802ca39f5265 Mon Sep 17 00:00:00 2001 From: Chen Tao <70054568+eeee0717@users.noreply.github.com> Date: Wed, 7 May 2025 14:17:03 +0800 Subject: [PATCH] feat: support dashscope reranker (#5725) --- src/main/reranker/DashscopeReranker.ts | 58 ++++++++++++++++++++++++++ src/main/reranker/RerankerFactory.ts | 3 ++ src/renderer/src/config/providers.ts | 2 +- 3 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 src/main/reranker/DashscopeReranker.ts diff --git a/src/main/reranker/DashscopeReranker.ts b/src/main/reranker/DashscopeReranker.ts new file mode 100644 index 0000000000..ac96092a1e --- /dev/null +++ b/src/main/reranker/DashscopeReranker.ts @@ -0,0 +1,58 @@ +import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' +import axiosProxy from '@main/services/AxiosProxy' +import { KnowledgeBaseParams } from '@types' + +import BaseReranker from './BaseReranker' + +interface DashscopeRerankResultItem { + document: { + text: string + } + index: number + relevance_score: number +} + +interface DashscopeRerankResponse { + output: { + results: DashscopeRerankResultItem[] + } + usage: { + total_tokens: number + } + request_id: string +} + +export default class DashscopeReranker extends BaseReranker { + constructor(base: KnowledgeBaseParams) { + super(base) + } + + public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise => { + const url = 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank' + + const requestBody = { + model: this.base.rerankModel, + input: { + query, + documents: searchResults.map((doc) => doc.pageContent) + }, + parameters: { + return_documents: true, // Recommended to be true to get document details if needed, though scores are primary + top_n: this.base.topN || 5 // Default to 5 if topN is not specified, as per API example + } + } + + try { + const { data } = await axiosProxy.axios.post(url, requestBody, { + headers: this.defaultHeaders() + }) + + const rerankResults = data.output.results + return this.getRerankResult(searchResults, rerankResults) + } catch (error: any) { + const errorDetails = this.formatErrorMessage(url, error, requestBody) + console.error('Dashscope Reranker API 错误:', errorDetails) + throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`) + } + } +} diff --git a/src/main/reranker/RerankerFactory.ts b/src/main/reranker/RerankerFactory.ts index 9557d58a97..d1ae18d788 100644 --- a/src/main/reranker/RerankerFactory.ts +++ b/src/main/reranker/RerankerFactory.ts @@ -1,6 +1,7 @@ import { KnowledgeBaseParams } from '@types' import BaseReranker from './BaseReranker' +import DashscopeReranker from './DashscopeReranker' import DefaultReranker from './DefaultReranker' import JinaReranker from './JinaReranker' import SiliconFlowReranker from './SiliconFlowReranker' @@ -14,6 +15,8 @@ export default class RerankerFactory { return new JinaReranker(base) } else if (base.rerankModelProvider === 'voyageai') { return new VoyageReranker(base) + } else if (base.rerankModelProvider === 'dashscope') { + return new DashscopeReranker(base) } return new DefaultReranker(base) } diff --git a/src/renderer/src/config/providers.ts b/src/renderer/src/config/providers.ts index 7bff824a43..4879e67419 100644 --- a/src/renderer/src/config/providers.ts +++ b/src/renderer/src/config/providers.ts @@ -95,7 +95,7 @@ export function getProviderLogo(providerId: string) { return PROVIDER_LOGO_MAP[providerId as keyof typeof PROVIDER_LOGO_MAP] } -export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai'] +export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope'] export const PROVIDER_CONFIG = { openai: {