diff --git a/src/main/ipc.ts b/src/main/ipc.ts
index 2277299dc6..1f1f9dded9 100644
--- a/src/main/ipc.ts
+++ b/src/main/ipc.ts
@@ -10,6 +10,6 @@ import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
 import { MIN_WINDOW_HEIGHT, MIN_WINDOW_WIDTH, UpgradeChannel } from '@shared/config/constant'
 import { IpcChannel } from '@shared/IpcChannel'
-import { FileMetadata, Provider, Shortcut, ThemeMode } from '@types'
+import { FileMetadata, OcrProvider, Provider, Shortcut, SupportedOcrFile, ThemeMode } from '@types'
 import { BrowserWindow, dialog, ipcMain, ProxyConfig, session, shell, systemPreferences, webContents } from 'electron'
 import { Notification } from 'src/renderer/src/types/notification'
 
@@ -30,7 +30,7 @@ import { openTraceWindow, setTraceWindowTitle } from './services/NodeTraceServic
 import NotificationService from './services/NotificationService'
 import * as NutstoreService from './services/NutstoreService'
 import ObsidianVaultService from './services/ObsidianVaultService'
-import { ipcOcr } from './services/ocr/OcrService'
+import { ocrService } from './services/ocr/OcrService'
 import { proxyManager } from './services/ProxyManager'
 import { pythonService } from './services/PythonService'
 import { FileServiceManager } from './services/remotefile/FileServiceManager'
@@ -712,5 +712,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
   ipcMain.handle(IpcChannel.CodeTools_Run, codeToolsService.run)
 
   // OCR
-  ipcMain.handle(IpcChannel.OCR_ocr, ipcOcr)
+  ipcMain.handle(IpcChannel.OCR_ocr, (_, file: SupportedOcrFile, provider: OcrProvider) =>
+    ocrService.ocr(file, provider)
+  )
 }
diff --git a/src/main/services/ocr/OcrService.ts b/src/main/services/ocr/OcrService.ts
index dd6dd5de1a..1361b9a6af 100644
--- a/src/main/services/ocr/OcrService.ts
+++ b/src/main/services/ocr/OcrService.ts
@@ -1,91 +1,53 @@
-import { loggerService } from '@logger'
-import { MB } from '@shared/config/constant'
-import {
-  ImageFileMetadata,
-  ImageOcrProvider,
-  isBuiltinOcrProvider,
-  isImageFile,
-  isImageOcrProvider,
-  OcrProvider,
-  OcrResult,
-  SupportedOcrFile
-} from '@types'
-import { statSync } from 'fs'
-import { readFile } from 'fs/promises'
+import { BuiltinOcrProviderIds, FileMetadata, OcrProvider, OcrResult, SupportedOcrFile } from '@types'
 
 import { tesseractService } from './tesseract/TesseractService'
 
-const logger = loggerService.withContext('main:OcrService')
+/** Function that performs OCR for one provider: takes the file and resolves to the OCR result. */
+type OcrHandler = (file: FileMetadata) => Promise<OcrResult>
 
