Skip to content

Commit e84f5cc

Browse files
committed
extensive updates to middleware functionality
1 parent 913c5ef commit e84f5cc

File tree

8 files changed

+215
-94
lines changed

8 files changed

+215
-94
lines changed

backend/app/middleware/__init__.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,57 @@
11
"""
22
Middleware package for LLSC backend.
3-
"""
3+
"""
4+
5+
from typing import List
6+
7+
from fastapi import FastAPI
8+
from .user_context_middleware import UserContextMiddleware
9+
from .firebase_auth_middleware import FirebaseAuthMiddleware
10+
11+
__all__ = [
12+
"UserContextMiddleware",
13+
"FirebaseAuthMiddleware",
14+
"setup_middlewares"
15+
]
16+
17+
def setup_middlewares(
18+
app: FastAPI,
19+
enable_user_context: bool = True,
20+
enable_firebase_auth: bool = True,
21+
auth_exclude_paths: List[str] = None,
22+
user_context_exclude_paths: List[str] = None
23+
):
24+
"""
25+
Set up middleware for the FastAPI application.
26+
27+
Args:
28+
app: The FastAPI application instance
29+
enable_user_context: Whether to enable the user context middleware
30+
enable_firebase_auth: Whether to enable the Firebase auth middleware
31+
auth_exclude_paths: Paths to exclude from Firebase auth checks
32+
user_context_exclude_paths: Paths to exclude from user context middleware
33+
"""
34+
# Default excluded paths
35+
if auth_exclude_paths is None:
36+
auth_exclude_paths = [
37+
"/docs",
38+
"/redoc",
39+
"/openapi.json",
40+
"/health"
41+
]
42+
43+
if user_context_exclude_paths is None:
44+
user_context_exclude_paths = []
45+
46+
# Add middlewares in reverse order (last added is executed first)
47+
if enable_user_context:
48+
app.add_middleware(
49+
UserContextMiddleware,
50+
exclude_paths=user_context_exclude_paths
51+
)
52+
53+
if enable_firebase_auth:
54+
app.add_middleware(
55+
FirebaseAuthMiddleware,
56+
exclude_paths=auth_exclude_paths
57+
)
Lines changed: 58 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,73 @@
1-
from fastapi import Request, HTTPException
1+
from fastapi import Request
2+
from typing import List
3+
import logging
4+
import firebase_admin
25
from firebase_admin import auth
3-
from typing import Optional, List
4-
from functools import wraps
6+
from fastapi.responses import JSONResponse
7+
8+
logger = logging.getLogger(__name__)
59

610
class FirebaseAuthMiddleware:
711
def __init__(self, app, exclude_paths: List[str] = None):
812
self.app = app
913
self.exclude_paths = exclude_paths or []
14+
logger.info("FirebaseAuthMiddleware initialized")
1015

11-
async def __call__(self, request: Request, call_next):
12-
if request.url.path in self.exclude_paths:
13-
return await call_next(request)
16+
async def __call__(self, scope, receive, send):
17+
if scope["type"] != "http":
18+
return await self.app(scope, receive, send)
1419

15-
# Get the Authorization header
16-
authorization = request.headers.get("Authorization")
17-
if not authorization or not authorization.startswith("Bearer "):
18-
raise HTTPException(
19-
status_code=401,
20-
detail="Missing or invalid authorization header"
21-
)
20+
request = Request(scope, receive)
21+
22+
# Skip excluded paths
23+
if request.url.path in self.exclude_paths:
24+
logger.debug(f"Path {request.url.path} is excluded from auth check")
25+
return await self.app(scope, receive, send)
2226

27+
# Extract token from headers
2328
try:
24-
# Verify Firebase token
25-
token = authorization.split("Bearer ")[1]
26-
decoded_token = auth.verify_id_token(token)
29+
authorization = request.headers.get("Authorization", "")
30+
if not authorization or not authorization.startswith("Bearer "):
31+
return await self.send_error_response(
32+
scope, receive, send,
33+
status_code=401,
34+
detail="Missing or invalid authorization header"
35+
)
2736

28-
# Add user info to request state
29-
request.state.user_id = decoded_token.get("uid")
30-
request.state.user_claims = decoded_token.get("claims", {})
37+
token = authorization.replace("Bearer ", "")
3138

32-
response = await call_next(request)
33-
return response
34-
39+
try:
40+
# Verify the Firebase token
41+
decoded_token = auth.verify_id_token(token)
42+
user_id = decoded_token.get("uid")
43+
44+
# Add user info to request state
45+
request.state.user_id = user_id
46+
request.state.firebase_user = decoded_token
47+
request.state.user_claims = decoded_token.get("claims", {})
48+
logger.debug(f"Authenticated user: {user_id}")
49+
50+
# Continue with the request
51+
return await self.app(scope, receive, send)
52+
53+
except firebase_admin.exceptions.FirebaseError as e:
54+
logger.error(f"Firebase auth error: {str(e)}")
55+
return await self.send_error_response(
56+
scope, receive, send,
57+
status_code=401,
58+
detail="Invalid or expired token"
59+
)
3560
except Exception as e:
36-
raise HTTPException(
37-
status_code=401,
38-
detail=f"Invalid authentication credentials: {str(e)}"
61+
logger.exception(f"Unexpected error in auth middleware: {str(e)}")
62+
return await self.send_error_response(
63+
scope, receive, send,
64+
status_code=401,
65+
detail="Authentication error"
3966
)
4067

41-
def require_roles(roles: List[str]):
42-
"""Dependency for role-based access control"""
43-
async def role_checker(request: Request):
44-
user_roles = request.state.user_claims.get("roles", [])
45-
if not any(role in user_roles for role in roles):
46-
raise HTTPException(
47-
status_code=403,
48-
detail="Insufficient permissions"
49-
)
50-
return True
51-
return role_checker
68+
async def send_error_response(self, scope, receive, send, status_code, detail):
69+
response = JSONResponse(
70+
status_code=status_code,
71+
content={"detail": detail}
72+
)
73+
await response(scope, receive, send)

backend/app/middleware/user_context_middleware.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class UserContextMiddleware:
1010
def __init__(self, app, exclude_paths: List[str] = None):
1111
self.app = app
1212
self.exclude_paths = exclude_paths or []
13-
print("[DEBUG] UserContextMiddleware initialized")
13+
logger.info("UserContextMiddleware initialized")
1414

1515
async def __call__(self, scope, receive, send):
1616
if scope["type"] != "http":
@@ -19,21 +19,21 @@ async def __call__(self, scope, receive, send):
1919
request = Request(scope, receive)
2020

2121
start_time = time.time()
22-
print(f"[DEBUG] Processing request in UserContextMiddleware: {request.url.path}")
22+
logger.debug(f"Processing request in UserContextMiddleware: {request.url.path}")
2323

2424
# Skip excluded paths
2525
if request.url.path in self.exclude_paths:
26-
print(f"[DEBUG] Path {request.url.path} is excluded")
26+
logger.debug(f"Path {request.url.path} is excluded")
2727
return await self.app(scope, receive, send)
2828

2929
# Initialize state attributes with defaults if they don't exist
3030
if not hasattr(request.state, "request_id"):
3131
request.state.request_id = request.headers.get("X-Request-ID", str(uuid.uuid4()))
32-
print(f"[DEBUG] Set request_id: {request.state.request_id}")
32+
logger.debug(f"Set request_id: {request.state.request_id}")
3333

3434
if not hasattr(request.state, "request_timestamp"):
3535
request.state.request_timestamp = start_time
36-
print(f"[DEBUG] Set timestamp: {request.state.request_timestamp}")
36+
logger.debug(f"Set timestamp: {request.state.request_timestamp}")
3737

3838
async def send_wrapper(message):
3939
if message["type"] == "http.response.start":
@@ -46,5 +46,5 @@ async def send_wrapper(message):
4646
try:
4747
await self.app(scope, receive, send_wrapper)
4848
except Exception as e:
49-
print(f"[DEBUG] Error in UserContextMiddleware: {str(e)}")
49+
logger.exception(f"Error in UserContextMiddleware: {str(e)}")
5050
raise

backend/app/routes/auth.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ async def refresh(
4444
except Exception as e:
4545
raise HTTPException(status_code=401, detail=str(e))
4646

47+
48+
4749
@router.get("/verify", response_model=dict)
4850
async def verify_token(
4951
current_user = Depends(get_current_user)
Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,37 @@
1-
from fastapi import APIRouter, Depends, Request
2-
from ..middleware.auth_middleware import require_auth, require_roles, require_user_id, get_token_from_header
3-
from ..middleware.firebase_auth_middleware import require_roles as firebase_require_roles
4-
from ..schemas.user import UserRole
1+
# from fastapi import APIRouter, Depends, Request
2+
# # from ..middleware.auth_middleware import require_auth, require_roles, require_user_id, get_token_from_header
3+
# # from ..middleware.firebase_auth_middleware import require_roles as firebase_require_roles
4+
# from ..schemas.user import UserRole
55

6-
router = APIRouter(prefix="/test", tags=["test"])
6+
# router = APIRouter(prefix="/test", tags=["test"])
77

8-
# # Basic auth test - any valid token
9-
# @router.get("/auth")
10-
# @require_auth
11-
# async def test_auth():
12-
# """Test endpoint requiring just authentication"""
13-
# return {"message": "You are authenticated!"}
8+
# # # Basic auth test - any valid token
9+
# # @router.get("/auth")
10+
# # @require_auth
11+
# # async def test_auth():
12+
# # """Test endpoint requiring just authentication"""
13+
# # return {"message": "You are authenticated!"}
1414

15-
# Basic Firebase middleware test
16-
@router.get("/auth-middleware")
17-
async def test_firebase_middleware(request: Request):
18-
"""Test endpoint to verify Firebase middleware is working"""
19-
return {
20-
"message": "Firebase auth successful",
21-
"user_id": request.state.user_id,
22-
"claims": request.state.user_claims
23-
}
15+
# # Basic Firebase middleware test
16+
# @router.get("/auth-middleware")
17+
# async def test_firebase_middleware(request: Request):
18+
# """Test endpoint to verify Firebase middleware is working"""
19+
# return {
20+
# "message": "Firebase auth successful",
21+
# "user_id": request.state.user_id,
22+
# "claims": request.state.user_claims
23+
# }
2424

25-
# Test user context middleware
26-
@router.get("/context")
27-
async def test_context(request: Request):
28-
"""Test endpoint to verify user context middleware"""
29-
try:
30-
return {
31-
"request_id": request.state.request_id,
32-
"timestamp": request.state.request_timestamp,
33-
}
34-
except Exception as e:
35-
return {
36-
"error": str(e)
37-
}
25+
# # Test user context middleware
26+
# @router.get("/context")
27+
# async def test_context(request: Request):
28+
# """Test endpoint to verify user context middleware"""
29+
# try:
30+
# return {
31+
# "request_id": request.state.request_id,
32+
# "timestamp": request.state.request_timestamp,
33+
# }
34+
# except Exception as e:
35+
# return {
36+
# "error": str(e)
37+
# }

0 commit comments

Comments
 (0)