|
3 | 3 | from collections.abc import Callable, Coroutine |
4 | 4 | from contextvars import ContextVar |
5 | 5 | from io import StringIO |
6 | | -from typing import Any, Literal, Union |
| 6 | +from typing import Any, Literal, Optional, Union |
7 | 7 |
|
8 | 8 | import fastapi |
9 | 9 | import pytest |
10 | 10 | from dirty_equals import IsPartialDict, IsStr |
11 | 11 | from fastapi import APIRouter, Body, Cookie, File, Header, HTTPException, Query, Request, Response, UploadFile |
12 | 12 | from fastapi.responses import JSONResponse |
13 | 13 | from fastapi.routing import APIRoute |
| 14 | +from fastapi.testclient import TestClient |
14 | 15 | from pydantic import BaseModel, Field, RootModel |
15 | 16 | from starlette.responses import StreamingResponse |
16 | 17 |
|
|
32 | 33 | from cadwyn.structure.schemas import schema |
33 | 34 | from cadwyn.structure.versions import Version, VersionBundle |
34 | 35 | from tests.conftest import ( |
| 36 | + CreateVersionedApp, |
35 | 37 | CreateVersionedClients, |
36 | 38 | client, |
37 | 39 | version_change, |
@@ -1226,3 +1228,41 @@ def response_converter(response: ResponseInfo): |
1226 | 1228 | resp_2001 = client_2001.post(f"/{endpoint}", json={"i": ["original_request"]}) |
1227 | 1229 | assert resp_2001.status_code == 200 |
1228 | 1230 | assert resp_2001.json() == {"i": ["original_request", endpoint]} |
| 1231 | + |
| 1232 | + |
| 1233 | +def test__response_migrations__with_custom_http_exception( |
| 1234 | + create_versioned_app: CreateVersionedApp, |
| 1235 | + router: VersionedAPIRouter, |
| 1236 | +) -> None: |
| 1237 | + class CustomHTTPException(HTTPException): |
| 1238 | + error_code: Optional[str] = None |
| 1239 | + |
| 1240 | + def __init__(self, detail: str, error_code: Optional[str] = None): |
| 1241 | + self.error_code = error_code |
| 1242 | + super().__init__(status_code=400, detail=detail) |
| 1243 | + |
| 1244 | + def http_exception_handler(request, exc): |
| 1245 | + # Check if the exception has an error_code attribute |
| 1246 | + error_code = exc.error_code if hasattr(exc, "error_code") else "generic_error" |
| 1247 | + |
| 1248 | + return JSONResponse( |
| 1249 | + status_code=exc.status_code, |
| 1250 | + content={"code": error_code, "message": exc.detail}, |
| 1251 | + ) |
| 1252 | + |
| 1253 | + # Register exception handler for Cadwyn |
| 1254 | + |
| 1255 | + @router.post("/test") |
| 1256 | + async def endpoint(): |
| 1257 | + raise CustomHTTPException("Cadwyn error occurred", error_code="cadwyn_error") |
| 1258 | + |
| 1259 | + app = create_versioned_app(version_change()) |
| 1260 | + app.add_exception_handler(HTTPException, http_exception_handler) |
| 1261 | + with TestClient(app) as client: |
| 1262 | + resp = client.post("/test", headers={"X-API-VERSION": "2000-01-01"}) |
| 1263 | + assert resp.status_code == 400 |
| 1264 | + assert resp.json() == {"code": "cadwyn_error", "message": "Cadwyn error occurred"} |
| 1265 | + |
| 1266 | + resp_2001 = client.post("/test", headers={"X-API-VERSION": "2001-01-01"}) |
| 1267 | + assert resp_2001.status_code == 400 |
| 1268 | + assert resp_2001.json() == {"code": "cadwyn_error", "message": "Cadwyn error occurred"} |
0 commit comments