mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-07 13:59:28 +08:00
refactor(ocr): simplify ocr service interface and params handling
- Replace OcrProvider with OcrParams to simplify interface - Remove unused OcrApiClientFactory and related code - Consolidate ocr service calls to use consistent params structure
This commit is contained in:
parent
68aaf9df4a
commit
49c80620ae
@ -18,7 +18,7 @@ import type {
|
|||||||
AgentPersistedMessage,
|
AgentPersistedMessage,
|
||||||
FileMetadata,
|
FileMetadata,
|
||||||
Notification,
|
Notification,
|
||||||
OcrProvider,
|
OcrParams,
|
||||||
Provider,
|
Provider,
|
||||||
Shortcut,
|
Shortcut,
|
||||||
SupportedOcrFile
|
SupportedOcrFile
|
||||||
@ -872,9 +872,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
// OCR
|
// OCR
|
||||||
ipcMain.handle(IpcChannel.OCR_Ocr, (_, file: SupportedOcrFile, provider: OcrProvider) =>
|
ipcMain.handle(IpcChannel.OCR_Ocr, (_, file: SupportedOcrFile, params: OcrParams) => ocrService.ocr(file, params))
|
||||||
ocrService.ocr(file, provider)
|
|
||||||
)
|
|
||||||
|
|
||||||
// OVMS
|
// OVMS
|
||||||
ipcMain.handle(IpcChannel.Ovms_AddModel, (_, modelName: string, modelId: string, modelSource: string, task: string) =>
|
ipcMain.handle(IpcChannel.Ovms_AddModel, (_, modelName: string, modelId: string, modelSource: string, task: string) =>
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import type { OcrProvider, OcrResult, SupportedOcrFile } from '@types'
|
import type { OcrParams, OcrResult, SupportedOcrFile } from '@types'
|
||||||
import { BuiltinOcrProviderIds } from '@types'
|
import { BuiltinOcrProviderIds } from '@types'
|
||||||
|
|
||||||
import type { OcrBaseService } from './builtin/OcrBaseService'
|
import type { OcrBaseService } from './builtin/OcrBaseService'
|
||||||
@ -28,12 +28,12 @@ export class OcrService {
|
|||||||
return Array.from(this.registry.keys())
|
return Array.from(this.registry.keys())
|
||||||
}
|
}
|
||||||
|
|
||||||
public async ocr(file: SupportedOcrFile, provider: OcrProvider): Promise<OcrResult> {
|
public async ocr(file: SupportedOcrFile, params: OcrParams): Promise<OcrResult> {
|
||||||
const service = this.registry.get(provider.id)
|
const service = this.registry.get(params.providerId)
|
||||||
if (!service) {
|
if (!service) {
|
||||||
throw new Error(`Provider ${provider.id} is not registered`)
|
throw new Error(`Provider ${params.providerId} is not registered`)
|
||||||
}
|
}
|
||||||
return service.ocr(file, provider.config)
|
return service.ocr(file)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,7 @@ import type {
|
|||||||
} from '@shared/data/preference/preferenceTypes'
|
} from '@shared/data/preference/preferenceTypes'
|
||||||
import type { UpgradeChannel } from '@shared/data/preference/preferenceTypes'
|
import type { UpgradeChannel } from '@shared/data/preference/preferenceTypes'
|
||||||
import { IpcChannel } from '@shared/IpcChannel'
|
import { IpcChannel } from '@shared/IpcChannel'
|
||||||
import type { Notification } from '@types'
|
import type { Notification, OcrParams } from '@types'
|
||||||
import type {
|
import type {
|
||||||
AddMemoryOptions,
|
AddMemoryOptions,
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
@ -27,7 +27,6 @@ import type {
|
|||||||
MemoryConfig,
|
MemoryConfig,
|
||||||
MemoryListOptions,
|
MemoryListOptions,
|
||||||
MemorySearchOptions,
|
MemorySearchOptions,
|
||||||
OcrProvider,
|
|
||||||
OcrResult,
|
OcrResult,
|
||||||
Provider,
|
Provider,
|
||||||
RestartApiServerStatusResult,
|
RestartApiServerStatusResult,
|
||||||
@ -476,8 +475,8 @@ const api = {
|
|||||||
ipcRenderer.invoke(IpcChannel.CodeTools_RemoveCustomTerminalPath, terminalId)
|
ipcRenderer.invoke(IpcChannel.CodeTools_RemoveCustomTerminalPath, terminalId)
|
||||||
},
|
},
|
||||||
ocr: {
|
ocr: {
|
||||||
ocr: (file: SupportedOcrFile, provider: OcrProvider): Promise<OcrResult> =>
|
ocr: (file: SupportedOcrFile, params: OcrParams): Promise<OcrResult> =>
|
||||||
ipcRenderer.invoke(IpcChannel.OCR_Ocr, file, provider)
|
ipcRenderer.invoke(IpcChannel.OCR_Ocr, file, params)
|
||||||
},
|
},
|
||||||
cherryai: {
|
cherryai: {
|
||||||
generateSignature: (params: { method: string; path: string; query: string; body: Record<string, any> }) =>
|
generateSignature: (params: { method: string; path: string; query: string; body: Record<string, any> }) =>
|
||||||
|
|||||||
@ -26,8 +26,10 @@ export const useOcr = () => {
|
|||||||
const ocrImage = useCallback(
|
const ocrImage = useCallback(
|
||||||
async (image: ImageFileMetadata) => {
|
async (image: ImageFileMetadata) => {
|
||||||
if (isProviderAvailable(imageProvider)) {
|
if (isProviderAvailable(imageProvider)) {
|
||||||
logger.debug('ocrImage', { config: imageProvider.config })
|
logger.debug('ocrImage', { provider: imageProvider })
|
||||||
return OcrService.ocr(image, imageProvider)
|
return OcrService.ocr(image, {
|
||||||
|
providerId: imageProvider.id
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
throw new Error(t('ocr.error.provider.'))
|
throw new Error(t('ocr.error.provider.'))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,8 +1,5 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import type { OcrProvider, OcrResult, SupportedOcrFile } from '@renderer/types'
|
import type { OcrParams, OcrResult, SupportedOcrFile } from '@renderer/types'
|
||||||
import { isOcrApiProvider } from '@renderer/types'
|
|
||||||
|
|
||||||
import { OcrApiClientFactory } from './clients/OcrApiClientFactory'
|
|
||||||
|
|
||||||
const logger = loggerService.withContext('renderer:OcrService')
|
const logger = loggerService.withContext('renderer:OcrService')
|
||||||
|
|
||||||
@ -13,12 +10,7 @@ const logger = loggerService.withContext('renderer:OcrService')
|
|||||||
* @returns ocr result
|
* @returns ocr result
|
||||||
* @throws {Error}
|
* @throws {Error}
|
||||||
*/
|
*/
|
||||||
export const ocr = async (file: SupportedOcrFile, provider: OcrProvider): Promise<OcrResult> => {
|
export const ocr = async (file: SupportedOcrFile, params: OcrParams): Promise<OcrResult> => {
|
||||||
logger.info(`ocr file ${file.path}`)
|
logger.info(`ocr file ${file.path}`)
|
||||||
if (isOcrApiProvider(provider)) {
|
return window.api.ocr.ocr(file, params)
|
||||||
const client = OcrApiClientFactory.create(provider)
|
|
||||||
return client.ocr(file, provider.config)
|
|
||||||
} else {
|
|
||||||
return window.api.ocr.ocr(file, provider)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,27 +1,29 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import type { OcrApiProvider } from '@renderer/types'
|
import type { OcrApiProvider, OcrApiProviderConfig } from '@renderer/types'
|
||||||
|
|
||||||
import type { OcrBaseApiClient } from './OcrBaseApiClient'
|
import type { OcrBaseApiClient } from './OcrBaseApiClient'
|
||||||
import { OcrExampleApiClient } from './OcrExampleApiClient'
|
import { OcrExampleApiClient } from './OcrExampleApiClient'
|
||||||
|
|
||||||
const logger = loggerService.withContext('OcrApiClientFactory')
|
const logger = loggerService.withContext('OcrApiClientFactory')
|
||||||
|
|
||||||
|
// Not being used for now.
|
||||||
|
// TODO: Migrate to main in the future.
|
||||||
export class OcrApiClientFactory {
|
export class OcrApiClientFactory {
|
||||||
/**
|
/**
|
||||||
* Create an ApiClient instance for the given provider
|
* Create an ApiClient instance for the given provider
|
||||||
* 为给定的提供者创建ApiClient实例
|
* 为给定的提供者创建ApiClient实例
|
||||||
*/
|
*/
|
||||||
static create(provider: OcrApiProvider): OcrBaseApiClient {
|
static create(provider: OcrApiProvider, config: OcrApiProviderConfig): OcrBaseApiClient {
|
||||||
logger.debug(`Creating ApiClient for provider:`, {
|
logger.debug(`Creating ApiClient for provider:`, {
|
||||||
id: provider.id,
|
id: provider.id,
|
||||||
config: provider.config
|
config
|
||||||
})
|
})
|
||||||
|
|
||||||
let instance: OcrBaseApiClient
|
let instance: OcrBaseApiClient
|
||||||
|
|
||||||
// Extend other clients here
|
// Extend other clients here
|
||||||
// eslint-disable-next-line prefer-const
|
// eslint-disable-next-line prefer-const
|
||||||
instance = new OcrExampleApiClient(provider)
|
instance = new OcrExampleApiClient(provider, config)
|
||||||
|
|
||||||
return instance
|
return instance
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,26 +1,31 @@
|
|||||||
import { cacheService } from '@data/CacheService'
|
import { cacheService } from '@data/CacheService'
|
||||||
import type { OcrApiProvider, OcrHandler } from '@renderer/types'
|
import type { OcrApiProvider, OcrApiProviderConfig, OcrHandler } from '@renderer/types'
|
||||||
|
|
||||||
|
// Not being used for now.
|
||||||
|
// TODO: Migrate to main in the future.
|
||||||
export abstract class OcrBaseApiClient {
|
export abstract class OcrBaseApiClient {
|
||||||
public provider: OcrApiProvider
|
public provider: OcrApiProvider
|
||||||
|
public config: OcrApiProviderConfig
|
||||||
protected host: string
|
protected host: string
|
||||||
protected apiKey: string
|
protected apiKey: string
|
||||||
|
|
||||||
constructor(provider: OcrApiProvider) {
|
constructor(provider: OcrApiProvider, config: OcrApiProviderConfig) {
|
||||||
this.provider = provider
|
this.provider = provider
|
||||||
this.host = this.getHost()
|
this.host = this.getHost()
|
||||||
this.apiKey = this.getApiKey()
|
this.apiKey = this.getApiKey()
|
||||||
|
this.config = config
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract ocr: OcrHandler
|
abstract ocr: OcrHandler
|
||||||
|
|
||||||
// copy from BaseApiClient
|
// copy from BaseApiClient
|
||||||
public getHost(): string {
|
public getHost(): string {
|
||||||
return this.provider.config.api.apiHost
|
return this.config.api.apiHost
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy from BaseApiClient
|
// copy from BaseApiClient
|
||||||
public getApiKey() {
|
public getApiKey() {
|
||||||
const keys = this.provider.config.api.apiKey.split(',').map((key) => key.trim())
|
const keys = this.config.api.apiKey.split(',').map((key) => key.trim())
|
||||||
const keyName = `ocr_provider:${this.provider.id}:last_used_key`
|
const keyName = `ocr_provider:${this.provider.id}:last_used_key`
|
||||||
|
|
||||||
if (keys.length === 1) {
|
if (keys.length === 1) {
|
||||||
|
|||||||
@ -1,12 +1,14 @@
|
|||||||
import type { OcrApiProvider, SupportedOcrFile } from '@renderer/types'
|
import type { OcrApiProvider, OcrApiProviderConfig, SupportedOcrFile } from '@renderer/types'
|
||||||
|
|
||||||
import { OcrBaseApiClient } from './OcrBaseApiClient'
|
import { OcrBaseApiClient } from './OcrBaseApiClient'
|
||||||
|
|
||||||
export type OcrExampleProvider = OcrApiProvider
|
export type OcrExampleProvider = OcrApiProvider
|
||||||
|
|
||||||
|
// Not being used for now.
|
||||||
|
// TODO: Migrate to main in the future.
|
||||||
export class OcrExampleApiClient extends OcrBaseApiClient {
|
export class OcrExampleApiClient extends OcrBaseApiClient {
|
||||||
constructor(provider: OcrApiProvider) {
|
constructor(provider: OcrApiProvider, config: OcrApiProviderConfig) {
|
||||||
super(provider)
|
super(provider, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
public ocr = async (file: SupportedOcrFile) => {
|
public ocr = async (file: SupportedOcrFile) => {
|
||||||
|
|||||||
@ -104,18 +104,23 @@ export const isOcrProvider = (p: unknown): p is OcrProvider => {
|
|||||||
return OcrProviderSchema.safeParse(p).success
|
return OcrProviderSchema.safeParse(p).success
|
||||||
}
|
}
|
||||||
|
|
||||||
export type OcrApiProviderConfig = OcrProviderBaseConfig & {
|
export const OcrApiProviderConfigSchema = OcrProviderBaseConfigSchema.extend({
|
||||||
api: OcrProviderApiConfig
|
api: OcrProviderApiConfigSchema
|
||||||
|
})
|
||||||
|
|
||||||
|
export type OcrApiProviderConfig = z.infer<typeof OcrApiProviderConfigSchema>
|
||||||
|
|
||||||
|
export const isOcrApiProviderConfig = (config: unknown): config is OcrApiProviderConfig => {
|
||||||
|
return OcrApiProviderConfigSchema.safeParse(config).success
|
||||||
}
|
}
|
||||||
|
|
||||||
/** This type is not being used. */
|
export const OcrApiProviderSchema = OcrProviderSchema
|
||||||
export type OcrApiProvider = OcrProvider & {
|
|
||||||
config: OcrApiProviderConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
/** This function is not being used. */
|
/** Currently, there is no API provider yet, but we've left room for expansion. */
|
||||||
export const isOcrApiProvider = (p: OcrProvider): p is OcrApiProvider => {
|
export type OcrApiProvider = z.infer<typeof OcrApiProviderSchema>
|
||||||
return !!(p.config && p.config.api && isOcrProviderApiConfig(p.config.api))
|
|
||||||
|
export const isOcrApiProvider = (p: unknown): p is OcrApiProvider => {
|
||||||
|
return OcrApiProviderSchema.safeParse(p).success
|
||||||
}
|
}
|
||||||
|
|
||||||
export type BuiltinOcrProvider = OcrProvider & {
|
export type BuiltinOcrProvider = OcrProvider & {
|
||||||
@ -153,13 +158,17 @@ export const isSupportedOcrFile = (file: FileMetadata): file is SupportedOcrFile
|
|||||||
return isImageFileMetadata(file)
|
return isImageFileMetadata(file)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type OcrParams = {
|
||||||
|
providerId: string
|
||||||
|
}
|
||||||
|
|
||||||
export type OcrResult = {
|
export type OcrResult = {
|
||||||
text: string
|
text: string
|
||||||
}
|
}
|
||||||
|
|
||||||
export type OcrHandler = (file: SupportedOcrFile, options?: OcrProviderBaseConfig) => Promise<OcrResult>
|
export type OcrHandler = (file: SupportedOcrFile) => Promise<OcrResult>
|
||||||
|
|
||||||
export type OcrImageHandler = (file: ImageFileMetadata, options?: OcrProviderBaseConfig) => Promise<OcrResult>
|
export type OcrImageHandler = (file: ImageFileMetadata) => Promise<OcrResult>
|
||||||
|
|
||||||
// Tesseract Types
|
// Tesseract Types
|
||||||
export type OcrTesseractConfig = OcrProviderBaseConfig & {
|
export type OcrTesseractConfig = OcrProviderBaseConfig & {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user