diff --git a/src/renderer/src/hooks/ocr/useOcrProvider.ts b/src/renderer/src/hooks/ocr/useOcrProvider.ts index 719455429f..8087235e13 100644 --- a/src/renderer/src/hooks/ocr/useOcrProvider.ts +++ b/src/renderer/src/hooks/ocr/useOcrProvider.ts @@ -1,35 +1,106 @@ +import { usePreference } from '@data/hooks/usePreference' import { loggerService } from '@logger' -import { updateOcrProviderConfig } from '@renderer/store/ocr' -import type { OcrProviderConfig } from '@renderer/types' +import type { + BuiltinOcrProviderId, + OcrOvConfig, + OcrOvProvider, + OcrPpocrConfig, + OcrPpocrProvider, + OcrSystemConfig, + OcrSystemProvider, + OcrTesseractConfig, + OcrTesseractProvider +} from '@renderer/types' import { BUILTIN_OCR_PROVIDERS_MAP } from '@shared/config/ocr' -import { useTranslation } from 'react-i18next' -import { useDispatch } from 'react-redux' - -import { useOcrProviders } from './useOcrProviders' +import { merge } from 'lodash' +import { useCallback, useMemo } from 'react' const logger = loggerService.withContext('useOcrProvider') -export const useOcrProvider = (id: string) => { - const { t } = useTranslation() - const dispatch = useDispatch() - const { providers } = useOcrProviders() - let provider = providers.find((p) => p.id === id) +const PROVIDER_REGISTRY = { + ovocr: null as unknown as OcrOvProvider, + paddleocr: null as unknown as OcrPpocrProvider, + system: null as unknown as OcrSystemProvider, + tesseract: null as unknown as OcrTesseractProvider +} - // safely fallback - if (!provider) { - logger.error(`Ocr Provider ${id} not found`) - logger.warn(`Fallback to tesseract`) - window.toast.error(t('ocr.error.provider.not_found')) - window.toast.warning(t('ocr.warning.provider.fallback', { name: 'Tesseract' })) - provider = BUILTIN_OCR_PROVIDERS_MAP.tesseract - } +const CONFIG_REGISTRY = { + ovocr: null as unknown as OcrOvConfig, + paddleocr: null as unknown as OcrPpocrConfig, + system: null as unknown as OcrSystemConfig, + tesseract: null as unknown as OcrTesseractConfig +} as const - const updateConfig = (update: Partial) => { - dispatch(updateOcrProviderConfig({ id: provider.id, update })) - } +type ProviderMap = typeof PROVIDER_REGISTRY + +type ConfigMap = typeof CONFIG_REGISTRY + +type TProvider = ProviderMap[T] + +type TConfig = ConfigMap[T] + +type UseOcrProviderReturn = { + provider: TProvider + config: TConfig + updateConfig: (update: Partial>) => void +} + +export const useOcrProvider = (id: T): UseOcrProviderReturn => { + const provider = useMemo(() => { + switch (id) { + case 'ovocr': + return BUILTIN_OCR_PROVIDERS_MAP.ovocr + case 'paddleocr': + return BUILTIN_OCR_PROVIDERS_MAP.paddleocr + case 'system': + return BUILTIN_OCR_PROVIDERS_MAP.system + case 'tesseract': + return BUILTIN_OCR_PROVIDERS_MAP.tesseract + } + }, [id]) + const [ovConfig, setOvConfig] = usePreference('ocr.provider.config.ovocr') + const [ppConfig, setPpConfig] = usePreference('ocr.provider.config.paddleocr') + const [sysConfig, setSysConfig] = usePreference('ocr.provider.config.system') + const [tesConfig, setTesConfig] = usePreference('ocr.provider.config.tesseract') + + const config = useMemo(() => { + switch (id) { + case 'ovocr': + return ovConfig + case 'paddleocr': + return ppConfig + case 'system': + return sysConfig + case 'tesseract': + return tesConfig + } + }, [id, ovConfig, ppConfig, sysConfig, tesConfig]) + + const updateConfig = useCallback( + (update: Partial>) => { + switch (id) { + case 'ovocr': + setOvConfig(merge({}, ovConfig, update)) + break + case 'paddleocr': + setPpConfig(merge({}, ppConfig, update)) + break + case 'system': + setSysConfig(merge({}, sysConfig, update)) + break + case 'tesseract': + setTesConfig(merge({}, tesConfig, update)) + break + default: + logger.warn(`Unsupported OCR provider id: ${id}`) + } + }, + [id, ovConfig, ppConfig, setOvConfig, setPpConfig, setSysConfig, setTesConfig, sysConfig, tesConfig] + ) return { provider, + config, updateConfig - } + } as UseOcrProviderReturn } diff --git a/src/renderer/src/hooks/ocr/useOcrProviders.ts b/src/renderer/src/hooks/ocr/useOcrProviders.ts index 2974c5c46a..717c1ea565 100644 --- a/src/renderer/src/hooks/ocr/useOcrProviders.ts +++ b/src/renderer/src/hooks/ocr/useOcrProviders.ts @@ -1,13 +1,10 @@ import { useQuery } from '@data/hooks/useDataApi' -import { loggerService } from '@logger' import { getBuiltinOcrProviderLabel } from '@renderer/i18n/label' import type { OcrProvider } from '@renderer/types' import { isBuiltinOcrProvider } from '@renderer/types' import { BUILTIN_OCR_PROVIDERS } from '@shared/config/ocr' import { useMemo } from 'react' -const logger = loggerService.withContext('useOcrProviders') - export const useOcrProviders = () => { const { data: validProviderIds, loading, error } = useQuery('/ocr/providers') const providers = useMemo(