diff --git a/packages/shared/data/api/apiSchemas.ts b/packages/shared/data/api/apiSchemas.ts index 22f3436772..f183b107aa 100644 --- a/packages/shared/data/api/apiSchemas.ts +++ b/packages/shared/data/api/apiSchemas.ts @@ -1,11 +1,15 @@ // NOTE: Types are defined inline in the schema for simplicity // If needed, specific types can be imported from './apiModels' import type { + CreateOcrProviderRequest, + CreateOcrProviderResponse, GetOcrProviderResponse, ListOcrProvidersResponse, OcrProviderId, PatchOcrProviderRequest, - PatchOcrProviderResponse + PatchOcrProviderResponse, + PutOcrProviderRequest, + PutOcrProviderResponse } from '@types' import type { BodyForPath, ConcreteApiPaths, QueryParamsForPath, ResponseForPath } from './apiPaths' @@ -359,9 +363,8 @@ export interface ApiSchemas { response: ListOcrProvidersResponse } POST: { - body: { - // TODO - } + body: CreateOcrProviderRequest + response: CreateOcrProviderResponse } } @@ -375,7 +378,8 @@ export interface ApiSchemas { response: PatchOcrProviderResponse } PUT: { - // TODO + body: PutOcrProviderRequest + response: PutOcrProviderResponse } DELETE: { // TODO diff --git a/src/main/data/api/handlers/index.ts b/src/main/data/api/handlers/index.ts index 0e843622c4..ae887d4d29 100644 --- a/src/main/data/api/handlers/index.ts +++ b/src/main/data/api/handlers/index.ts @@ -215,8 +215,8 @@ export const apiHandlers: ApiImplementation = { GET: async () => { return ocrService.listProviders() }, - POST: async () => { - throw new Error('Not implemented') + POST: async ({ body }) => { + return ocrService.createProvider(body) } }, @@ -227,8 +227,8 @@ export const apiHandlers: ApiImplementation = { PATCH: async ({ body }) => { return ocrService.patchProvider(body) }, - PUT: async () => { - throw new Error('Not implemented') + PUT: async ({ body }) => { + return ocrService.putProvider(body) }, DELETE: async () => { throw new Error('Not implemented') diff --git a/src/main/services/ocr/OcrService.ts b/src/main/services/ocr/OcrService.ts index 3a837fb049..836d6fa764 100644 --- a/src/main/services/ocr/OcrService.ts +++ b/src/main/services/ocr/OcrService.ts @@ -2,11 +2,15 @@ import { dbService } from '@data/db/DbService' import { ocrProviderTable } from '@data/db/schemas/ocr/provider' import { loggerService } from '@logger' import type { + CreateOcrProviderRequest, + CreateOcrProviderResponse, ListOcrProvidersResponse, OcrParams, OcrResult, PatchOcrProviderRequest, PatchOcrProviderResponse, + PutOcrProviderRequest, + PutOcrProviderResponse, SupportedOcrFile } from '@types' import { BuiltinOcrProviderIds } from '@types' @@ -99,6 +103,46 @@ export class OcrService { return { data: updated } } + public async createProvider(create: CreateOcrProviderRequest): Promise { + const providers = await dbService + .getDb() + .select() + .from(ocrProviderTable) + .where(eq(ocrProviderTable.id, create.id)) + .limit(1) + + if (providers.length > 0) { + throw new Error(`OCR provider ${create.id} already exists`) + } + + const [created] = await dbService.getDb().insert(ocrProviderTable).values(create).returning() + + return { data: created } + } + + public async putProvider(update: PutOcrProviderRequest): Promise { + const providers = await dbService + .getDb() + .select() + .from(ocrProviderTable) + .where(eq(ocrProviderTable.id, update.id)) + .limit(1) + + if (providers.length === 0) { + const [created] = await dbService.getDb().insert(ocrProviderTable).values(update).returning() + return { data: created } + } + + const [updated] = await dbService + .getDb() + .update(ocrProviderTable) + .set(update) + .where(eq(ocrProviderTable.id, update.id)) + .returning() + + return { data: updated } + } + public async ocr(file: SupportedOcrFile, params: OcrParams): Promise { const service = this.registry.get(params.providerId) if (!service) { diff --git a/src/renderer/src/types/ocr.ts b/src/renderer/src/types/ocr.ts index e469314179..f29a350cc5 100644 --- a/src/renderer/src/types/ocr.ts +++ b/src/renderer/src/types/ocr.ts @@ -291,3 +291,23 @@ export const PatchOcrProviderResponseSchema = z.object({ }) export type PatchOcrProviderResponse = z.infer + +export const CreateOcrProviderRequestSchema = OcrProviderSchema + +export type CreateOcrProviderRequest = z.infer + +export const CreateOcrProviderResponseSchema = z.object({ + data: DbOcrProviderSchema +}) + +export type CreateOcrProviderResponse = z.infer + +export const PutOcrProviderRequestSchema = OcrProviderSchema + +export type PutOcrProviderRequest = z.infer + +export const PutOcrProviderResponseSchema = z.object({ + data: DbOcrProviderSchema +}) + +export type PutOcrProviderResponse = z.infer