mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-22 17:00:14 +08:00
fix: use ModernAiProvider for embedding dimensions (#11876)
This commit is contained in:
parent
96aba33077
commit
66feee714b
@ -2,7 +2,6 @@ import { loggerService } from '@logger'
|
|||||||
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
|
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
|
||||||
import type { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient'
|
import type { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient'
|
||||||
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
|
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
|
||||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
|
||||||
import { withSpanResult } from '@renderer/services/SpanManagerService'
|
import { withSpanResult } from '@renderer/services/SpanManagerService'
|
||||||
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||||
@ -160,9 +159,6 @@ export default class AiProvider {
|
|||||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||||
try {
|
try {
|
||||||
// Use the SDK instance to test embedding capabilities
|
// Use the SDK instance to test embedding capabilities
|
||||||
if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
|
|
||||||
this.apiClient = this.apiClient.getClient(model) as BaseApiClient
|
|
||||||
}
|
|
||||||
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
|
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
|
||||||
return dimensions
|
return dimensions
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import AiProvider from '@renderer/aiCore'
|
|
||||||
import { RefreshIcon } from '@renderer/components/Icons'
|
import { RefreshIcon } from '@renderer/components/Icons'
|
||||||
import { useProvider } from '@renderer/hooks/useProvider'
|
import { useProvider } from '@renderer/hooks/useProvider'
|
||||||
import type { Model } from '@renderer/types'
|
import type { Model } from '@renderer/types'
|
||||||
@ -8,6 +7,8 @@ import { Button, InputNumber, Space, Tooltip } from 'antd'
|
|||||||
import { memo, useCallback, useMemo, useState } from 'react'
|
import { memo, useCallback, useMemo, useState } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
|
|
||||||
|
import AiProviderNew from '../aiCore/index_new'
|
||||||
|
|
||||||
const logger = loggerService.withContext('DimensionsInput')
|
const logger = loggerService.withContext('DimensionsInput')
|
||||||
|
|
||||||
interface InputEmbeddingDimensionProps {
|
interface InputEmbeddingDimensionProps {
|
||||||
@ -47,7 +48,7 @@ const InputEmbeddingDimension = ({
|
|||||||
|
|
||||||
setLoading(true)
|
setLoading(true)
|
||||||
try {
|
try {
|
||||||
const aiProvider = new AiProvider(provider)
|
const aiProvider = new AiProviderNew(provider)
|
||||||
const dimension = await aiProvider.getEmbeddingDimensions(model)
|
const dimension = await aiProvider.getEmbeddingDimensions(model)
|
||||||
// for controlled input
|
// for controlled input
|
||||||
if (ref?.current) {
|
if (ref?.current) {
|
||||||
|
|||||||
@ -79,7 +79,7 @@ vi.mock('antd', () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Mock dependencies
|
// Mock dependencies
|
||||||
vi.mock('@renderer/aiCore', () => ({
|
vi.mock('@renderer/aiCore/index_new', () => ({
|
||||||
default: vi.fn().mockImplementation(() => ({
|
default: vi.fn().mockImplementation(() => ({
|
||||||
getEmbeddingDimensions: mocks.aiCore.getEmbeddingDimensions
|
getEmbeddingDimensions: mocks.aiCore.getEmbeddingDimensions
|
||||||
}))
|
}))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user