fix: VoyageEmbeddings (#8034)

* fix(embeddings): 修复VoyageAI嵌入格式和模型验证错误

修复OpenAIBaseClient中VoyageAI提供商的embedding格式设置问题
完善VoyageEmbeddings模型验证的错误提示信息

* refactor(embeddings): 移除VoyageEmbeddings的模型维度限制检查

简化VoyageEmbeddings的创建逻辑,不再对支持的模型维度进行校验

* fix(embeddings): 修复VoyageEmbeddings模型维度设置问题

修复VoyageEmbeddings中未正确校验模型是否支持设置outputDimension的问题
当provider为voyageai且模型不支持设置dimensions时,自动忽略传入的dimensions参数

* refactor(embeddings): 集中管理支持设置维度的模型列表

将各嵌入模型支持设置维度的模型列表集中到utils模块
不再让VoyageEmbeddings中getDimensions抛出错误,而是自动修复
This commit is contained in:
Phantom 2025-07-10 21:53:37 +08:00 committed by GitHub
parent 446ebae175
commit c05c06b7a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 66 additions and 26 deletions

View File

@ -5,26 +5,19 @@ import { AzureOpenAiEmbeddings } from '@cherrystudio/embedjs-openai/src/azure-op
import { getInstanceName } from '@main/utils' import { getInstanceName } from '@main/utils'
import { KnowledgeBaseParams } from '@types' import { KnowledgeBaseParams } from '@types'
import { SUPPORTED_DIM_MODELS as VOYAGE_SUPPORTED_DIM_MODELS, VoyageEmbeddings } from './VoyageEmbeddings' import { VOYAGE_SUPPORTED_DIM_MODELS } from './utils'
import { VoyageEmbeddings } from './VoyageEmbeddings'
export default class EmbeddingsFactory { export default class EmbeddingsFactory {
static create({ model, provider, apiKey, apiVersion, baseURL, dimensions }: KnowledgeBaseParams): BaseEmbeddings { static create({ model, provider, apiKey, apiVersion, baseURL, dimensions }: KnowledgeBaseParams): BaseEmbeddings {
const batchSize = 10 const batchSize = 10
if (provider === 'voyageai') { if (provider === 'voyageai') {
if (VOYAGE_SUPPORTED_DIM_MODELS.includes(model)) { return new VoyageEmbeddings({
return new VoyageEmbeddings({ modelName: model,
modelName: model, apiKey,
apiKey, outputDimension: VOYAGE_SUPPORTED_DIM_MODELS.includes(model) ? dimensions : undefined,
outputDimension: dimensions, batchSize: 8
batchSize: 8 })
})
} else {
return new VoyageEmbeddings({
modelName: model,
apiKey,
batchSize: 8
})
}
} }
if (provider === 'ollama') { if (provider === 'ollama') {
if (baseURL.includes('v1/')) { if (baseURL.includes('v1/')) {

View File

@ -1,27 +1,29 @@
import { BaseEmbeddings } from '@cherrystudio/embedjs-interfaces' import { BaseEmbeddings } from '@cherrystudio/embedjs-interfaces'
import { VoyageEmbeddings as _VoyageEmbeddings } from '@langchain/community/embeddings/voyage' import { VoyageEmbeddings as _VoyageEmbeddings } from '@langchain/community/embeddings/voyage'
import { VOYAGE_SUPPORTED_DIM_MODELS } from './utils'
/** /**
* *
*/ */
export const SUPPORTED_DIM_MODELS = ['voyage-3-large', 'voyage-3.5', 'voyage-3.5-lite', 'voyage-code-3']
export class VoyageEmbeddings extends BaseEmbeddings { export class VoyageEmbeddings extends BaseEmbeddings {
private model: _VoyageEmbeddings private model: _VoyageEmbeddings
constructor(private readonly configuration?: ConstructorParameters<typeof _VoyageEmbeddings>[0]) { constructor(private readonly configuration?: ConstructorParameters<typeof _VoyageEmbeddings>[0]) {
super() super()
if (!this.configuration) this.configuration = {} if (!this.configuration) {
if (!this.configuration.modelName) this.configuration.modelName = 'voyage-3' throw new Error('Pass in a configuration.')
if (!SUPPORTED_DIM_MODELS.includes(this.configuration.modelName) && this.configuration.outputDimension) {
throw new Error(`VoyageEmbeddings only supports ${SUPPORTED_DIM_MODELS.join(', ')}`)
} }
if (!this.configuration.modelName) this.configuration.modelName = 'voyage-3'
this.model = new _VoyageEmbeddings(this.configuration) if (!VOYAGE_SUPPORTED_DIM_MODELS.includes(this.configuration.modelName) && this.configuration.outputDimension) {
console.error(`VoyageEmbeddings only supports ${VOYAGE_SUPPORTED_DIM_MODELS.join(', ')} to set outputDimension.`)
this.model = new _VoyageEmbeddings({ ...this.configuration, outputDimension: undefined })
} else {
this.model = new _VoyageEmbeddings(this.configuration)
}
} }
override async getDimensions(): Promise<number> { override async getDimensions(): Promise<number> {
if (!this.configuration?.outputDimension) { return this.configuration?.outputDimension ?? (this.configuration?.modelName === 'voyage-code-2' ? 1536 : 1024)
throw new Error('You need to pass in the optional dimensions parameter for this model')
}
return this.configuration?.outputDimension
} }
override async embedDocuments(texts: string[]): Promise<number[][]> { override async embedDocuments(texts: string[]): Promise<number[][]> {

View File

@ -0,0 +1,45 @@
export const VOYAGE_SUPPORTED_DIM_MODELS = ['voyage-3-large', 'voyage-3.5', 'voyage-3.5-lite', 'voyage-code-3']
// NOTE: 下面的暂时没用上,但先留着吧
export const OPENAI_SUPPORTED_DIM_MODELS = ['text-embedding-3-small', 'text-embedding-3-large']
export const DASHSCOPE_SUPPORTED_DIM_MODELS = ['text-embedding-v3', 'text-embedding-v4']
export const OPENSOURCE_SUPPORTED_DIM_MODELS = ['qwen3-embedding-0.6B', 'qwen3-embedding-4B', 'qwen3-embedding-8B']
export const GOOGLE_SUPPORTED_DIM_MODELS = ['gemini-embedding-exp-03-07', 'gemini-embedding-exp']
export const SUPPORTED_DIM_MODELS = [
...VOYAGE_SUPPORTED_DIM_MODELS,
...OPENAI_SUPPORTED_DIM_MODELS,
...DASHSCOPE_SUPPORTED_DIM_MODELS,
...OPENSOURCE_SUPPORTED_DIM_MODELS,
...GOOGLE_SUPPORTED_DIM_MODELS
]
/**
* ID
*
* - 'deepseek/deepseek-r1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1'
* @param {string} id ID
* @param {string} [delimiter='/'] '/'
* @returns {string}
*/
export const getBaseModelName = (id: string, delimiter: string = '/'): string => {
const parts = id.split(delimiter)
return parts[parts.length - 1]
}
/**
* ID
*
* - 'deepseek/DeepSeek-R1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1'
* @param {string} id ID
* @param {string} [delimiter='/'] '/'
* @returns {string}
*/
export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
return getBaseModelName(id, delimiter).toLowerCase()
}

View File

@ -89,7 +89,7 @@ export abstract class OpenAIBaseClient<
const data = await sdk.embeddings.create({ const data = await sdk.embeddings.create({
model: model.id, model: model.id,
input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi', input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
encoding_format: 'float' encoding_format: this.provider.id === 'voyageai' ? undefined : 'float'
}) })
return data.data[0].embedding.length return data.data[0].embedding.length
} }