From c05c06b7a15d2eadc0eb03865d73c9ccc0a547d3 Mon Sep 17 00:00:00 2001 From: Phantom <59059173+EurFelux@users.noreply.github.com> Date: Thu, 10 Jul 2025 21:53:37 +0800 Subject: [PATCH] fix: VoyageEmbeddings (#8034) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(embeddings): 修复VoyageAI嵌入格式和模型验证错误 修复OpenAIBaseClient中VoyageAI提供商的embedding格式设置问题 完善VoyageEmbeddings模型验证的错误提示信息 * refactor(embeddings): 移除VoyageEmbeddings的模型维度限制检查 简化VoyageEmbeddings的创建逻辑,不再对支持的模型维度进行校验 * fix(embeddings): 修复VoyageEmbeddings模型维度设置问题 修复VoyageEmbeddings中未正确校验模型是否支持设置outputDimension的问题 当provider为voyageai且模型不支持设置dimensions时,自动忽略传入的dimensions参数 * refactor(embeddings): 集中管理支持设置维度的模型列表 将各嵌入模型支持设置维度的模型列表集中到utils模块 不再让VoyageEmbeddings中getDimensions抛出错误,而是自动修复 --- .../embeddings/EmbeddingsFactory.ts | 23 ++++------ .../knowledage/embeddings/VoyageEmbeddings.ts | 22 ++++----- src/main/knowledage/embeddings/utils.ts | 45 +++++++++++++++++++ .../aiCore/clients/openai/OpenAIBaseClient.ts | 2 +- 4 files changed, 66 insertions(+), 26 deletions(-) create mode 100644 src/main/knowledage/embeddings/utils.ts diff --git a/src/main/knowledage/embeddings/EmbeddingsFactory.ts b/src/main/knowledage/embeddings/EmbeddingsFactory.ts index 808db05794..5a7561cac8 100644 --- a/src/main/knowledage/embeddings/EmbeddingsFactory.ts +++ b/src/main/knowledage/embeddings/EmbeddingsFactory.ts @@ -5,26 +5,19 @@ import { AzureOpenAiEmbeddings } from '@cherrystudio/embedjs-openai/src/azure-op import { getInstanceName } from '@main/utils' import { KnowledgeBaseParams } from '@types' -import { SUPPORTED_DIM_MODELS as VOYAGE_SUPPORTED_DIM_MODELS, VoyageEmbeddings } from './VoyageEmbeddings' +import { VOYAGE_SUPPORTED_DIM_MODELS } from './utils' +import { VoyageEmbeddings } from './VoyageEmbeddings' export default class EmbeddingsFactory { static create({ model, provider, apiKey, apiVersion, baseURL, dimensions }: KnowledgeBaseParams): BaseEmbeddings { const batchSize = 10 if (provider === 'voyageai') { - if (VOYAGE_SUPPORTED_DIM_MODELS.includes(model)) { - return new VoyageEmbeddings({ - modelName: model, - apiKey, - outputDimension: dimensions, - batchSize: 8 - }) - } else { - return new VoyageEmbeddings({ - modelName: model, - apiKey, - batchSize: 8 - }) - } + return new VoyageEmbeddings({ + modelName: model, + apiKey, + outputDimension: VOYAGE_SUPPORTED_DIM_MODELS.includes(model) ? dimensions : undefined, + batchSize: 8 + }) } if (provider === 'ollama') { if (baseURL.includes('v1/')) { diff --git a/src/main/knowledage/embeddings/VoyageEmbeddings.ts b/src/main/knowledage/embeddings/VoyageEmbeddings.ts index edec32dc51..61f23ad767 100644 --- a/src/main/knowledage/embeddings/VoyageEmbeddings.ts +++ b/src/main/knowledage/embeddings/VoyageEmbeddings.ts @@ -1,27 +1,29 @@ import { BaseEmbeddings } from '@cherrystudio/embedjs-interfaces' import { VoyageEmbeddings as _VoyageEmbeddings } from '@langchain/community/embeddings/voyage' +import { VOYAGE_SUPPORTED_DIM_MODELS } from './utils' + /** * 支持设置嵌入维度的模型 */ -export const SUPPORTED_DIM_MODELS = ['voyage-3-large', 'voyage-3.5', 'voyage-3.5-lite', 'voyage-code-3'] export class VoyageEmbeddings extends BaseEmbeddings { private model: _VoyageEmbeddings constructor(private readonly configuration?: ConstructorParameters[0]) { super() - if (!this.configuration) this.configuration = {} - if (!this.configuration.modelName) this.configuration.modelName = 'voyage-3' - if (!SUPPORTED_DIM_MODELS.includes(this.configuration.modelName) && this.configuration.outputDimension) { - throw new Error(`VoyageEmbeddings only supports ${SUPPORTED_DIM_MODELS.join(', ')}`) + if (!this.configuration) { + throw new Error('Pass in a configuration.') } + if (!this.configuration.modelName) this.configuration.modelName = 'voyage-3' - this.model = new _VoyageEmbeddings(this.configuration) + if (!VOYAGE_SUPPORTED_DIM_MODELS.includes(this.configuration.modelName) && this.configuration.outputDimension) { + console.error(`VoyageEmbeddings only supports ${VOYAGE_SUPPORTED_DIM_MODELS.join(', ')} to set outputDimension.`) + this.model = new _VoyageEmbeddings({ ...this.configuration, outputDimension: undefined }) + } else { + this.model = new _VoyageEmbeddings(this.configuration) + } } override async getDimensions(): Promise { - if (!this.configuration?.outputDimension) { - throw new Error('You need to pass in the optional dimensions parameter for this model') - } - return this.configuration?.outputDimension + return this.configuration?.outputDimension ?? (this.configuration?.modelName === 'voyage-code-2' ? 1536 : 1024) } override async embedDocuments(texts: string[]): Promise { diff --git a/src/main/knowledage/embeddings/utils.ts b/src/main/knowledage/embeddings/utils.ts new file mode 100644 index 0000000000..9b6bd54935 --- /dev/null +++ b/src/main/knowledage/embeddings/utils.ts @@ -0,0 +1,45 @@ +export const VOYAGE_SUPPORTED_DIM_MODELS = ['voyage-3-large', 'voyage-3.5', 'voyage-3.5-lite', 'voyage-code-3'] + +// NOTE: 下面的暂时没用上,但先留着吧 +export const OPENAI_SUPPORTED_DIM_MODELS = ['text-embedding-3-small', 'text-embedding-3-large'] + +export const DASHSCOPE_SUPPORTED_DIM_MODELS = ['text-embedding-v3', 'text-embedding-v4'] + +export const OPENSOURCE_SUPPORTED_DIM_MODELS = ['qwen3-embedding-0.6B', 'qwen3-embedding-4B', 'qwen3-embedding-8B'] + +export const GOOGLE_SUPPORTED_DIM_MODELS = ['gemini-embedding-exp-03-07', 'gemini-embedding-exp'] + +export const SUPPORTED_DIM_MODELS = [ + ...VOYAGE_SUPPORTED_DIM_MODELS, + ...OPENAI_SUPPORTED_DIM_MODELS, + ...DASHSCOPE_SUPPORTED_DIM_MODELS, + ...OPENSOURCE_SUPPORTED_DIM_MODELS, + ...GOOGLE_SUPPORTED_DIM_MODELS +] + +/** + * 从模型 ID 中提取基础名称。 + * 例如: + * - 'deepseek/deepseek-r1' => 'deepseek-r1' + * - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1' + * @param {string} id 模型 ID + * @param {string} [delimiter='/'] 分隔符,默认为 '/' + * @returns {string} 基础名称 + */ +export const getBaseModelName = (id: string, delimiter: string = '/'): string => { + const parts = id.split(delimiter) + return parts[parts.length - 1] +} + +/** + * 从模型 ID 中提取基础名称并转换为小写。 + * 例如: + * - 'deepseek/DeepSeek-R1' => 'deepseek-r1' + * - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1' + * @param {string} id 模型 ID + * @param {string} [delimiter='/'] 分隔符,默认为 '/' + * @returns {string} 小写的基础名称 + */ +export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => { + return getBaseModelName(id, delimiter).toLowerCase() +} diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts b/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts index 72fa1d7df8..95ddcbedd0 100644 --- a/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts +++ b/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts @@ -89,7 +89,7 @@ export abstract class OpenAIBaseClient< const data = await sdk.embeddings.create({ model: model.id, input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi', - encoding_format: 'float' + encoding_format: this.provider.id === 'voyageai' ? undefined : 'float' }) return data.data[0].embedding.length }