mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-25 03:10:08 +08:00
feat(ocr): add config validation and pass provider config to ocr handlers
Add type guards for OCR provider configs and ensure config is passed to OCR handlers Update all built-in OCR services to validate config before processing
This commit is contained in:
parent
96f71f12ec
commit
0176cf7679
@ -224,10 +224,10 @@ class OcrService {
|
||||
}
|
||||
|
||||
// Validate that the provider exists in database
|
||||
await this.getProvider(params.providerId)
|
||||
const provider = await this.getProvider(params.providerId)
|
||||
|
||||
logger.debug(`Performing OCR with provider: ${params.providerId}`)
|
||||
const result = await service.ocr(file)
|
||||
const result = await service.ocr(file, provider.config)
|
||||
|
||||
logger.info(`OCR completed successfully with provider: ${params.providerId}`)
|
||||
return result
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { isWin } from '@main/constant'
|
||||
import type { OcrOvConfig, OcrResult, SupportedOcrFile } from '@types'
|
||||
import { isImageFileMetadata } from '@types'
|
||||
import type { OcrOvConfig, OcrProviderConfig, OcrResult, SupportedOcrFile } from '@types'
|
||||
import { isImageFileMetadata, isOcrOvConfig } from '@types'
|
||||
import { exec } from 'child_process'
|
||||
import * as fs from 'fs'
|
||||
import * as os from 'os'
|
||||
@ -78,8 +78,8 @@ export class OvOcrService extends OcrBaseService {
|
||||
}
|
||||
}
|
||||
|
||||
private async ocrImage(filePath: string, options?: OcrOvConfig): Promise<OcrResult> {
|
||||
logger.info(`OV OCR called on ${filePath} with options ${JSON.stringify(options)}`)
|
||||
private async ocrImage(filePath: string, config?: OcrOvConfig): Promise<OcrResult> {
|
||||
logger.info(`OV OCR called on ${filePath} with options ${JSON.stringify(config)}`)
|
||||
|
||||
try {
|
||||
// 1. Clear img directory and output directory
|
||||
@ -114,9 +114,12 @@ export class OvOcrService extends OcrBaseService {
|
||||
}
|
||||
}
|
||||
|
||||
public ocr = async (file: SupportedOcrFile, options?: OcrOvConfig): Promise<OcrResult> => {
|
||||
public ocr = async (file: SupportedOcrFile, config?: OcrProviderConfig): Promise<OcrResult> => {
|
||||
if (!isOcrOvConfig(config)) {
|
||||
throw new Error('Invalid OCR OV config')
|
||||
}
|
||||
if (isImageFileMetadata(file)) {
|
||||
return this.ocrImage(file.path, options)
|
||||
return this.ocrImage(file.path, config)
|
||||
} else {
|
||||
throw new Error('Unsupported file type, currently only image files are supported')
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { loadOcrImage } from '@main/utils/ocr'
|
||||
import type { ImageFileMetadata, OcrPpocrConfig, OcrResult, SupportedOcrFile } from '@types'
|
||||
import { isImageFileMetadata } from '@types'
|
||||
import { isImageFileMetadata, isOcrPpocrConfig } from '@types'
|
||||
import { net } from 'electron'
|
||||
import * as z from 'zod'
|
||||
|
||||
@ -40,14 +40,17 @@ const OcrResponseSchema = z.object({
|
||||
})
|
||||
|
||||
export class PpocrService extends OcrBaseService {
|
||||
public ocr = async (file: SupportedOcrFile, options?: OcrPpocrConfig): Promise<OcrResult> => {
|
||||
public ocr = async (file: SupportedOcrFile, config?: OcrPpocrConfig): Promise<OcrResult> => {
|
||||
if (!isOcrPpocrConfig(config)) {
|
||||
throw new Error('Invalid OCR config')
|
||||
}
|
||||
if (!isImageFileMetadata(file)) {
|
||||
throw new Error('Only image files are supported currently')
|
||||
}
|
||||
if (!options) {
|
||||
if (!config) {
|
||||
throw new Error('config is required')
|
||||
}
|
||||
return this.imageOcr(file, options)
|
||||
return this.imageOcr(file, config)
|
||||
}
|
||||
|
||||
private async imageOcr(file: ImageFileMetadata, options: OcrPpocrConfig): Promise<OcrResult> {
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import { isLinux, isWin } from '@main/constant'
|
||||
import { loadOcrImage } from '@main/utils/ocr'
|
||||
import { OcrAccuracy, recognize } from '@napi-rs/system-ocr'
|
||||
import type { ImageFileMetadata, OcrResult, OcrSystemConfig, SupportedOcrFile } from '@types'
|
||||
import { isImageFileMetadata } from '@types'
|
||||
import type { ImageFileMetadata, OcrProviderConfig, OcrResult, OcrSystemConfig, SupportedOcrFile } from '@types'
|
||||
import { isImageFileMetadata, isOcrSystemConfig } from '@types'
|
||||
|
||||
import { OcrBaseService } from './OcrBaseService'
|
||||
|
||||
@ -12,19 +12,22 @@ export class SystemOcrService extends OcrBaseService {
|
||||
super()
|
||||
}
|
||||
|
||||
private async ocrImage(file: ImageFileMetadata, options?: OcrSystemConfig): Promise<OcrResult> {
|
||||
private async ocrImage(file: ImageFileMetadata, config?: OcrSystemConfig): Promise<OcrResult> {
|
||||
if (isLinux) {
|
||||
return { text: '' }
|
||||
}
|
||||
const buffer = await loadOcrImage(file)
|
||||
const langs = isWin ? options?.langs : undefined
|
||||
const langs = isWin ? config?.langs : undefined
|
||||
const result = await recognize(buffer, OcrAccuracy.Accurate, langs)
|
||||
return { text: result.text }
|
||||
}
|
||||
|
||||
public ocr = async (file: SupportedOcrFile, options?: OcrSystemConfig): Promise<OcrResult> => {
|
||||
public ocr = async (file: SupportedOcrFile, config?: OcrProviderConfig): Promise<OcrResult> => {
|
||||
if (!isOcrSystemConfig(config)) {
|
||||
throw new Error('Invalid OCR configuration')
|
||||
}
|
||||
if (isImageFileMetadata(file)) {
|
||||
return this.ocrImage(file, options)
|
||||
return this.ocrImage(file, config)
|
||||
} else {
|
||||
throw new Error('Unsupported file type, currently only image files are supported')
|
||||
}
|
||||
|
||||
@ -2,8 +2,8 @@ import { loggerService } from '@logger'
|
||||
import { getIpCountry } from '@main/utils/ipService'
|
||||
import { loadOcrImage } from '@main/utils/ocr'
|
||||
import { MB } from '@shared/config/constant'
|
||||
import type { ImageFileMetadata, OcrResult, OcrTesseractConfig, SupportedOcrFile } from '@types'
|
||||
import { isImageFileMetadata } from '@types'
|
||||
import type { ImageFileMetadata, OcrProviderConfig, OcrResult, OcrTesseractConfig, SupportedOcrFile } from '@types'
|
||||
import { isImageFileMetadata, isOcrTesseractConfig } from '@types'
|
||||
import { app } from 'electron'
|
||||
import fs from 'fs'
|
||||
import { isEqual } from 'lodash'
|
||||
@ -70,8 +70,8 @@ export class TesseractService extends OcrBaseService {
|
||||
return this.worker
|
||||
}
|
||||
|
||||
private async imageOcr(file: ImageFileMetadata, options?: OcrTesseractConfig): Promise<OcrResult> {
|
||||
const worker = await this.getWorker(options)
|
||||
private async imageOcr(file: ImageFileMetadata, config?: OcrTesseractConfig): Promise<OcrResult> {
|
||||
const worker = await this.getWorker(config)
|
||||
const stat = await fs.promises.stat(file.path)
|
||||
if (stat.size > MB_SIZE_THRESHOLD * MB) {
|
||||
throw new Error(`This image is too large (max ${MB_SIZE_THRESHOLD}MB)`)
|
||||
@ -81,11 +81,14 @@ export class TesseractService extends OcrBaseService {
|
||||
return { text: result.data.text }
|
||||
}
|
||||
|
||||
public ocr = async (file: SupportedOcrFile, options?: OcrTesseractConfig): Promise<OcrResult> => {
|
||||
public ocr = async (file: SupportedOcrFile, config?: OcrProviderConfig): Promise<OcrResult> => {
|
||||
if (!isOcrTesseractConfig(config)) {
|
||||
throw new Error('Invalid Tesseract config')
|
||||
}
|
||||
if (!isImageFileMetadata(file)) {
|
||||
throw new Error('Only image files are supported currently')
|
||||
}
|
||||
return this.imageOcr(file, options)
|
||||
return this.imageOcr(file, config)
|
||||
}
|
||||
|
||||
private async _getLangPath(): Promise<string> {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { FileMetadata, ImageFileMetadata } from '..'
|
||||
import type { FileMetadata, ImageFileMetadata, OcrProviderConfig } from '..'
|
||||
import { isImageFileMetadata } from '..'
|
||||
|
||||
export type SupportedOcrFile = ImageFileMetadata
|
||||
@ -15,6 +15,6 @@ export type OcrResult = {
|
||||
text: string
|
||||
}
|
||||
|
||||
export type OcrHandler = (file: SupportedOcrFile) => Promise<OcrResult>
|
||||
export type OcrHandler = (file: SupportedOcrFile, config?: OcrProviderConfig) => Promise<OcrResult>
|
||||
|
||||
export type OcrImageHandler = (file: ImageFileMetadata) => Promise<OcrResult>
|
||||
export type OcrImageHandler = (file: ImageFileMetadata, config?: OcrProviderConfig) => Promise<OcrResult>
|
||||
|
||||
@ -1,17 +1,24 @@
|
||||
import type { TranslateLanguageCode } from '../../translate'
|
||||
import * as z from 'zod'
|
||||
|
||||
import { TranslateLanguageCodeSchema } from '../../translate'
|
||||
import type { OcrProvider } from './base'
|
||||
import { type ImageOcrProvider } from './base'
|
||||
import { type ImageOcrProvider, OcrProviderBaseConfigSchema } from './base'
|
||||
import { type BuiltinOcrProvider } from './base'
|
||||
import { type OcrProviderBaseConfig } from './base'
|
||||
import { BuiltinOcrProviderIdMap } from './base'
|
||||
|
||||
// ==========================================================
|
||||
// System OCR Types
|
||||
// ==========================================================
|
||||
|
||||
export interface OcrSystemConfig extends OcrProviderBaseConfig {
|
||||
langs?: TranslateLanguageCode[]
|
||||
export const OcrSystemConfigSchema = OcrProviderBaseConfigSchema.extend({
|
||||
langs: z.array(TranslateLanguageCodeSchema).optional()
|
||||
})
|
||||
|
||||
export type OcrSystemConfig = z.infer<typeof OcrSystemConfigSchema>
|
||||
export const isOcrSystemConfig = (c: unknown): c is OcrSystemConfig => {
|
||||
return OcrSystemConfigSchema.safeParse(c).success
|
||||
}
|
||||
|
||||
export type OcrSystemProvider = {
|
||||
id: 'system'
|
||||
config: OcrSystemConfig
|
||||
|
||||
@ -18,6 +18,10 @@ export const OcrTesseractConfigSchema = OcrProviderBaseConfigSchema.extend({
|
||||
|
||||
export type OcrTesseractConfig = z.infer<typeof OcrTesseractConfigSchema>
|
||||
|
||||
export const isOcrTesseractConfig = (value: unknown): value is OcrTesseractConfig => {
|
||||
return OcrTesseractConfigSchema.safeParse(value).success
|
||||
}
|
||||
|
||||
export type OcrTesseractProvider = {
|
||||
id: 'tesseract'
|
||||
config: OcrTesseractConfig
|
||||
|
||||
Loading…
Reference in New Issue
Block a user