55import secrets
66import string
77import urllib .parse
8- from datetime import datetime , timedelta
8+ from datetime import datetime , timedelta , timezone
99from typing import Annotated , Optional
1010
1111import aiohttp
@@ -112,15 +112,35 @@ def hash_refresh_token(raw_token: str) -> str:
112112 digest = hmac .new (key , raw_token .encode ("utf-8" ), hashlib .sha256 ).digest ()
113113 return base64 .urlsafe_b64encode (digest ).decode ("utf-8" )
114114
115+ async def get_user_by_id (user_id : int ) -> schemas .User | None :
116+ engine = db .get_engine ()
117+ async_session = db .get_async_session (engine )
115118
116- async def get_any_group_id_for_user (session , user_id : int ) -> int | None :
117- stmt = (
118- select (schemas .GroupMembers .group_id )
119- .where (schemas .GroupMembers .user_id == user_id )
120- .order_by (schemas .GroupMembers .group_id .asc ())
121- .limit (1 )
122- )
123- return await session .scalar (stmt )
119+ async with async_session () as session :
120+ stmt = (
121+ select (schemas .User )
122+ .options (selectinload (schemas .User .groups ))
123+ .where (schemas .User .id == user_id )
124+ )
125+ return await session .scalar (stmt )
126+
127+ def parse_redirect_uri ():
128+ """Parse REDIRECT_URI_ENV once and reuse consistently."""
129+ uri = os .environ ["REDIRECT_URI_ENV" ]
130+ parsed = urllib .parse .urlparse (uri )
131+ hostname = parsed .hostname or ""
132+ scheme = parsed .scheme or "http"
133+ secure = scheme == "https"
134+ cookie_domain = None if hostname in ("localhost" , "127.0.0.1" ) else hostname
135+ return parsed , hostname , cookie_domain , secure
136+
137+ def clear_auth_cookies (response : Response ):
138+ """
139+ Attempt to delete cookies for both host-only and domain cookies"""
140+ _ , hostname , cookie_domain , _ = parse_redirect_uri ()
141+ for dom in {None , cookie_domain , "localhost" , "127.0.0.1" , hostname }:
142+ response .delete_cookie (key = access_token_key , domain = dom )
143+ response .delete_cookie (key = refresh_token_key , domain = dom )
124144
125145
126146async def get_groups_from_header_token (
@@ -147,6 +167,7 @@ async def get_groups_from_header_token(
147167 return token .group
148168
149169
170+
150171async def get_user (sub : str ) -> schemas .User | None :
151172 """Get an existing user"""
152173
@@ -241,11 +262,6 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):
241262 return encoded_jwt
242263
243264
244- def get_domain (url : str ):
245- parsed_url = urllib .parse .urlparse (url )
246- return parsed_url .netloc
247-
248-
249265@router .get ("/login" )
250266async def redirect_authorization (return_url : str = None ):
251267 """Redirect to the authorization URL with the appropriate parameters"""
@@ -278,9 +294,7 @@ async def redirect_callback(code: str, state: Optional[str] = None):
278294 "redirect_uri" : uri ,
279295 }
280296
281- # Get the domain for the redirect URL
282- parsed_url = urllib .parse .urlparse (uri )
283- domain = parsed_url .netloc
297+ parsed_url , hostname , cookie_domain , secure = parse_redirect_uri ()
284298
285299 async with aiohttp .ClientSession () as session :
286300 async with session .post (
@@ -346,52 +360,30 @@ async def redirect_callback(code: str, state: Optional[str] = None):
346360
347361 response = RedirectResponse (state if state else "/" )
348362
349- redirect_domain = urllib .parse .urlparse (state ).netloc if state else ""
350- _domain = domain
351- for override in ["localhost" , "127.0.0.1" ]:
352- if override in redirect_domain :
353- _domain = override
354-
355363 response .set_cookie (
356364 access_token_key ,
357365 f"Bearer { access_token } " ,
358- domain = _domain ,
366+ domain = cookie_domain ,
359367 httponly = True ,
360368 samesite = "lax" ,
361369 secure = (parsed_url .scheme == "https" ),
362370 )
363371
364372 refresh_jwt = jwt .encode (
365373 {
374+ "user_id" : user .id ,
366375 "sub" : user .sub ,
367- "exp " : datetime . utcnow ()
368- + timedelta (days = REFRESH_TOKEN_EXPIRE_DAYS ),
376+ "type " : "refresh" ,
377+ "exp" : datetime . utcnow () + timedelta (days = REFRESH_TOKEN_EXPIRE_DAYS ),
369378 },
370379 os .environ ["SECRET_KEY" ],
371380 algorithm = os .environ ["JWT_ENCRYPTION_ALGORITHM" ],
372381 )
373382
374- refresh_hash = hash_refresh_token (refresh_jwt )
375- engine = db .get_engine ()
376- async_session = db .get_async_session (engine )
377-
378- async with async_session () as db_session :
379- refresh_group_id = (
380- await get_any_group_id_for_user (db_session , user .id ) or 1
381- )
382- await db .insert_refresh_token (
383- engine = db .get_engine (),
384- token = refresh_hash ,
385- group_id = refresh_group_id ,
386- expiration = datetime .utcnow ()
387- + timedelta (days = REFRESH_TOKEN_EXPIRE_DAYS ),
388- token_type = "refresh" ,
389- )
390-
391383 response .set_cookie (
392384 refresh_token_key ,
393385 refresh_jwt ,
394- domain = _domain ,
386+ domain = cookie_domain ,
395387 httponly = True ,
396388 samesite = "lax" ,
397389 secure = (parsed_url .scheme == "https" ),
@@ -409,62 +401,44 @@ async def refresh_token(
409401 if not refresh_token :
410402 raise HTTPException (status_code = 401 , detail = "Not authenticated" )
411403
412- refresh_hash = hash_refresh_token (refresh_token )
413-
414- engine = db .get_engine ()
415- async_session = db .get_async_session (engine )
416-
417- # validate refresh token exists in DB + not expired
418- tok = await db .get_refresh_token (async_session = async_session , token = refresh_hash )
419- if tok is None :
420- # delete cookie across possible domainsz
421- main_domain = get_domain (os .environ ["REDIRECT_URI_ENV" ])
422- for dom in [main_domain , "localhost" , "127.0.0.1" , None ]:
423- response .delete_cookie (refresh_token_key , domain = dom )
424- raise HTTPException (status_code = 401 , detail = "Refresh token invalid" )
425-
426- # decode refresh JWT to get user id
404+ #verify the jwt is valid/not expired and signature
427405 try :
428406 payload = jwt .decode (
429407 refresh_token ,
430408 os .environ ["SECRET_KEY" ],
431409 algorithms = [os .environ ["JWT_ENCRYPTION_ALGORITHM" ]],
432410 )
433- sub = payload .get ("sub" )
434- if not sub :
435- raise HTTPException (status_code = 401 , detail = "Refresh token invalid" )
436411 except JWTError :
412+ clear_auth_cookies (response )
413+ raise HTTPException (status_code = 401 , detail = "Refresh token invalid" )
414+
415+ if payload .get ("type" ) != "refresh" :
416+ clear_auth_cookies (response )
417+ raise HTTPException (status_code = 401 , detail = "Refresh token invalid" )
418+
419+ user_id = payload .get ("user_id" )
420+ if not user_id :
421+ clear_auth_cookies (response )
437422 raise HTTPException (status_code = 401 , detail = "Refresh token invalid" )
438423
439- user = await get_user (sub )
424+ #verifying the user_id and group_id
425+ user = await get_user_by_id (int (user_id ))
440426 if user is None :
441427 raise HTTPException (status_code = 404 , detail = "User not found" )
442-
443428 names = {g .name for g in user .groups }
444429 ids = {g .id for g in user .groups }
445- role = (
446- "web_admin"
447- if ("web_admin" in names or "admin" in names or 1 in ids )
448- else "web_user"
449- )
450-
430+ role = "web_admin" if ("web_admin" in names or "admin" in names or 1 in ids ) else "web_user"
431+ #setting new access cookie
451432 access_token = create_access_token (
452433 data = {"sub" : user .sub , "role" : role , "user_id" : user .id , "groups" : list (ids )}
453434 )
454435
455- # set new access cookie
456- uri = os .environ ["REDIRECT_URI_ENV" ]
457- parsed_url = urllib .parse .urlparse (uri )
458- redirect_domain = parsed_url .netloc
459- _domain = redirect_domain
460- for override in ["localhost" , "127.0.0.1" ]:
461- if override in redirect_domain :
462- _domain = override
436+ parsed_url , hostname , cookie_domain , secure = parse_redirect_uri ()
463437
464438 response .set_cookie (
465439 access_token_key ,
466440 f"Bearer { access_token } " ,
467- domain = _domain ,
441+ domain = cookie_domain ,
468442 httponly = True ,
469443 samesite = "lax" ,
470444 secure = (parsed_url .scheme == "https" ),
@@ -473,6 +447,7 @@ async def refresh_token(
473447 return {"status" : "refreshed" }
474448
475449
450+
476451@router .post ("/token" , response_model = AccessToken )
477452async def create_group_token (
478453 group_token_request : GroupTokenRequest ,
@@ -500,27 +475,19 @@ async def create_group_token(
500475 engine = db .get_engine (),
501476 token_hash_string = token_hash_string ,
502477 group_id = group_token_request .group_id ,
503- expiration_dt = datetime .datetime .fromtimestamp (
504- group_token_request .expiration , tz = datetime .timezone .utc
505- ),
478+ expiration_dt = datetime .fromtimestamp (group_token_request .expiration , tz = timezone .utc ),
506479 )
507480
508481 return AccessToken (group = group_token_request .group_id , token = token )
509482
510483
511484@router .post ("/logout" )
512485async def logout (response : Response ):
513- """Logout the active user"""
514-
515- main_domain = get_domain (os .environ ["REDIRECT_URI_ENV" ])
516- # Delete all instances of cookies that we might conceivably have set
517- for domain in [main_domain , "localhost" , "127.0.0.1" , None ]:
518- response .delete_cookie (key = access_token_key , domain = domain )
519- response .delete_cookie (key = refresh_token_key , domain = domain )
520-
486+ clear_auth_cookies (response )
521487 return {"status" : "success" }
522488
523489
490+
524491@router .get ("/groups" )
525492async def get_security_groups (groups : list [int ] = Depends (get_groups )):
526493 """Get the groups for the current user"""
0 commit comments