diff --git a/packages/shared/data/api/apiSchemas.ts b/packages/shared/data/api/apiSchemas.ts index 7fdb417e22..e92a5756f8 100644 --- a/packages/shared/data/api/apiSchemas.ts +++ b/packages/shared/data/api/apiSchemas.ts @@ -376,10 +376,12 @@ export interface ApiSchemas { response: GetOcrProviderResponse } PATCH: { + params: { id: OcrProviderId } body: PatchOcrProviderRequest response: PatchOcrProviderResponse } PUT: { + params: { id: OcrProviderId } body: PutOcrProviderRequest response: PutOcrProviderResponse } diff --git a/src/main/data/api/handlers/index.ts b/src/main/data/api/handlers/index.ts index 9b524ea46c..70fac91cbb 100644 --- a/src/main/data/api/handlers/index.ts +++ b/src/main/data/api/handlers/index.ts @@ -224,10 +224,16 @@ export const apiHandlers: ApiImplementation = { GET: async ({ params }) => { return ocrService.getProvider(params.id) }, - PATCH: async ({ body }) => { + PATCH: async ({ params, body }) => { + if (params.id !== body.id) { + throw new Error('Provider ID in path does not match ID in body') + } return ocrService.patchProvider(body) }, - PUT: async ({ body }) => { + PUT: async ({ params, body }) => { + if (params.id !== body.id) { + throw new Error('Provider ID in path does not match ID in body') + } return ocrService.putProvider(body) }, DELETE: async ({ params }) => { diff --git a/src/main/services/ocr/OcrService.ts b/src/main/services/ocr/OcrService.ts index bce1e3b01a..7dac466149 100644 --- a/src/main/services/ocr/OcrService.ts +++ b/src/main/services/ocr/OcrService.ts @@ -14,7 +14,7 @@ import type { PutOcrProviderResponse, SupportedOcrFile } from '@types' -import { BuiltinOcrProviderIdMap, BuiltinOcrProviderIds } from '@types' +import { BuiltinOcrProviderIdMap, BuiltinOcrProviderIds, isDbOcrProvider } from '@types' import dayjs from 'dayjs' import { eq } from 'drizzle-orm' import { merge } from 'lodash' @@ -93,6 +93,9 @@ export class OcrService { } const found = providers[0] const newProvider = { ...merge({}, found, update), updatedAt: dayjs().valueOf() } satisfies DbOcrProvider + if (!isDbOcrProvider(newProvider)) { + throw new Error('Invalid OCR provider data') + } const [updated] = await dbService .getDb() .update(ocrProviderTable) @@ -121,6 +124,9 @@ export class OcrService { updatedAt: timestamp } satisfies DbOcrProvider + if (!isDbOcrProvider(newProvider)) { + throw new Error('Invalid OCR provider data') + } const [created] = await dbService.getDb().insert(ocrProviderTable).values(newProvider).returning() return { data: created } @@ -144,6 +150,9 @@ export class OcrService { createdAt: timestamp, updatedAt: timestamp } satisfies DbOcrProvider + if (!isDbOcrProvider(newProvider)) { + throw new Error('Invalid OCR provider data') + } const [created] = await dbService.getDb().insert(ocrProviderTable).values(newProvider).returning() return { data: created } } @@ -154,6 +163,9 @@ export class OcrService { updatedAt: timestamp, createdAt: existed.createdAt } satisfies DbOcrProvider + if (!isDbOcrProvider(newProvider)) { + throw new Error('Invalid OCR provider data') + } const [updated] = await dbService .getDb() .update(ocrProviderTable) diff --git a/src/renderer/src/types/ocr.ts b/src/renderer/src/types/ocr.ts index cdf1686fab..a0355f1428 100644 --- a/src/renderer/src/types/ocr.ts +++ b/src/renderer/src/types/ocr.ts @@ -270,6 +270,10 @@ export const DbOcrProviderSchema = OcrProviderSchema.extend(TimestampExtendShape export type DbOcrProvider = z.infer +export const isDbOcrProvider = (p: unknown): p is DbOcrProvider => { + return DbOcrProviderSchema.safeParse(p).success +} + export type ListOcrProvidersQuery = { registered?: boolean } export const ListOcrProvidersResponseSchema = z.object({