mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-08 14:29:15 +08:00
Fix: Handle embedding dimension retrieval failure when creating knowledge base (#7324)
* fix(知识库): 处理获取嵌入维度为0时的错误情况 * fix(aiCore): 修复获取嵌入维度时错误处理不当的问题 修改各AI客户端获取嵌入维度的方法,在出错时抛出异常而不是返回0 同时在调用处移除对返回值为0的特殊处理,直接捕获异常 * refactor(aiCore): 移除获取嵌入维度的冗余try-catch块 简化代码结构,移除不必要的错误处理,因为错误会由上层调用者处理
This commit is contained in:
parent
90805e03d5
commit
48c809da51
@ -125,7 +125,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
|||||||
|
|
||||||
// @ts-ignore sdk未提供
|
// @ts-ignore sdk未提供
|
||||||
override async getEmbeddingDimensions(): Promise<number> {
|
override async getEmbeddingDimensions(): Promise<number> {
|
||||||
return 0
|
throw new Error("Anthropic SDK doesn't support getEmbeddingDimensions method.")
|
||||||
}
|
}
|
||||||
|
|
||||||
override getTemperature(assistant: Assistant, model: Model): number | undefined {
|
override getTemperature(assistant: Assistant, model: Model): number | undefined {
|
||||||
|
|||||||
@ -147,15 +147,12 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
|
|
||||||
override async getEmbeddingDimensions(model: Model): Promise<number> {
|
override async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||||
const sdk = await this.getSdkInstance()
|
const sdk = await this.getSdkInstance()
|
||||||
try {
|
|
||||||
const data = await sdk.models.embedContent({
|
const data = await sdk.models.embedContent({
|
||||||
model: model.id,
|
model: model.id,
|
||||||
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
|
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
|
||||||
})
|
})
|
||||||
return data.embeddings?.[0]?.values?.length || 0
|
return data.embeddings?.[0]?.values?.length || 0
|
||||||
} catch (e) {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override async listModels(): Promise<GeminiModel[]> {
|
override async listModels(): Promise<GeminiModel[]> {
|
||||||
|
|||||||
@ -85,16 +85,13 @@ export abstract class OpenAIBaseClient<
|
|||||||
|
|
||||||
override async getEmbeddingDimensions(model: Model): Promise<number> {
|
override async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||||
const sdk = await this.getSdkInstance()
|
const sdk = await this.getSdkInstance()
|
||||||
try {
|
|
||||||
const data = await sdk.embeddings.create({
|
const data = await sdk.embeddings.create({
|
||||||
model: model.id,
|
model: model.id,
|
||||||
input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
|
input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
|
||||||
encoding_format: 'float'
|
encoding_format: 'float'
|
||||||
})
|
})
|
||||||
return data.data[0].embedding.length
|
return data.data[0].embedding.length
|
||||||
} catch (e) {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override async listModels(): Promise<OpenAI.Models.Model[]> {
|
override async listModels(): Promise<OpenAI.Models.Model[]> {
|
||||||
|
|||||||
@ -114,7 +114,7 @@ export default class AiProvider {
|
|||||||
return dimensions
|
return dimensions
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error getting embedding dimensions:', error)
|
console.error('Error getting embedding dimensions:', error)
|
||||||
return 0
|
throw error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -116,7 +116,6 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
|||||||
const aiProvider = new AiProvider(provider)
|
const aiProvider = new AiProvider(provider)
|
||||||
values.dimensions = await aiProvider.getEmbeddingDimensions(selectedEmbeddingModel)
|
values.dimensions = await aiProvider.getEmbeddingDimensions(selectedEmbeddingModel)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error getting embedding dimensions:', error)
|
|
||||||
window.message.error(t('message.error.get_embedding_dimensions') + '\n' + getErrorMessage(error))
|
window.message.error(t('message.error.get_embedding_dimensions') + '\n' + getErrorMessage(error))
|
||||||
setLoading(false)
|
setLoading(false)
|
||||||
return
|
return
|
||||||
|
|||||||
@ -570,10 +570,7 @@ export async function checkApi(provider: Provider, model: Model): Promise<void>
|
|||||||
assistant.model = model
|
assistant.model = model
|
||||||
try {
|
try {
|
||||||
if (isEmbeddingModel(model)) {
|
if (isEmbeddingModel(model)) {
|
||||||
const result = await ai.getEmbeddingDimensions(model)
|
await ai.getEmbeddingDimensions(model)
|
||||||
if (result === 0) {
|
|
||||||
throw new Error(i18n.t('message.error.enter.model'))
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
const params: CompletionsParams = {
|
const params: CompletionsParams = {
|
||||||
callType: 'check',
|
callType: 'check',
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user