Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions backend/nebula/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,10 @@ class NebulaException(Exception):
status: int = 500
log: bool = True

def __init__(
self,
detail: str | None = None,
log: bool | str = False,
user_name: str | None = None,
**kwargs: Any,
) -> None:
def __init__(self, detail: str | None = None, **kwargs: Any) -> None:
self.kwargs = kwargs

if detail is not None:
self.detail = detail

if log is True or self.log:
logger.error(f"EXCEPTION: {self.status} {self.detail}", user=user_name)
elif isinstance(log, str):
logger.error(f"EXCEPTION: {self.status} {log}", user=user_name)

super().__init__(self.detail)


Expand Down
28 changes: 26 additions & 2 deletions backend/nebula/log.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import contextvars
import enum
import logging
import sys
import traceback
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any

_log_context: contextvars.ContextVar[dict[str, Any] | None] = contextvars.ContextVar(
"_log_context", default=None
)

def get_log_context() -> dict[str, Any] | None:
"""Get the current log context."""
return _log_context.get()

def indent(text: str, level: int = 4) -> str:
return text.replace("\n", f"\n{' ' * level}")
Expand All @@ -26,17 +36,22 @@ class Logger:
level = LogLevel.DEBUG
user_max_length: int = 16


def __call__(
self,
level: LogLevel,
*args: Any,
user: str | None = None,
**kwargs: Any,
) -> None:
if level < self.level:
return

context = get_log_context()
usr = self.user
if context:
usr = context.get("user", self.user)

lvl = level.name.upper()
usr = user or self.user
usr = usr[: self.user_max_length].ljust(self.user_max_length)
msg = " ".join([str(arg) for arg in args])

Expand Down Expand Up @@ -77,6 +92,15 @@ def traceback(self, *args: Any, user: str | None = None) -> str:
def critical(self, *args: Any, user: str | None = None) -> None:
self(LogLevel.CRITICAL, *args, user=user)

@contextmanager
def contextualize(self, **context: Any) -> Generator[None, None, None]:
token = _log_context.set(context)
try:
yield
finally:
_log_context.reset(token)



log = Logger()

Expand Down
58 changes: 58 additions & 0 deletions backend/server/middleware/bubblewrap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import traceback

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR

import nebula
from nebula.exceptions import NebulaException


class Bubblewrap(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
try:
return await call_next(request)

except NebulaException as exc:
nebula.log.warning(f"[Bubblewrap] NebulaException: {exc.status} - {exc.detail}")
return JSONResponse(
{"status": exc.status, "detail": exc.detail}, status_code=exc.status
)

except ExceptionGroup as eg:
messages = []
for e in eg.exceptions:
if isinstance(e, BaseException):
tb = "".join(
traceback.format_exception(type(e), e, e.__traceback__)
)
nebula.log.error(
f"[Bubblewrap] ExceptionGroup member: {e.__class__.__name__} - {e}\n{tb}"
)
messages.append(f"{e.__class__.__name__}: {str(e)}")
else:
messages.append(f"Non-standard exception: {repr(e)}")

return JSONResponse(
{
"status": HTTP_500_INTERNAL_SERVER_ERROR,
"detail": f"ExceptionGroup: {len(eg.exceptions)} exceptions",
"traceback": messages,
},
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
)

except Exception as e:
tb = "".join(traceback.format_exception(type(e), e, e.__traceback__))
nebula.log.error(
f"[Bubblewrap] Uncaught Exception: {e.__class__.__name__} - {e}\n{tb}"
)
return JSONResponse(
{
"status": HTTP_500_INTERNAL_SERVER_ERROR,
"detail": str(e),
"traceback": tb,
},
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
)
112 changes: 112 additions & 0 deletions backend/server/middleware/gatekeeper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import time

from shortuuid import ShortUUID
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

import nebula
from server.session import Session
from server.utils import parse_access_token

# SESSION_TTL = 3600
# SESSION_REFRESH_THRESHOLD = 300


async def authenticate_token(request: Request) -> nebula.User:
token1 = request.query_params.get("token", None)
token2 = request.headers.get("Authorization", None)
session_id = token1 or parse_access_token(token2 or "")
if not session_id:
raise nebula.UnauthorizedException("No access token provided")
session = await Session.check(session_id, request)
if session is None:
raise nebula.UnauthorizedException("Invalid access token")
return nebula.User(meta=session.user)


async def authenticate_api_key(request: Request) -> nebula.User:
key1 = request.headers.get("x-api-key")
key2 = request.query_params.get("api_key")
if (api_key := key1 or key2) is None:
raise nebula.UnauthorizedException("No API key provided")
try:
return await nebula.User.by_api_key(api_key)
except nebula.NotFoundException as e:
raise nebula.UnauthorizedException("Invalid API key") from e


async def authenticate_session(request: Request) -> nebula.User:
if session_id := request.cookies.get("session_id"):
session = await Session.check(session_id, request)
if session is not None:
return nebula.User(meta=session.user)
raise nebula.UnauthorizedException("No session ID provided")

# if session_id:
# key = f"session:{session_id}"
# session_raw = await redis.get(key)
# if session_raw:
# try:
# session = json.loads(session_raw)
# user_id = session.get("user_id")
# persistent = session.get("persistent", False)
# ttl = await redis.ttl(key)
#
# # Optional: verify IP / UA here if you want
#
# # Refresh if TTL low
# if ttl is not None and ttl < SESSION_REFRESH_THRESHOLD:
# new_ttl = 30*24*3600 if persistent else SESSION_TTL
# await redis.expire(key, new_ttl)
# request.state.refresh_cookie = (session_id, new_ttl)
# except Exception as e:
# # log error if needed
# pass
#
# request.state.user_id = user_id


async def authenticate(request: Request) -> nebula.User:
for auth_method in [authenticate_token, authenticate_session, authenticate_api_key]:
try:
user = await auth_method(request)
except nebula.UnauthorizedException:
continue
return user
raise nebula.UnauthorizedException("No authentication method provided")


def req_id() -> str:
return ShortUUID().random(length=16)


class GatekeeperMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
request_id = req_id()
context = {"request_id": request_id}
path = request.url.path

with nebula.log.contextualize(**context):
try:
user = await authenticate(request)
context["user"] = user.name
request.state.user = user
request.state.unauthorized_reason = None
except nebula.UnauthorizedException as e:
request.state.user = None
request.state.unauthorized_reason = str(e)

with nebula.log.contextualize(**context):
start_time = time.perf_counter()
status_code = 100

try:
response = await call_next(request)
status_code = response.status_code
return response
finally:
process_time = round(time.perf_counter() - start_time, 3)
f_result = f"| {status_code} in {process_time}s"
nebula.log.trace(f"[{request.method}] {path} {f_result}")