Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions gs/backend/exceptions/exception_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.status import (
HTTP_400_BAD_REQUEST,
HTTP_401_UNAUTHORIZED,
HTTP_404_NOT_FOUND,
HTTP_409_CONFLICT,
HTTP_422_UNPROCESSABLE_ENTITY,
)

from gs.backend.exceptions.exceptions import (
BaseOrbitalError,
DatabaseError,
InvalidArgumentError,
InvalidStateError,
NotFoundError,
ServiceError,
SunPositionError,
UnauthorizedError,
UnknownError,
)


def setup_exception_handlers(app: FastAPI) -> None:
"""
@brief Add exception handlers to the FastAPI app
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure you're following our docstring convention for Python

@attribute app (FastAPI) - The FastAPI app to add the exception handlers to
"""

@app.exception_handler(BaseOrbitalError)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When would this base exception handler ever be called? What handler is used in the case of a general exception (fallback)?

async def base_orbital_exception_handler(request: Request, exc: BaseOrbitalError) -> JSONResponse:
"""
@brief handle all BaseOrbitalError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (BaseOrbitalError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_400_BAD_REQUEST,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're returning 400 as a fallback for a lot of exceptions, but this is a client-side error that we're returning, even if it wasn't the client's fault. We should also be more specific.

content={"message": exc.message},
)

@app.exception_handler(ServiceError)
async def service_exception_handler(request: Request, exc: ServiceError) -> JSONResponse:
"""
@brief handle all ServiceError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (ServiceError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_400_BAD_REQUEST,
content={"message": exc.message},
)

@app.exception_handler(NotFoundError)
async def not_found_exception_handler(request: Request, exc: NotFoundError) -> JSONResponse:
"""
@brief handle all NotFoundError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (NotFoundError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_404_NOT_FOUND,
content={"message": exc.message},
)

@app.exception_handler(InvalidArgumentError)
async def invalid_argument_exception_handler(request: Request, exc: InvalidArgumentError) -> JSONResponse:
"""
@brief handle all InvalidArgumentError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (InvalidArgumentError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
content={"message": exc.message},
)

@app.exception_handler(InvalidStateError)
async def invalid_state_exception_handler(request: Request, exc: InvalidStateError) -> JSONResponse:
"""
@brief handle all InvalidStateError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (InvalidStateError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_409_CONFLICT,
content={"message": exc.message},
)

@app.exception_handler(DatabaseError)
async def data_base_exception_handler(request: Request, exc: DatabaseError) -> JSONResponse:
"""
@brief handle all DatabaseError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (DatabaseError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_400_BAD_REQUEST,
content={"message": exc.message},
)

@app.exception_handler(UnauthorizedError)
async def unauthorized_exception_handler(request: Request, exc: UnauthorizedError) -> JSONResponse:
"""
@brief handle all UnauthorizedError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (UnauthorizedError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_401_UNAUTHORIZED,
content={"message": exc.message},
)

@app.exception_handler(UnknownError)
async def unknown_exception_handler(request: Request, exc: UnknownError) -> JSONResponse:
"""
@brief handle all UnknownError exceptions with HTTPException
@attribute request (Request) - The request that caused the exception
@attribute exc (UnknownError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_400_BAD_REQUEST,
content={"detail": exc.message},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why this key is detail and not message?

)

@app.exception_handler(SunPositionError)
async def sun_position_exception_handler(request: Request, exc: SunPositionError) -> JSONResponse:
"""
@brief handle all SunPositionError exceptions
@attribute request (Request) - The request that caused the exception
@attribute exc (SunPositionError) - The exception that was raised
"""
return JSONResponse(
status_code=HTTP_400_BAD_REQUEST,
content={"message": exc.message},
)
2 changes: 2 additions & 0 deletions gs/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from gs.backend.api.backend_setup import setup_middlewares, setup_routes
from gs.backend.api.lifespan import lifespan
from gs.backend.exceptions.exception_handlers import setup_exception_handlers

app = FastAPI(lifespan=lifespan)
setup_routes(app)
setup_middlewares(app)
setup_exception_handlers(app)
86 changes: 86 additions & 0 deletions python_test/test_custom_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from gs.backend.exceptions.exception_handlers import setup_exception_handlers
from gs.backend.exceptions.exceptions import (
BaseOrbitalError,
DatabaseError,
InvalidArgumentError,
InvalidStateError,
NotFoundError,
ServiceError,
SunPositionError,
UnauthorizedError,
UnknownError,
)
from starlette.status import (
HTTP_400_BAD_REQUEST,
HTTP_401_UNAUTHORIZED,
HTTP_404_NOT_FOUND,
HTTP_409_CONFLICT,
HTTP_422_UNPROCESSABLE_ENTITY,
)


def create_exception_handler_client() -> TestClient:
app = FastAPI()
setup_exception_handlers(app)

@app.get("/test_base_orbital_exceptions")
async def test_base_orbital_exceptions():
raise BaseOrbitalError("Test BaseOrbitalError")

@app.get("/test_database_exceptions")
async def test_database_exceptions():
raise DatabaseError("Test DatabaseError")

@app.get("/test_invalid_argument_exceptions")
async def test_invalid_argument_exceptions():
raise InvalidArgumentError("Test InvalidArgumentError")

@app.get("/test_invalid_state_exceptions")
async def test_invalid_state_exceptions():
raise InvalidStateError("Test InvalidStateError")

@app.get("/test_not_found_exceptions")
async def test_not_found_exceptions():
raise NotFoundError("Test NotFoundError")

@app.get("/test_service_exceptions")
async def test_service_exceptions():
raise ServiceError("Test ServiceError")

@app.get("/test_sun_position_exceptions")
async def test_sun_position_exceptions():
raise SunPositionError("Test SunPositionError")

@app.get("/test_unauthorized_exceptions")
async def test_unauthorized_exceptions():
raise UnauthorizedError("Test UnauthorizedError")

@app.get("/test_unknown_exceptions")
async def test_unknown_exceptions():
raise UnknownError("Test UnknownError")

return TestClient(app)


@pytest.mark.parametrize(
"endpoint, status_code, response_key, expected_message",
[
("/test_base_orbital_exceptions", HTTP_400_BAD_REQUEST, "message", "Test BaseOrbitalError"),
("/test_database_exceptions", HTTP_400_BAD_REQUEST, "message", "Test DatabaseError"),
("/test_invalid_argument_exceptions", HTTP_422_UNPROCESSABLE_ENTITY, "message", "Test InvalidArgumentError"),
("/test_invalid_state_exceptions", HTTP_409_CONFLICT, "message", "Test InvalidStateError"),
("/test_not_found_exceptions", HTTP_404_NOT_FOUND, "message", "Test NotFoundError"),
("/test_service_exceptions", HTTP_400_BAD_REQUEST, "message", "Test ServiceError"),
("/test_sun_position_exceptions", HTTP_400_BAD_REQUEST, "message", "Test SunPositionError"),
("/test_unauthorized_exceptions", HTTP_401_UNAUTHORIZED, "message", "Test UnauthorizedError"),
("/test_unknown_exceptions", HTTP_400_BAD_REQUEST, "detail", "Test UnknownError"),
],
)
def test_custom_exceptions(endpoint, status_code, response_key, expected_message):
fastapi_test_client = create_exception_handler_client()
response = fastapi_test_client.get(endpoint)
assert response.status_code == status_code
assert response.json().get(response_key) == expected_message
Loading