-/**
- * ocr by tesseract
- * @param file image file or base64 string
- * @returns ocr result
- * @throws {Error}
- */
-const tesseractOcr = async (file: ImageFileMetadata | string): Promise<Tesseract.RecognizeResult> => {
-  try {
-    const worker = await tesseractService.getWorker()
-    let ret: Tesseract.RecognizeResult
-    if (typeof file === 'string') {
-      ret = await worker.recognize(file)
-    } else {
-      const stat = statSync(file.path)
-      if (stat.size > 50 * MB) {
-        throw new Error('This image is too large (max 50MB)')
-      }
-      const buffer = await readFile(file.path)
-      ret = await worker.recognize(buffer)
+/**
+ * Registry of OCR handlers keyed by provider id; dispatches each OCR request to the matching handler.
+ */
+export class OcrService {
+  private registry: Map<string, OcrHandler> = new Map()
+
+  /**
+   * Registers the handler for a provider id, replacing any handler already stored under that id.
+   * @param providerId id that `provider.id` is matched against in {@link OcrService.ocr}
+   * @param handler function that performs the OCR
+   */
+  register(providerId: string, handler: OcrHandler): void {
+    this.registry.set(providerId, handler)
+  }
+
+  /**
+   * Removes the handler stored for a provider id; does nothing when none is stored.
+   * @param providerId id of the provider to remove
+   */
+  unregister(providerId: string): void {
+    this.registry.delete(providerId)
+  }
+
+  /**
+   * Runs OCR on a file with the handler registered for the given provider.
+   * @param file file to recognize
+   * @param provider provider whose `id` selects the handler
+   * @returns the result produced by the handler
+   * @throws {Error} when no handler is registered for `provider.id`
+   */
+  public async ocr(file: SupportedOcrFile, provider: OcrProvider): Promise<OcrResult> {
+    const handler = this.registry.get(provider.id)
+    if (!handler) {
+      throw new Error(`Provider ${provider.id} is not registered`)
     }
-    return ret
-  } catch (e) {
-    logger.error('Failed to ocr with tesseract.', e as Error)
-    throw e
+    return handler(file)
   }
 }
 
-/**
- * ocr image file
- * @param file image file
- * @param provider ocr provider that supports image ocr
- * @returns ocr result
- * @throws {Error}
- */
-const imageOcr = async (file: ImageFileMetadata, provider: ImageOcrProvider): Promise<OcrResult> => {
-  if (isBuiltinOcrProvider(provider)) {
-    if (provider.id === 'tesseract') {
-      const result = await tesseractOcr(file)
-      return { text: result.data.text }
-    } else {
-      throw new Error(`Unsupported built-in ocr provider: ${provider.id}`)
-    }
-  }
-  throw new Error(`Provider ${provider.id} is not supported.`)
-}
+/** Module-level OcrService instance; the built-in providers are registered on it below. */
+export const ocrService = new OcrService()
 
-/**
- * ocr a file
- * @param file any supported file
- * @param provider ocr provider
- * @returns ocr result
- * @throws {Error}
- */
-export const ocr = async (file: SupportedOcrFile, provider: OcrProvider): Promise<OcrResult> => {
-  if (isImageFile(file) && isImageOcrProvider(provider)) {
-    return imageOcr(file, provider)
-  } else {
-    throw new Error(`File type and provider capability is not matched, otherwise one of them is not supported.`)
-  }
-}
-
-/**
- * ocr a file
- * @param _ ipc event
- * @param file any supported file
- * @param provider ocr provider
- * @returns ocr result
- * @throws {Error}
- */
-export const ipcOcr = async (_: Electron.IpcMainInvokeEvent, ...args: Parameters<typeof ocr>) => {
-  return ocr(...args)
-}
+// Register built-in providers
+ocrService.register(BuiltinOcrProviderIds.tesseract, async (file) => {
+  return tesseractService.ocr(file)
+})
diff --git a/src/main/services/ocr/tesseract/TesseractService.ts b/src/main/services/ocr/tesseract/TesseractService.ts
index d58c53bfea..2472f02fba 100644
--- a/src/main/services/ocr/tesseract/TesseractService.ts
+++ b/src/main/services/ocr/tesseract/TesseractService.ts
@@ -1,6 +1,7 @@
 import { loggerService } from '@logger'
 import { getIpCountry } from '@main/utils/ipService'
-import { TesseractLangsDownloadUrl } from '@shared/config/constant'
+import { MB, TesseractLangsDownloadUrl } from '@shared/config/constant'
+import { FileMetadata, ImageFileMetadata, isImageFile, OcrResult } from '@types'
 import { app } from 'electron'
 import fs from 'fs'
 import path from 'path'
@@ -129,6 +130,37 @@ export class TesseractService {
     return this.worker
   }
 
+  /**
+   * Recognizes the text in an image file with the Tesseract worker.
+   * @param file image file; its content is read from `file.path`
+   * @returns the recognized text
+   * @throws {Error} when the file is larger than 50MB
+   */
+  async imageOcr(file: ImageFileMetadata): Promise<OcrResult> {
+    // Check the size before acquiring the worker so an oversized file is rejected without touching it.
+    const stat = await fs.promises.stat(file.path)
+    if (stat.size > 50 * MB) {
+      throw new Error('This image is too large (max 50MB)')
+    }
+    const worker = await this.getWorker()
+    const buffer = await fs.promises.readFile(file.path)
+    const result = await worker.recognize(buffer)
+    return { text: result.data.text }
+  }
+
+  /**
+   * Runs OCR on a file; only image files are handled for now.
+   * @param file file to recognize
+   * @returns the recognized text
+   * @throws {Error} when `file` is not an image file
+   */
+  async ocr(file: FileMetadata): Promise<OcrResult> {
+    if (!isImageFile(file)) {
+      throw new Error('Only image files are supported currently')
+    }
+    return this.imageOcr(file)
+  }
+
   private async _getLangPath(): Promise<string> {
     const country = await getIpCountry()
     return country.toLowerCase() === 'cn' ? TesseractLangsDownloadUrl.CN : TesseractLangsDownloadUrl.GLOBAL
diff --git a/src/preload/index.ts b/src/preload/index.ts
index af4803fd50..2c20f1ed57 100644
--- a/src/preload/index.ts
+++ b/src/preload/index.ts
@@ -18,7 +18,6 @@ import {
   MemoryListOptions,
   MemorySearchOptions,
   OcrProvider,
-  OcrResult,
   Provider,
   S3Config,
   Shortcut,
@@ -411,8 +410,7 @@ const api = {
     ) => ipcRenderer.invoke(IpcChannel.CodeTools_Run, cliTool, model, directory, env, options)
   },
   ocr: {
-    ocr: (file: SupportedOcrFile, provider: OcrProvider): Promise<OcrResult> =>
-      ipcRenderer.invoke(IpcChannel.OCR_ocr, file, provider)
+    ocr: (file: SupportedOcrFile, provider: OcrProvider) => ipcRenderer.invoke(IpcChannel.OCR_ocr, file, provider)
   }
 }
 