Skip to content

Commit 913c5ef

Browse files
committed
updated middleware structure
1 parent 0d3d771 commit 913c5ef

File tree

7 files changed

+208
-59
lines changed

7 files changed

+208
-59
lines changed

backend/app/middleware/auth_middleware.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,40 +11,59 @@
1111
from app.services.implementations.user_service import UserService
1212
from app.utilities.service_utils import get_auth_service, get_db
1313

14+
import firebase_admin
15+
1416
security = HTTPBearer()
1517

1618

1719
def get_token_from_header(
1820
credentials: HTTPAuthorizationCredentials = Security(security),
1921
) -> str:
20-
"""Extract token from Authorization header."""
21-
print("\n=== Token Extraction ===")
22-
print(f"Raw credentials type: {type(credentials)}")
23-
print(f"Raw credentials: {credentials}")
24-
print(f"Token from header: {credentials.credentials[:50]}...") # Print first 50 chars
25-
print("=== End Token Extraction ===\n")
2622
return credentials.credentials
2723

2824

2925
def require_auth(
30-
auth_service: AuthService = Depends(get_auth_service),
3126
token: str = Depends(get_token_from_header),
3227
) -> None:
3328
"""Verify that the request has a valid access token."""
3429
try:
3530
# The token validation is done in the role check, so we just need to check
3631
# if the token is valid for any role
37-
if not auth_service.is_authorized_by_role(token, {role.value for role in UserRole}):
38-
raise HTTPException(
39-
status_code=401,
40-
detail="Invalid or expired token",
41-
)
32+
# if not auth_service.is_authorized_by_role(token, {role.value for role in UserRole}):
33+
# raise HTTPException(
34+
# status_code=401,
35+
# detail="Invalid or expired token",
36+
# )
37+
if token:
38+
return
4239
except Exception as e:
4340
raise HTTPException(
4441
status_code=401,
4542
detail="Invalid or expired token",
4643
)
4744

45+
# def require_auth(func: Callable) -> Callable:
46+
# """Decorator to verify that the request has a valid access token."""
47+
# @wraps(func)
48+
# async def wrapper(
49+
# *args,
50+
# # auth_service: AuthService = Depends(get_auth_service),
51+
# token: str = Depends(get_token_from_header),
52+
# **kwargs
53+
# ):
54+
# try:
55+
# if token:
56+
# raise HTTPException(
57+
# status_code=401,
58+
# detail="valid token",
59+
# )
60+
# return await func(*args, **kwargs)
61+
# except Exception as e:
62+
# raise HTTPException(
63+
# status_code=401,
64+
# detail="Invalid or expired token",
65+
# )
66+
# return wrapper
4867

4968
def require_roles(roles: Set[UserRole]):
5069
"""Require specific roles to access the endpoint."""
@@ -58,17 +77,13 @@ async def wrapper(
5877
**kwargs,
5978
):
6079
try:
61-
print("\n=== Role Check ===")
62-
print(f"Checking roles: {roles}")
6380
role_values = {role.value for role in roles}
64-
print(f"Role values to check: {role_values}")
6581

6682
# Create auth service directly
6783
logger = logging.getLogger(__name__)
6884
auth_service = AuthService(logger=logger, user_service=UserService(db))
6985

7086
is_authorized = auth_service.is_authorized_by_role(token, role_values)
71-
print(f"Authorization result: {is_authorized}")
7287

