fix: azure openai embed (#8250)

* fix/Azure embed

* fix/azure-patch1
This commit is contained in:
SuYao 2025-07-20 09:38:57 +08:00 committed by GitHub
parent 4962f692a7
commit 411c5bc94e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 2 deletions

View File

@ -2,7 +2,6 @@ import type { BaseEmbeddings } from '@cherrystudio/embedjs-interfaces'
import { OllamaEmbeddings } from '@cherrystudio/embedjs-ollama'
import { OpenAiEmbeddings } from '@cherrystudio/embedjs-openai'
import { AzureOpenAiEmbeddings } from '@cherrystudio/embedjs-openai/src/azure-openai-embeddings'
import { getInstanceName } from '@main/utils'
import { ApiClient } from '@types'
import { VOYAGE_SUPPORTED_DIM_MODELS } from './utils'
@ -45,7 +44,7 @@ export default class EmbeddingsFactory {
azureOpenAIApiKey: apiKey,
azureOpenAIApiVersion: apiVersion,
azureOpenAIApiDeploymentName: model,
azureOpenAIApiInstanceName: getInstanceName(baseURL),
azureOpenAIEndpoint: baseURL,
dimensions,
batchSize
})

View File

@ -2,6 +2,7 @@ import { loggerService } from '@logger'
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import { RequestOptions, SdkModel } from '@renderer/types/sdk'
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
@ -135,6 +136,9 @@ export default class AiProvider {
public async getEmbeddingDimensions(model: Model): Promise<number> {
try {
// Use the SDK instance to test embedding capabilities
if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
this.apiClient = this.apiClient.getClient(model) as BaseApiClient
}
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
return dimensions
} catch (error) {