diff --git a/src/main/knowledage/embeddings/EmbeddingsFactory.ts b/src/main/knowledage/embeddings/EmbeddingsFactory.ts index 808db05794..5a7561cac8 100644 --- a/src/main/knowledage/embeddings/EmbeddingsFactory.ts +++ b/src/main/knowledage/embeddings/EmbeddingsFactory.ts @@ -5,26 +5,19 @@ import { AzureOpenAiEmbeddings } from '@cherrystudio/embedjs-openai/src/azure-op import { getInstanceName } from '@main/utils' import { KnowledgeBaseParams } from '@types' -import { SUPPORTED_DIM_MODELS as VOYAGE_SUPPORTED_DIM_MODELS, VoyageEmbeddings } from './VoyageEmbeddings' +import { VOYAGE_SUPPORTED_DIM_MODELS } from './utils' +import { VoyageEmbeddings } from './VoyageEmbeddings' export default class EmbeddingsFactory { static create({ model, provider, apiKey, apiVersion, baseURL, dimensions }: KnowledgeBaseParams): BaseEmbeddings { const batchSize = 10 if (provider === 'voyageai') { - if (VOYAGE_SUPPORTED_DIM_MODELS.includes(model)) { - return new VoyageEmbeddings({ - modelName: model, - apiKey, - outputDimension: dimensions, - batchSize: 8 - }) - } else { - return new VoyageEmbeddings({ - modelName: model, - apiKey, - batchSize: 8 - }) - } + return new VoyageEmbeddings({ + modelName: model, + apiKey, + outputDimension: VOYAGE_SUPPORTED_DIM_MODELS.includes(model) ? dimensions : undefined, + batchSize: 8 + }) } if (provider === 'ollama') { if (baseURL.includes('v1/')) { diff --git a/src/main/knowledage/embeddings/VoyageEmbeddings.ts b/src/main/knowledage/embeddings/VoyageEmbeddings.ts index edec32dc51..61f23ad767 100644 --- a/src/main/knowledage/embeddings/VoyageEmbeddings.ts +++ b/src/main/knowledage/embeddings/VoyageEmbeddings.ts @@ -1,27 +1,29 @@ import { BaseEmbeddings } from '@cherrystudio/embedjs-interfaces' import { VoyageEmbeddings as _VoyageEmbeddings } from '@langchain/community/embeddings/voyage' +import { VOYAGE_SUPPORTED_DIM_MODELS } from './utils' + /** * 支持设置嵌入维度的模型 */ -export const SUPPORTED_DIM_MODELS = ['voyage-3-large', 'voyage-3.5', 'voyage-3.5-lite', 'voyage-code-3'] export class VoyageEmbeddings extends BaseEmbeddings { private model: _VoyageEmbeddings constructor(private readonly configuration?: ConstructorParameters[0]) { super() - if (!this.configuration) this.configuration = {} - if (!this.configuration.modelName) this.configuration.modelName = 'voyage-3' - if (!SUPPORTED_DIM_MODELS.includes(this.configuration.modelName) && this.configuration.outputDimension) { - throw new Error(`VoyageEmbeddings only supports ${SUPPORTED_DIM_MODELS.join(', ')}`) + if (!this.configuration) { + throw new Error('Pass in a configuration.') } + if (!this.configuration.modelName) this.configuration.modelName = 'voyage-3' - this.model = new _VoyageEmbeddings(this.configuration) + if (!VOYAGE_SUPPORTED_DIM_MODELS.includes(this.configuration.modelName) && this.configuration.outputDimension) { + console.error(`VoyageEmbeddings only supports ${VOYAGE_SUPPORTED_DIM_MODELS.join(', ')} to set outputDimension.`) + this.model = new _VoyageEmbeddings({ ...this.configuration, outputDimension: undefined }) + } else { + this.model = new _VoyageEmbeddings(this.configuration) + } } override async getDimensions(): Promise { - if (!this.configuration?.outputDimension) { - throw new Error('You need to pass in the optional dimensions parameter for this model') - } - return this.configuration?.outputDimension + return this.configuration?.outputDimension ?? (this.configuration?.modelName === 'voyage-code-2' ? 1536 : 1024) } override async embedDocuments(texts: string[]): Promise { diff --git a/src/main/knowledage/embeddings/utils.ts b/src/main/knowledage/embeddings/utils.ts new file mode 100644 index 0000000000..9b6bd54935 --- /dev/null +++ b/src/main/knowledage/embeddings/utils.ts @@ -0,0 +1,45 @@ +export const VOYAGE_SUPPORTED_DIM_MODELS = ['voyage-3-large', 'voyage-3.5', 'voyage-3.5-lite', 'voyage-code-3'] + +// NOTE: 下面的暂时没用上,但先留着吧 +export const OPENAI_SUPPORTED_DIM_MODELS = ['text-embedding-3-small', 'text-embedding-3-large'] + +export const DASHSCOPE_SUPPORTED_DIM_MODELS = ['text-embedding-v3', 'text-embedding-v4'] + +export const OPENSOURCE_SUPPORTED_DIM_MODELS = ['qwen3-embedding-0.6B', 'qwen3-embedding-4B', 'qwen3-embedding-8B'] + +export const GOOGLE_SUPPORTED_DIM_MODELS = ['gemini-embedding-exp-03-07', 'gemini-embedding-exp'] + +export const SUPPORTED_DIM_MODELS = [ + ...VOYAGE_SUPPORTED_DIM_MODELS, + ...OPENAI_SUPPORTED_DIM_MODELS, + ...DASHSCOPE_SUPPORTED_DIM_MODELS, + ...OPENSOURCE_SUPPORTED_DIM_MODELS, + ...GOOGLE_SUPPORTED_DIM_MODELS +] + +/** + * 从模型 ID 中提取基础名称。 + * 例如: + * - 'deepseek/deepseek-r1' => 'deepseek-r1' + * - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1' + * @param {string} id 模型 ID + * @param {string} [delimiter='/'] 分隔符,默认为 '/' + * @returns {string} 基础名称 + */ +export const getBaseModelName = (id: string, delimiter: string = '/'): string => { + const parts = id.split(delimiter) + return parts[parts.length - 1] +} + +/** + * 从模型 ID 中提取基础名称并转换为小写。 + * 例如: + * - 'deepseek/DeepSeek-R1' => 'deepseek-r1' + * - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1' + * @param {string} id 模型 ID + * @param {string} [delimiter='/'] 分隔符,默认为 '/' + * @returns {string} 小写的基础名称 + */ +export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => { + return getBaseModelName(id, delimiter).toLowerCase() +} diff --git a/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts b/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts index 72fa1d7df8..95ddcbedd0 100644 --- a/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts +++ b/src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts @@ -89,7 +89,7 @@ export abstract class OpenAIBaseClient< const data = await sdk.embeddings.create({ model: model.id, input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi', - encoding_format: 'float' + encoding_format: this.provider.id === 'voyageai' ? undefined : 'float' }) return data.data[0].embedding.length }