44import firebase_admin .auth
55from ..services .implementations .user_service import UserService
66from ..utilities .db_utils import get_db
7+ from ..schemas .user import UserRole
8+ from functools import wraps
9+ from typing import Set
710
811security = HTTPBearer ()
912logger = logging .getLogger (__name__ )
@@ -13,7 +16,7 @@ async def get_current_user(
1316 db = Depends (get_db )
1417):
1518 """
16- Validates the authorization token and returns the current user
19+ Validates the authorization token and returns the current user and token
1720 """
1821 try :
1922 # Remove 'Bearer ' prefix
@@ -34,18 +37,46 @@ async def get_current_user(
3437 status_code = 401 ,
3538 detail = "User not found in database"
3639 )
37- logger . info ( "workedd boss" )
38- return user
40+
41+ return { " user" : user , "token" : token }
3942
4043 except firebase_admin .auth .InvalidIdTokenError as e :
4144 logger .error (f"Invalid token: { str (e )} " )
4245 raise HTTPException (
43- status_code = 401 ,
46+ status_code = 401 ,
4447 detail = f"Invalid token: { str (e )} "
4548 )
4649 except Exception as e :
4750 logger .error (f"Authentication error: { str (e )} " )
4851 raise HTTPException (
4952 status_code = 401 ,
5053 detail = str (e )
51- )
54+ )
55+
56+ def require_roles (allowed_roles : Set [str ]):
57+ def decorator (func ):
58+ @wraps (func )
59+ async def wrapper (* args , current_user = Depends (get_current_user ), ** kwargs ):
60+ try :
61+ # Get user role using the token from current_user
62+ user_service = UserService (kwargs .get ('db' ))
63+ user_role = user_service .get_user_role_by_auth_id (
64+ firebase_admin .auth .verify_id_token (current_user ["token" ])["uid" ]
65+ )
66+
67+ # Check if user's role is allowed
68+ if user_role not in allowed_roles :
69+ raise HTTPException (
70+ status_code = 403 ,
71+ detail = f"Access denied: role '{ user_role } ' not authorized"
72+ )
73+
74+ return await func (* args , current_user = current_user , ** kwargs )
75+
76+ except Exception as e :
77+ raise HTTPException (
78+ status_code = 403 ,
79+ detail = str (e )
80+ )
81+ return wrapper
82+ return decorator
0 commit comments