diff --git a/.github/workflows/pr_stats.yml b/.github/workflows/pr_stats.yml index e5f16557a..134ab68ed 100644 --- a/.github/workflows/pr_stats.yml +++ b/.github/workflows/pr_stats.yml @@ -14,4 +14,4 @@ jobs: charts: true disable-links: false sort-by: 'REVIEWS' - organization: 'UWOrbital' + organization: 'UWOrbital' \ No newline at end of file diff --git a/gs/backend/exceptions/exception_handlers.py b/gs/backend/exceptions/exception_handlers.py new file mode 100644 index 000000000..275028339 --- /dev/null +++ b/gs/backend/exceptions/exception_handlers.py @@ -0,0 +1,136 @@ +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_401_UNAUTHORIZED, + HTTP_404_NOT_FOUND, + HTTP_409_CONFLICT, + HTTP_422_UNPROCESSABLE_ENTITY, +) + +from gs.backend.exceptions.exceptions import ( + BaseOrbitalError, + DatabaseError, + InvalidArgumentError, + InvalidStateError, + NotFoundError, + ServiceError, + SunPositionError, + UnauthorizedError, + UnknownError, +) + + +def setup_exception_handlers(app: FastAPI) -> None: + """ + @brief Add exception handlers to the FastAPI app + @attribute app (FastAPI) - The FastAPI app to add the exception handlers to + """ + + @app.exception_handler(BaseOrbitalError) + async def base_orbital_exception_handler(request: Request, exc: BaseOrbitalError) -> JSONResponse: + """ + @brief handle all BaseOrbitalError exceptions + @attribute request (Request) - The request that caused the exception + @attribute exc (BaseOrbitalError) - The exception that was raised + """ + return JSONResponse( + status_code=HTTP_400_BAD_REQUEST, + content={"message": exc.message}, + ) + + @app.exception_handler(ServiceError) + async def service_exception_handler(request: Request, exc: ServiceError) -> JSONResponse: + """ + @brief handle all ServiceError exceptions + @attribute request (Request) - The request that caused the exception + @attribute exc (ServiceError) - The exception that was raised + """ + return JSONResponse( + status_code=HTTP_400_BAD_REQUEST, + content={"message": exc.message}, + ) + + @app.exception_handler(NotFoundError) + async def not_found_exception_handler(request: Request, exc: NotFoundError) -> JSONResponse: + """ + @brief handle all NotFoundError exceptions + @attribute request (Request) - The request that caused the exception + @attribute exc (NotFoundError) - The exception that was raised + """ + return JSONResponse( + status_code=HTTP_404_NOT_FOUND, + content={"message": exc.message}, + ) + + @app.exception_handler(InvalidArgumentError) + async def invalid_argument_exception_handler(request: Request, exc: InvalidArgumentError) -> JSONResponse: + """ + @brief handle all InvalidArgumentError exceptions + @attribute request (Request) - The request that caused the exception + @attribute exc (InvalidArgumentError) - The exception that was raised + """ + return JSONResponse( + status_code=HTTP_422_UNPROCESSABLE_ENTITY, + content={"message": exc.message}, + ) + + @app.exception_handler(InvalidStateError) + async def invalid_state_exception_handler(request: Request, exc: InvalidStateError) -> JSONResponse: + """ + @brief handle all InvalidStateError exceptions + @attribute request (Request) - The request that caused the exception + @attribute exc (InvalidStateError) - The exception that was raised + """ + return JSONResponse( + status_code=HTTP_409_CONFLICT, + content={"message": exc.message}, + ) + + @app.exception_handler(DatabaseError) + async def data_base_exception_handler(request: Request, exc: DatabaseError) -> JSONResponse: + """ + @brief handle all DatabaseError exceptions + @attribute request (Request) - The request that caused the exception + @attribute exc (DatabaseError) - The exception that was raised + """ + return JSONResponse( + status_code=HTTP_400_BAD_REQUEST, + content={"message": exc.message}, + ) + + @app.exception_handler(UnauthorizedError) + async def unauthorized_exception_handler(request: Request, exc: UnauthorizedError) -> JSONResponse: + """ + @brief handle all UnauthorizedError exceptions + @attribute request (Request) - The request that caused the exception + @attribute exc (UnauthorizedError) - The exception that was raised + """ + return JSONResponse( + status_code=HTTP_401_UNAUTHORIZED, + content={"message": exc.message}, + ) + + @app.exception_handler(UnknownError) + async def unknown_exception_handler(request: Request, exc: UnknownError) -> JSONResponse: + """ + @brief handle all UnknownError exceptions with HTTPException + @attribute request (Request) - The request that caused the exception + @attribute exc (UnknownError) - The exception that was raised + """ + return JSONResponse( + status_code=HTTP_400_BAD_REQUEST, + content={"detail": exc.message}, + ) + + @app.exception_handler(SunPositionError) + async def sun_position_exception_handler(request: Request, exc: SunPositionError) -> JSONResponse: + """ + @brief handle all SunPositionError exceptions + @attribute request (Request) - The request that caused the exception + @attribute exc (SunPositionError) - The exception that was raised + """ + return JSONResponse( + status_code=HTTP_400_BAD_REQUEST, + content={"message": exc.message}, + ) diff --git a/gs/backend/main.py b/gs/backend/main.py index 2f4dc9a78..1c55b6539 100644 --- a/gs/backend/main.py +++ b/gs/backend/main.py @@ -2,7 +2,9 @@ from gs.backend.api.backend_setup import setup_middlewares, setup_routes from gs.backend.api.lifespan import lifespan +from gs.backend.exceptions.exception_handlers import setup_exception_handlers app = FastAPI(lifespan=lifespan) setup_routes(app) setup_middlewares(app) +setup_exception_handlers(app) diff --git a/python_test/test_custom_exceptions.py b/python_test/test_custom_exceptions.py new file mode 100644 index 000000000..75411a728 --- /dev/null +++ b/python_test/test_custom_exceptions.py @@ -0,0 +1,86 @@ +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from gs.backend.exceptions.exception_handlers import setup_exception_handlers +from gs.backend.exceptions.exceptions import ( + BaseOrbitalError, + DatabaseError, + InvalidArgumentError, + InvalidStateError, + NotFoundError, + ServiceError, + SunPositionError, + UnauthorizedError, + UnknownError, +) +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_401_UNAUTHORIZED, + HTTP_404_NOT_FOUND, + HTTP_409_CONFLICT, + HTTP_422_UNPROCESSABLE_ENTITY, +) + + +def create_exception_handler_client() -> TestClient: + app = FastAPI() + setup_exception_handlers(app) + + @app.get("/test_base_orbital_exceptions") + async def test_base_orbital_exceptions(): + raise BaseOrbitalError("Test BaseOrbitalError") + + @app.get("/test_database_exceptions") + async def test_database_exceptions(): + raise DatabaseError("Test DatabaseError") + + @app.get("/test_invalid_argument_exceptions") + async def test_invalid_argument_exceptions(): + raise InvalidArgumentError("Test InvalidArgumentError") + + @app.get("/test_invalid_state_exceptions") + async def test_invalid_state_exceptions(): + raise InvalidStateError("Test InvalidStateError") + + @app.get("/test_not_found_exceptions") + async def test_not_found_exceptions(): + raise NotFoundError("Test NotFoundError") + + @app.get("/test_service_exceptions") + async def test_service_exceptions(): + raise ServiceError("Test ServiceError") + + @app.get("/test_sun_position_exceptions") + async def test_sun_position_exceptions(): + raise SunPositionError("Test SunPositionError") + + @app.get("/test_unauthorized_exceptions") + async def test_unauthorized_exceptions(): + raise UnauthorizedError("Test UnauthorizedError") + + @app.get("/test_unknown_exceptions") + async def test_unknown_exceptions(): + raise UnknownError("Test UnknownError") + + return TestClient(app) + + +@pytest.mark.parametrize( + "endpoint, status_code, response_key, expected_message", + [ + ("/test_base_orbital_exceptions", HTTP_400_BAD_REQUEST, "message", "Test BaseOrbitalError"), + ("/test_database_exceptions", HTTP_400_BAD_REQUEST, "message", "Test DatabaseError"), + ("/test_invalid_argument_exceptions", HTTP_422_UNPROCESSABLE_ENTITY, "message", "Test InvalidArgumentError"), + ("/test_invalid_state_exceptions", HTTP_409_CONFLICT, "message", "Test InvalidStateError"), + ("/test_not_found_exceptions", HTTP_404_NOT_FOUND, "message", "Test NotFoundError"), + ("/test_service_exceptions", HTTP_400_BAD_REQUEST, "message", "Test ServiceError"), + ("/test_sun_position_exceptions", HTTP_400_BAD_REQUEST, "message", "Test SunPositionError"), + ("/test_unauthorized_exceptions", HTTP_401_UNAUTHORIZED, "message", "Test UnauthorizedError"), + ("/test_unknown_exceptions", HTTP_400_BAD_REQUEST, "detail", "Test UnknownError"), + ], +) +def test_custom_exceptions(endpoint, status_code, response_key, expected_message): + fastapi_test_client = create_exception_handler_client() + response = fastapi_test_client.get(endpoint) + assert response.status_code == status_code + assert response.json().get(response_key) == expected_message