diff --git a/src/main/knowledge/embeddings/EmbeddingsFactory.ts b/src/main/knowledge/embeddings/EmbeddingsFactory.ts index b0ecf360f..6b54f6d90 100644 --- a/src/main/knowledge/embeddings/EmbeddingsFactory.ts +++ b/src/main/knowledge/embeddings/EmbeddingsFactory.ts @@ -2,7 +2,6 @@ import type { BaseEmbeddings } from '@cherrystudio/embedjs-interfaces' import { OllamaEmbeddings } from '@cherrystudio/embedjs-ollama' import { OpenAiEmbeddings } from '@cherrystudio/embedjs-openai' import { AzureOpenAiEmbeddings } from '@cherrystudio/embedjs-openai/src/azure-openai-embeddings' -import { getInstanceName } from '@main/utils' import { ApiClient } from '@types' import { VOYAGE_SUPPORTED_DIM_MODELS } from './utils' @@ -45,7 +44,7 @@ export default class EmbeddingsFactory { azureOpenAIApiKey: apiKey, azureOpenAIApiVersion: apiVersion, azureOpenAIApiDeploymentName: model, - azureOpenAIApiInstanceName: getInstanceName(baseURL), + azureOpenAIEndpoint: baseURL, dimensions, batchSize }) diff --git a/src/renderer/src/aiCore/index.ts b/src/renderer/src/aiCore/index.ts index 7e79cb0fc..9e7a9b731 100644 --- a/src/renderer/src/aiCore/index.ts +++ b/src/renderer/src/aiCore/index.ts @@ -2,6 +2,7 @@ import { loggerService } from '@logger' import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory' import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient' import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models' +import { getProviderByModel } from '@renderer/services/AssistantService' import type { GenerateImageParams, Model, Provider } from '@renderer/types' import { RequestOptions, SdkModel } from '@renderer/types/sdk' import { isEnabledToolUse } from '@renderer/utils/mcp-tools' @@ -135,6 +136,9 @@ export default class AiProvider { public async getEmbeddingDimensions(model: Model): Promise { try { // Use the SDK instance to test embedding capabilities + if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') { + this.apiClient = this.apiClient.getClient(model) as BaseApiClient + } const dimensions = await this.apiClient.getEmbeddingDimensions(model) return dimensions } catch (error) {