1+ from functools import wraps
2+ from typing import Callable , List , Optional , Set
3+ import logging
4+
5+ from fastapi import Depends , HTTPException , Security
6+ from fastapi .security import HTTPAuthorizationCredentials , HTTPBearer
7+ from sqlalchemy .orm import Session
8+
9+ from app .schemas .user import UserRole
10+ from app .services .implementations .auth_service import AuthService
11+ from app .services .implementations .user_service import UserService
12+ from app .utilities .service_utils import get_auth_service , get_db
13+
14+ security = HTTPBearer ()
15+
16+
17+ def get_token_from_header (
18+ credentials : HTTPAuthorizationCredentials = Security (security ),
19+ ) -> str :
20+ """Extract token from Authorization header."""
21+ print ("\n === Token Extraction ===" )
22+ print (f"Raw credentials type: { type (credentials )} " )
23+ print (f"Raw credentials: { credentials } " )
24+ print (f"Token from header: { credentials .credentials [:50 ]} ..." ) # Print first 50 chars
25+ print ("=== End Token Extraction ===\n " )
26+ return credentials .credentials
27+
28+
29+ def require_auth (
30+ auth_service : AuthService = Depends (get_auth_service ),
31+ token : str = Depends (get_token_from_header ),
32+ ) -> None :
33+ """Verify that the request has a valid access token."""
34+ try :
35+ # The token validation is done in the role check, so we just need to check
36+ # if the token is valid for any role
37+ if not auth_service .is_authorized_by_role (token , {role .value for role in UserRole }):
38+ raise HTTPException (
39+ status_code = 401 ,
40+ detail = "Invalid or expired token" ,
41+ )
42+ except Exception as e :
43+ raise HTTPException (
44+ status_code = 401 ,
45+ detail = "Invalid or expired token" ,
46+ )
47+
48+
49+ def require_roles (roles : Set [UserRole ]):
50+ """Require specific roles to access the endpoint."""
51+
52+ def decorator (func : Callable ) -> Callable :
53+ @wraps (func )
54+ async def wrapper (
55+ * args ,
56+ db : Session = Depends (get_db ),
57+ token : str = Depends (get_token_from_header ),
58+ ** kwargs ,
59+ ):
60+ try :
61+ print ("\n === Role Check ===" )
62+ print (f"Checking roles: { roles } " )
63+ role_values = {role .value for role in roles }
64+ print (f"Role values to check: { role_values } " )
65+
66+ # Create auth service directly
67+ logger = logging .getLogger (__name__ )
68+ auth_service = AuthService (logger = logger , user_service = UserService (db ))
69+
70+ is_authorized = auth_service .is_authorized_by_role (token , role_values )
71+ print (f"Authorization result: { is_authorized } " )
72+
73+ if not is_authorized :
74+ raise HTTPException (
75+ status_code = 403 ,
76+ detail = "Insufficient permissions" ,
77+ )
78+ return await func (* args , ** kwargs )
79+ except HTTPException as e :
80+ print (f"HTTP Exception: { str (e )} " )
81+ raise e
82+ except Exception as e :
83+ print (f"Unexpected error: { str (e )} " )
84+ raise HTTPException (
85+ status_code = 401 ,
86+ detail = "Invalid or expired token" ,
87+ )
88+
89+ return wrapper
90+
91+ return decorator
92+
93+
94+ def require_user_id (user_id_param : str = "user_id" ):
95+ """Require that the token belongs to the requested user."""
96+
97+ def decorator (func : Callable ) -> Callable :
98+ @wraps (func )
99+ async def wrapper (
100+ * args ,
101+ auth_service : AuthService = Depends (get_auth_service ),
102+ token : str = Depends (get_token_from_header ),
103+ ** kwargs ,
104+ ):
105+ try :
106+ user_id = kwargs .get (user_id_param )
107+ if not user_id :
108+ raise HTTPException (
109+ status_code = 400 ,
110+ detail = f"Missing { user_id_param } parameter" ,
111+ )
112+
113+ if not auth_service .is_authorized_by_user_id (token , user_id ):
114+ raise HTTPException (
115+ status_code = 403 ,
116+ detail = "Not authorized to access this resource" ,
117+ )
118+ return await func (* args , ** kwargs )
119+ except HTTPException as e :
120+ raise e
121+ except Exception as e :
122+ raise HTTPException (
123+ status_code = 401 ,
124+ detail = "Invalid or expired token" ,
125+ )
126+
127+ return wrapper
128+
129+ return decorator
0 commit comments