diff --git a/backend/nebula/exceptions.py b/backend/nebula/exceptions.py index d719f49a..a790386b 100644 --- a/backend/nebula/exceptions.py +++ b/backend/nebula/exceptions.py @@ -14,23 +14,10 @@ class NebulaException(Exception): status: int = 500 log: bool = True - def __init__( - self, - detail: str | None = None, - log: bool | str = False, - user_name: str | None = None, - **kwargs: Any, - ) -> None: + def __init__(self, detail: str | None = None, **kwargs: Any) -> None: self.kwargs = kwargs - if detail is not None: self.detail = detail - - if log is True or self.log: - logger.error(f"EXCEPTION: {self.status} {self.detail}", user=user_name) - elif isinstance(log, str): - logger.error(f"EXCEPTION: {self.status} {log}", user=user_name) - super().__init__(self.detail) diff --git a/backend/nebula/log.py b/backend/nebula/log.py index 9dd7b268..9c8a0212 100644 --- a/backend/nebula/log.py +++ b/backend/nebula/log.py @@ -1,9 +1,19 @@ +import contextvars import enum import logging import sys import traceback +from collections.abc import Generator +from contextlib import contextmanager from typing import Any +_log_context: contextvars.ContextVar[dict[str, Any] | None] = contextvars.ContextVar( + "_log_context", default=None +) + +def get_log_context() -> dict[str, Any] | None: + """Get the current log context.""" + return _log_context.get() def indent(text: str, level: int = 4) -> str: return text.replace("\n", f"\n{' ' * level}") @@ -26,17 +36,22 @@ class Logger: level = LogLevel.DEBUG user_max_length: int = 16 + def __call__( self, level: LogLevel, *args: Any, - user: str | None = None, + **kwargs: Any, ) -> None: if level < self.level: return + context = get_log_context() + usr = self.user + if context: + usr = context.get("user", self.user) + lvl = level.name.upper() - usr = user or self.user usr = usr[: self.user_max_length].ljust(self.user_max_length) msg = " ".join([str(arg) for arg in args]) @@ -77,6 +92,15 @@ def traceback(self, *args: Any, user: str | None = None) -> str: def critical(self, *args: Any, user: str | None = None) -> None: self(LogLevel.CRITICAL, *args, user=user) + @contextmanager + def contextualize(self, **context: Any) -> Generator[None, None, None]: + token = _log_context.set(context) + try: + yield + finally: + _log_context.reset(token) + + log = Logger() diff --git a/backend/server/middleware/bubblewrap.py b/backend/server/middleware/bubblewrap.py new file mode 100644 index 00000000..c0c0b762 --- /dev/null +++ b/backend/server/middleware/bubblewrap.py @@ -0,0 +1,58 @@ +import traceback + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR + +import nebula +from nebula.exceptions import NebulaException + + +class Bubblewrap(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + try: + return await call_next(request) + + except NebulaException as exc: + nebula.log.warning(f"[Bubblewrap] NebulaException: {exc.status} - {exc.detail}") + return JSONResponse( + {"status": exc.status, "detail": exc.detail}, status_code=exc.status + ) + + except ExceptionGroup as eg: + messages = [] + for e in eg.exceptions: + if isinstance(e, BaseException): + tb = "".join( + traceback.format_exception(type(e), e, e.__traceback__) + ) + nebula.log.error( + f"[Bubblewrap] ExceptionGroup member: {e.__class__.__name__} - {e}\n{tb}" + ) + messages.append(f"{e.__class__.__name__}: {str(e)}") + else: + messages.append(f"Non-standard exception: {repr(e)}") + + return JSONResponse( + { + "status": HTTP_500_INTERNAL_SERVER_ERROR, + "detail": f"ExceptionGroup: {len(eg.exceptions)} exceptions", + "traceback": messages, + }, + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + ) + + except Exception as e: + tb = "".join(traceback.format_exception(type(e), e, e.__traceback__)) + nebula.log.error( + f"[Bubblewrap] Uncaught Exception: {e.__class__.__name__} - {e}\n{tb}" + ) + return JSONResponse( + { + "status": HTTP_500_INTERNAL_SERVER_ERROR, + "detail": str(e), + "traceback": tb, + }, + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + ) diff --git a/backend/server/middleware/gatekeeper.py b/backend/server/middleware/gatekeeper.py new file mode 100644 index 00000000..c2a6b453 --- /dev/null +++ b/backend/server/middleware/gatekeeper.py @@ -0,0 +1,112 @@ +import time + +from shortuuid import ShortUUID +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request + +import nebula +from server.session import Session +from server.utils import parse_access_token + +# SESSION_TTL = 3600 +# SESSION_REFRESH_THRESHOLD = 300 + + +async def authenticate_token(request: Request) -> nebula.User: + token1 = request.query_params.get("token", None) + token2 = request.headers.get("Authorization", None) + session_id = token1 or parse_access_token(token2 or "") + if not session_id: + raise nebula.UnauthorizedException("No access token provided") + session = await Session.check(session_id, request) + if session is None: + raise nebula.UnauthorizedException("Invalid access token") + return nebula.User(meta=session.user) + + +async def authenticate_api_key(request: Request) -> nebula.User: + key1 = request.headers.get("x-api-key") + key2 = request.query_params.get("api_key") + if (api_key := key1 or key2) is None: + raise nebula.UnauthorizedException("No API key provided") + try: + return await nebula.User.by_api_key(api_key) + except nebula.NotFoundException as e: + raise nebula.UnauthorizedException("Invalid API key") from e + + +async def authenticate_session(request: Request) -> nebula.User: + if session_id := request.cookies.get("session_id"): + session = await Session.check(session_id, request) + if session is not None: + return nebula.User(meta=session.user) + raise nebula.UnauthorizedException("No session ID provided") + + # if session_id: + # key = f"session:{session_id}" + # session_raw = await redis.get(key) + # if session_raw: + # try: + # session = json.loads(session_raw) + # user_id = session.get("user_id") + # persistent = session.get("persistent", False) + # ttl = await redis.ttl(key) + # + # # Optional: verify IP / UA here if you want + # + # # Refresh if TTL low + # if ttl is not None and ttl < SESSION_REFRESH_THRESHOLD: + # new_ttl = 30*24*3600 if persistent else SESSION_TTL + # await redis.expire(key, new_ttl) + # request.state.refresh_cookie = (session_id, new_ttl) + # except Exception as e: + # # log error if needed + # pass + # + # request.state.user_id = user_id + + +async def authenticate(request: Request) -> nebula.User: + for auth_method in [authenticate_token, authenticate_session, authenticate_api_key]: + try: + user = await auth_method(request) + except nebula.UnauthorizedException: + continue + return user + raise nebula.UnauthorizedException("No authentication method provided") + + +def req_id() -> str: + return ShortUUID().random(length=16) + + +class GatekeeperMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + request_id = req_id() + context = {"request_id": request_id} + path = request.url.path + + with nebula.log.contextualize(**context): + try: + user = await authenticate(request) + context["user"] = user.name + request.state.user = user + request.state.unauthorized_reason = None + except nebula.UnauthorizedException as e: + request.state.user = None + request.state.unauthorized_reason = str(e) + + with nebula.log.contextualize(**context): + start_time = time.perf_counter() + status_code = 100 + + try: + response = await call_next(request) + status_code = response.status_code + return response + finally: + process_time = round(time.perf_counter() - start_time, 3) + f_result = f"| {status_code} in {process_time}s" + nebula.log.trace(f"[{request.method}] {path} {f_result}") + +