Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr_stats.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ jobs:
charts: true
disable-links: false
sort-by: 'REVIEWS'
organization: 'UWOrbital'
organization: 'UWOrbital'
136 changes: 136 additions & 0 deletions gs/backend/exceptions/exception_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.status import (
HTTP_400_BAD_REQUEST,
HTTP_401_UNAUTHORIZED,
HTTP_404_NOT_FOUND,
HTTP_409_CONFLICT,
HTTP_422_UNPROCESSABLE_ENTITY,
)

from gs.backend.exceptions.exceptions import (
BaseOrbitalError,
DatabaseError,
InvalidArgumentError,
InvalidStateError,
NotFoundError,
ServiceError,
SunPositionError,
UnauthorizedError,
UnknownError,
)


def setup_exception_handlers(app: FastAPI) -> None:
"""
@brief Add exception handlers to the FastAPI app
@attribute app (FastAPI) - The FastAPI app to add the exception handlers to
"""

@app.exception_handler(BaseOrbitalError)
async def base_orbital_exception_handler(request: Request, exc: BaseOrbitalError) -> JSONResponse:
"""
@brief handle all BaseOrbitalError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (BaseOrbitalError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_400_BAD_REQUEST,
content={"message": exc.message},
)

@app.exception_handler(ServiceError)
async def service_exception_handler(request: Request, exc: ServiceError) -> JSONResponse:
"""
@brief handle all ServiceError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (ServiceError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_400_BAD_REQUEST,
content={"message": exc.message},
)

@app.exception_handler(NotFoundError)
async def not_found_exception_handler(request: Request, exc: NotFoundError) -> JSONResponse:
"""
@brief handle all NotFoundError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (NotFoundError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_404_NOT_FOUND,
content={"message": exc.message},
)

@app.exception_handler(InvalidArgumentError)
async def invalid_argument_exception_handler(request: Request, exc: InvalidArgumentError) -> JSONResponse:
"""
@brief handle all InvalidArgumentError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (InvalidArgumentError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
content={"message": exc.message},
)

@app.exception_handler(InvalidStateError)
async def invalid_state_exception_handler(request: Request, exc: InvalidStateError) -> JSONResponse:
"""
@brief handle all InvalidStateError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (InvalidStateError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_409_CONFLICT,
content={"message": exc.message},
)

@app.exception_handler(DatabaseError)
async def data_base_exception_handler(request: Request, exc: DatabaseError) -> JSONResponse:
"""
@brief handle all DatabaseError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (DatabaseError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_400_BAD_REQUEST,
content={"message": exc.message},
)

@app.exception_handler(UnauthorizedError)
async def unauthorized_exception_handler(request: Request, exc: UnauthorizedError) -> JSONResponse:
"""
@brief handle all UnauthorizedError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (UnauthorizedError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_401_UNAUTHORIZED,
content={"message": exc.message},
)

@app.exception_handler(UnknownError)
async def unknown_exception_handler(request: Request, exc: UnknownError) -> JSONResponse:
"""
@brief handle all UnknownError exceptions with HTTPException
@attribute request (Request) - The request that caused the exception
@attribute exc (UnknownError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_400_BAD_REQUEST,
content={"detail": exc.message},
)

@app.exception_handler(SunPositionError)
async def sun_position_exception_handler(request: Request, exc: SunPositionError) -> JSONResponse:
"""
@brief handle all SunPositionError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (SunPositionError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_400_BAD_REQUEST,
content={"message": exc.message},
)
2 changes: 2 additions & 0 deletions gs/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from gs.backend.api.backend_setup import setup_middlewares, setup_routes
from gs.backend.api.lifespan import lifespan
from gs.backend.exceptions.exception_handlers import setup_exception_handlers

app = FastAPI(lifespan=lifespan)
setup_routes(app)
setup_middlewares(app)
setup_exception_handlers(app)
86 changes: 86 additions & 0 deletions python_test/test_custom_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from gs.backend.exceptions.exception_handlers import setup_exception_handlers
from gs.backend.exceptions.exceptions import (
BaseOrbitalError,
DatabaseError,
InvalidArgumentError,
InvalidStateError,
NotFoundError,
ServiceError,
SunPositionError,
UnauthorizedError,
UnknownError,
)
from starlette.status import (
HTTP_400_BAD_REQUEST,
HTTP_401_UNAUTHORIZED,
HTTP_404_NOT_FOUND,
HTTP_409_CONFLICT,
HTTP_422_UNPROCESSABLE_ENTITY,
)


def create_exception_handler_client() -> TestClient:
app = FastAPI()
setup_exception_handlers(app)

@app.get("/test_base_orbital_exceptions")
async def test_base_orbital_exceptions():
raise BaseOrbitalError("Test BaseOrbitalError")

@app.get("/test_database_exceptions")
async def test_database_exceptions():
raise DatabaseError("Test DatabaseError")

@app.get("/test_invalid_argument_exceptions")
async def test_invalid_argument_exceptions():
raise InvalidArgumentError("Test InvalidArgumentError")

@app.get("/test_invalid_state_exceptions")
async def test_invalid_state_exceptions():
raise InvalidStateError("Test InvalidStateError")

@app.get("/test_not_found_exceptions")
async def test_not_found_exceptions():
raise NotFoundError("Test NotFoundError")

@app.get("/test_service_exceptions")
async def test_service_exceptions():
raise ServiceError("Test ServiceError")

@app.get("/test_sun_position_exceptions")
async def test_sun_position_exceptions():
raise SunPositionError("Test SunPositionError")

@app.get("/test_unauthorized_exceptions")
async def test_unauthorized_exceptions():
raise UnauthorizedError("Test UnauthorizedError")

@app.get("/test_unknown_exceptions")
async def test_unknown_exceptions():
raise UnknownError("Test UnknownError")

return TestClient(app)


@pytest.mark.parametrize(
"endpoint, status_code, response_key, expected_message",
[
("/test_base_orbital_exceptions", HTTP_400_BAD_REQUEST, "message", "Test BaseOrbitalError"),
("/test_database_exceptions", HTTP_400_BAD_REQUEST, "message", "Test DatabaseError"),
("/test_invalid_argument_exceptions", HTTP_422_UNPROCESSABLE_ENTITY, "message", "Test InvalidArgumentError"),
("/test_invalid_state_exceptions", HTTP_409_CONFLICT, "message", "Test InvalidStateError"),
("/test_not_found_exceptions", HTTP_404_NOT_FOUND, "message", "Test NotFoundError"),
("/test_service_exceptions", HTTP_400_BAD_REQUEST, "message", "Test ServiceError"),
("/test_sun_position_exceptions", HTTP_400_BAD_REQUEST, "message", "Test SunPositionError"),
("/test_unauthorized_exceptions", HTTP_401_UNAUTHORIZED, "message", "Test UnauthorizedError"),
("/test_unknown_exceptions", HTTP_400_BAD_REQUEST, "detail", "Test UnknownError"),
],
)
def test_custom_exceptions(endpoint, status_code, response_key, expected_message):
fastapi_test_client = create_exception_handler_client()
response = fastapi_test_client.get(endpoint)
assert response.status_code == status_code
assert response.json().get(response_key) == expected_message
Loading