|
1 | 1 | import type { Column, TableRelationalConfig } from 'drizzle-orm' |
2 | 2 | import { is, Table } from 'drizzle-orm' |
3 | 3 | import type { IsNever, Simplify, ValueOf } from 'type-fest' |
4 | | -import type { ZodObject, ZodOptional, ZodType } from 'zod' |
| 4 | +import type { ZodError, ZodObject, ZodOptional, ZodType } from 'zod' |
5 | 5 |
|
| 6 | +import type { MaybePromise } from './collection' |
| 7 | +import type { ApiHttpStatus, ApiRouteHandlerPayloadWithContext, ApiRouteSchema } from './endpoint' |
6 | 8 | import type { Field, FieldRelation, Fields, FieldsInitial, FieldsWithFieldName } from './field' |
7 | 9 |
|
8 | 10 | export function isRelationField(field: Field): field is FieldRelation { |
@@ -115,6 +117,93 @@ export function mapValueToTsValue( |
115 | 117 | return Object.fromEntries(mappedEntries.filter((r) => r.length > 0)) |
116 | 118 | } |
117 | 119 |
|
| 120 | +export async function validateRequestBody< |
| 121 | + TApiRouteSchema extends ApiRouteSchema = any, |
| 122 | + TContext extends Record<string, unknown> = Record<string, unknown>, |
| 123 | +>(schema: TApiRouteSchema, payload: ApiRouteHandlerPayloadWithContext<TApiRouteSchema, TContext>) { |
| 124 | + const zodErrors: ZodError[] = [] |
| 125 | + if (schema.query) { |
| 126 | + const err = await schema.query.safeParseAsync((payload as any).query) |
| 127 | + if (!err.success) { |
| 128 | + zodErrors.push(err.error) |
| 129 | + } |
| 130 | + } |
| 131 | + |
| 132 | + if (schema.pathParams) { |
| 133 | + const err = await schema.pathParams.safeParseAsync((payload as any).pathParams) |
| 134 | + if (!err.success) { |
| 135 | + zodErrors.push(err.error) |
| 136 | + } |
| 137 | + } |
| 138 | + |
| 139 | + if (schema.headers) { |
| 140 | + const err = await schema.headers.safeParseAsync(payload.headers) |
| 141 | + if (!err.success) { |
| 142 | + zodErrors.push(err.error) |
| 143 | + } |
| 144 | + } |
| 145 | + |
| 146 | + if (schema.method !== 'GET' && schema.body) { |
| 147 | + const err = await schema.body.safeParseAsync((payload as any).body) |
| 148 | + if (!err.success) { |
| 149 | + zodErrors.push(err.error) |
| 150 | + } |
| 151 | + } |
| 152 | + |
| 153 | + return zodErrors |
| 154 | +} |
| 155 | + |
| 156 | +export function validateResponseBody<TApiRouteSchema extends ApiRouteSchema = any>( |
| 157 | + schema: TApiRouteSchema, |
| 158 | + statusCode: ApiHttpStatus, |
| 159 | + response: any |
| 160 | +) { |
| 161 | + if (!schema.responses[statusCode]) { |
| 162 | + throw new Error(`No response schema defined for status code ${statusCode}`) |
| 163 | + } |
| 164 | + |
| 165 | + const result = schema.responses[statusCode].safeParse(response) |
| 166 | + return result.error |
| 167 | +} |
| 168 | + |
| 169 | +export function withValidator< |
| 170 | + TApiRouteSchema extends ApiRouteSchema, |
| 171 | + TContext extends Record<string, unknown>, |
| 172 | +>( |
| 173 | + schema: TApiRouteSchema, |
| 174 | + handler: ( |
| 175 | + payload: ApiRouteHandlerPayloadWithContext<TApiRouteSchema, TContext> |
| 176 | + ) => MaybePromise<any> |
| 177 | +): (payload: ApiRouteHandlerPayloadWithContext<TApiRouteSchema, TContext>) => MaybePromise<any> { |
| 178 | + return async (payload: ApiRouteHandlerPayloadWithContext<TApiRouteSchema, TContext>) => { |
| 179 | + const zodErrors = await validateRequestBody(schema, payload) |
| 180 | + if (zodErrors.length > 0) { |
| 181 | + return { |
| 182 | + status: 400, |
| 183 | + body: { |
| 184 | + error: 'Validation failed', |
| 185 | + details: zodErrors.map((e) => e.message), |
| 186 | + }, |
| 187 | + } |
| 188 | + } |
| 189 | + |
| 190 | + const response = await handler(payload) |
| 191 | + |
| 192 | + const validationError = validateResponseBody(schema, response.status, response.body) |
| 193 | + if (validationError) { |
| 194 | + return { |
| 195 | + status: 500, |
| 196 | + body: { |
| 197 | + error: 'Response validation failed', |
| 198 | + details: validationError.errors.map((e) => e.message), |
| 199 | + }, |
| 200 | + } |
| 201 | + } |
| 202 | + |
| 203 | + return response |
| 204 | + } |
| 205 | +} |
| 206 | + |
118 | 207 | export type JoinArrays<T extends any[]> = Simplify< |
119 | 208 | T extends [infer A] |
120 | 209 | ? IsNever<A> extends true |
|
0 commit comments