7388
if not is_authorized:
7489
raise HTTPException(
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from fastapi import Request, HTTPException
2+
from firebase_admin import auth
3+
from typing import Optional, List
4+
from functools import wraps
5+
6+
class FirebaseAuthMiddleware:
7+
def __init__(self, app, exclude_paths: List[str] = None):
8+
self.app = app
9+
self.exclude_paths = exclude_paths or []
10+
11+
async def __call__(self, request: Request, call_next):
12+
if request.url.path in self.exclude_paths:
13+
return await call_next(request)
14+
15+
# Get the Authorization header
16+
authorization = request.headers.get("Authorization")
17+
if not authorization or not authorization.startswith("Bearer "):
18+
raise HTTPException(
19+
status_code=401,
20+
detail="Missing or invalid authorization header"
21+
)
22+
23+
try:
24+
# Verify Firebase token
25+
token = authorization.split("Bearer ")[1]
26+
decoded_token = auth.verify_id_token(token)
27+
28+
# Add user info to request state
29+
request.state.user_id = decoded_token.get("uid")
30+
request.state.user_claims = decoded_token.get("claims", {})
31+
32+
response = await call_next(request)
33+
return response
34+
35+
except Exception as e:
36+
raise HTTPException(
37+
status_code=401,
38+
detail=f"Invalid authentication credentials: {str(e)}"
39+
)
40+
41+
def require_roles(roles: List[str]):
42+
"""Dependency for role-based access control"""
43+
async def role_checker(request: Request):
44+
user_roles = request.state.user_claims.get("roles", [])
45+
if not any(role in user_roles for role in roles):
46+
raise HTTPException(
47+
status_code=403,
48+
detail="Insufficient permissions"
49+
)
50+
return True
51+
return role_checker
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from fastapi import Request
2+
from typing import List
3+
import time
4+
import logging
5+
import uuid
6+
7+
logger = logging.getLogger(__name__)
8+
9+
class UserContextMiddleware:
10+
def __init__(self, app, exclude_paths: List[str] = None):
11+
self.app = app
12+
self.exclude_paths = exclude_paths or []
13+
print("[DEBUG] UserContextMiddleware initialized")
14+
15+
async def __call__(self, scope, receive, send):
16+
if scope["type"] != "http":
17+
return await self.app(scope, receive, send)
18+
19+
request = Request(scope, receive)
20+
21+
start_time = time.time()
22+
print(f"[DEBUG] Processing request in UserContextMiddleware: {request.url.path}")
23+
24+
# Skip excluded paths
25+
if request.url.path in self.exclude_paths:
26+
print(f"[DEBUG] Path {request.url.path} is excluded")
27+
return await self.app(scope, receive, send)
28+
29+
# Initialize state attributes with defaults if they don't exist
30+
if not hasattr(request.state, "request_id"):
31+
request.state.request_id = request.headers.get("X-Request-ID", str(uuid.uuid4()))
32+
print(f"[DEBUG] Set request_id: {request.state.request_id}")
33+
34+
if not hasattr(request.state, "request_timestamp"):
35+
request.state.request_timestamp = start_time
36+
print(f"[DEBUG] Set timestamp: {request.state.request_timestamp}")
37+
38+
async def send_wrapper(message):
39+
if message["type"] == "http.response.start":
40+
headers = dict(message.get("headers", []))
41+
headers[b"X-Process-Time"] = str(time.time() - start_time).encode()
42+
headers[b"X-Request-ID"] = str(request.state.request_id).encode()
43+
message["headers"] = [(k, v) for k, v in headers.items()]
44+
await send(message)
45+
46+
try:
47+
await self.app(scope, receive, send_wrapper)
48+
except Exception as e:
49+
print(f"[DEBUG] Error in UserContextMiddleware: {str(e)}")
50+
raise

backend/app/routes/auth.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
from ..services.implementations.user_service import UserService
77
from ..utilities.db_utils import get_db
88
from ..utilities.service_utils import get_auth_service
9-
from ..middleware.auth import get_current_user
9+
from ..middleware.auth import get_current_user, require_roles
1010
import logging
1111

12+
from ..schemas.user import UserRole
13+
1214
router = APIRouter(prefix="/auth", tags=["auth"])
1315
security = HTTPBearer()
1416

@@ -42,3 +44,16 @@ async def refresh(
4244
except Exception as e:
4345
raise HTTPException(status_code=401, detail=str(e))
4446

47+
@router.get("/verify", response_model=dict)
48+
async def verify_token(
49+
current_user = Depends(get_current_user)
50+
):
51+
"""Test endpoint to verify a bearer token is valid and return the user info"""
52+
try:
53+
return {
54+
"message": "Token is valid",
55+
"user": current_user["user"]
56+
}
57+
except Exception as e:
58+
raise HTTPException(status_code=401, detail=str(e))
59+
Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from fastapi import APIRouter, Depends
2-
from ..middleware.auth_middleware import require_auth, require_roles, require_user_id
1+
from fastapi import APIRouter, Depends, Request
2+
from ..middleware.auth_middleware import require_auth, require_roles, require_user_id, get_token_from_header
3+
from ..middleware.firebase_auth_middleware import require_roles as firebase_require_roles
34
from ..schemas.user import UserRole
45

56
router = APIRouter(prefix="/test", tags=["test"])
@@ -11,36 +12,26 @@
1112
# """Test endpoint requiring just authentication"""
1213
# return {"message": "You are authenticated!"}
1314

14-
# Role-based tests
15-
@router.get("/admin-only")
16-
@require_roles({UserRole.ADMIN})
17-
async def test_admin_only():
18-
"""Test endpoint requiring admin role"""
19-
return {"message": "You are an admin!"}
15+
# Basic Firebase middleware test
16+
@router.get("/auth-middleware")
17+
async def test_firebase_middleware(request: Request):
18+
"""Test endpoint to verify Firebase middleware is working"""
19+
return {
20+
"message": "Firebase auth successful",
21+
"user_id": request.state.user_id,
22+
"claims": request.state.user_claims
23+
}
2024

21-
@router.get("/volunteer-or-admin")
22-
@require_roles({UserRole.VOLUNTEER, UserRole.ADMIN})
23-
async def test_volunteer_or_admin():
24-
"""Test endpoint requiring volunteer or admin role"""
25-
return {"message": "You are a volunteer or admin!"}
26-
27-
@router.get("/participant-only")
28-
@require_roles({UserRole.PARTICIPANT})
29-
async def test_participant_only():
30-
"""Test endpoint requiring participant role"""
31-
return {"message": "You are a participant!"}
32-
33-
# User-specific tests
34-
@router.get("/users/{user_id}/profile")
35-
@require_user_id()
36-
async def test_user_specific(user_id: str):
37-
"""Test endpoint requiring specific user access"""
38-
return {"message": f"You can access user {user_id}'s profile!"}
39-
40-
# Combined tests
41-
@router.get("/users/{user_id}/admin-action")
42-
@require_roles({UserRole.ADMIN})
43-
@require_user_id()
44-
async def test_admin_user_specific(user_id: str):
45-
"""Test endpoint requiring both admin role and specific user access"""
46-
return {"message": f"You are an admin accessing user {user_id}'s data!"}
25+
# Test user context middleware
26+
@router.get("/context")
27+
async def test_context(request: Request):
28+
"""Test endpoint to verify user context middleware"""
29+
try:
30+
return {
31+
"request_id": request.state.request_id,
32+
"timestamp": request.state.request_timestamp,
33+
}
34+
except Exception as e:
35+
return {
36+
"error": str(e)
37+
}

backend/app/server.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44

55
from dotenv import load_dotenv
66
from fastapi import FastAPI
7+
from fastapi.middleware.cors import CORSMiddleware
78

89
from app.routes import auth, email, test_endpoints
10+
from app.middleware.firebase_auth_middleware import FirebaseAuthMiddleware
11+
from app.middleware.user_context_middleware import UserContextMiddleware
912

1013
load_dotenv()
1114

@@ -16,6 +19,16 @@
1619

1720
log = logging.getLogger("uvicorn")
1821

22+
# Define paths that don't require authentication
23+
PUBLIC_PATHS = [
24+
"/",
25+
"/docs",
26+
"/redoc",
27+
"/openapi.json",
28+
"/auth/login",
29+
"/auth/register",
30+
"/health"
31+
]
1932

2033
@asynccontextmanager
2134
async def lifespan(_: FastAPI):
@@ -25,10 +38,31 @@ async def lifespan(_: FastAPI):
2538
yield
2639
log.info("Shutting down...")
2740

41+
app = FastAPI(lifespan=lifespan)
42+
43+
# Add CORS middleware first
44+
# app.add_middleware(
45+
# CORSMiddleware,
46+
# allow_origins=["*"], # Configure this appropriately for production
47+
# allow_credentials=True,
48+
# allow_methods=["*"],
49+
# allow_headers=["*"],
50+
# )
51+
52+
# Add our custom middleware
53+
# Note: Middleware is executed in reverse order (last added = first executed)
54+
app.add_middleware(
55+
UserContextMiddleware,
56+
exclude_paths=PUBLIC_PATHS
57+
)
58+
59+
# app.add_middleware(
60+
# FirebaseAuthMiddleware,
61+
# exclude_paths=PUBLIC_PATHS
62+
# )
2863

2964
# Source: https://stackoverflow.com/questions/77170361/
3065
# running-alembic-migrations-on-fastapi-startup
31-
app = FastAPI(lifespan=lifespan)
3266
app.include_router(auth.router)
3367
app.include_router(user.router)
3468
app.include_router(email.router)

backend/app/services/implementations/auth_service.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,10 @@ def send_email_verification_link(self, email: str) -> None:
5555

5656
def is_authorized_by_role(self, access_token: str, roles: set[str]) -> bool:
5757
try:
58-
print("\n=== Auth Service Role Check ===")
59-
# print(f"Verifying token: {access_token[:50]}...")
6058
decoded_token = firebase_admin.auth.verify_id_token(access_token, check_revoked=True)
61-
print(f"Decoded token UID: {decoded_token.get('uid')}")
6259
user_role = self.user_service.get_user_role_by_auth_id(decoded_token["uid"])
63-
print(f"User role from DB: {user_role}")
64-
print(f"Checking against roles: {roles}")
6560
firebase_user = firebase_admin.auth.get_user(decoded_token["uid"])
66-
print(f"Email verified: {firebase_user.email_verified}")
6761
result = firebase_user.email_verified and user_role in roles
68-
print(f"Final authorization result: {result}")
6962
return result
7063
except Exception as e:
7164
print(f"Authorization error: {str(e)}")

0 commit comments

Comments
 (0)