mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-08 22:39:36 +08:00
feat(ocr): add validation for OCR provider operations
- Add params validation in API handlers to ensure path ID matches body ID - Introduce isDbOcrProvider type guard for runtime validation - Validate provider data before database operations
This commit is contained in:
parent
e4b5e70c34
commit
281632f859
@ -376,10 +376,12 @@ export interface ApiSchemas {
|
||||
response: GetOcrProviderResponse
|
||||
}
|
||||
PATCH: {
|
||||
params: { id: OcrProviderId }
|
||||
body: PatchOcrProviderRequest
|
||||
response: PatchOcrProviderResponse
|
||||
}
|
||||
PUT: {
|
||||
params: { id: OcrProviderId }
|
||||
body: PutOcrProviderRequest
|
||||
response: PutOcrProviderResponse
|
||||
}
|
||||
|
||||
@ -224,10 +224,16 @@ export const apiHandlers: ApiImplementation = {
|
||||
GET: async ({ params }) => {
|
||||
return ocrService.getProvider(params.id)
|
||||
},
|
||||
PATCH: async ({ body }) => {
|
||||
PATCH: async ({ params, body }) => {
|
||||
if (params.id !== body.id) {
|
||||
throw new Error('Provider ID in path does not match ID in body')
|
||||
}
|
||||
return ocrService.patchProvider(body)
|
||||
},
|
||||
PUT: async ({ body }) => {
|
||||
PUT: async ({ params, body }) => {
|
||||
if (params.id !== body.id) {
|
||||
throw new Error('Provider ID in path does not match ID in body')
|
||||
}
|
||||
return ocrService.putProvider(body)
|
||||
},
|
||||
DELETE: async ({ params }) => {
|
||||
|
||||
@ -14,7 +14,7 @@ import type {
|
||||
PutOcrProviderResponse,
|
||||
SupportedOcrFile
|
||||
} from '@types'
|
||||
import { BuiltinOcrProviderIdMap, BuiltinOcrProviderIds } from '@types'
|
||||
import { BuiltinOcrProviderIdMap, BuiltinOcrProviderIds, isDbOcrProvider } from '@types'
|
||||
import dayjs from 'dayjs'
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { merge } from 'lodash'
|
||||
@ -93,6 +93,9 @@ export class OcrService {
|
||||
}
|
||||
const found = providers[0]
|
||||
const newProvider = { ...merge({}, found, update), updatedAt: dayjs().valueOf() } satisfies DbOcrProvider
|
||||
if (!isDbOcrProvider(newProvider)) {
|
||||
throw new Error('Invalid OCR provider data')
|
||||
}
|
||||
const [updated] = await dbService
|
||||
.getDb()
|
||||
.update(ocrProviderTable)
|
||||
@ -121,6 +124,9 @@ export class OcrService {
|
||||
updatedAt: timestamp
|
||||
} satisfies DbOcrProvider
|
||||
|
||||
if (!isDbOcrProvider(newProvider)) {
|
||||
throw new Error('Invalid OCR provider data')
|
||||
}
|
||||
const [created] = await dbService.getDb().insert(ocrProviderTable).values(newProvider).returning()
|
||||
|
||||
return { data: created }
|
||||
@ -144,6 +150,9 @@ export class OcrService {
|
||||
createdAt: timestamp,
|
||||
updatedAt: timestamp
|
||||
} satisfies DbOcrProvider
|
||||
if (!isDbOcrProvider(newProvider)) {
|
||||
throw new Error('Invalid OCR provider data')
|
||||
}
|
||||
const [created] = await dbService.getDb().insert(ocrProviderTable).values(newProvider).returning()
|
||||
return { data: created }
|
||||
}
|
||||
@ -154,6 +163,9 @@ export class OcrService {
|
||||
updatedAt: timestamp,
|
||||
createdAt: existed.createdAt
|
||||
} satisfies DbOcrProvider
|
||||
if (!isDbOcrProvider(newProvider)) {
|
||||
throw new Error('Invalid OCR provider data')
|
||||
}
|
||||
const [updated] = await dbService
|
||||
.getDb()
|
||||
.update(ocrProviderTable)
|
||||
|
||||
@ -270,6 +270,10 @@ export const DbOcrProviderSchema = OcrProviderSchema.extend(TimestampExtendShape
|
||||
|
||||
export type DbOcrProvider = z.infer<typeof DbOcrProviderSchema>
|
||||
|
||||
export const isDbOcrProvider = (p: unknown): p is DbOcrProvider => {
|
||||
return DbOcrProviderSchema.safeParse(p).success
|
||||
}
|
||||
|
||||
export type ListOcrProvidersQuery = { registered?: boolean }
|
||||
|
||||
export const ListOcrProvidersResponseSchema = z.object({
|
||||
|
||||
Loading…
Reference in New Issue
Block a user