Skip to content

Commit a32b738

Browse files
committed
added stateless refresh token and refactored the domains/timezones and deletion of cookies
1 parent c3a1585 commit a32b738

File tree

3 files changed

+58
-136
lines changed

3 files changed

+58
-136
lines changed

services/api-v3/api/database.py

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -138,24 +138,6 @@ async def insert_access_token(
138138
return result
139139

140140

141-
async def insert_refresh_token(
142-
engine: AsyncEngine,
143-
token: str,
144-
group_id: int,
145-
expiration: datetime.datetime,
146-
token_type: str = "refresh",
147-
):
148-
async with engine.begin() as conn:
149-
q = insert(schemas.Token).values(
150-
token=token,
151-
expires_on=expiration,
152-
group=group_id,
153-
token_type=token_type,
154-
)
155-
result = await conn.execute(q)
156-
return result
157-
158-
159141
async def get_access_token(async_session: async_sessionmaker[AsyncSession], token: str):
160142
async with async_session() as session:
161143
select_stmt = select(schemas.Token).where(
@@ -181,33 +163,6 @@ async def get_access_token(async_session: async_sessionmaker[AsyncSession], toke
181163
return result
182164

183165

184-
async def get_refresh_token(
185-
async_session: async_sessionmaker[AsyncSession], token: str
186-
):
187-
async with async_session() as session:
188-
select_stmt = select(schemas.Token).where(
189-
schemas.Token.token == token,
190-
schemas.Token.token_type == "refresh",
191-
)
192-
193-
result = (await session.scalars(select_stmt)).first()
194-
if result is None:
195-
return None
196-
197-
if result.expires_on < datetime.datetime.now(datetime.timezone.utc):
198-
return None
199-
200-
stmt = (
201-
update(schemas.Token)
202-
.where(schemas.Token.id == result.id)
203-
.values(used_on=datetime.datetime.utcnow())
204-
)
205-
await session.execute(stmt)
206-
await session.commit()
207-
208-
return result
209-
210-
211166
#
212167
# Here starts the use on the engine object directly
213168
#

services/api-v3/api/routes/security.py

Lines changed: 57 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import secrets
66
import string
77
import urllib.parse
8-
from datetime import datetime, timedelta
8+
from datetime import datetime, timedelta, timezone
99
from typing import Annotated, Optional
1010

1111
import aiohttp
@@ -112,15 +112,35 @@ def hash_refresh_token(raw_token: str) -> str:
112112
digest = hmac.new(key, raw_token.encode("utf-8"), hashlib.sha256).digest()
113113
return base64.urlsafe_b64encode(digest).decode("utf-8")
114114

115+
async def get_user_by_id(user_id: int) -> schemas.User | None:
116+
engine = db.get_engine()
117+
async_session = db.get_async_session(engine)
115118

116-
async def get_any_group_id_for_user(session, user_id: int) -> int | None:
117-
stmt = (
118-
select(schemas.GroupMembers.group_id)
119-
.where(schemas.GroupMembers.user_id == user_id)
120-
.order_by(schemas.GroupMembers.group_id.asc())
121-
.limit(1)
122-
)
123-
return await session.scalar(stmt)
119+
async with async_session() as session:
120+
stmt = (
121+
select(schemas.User)
122+
.options(selectinload(schemas.User.groups))
123+
.where(schemas.User.id == user_id)
124+
)
125+
return await session.scalar(stmt)
126+
127+
def parse_redirect_uri():
128+
"""Parse REDIRECT_URI_ENV once and reuse consistently."""
129+
uri = os.environ["REDIRECT_URI_ENV"]
130+
parsed = urllib.parse.urlparse(uri)
131+
hostname = parsed.hostname or ""
132+
scheme = parsed.scheme or "http"
133+
secure = scheme == "https"
134+
cookie_domain = None if hostname in ("localhost", "127.0.0.1") else hostname
135+
return parsed, hostname, cookie_domain, secure
136+
137+
def clear_auth_cookies(response: Response):
138+
"""
139+
Attempt to delete cookies for both host-only and domain cookies"""
140+
_, hostname, cookie_domain, _ = parse_redirect_uri()
141+
for dom in {None, cookie_domain, "localhost", "127.0.0.1", hostname}:
142+
response.delete_cookie(key=access_token_key, domain=dom)
143+
response.delete_cookie(key=refresh_token_key, domain=dom)
124144

125145

126146
async def get_groups_from_header_token(
@@ -147,6 +167,7 @@ async def get_groups_from_header_token(
147167
return token.group
148168

149169

170+
150171
async def get_user(sub: str) -> schemas.User | None:
151172
"""Get an existing user"""
152173

@@ -241,11 +262,6 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):
241262
return encoded_jwt
242263

243264

244-
def get_domain(url: str):
245-
parsed_url = urllib.parse.urlparse(url)
246-
return parsed_url.netloc
247-
248-
249265
@router.get("/login")
250266
async def redirect_authorization(return_url: str = None):
251267
"""Redirect to the authorization URL with the appropriate parameters"""
@@ -278,9 +294,7 @@ async def redirect_callback(code: str, state: Optional[str] = None):
278294
"redirect_uri": uri,
279295
}
280296

281-
# Get the domain for the redirect URL
282-
parsed_url = urllib.parse.urlparse(uri)
283-
domain = parsed_url.netloc
297+
parsed_url, hostname, cookie_domain, secure = parse_redirect_uri()
284298

285299
async with aiohttp.ClientSession() as session:
286300
async with session.post(
@@ -346,52 +360,30 @@ async def redirect_callback(code: str, state: Optional[str] = None):
346360

347361
response = RedirectResponse(state if state else "/")
348362

349-
redirect_domain = urllib.parse.urlparse(state).netloc if state else ""
350-
_domain = domain
351-
for override in ["localhost", "127.0.0.1"]:
352-
if override in redirect_domain:
353-
_domain = override
354-
355363
response.set_cookie(
356364
access_token_key,
357365
f"Bearer {access_token}",
358-
domain=_domain,
366+
domain=cookie_domain,
359367
httponly=True,
360368
samesite="lax",
361369
secure=(parsed_url.scheme == "https"),
362370
)
363371

364372
refresh_jwt = jwt.encode(
365373
{
374+
"user_id": user.id,
366375
"sub": user.sub,
367-
"exp": datetime.utcnow()
368-
+ timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
376+
"type": "refresh",
377+
"exp": datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
369378
},
370379
os.environ["SECRET_KEY"],
371380
algorithm=os.environ["JWT_ENCRYPTION_ALGORITHM"],
372381
)
373382

374-
refresh_hash = hash_refresh_token(refresh_jwt)
375-
engine = db.get_engine()
376-
async_session = db.get_async_session(engine)
377-
378-
async with async_session() as db_session:
379-
refresh_group_id = (
380-
await get_any_group_id_for_user(db_session, user.id) or 1
381-
)
382-
await db.insert_refresh_token(
383-
engine=db.get_engine(),
384-
token=refresh_hash,
385-
group_id=refresh_group_id,
386-
expiration=datetime.utcnow()
387-
+ timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
388-
token_type="refresh",
389-
)
390-
391383
response.set_cookie(
392384
refresh_token_key,
393385
refresh_jwt,
394-
domain=_domain,
386+
domain=cookie_domain,
395387
httponly=True,
396388
samesite="lax",
397389
secure=(parsed_url.scheme == "https"),
@@ -409,62 +401,44 @@ async def refresh_token(
409401
if not refresh_token:
410402
raise HTTPException(status_code=401, detail="Not authenticated")
411403

412-
refresh_hash = hash_refresh_token(refresh_token)
413-
414-
engine = db.get_engine()
415-
async_session = db.get_async_session(engine)
416-
417-
# validate refresh token exists in DB + not expired
418-
tok = await db.get_refresh_token(async_session=async_session, token=refresh_hash)
419-
if tok is None:
420-
# delete cookie across possible domainsz
421-
main_domain = get_domain(os.environ["REDIRECT_URI_ENV"])
422-
for dom in [main_domain, "localhost", "127.0.0.1", None]:
423-
response.delete_cookie(refresh_token_key, domain=dom)
424-
raise HTTPException(status_code=401, detail="Refresh token invalid")
425-
426-
# decode refresh JWT to get user id
404+
#verify the jwt is valid/not expired and signature
427405
try:
428406
payload = jwt.decode(
429407
refresh_token,
430408
os.environ["SECRET_KEY"],
431409
algorithms=[os.environ["JWT_ENCRYPTION_ALGORITHM"]],
432410
)
433-
sub = payload.get("sub")
434-
if not sub:
435-
raise HTTPException(status_code=401, detail="Refresh token invalid")
436411
except JWTError:
412+
clear_auth_cookies(response)
413+
raise HTTPException(status_code=401, detail="Refresh token invalid")
414+
415+
if payload.get("type") != "refresh":
416+
clear_auth_cookies(response)
417+
raise HTTPException(status_code=401, detail="Refresh token invalid")
418+
419+
user_id = payload.get("user_id")
420+
if not user_id:
421+
clear_auth_cookies(response)
437422
raise HTTPException(status_code=401, detail="Refresh token invalid")
438423

439-
user = await get_user(sub)
424+
#verifying the user_id and group_id
425+
user = await get_user_by_id(int(user_id))
440426
if user is None:
441427
raise HTTPException(status_code=404, detail="User not found")
442-
443428
names = {g.name for g in user.groups}
444429
ids = {g.id for g in user.groups}
445-
role = (
446-
"web_admin"
447-
if ("web_admin" in names or "admin" in names or 1 in ids)
448-
else "web_user"
449-
)
450-
430+
role = "web_admin" if ("web_admin" in names or "admin" in names or 1 in ids) else "web_user"
431+
#setting new access cookie
451432
access_token = create_access_token(
452433
data={"sub": user.sub, "role": role, "user_id": user.id, "groups": list(ids)}
453434
)
454435

455-
# set new access cookie
456-
uri = os.environ["REDIRECT_URI_ENV"]
457-
parsed_url = urllib.parse.urlparse(uri)
458-
redirect_domain = parsed_url.netloc
459-
_domain = redirect_domain
460-
for override in ["localhost", "127.0.0.1"]:
461-
if override in redirect_domain:
462-
_domain = override
436+
parsed_url, hostname, cookie_domain, secure = parse_redirect_uri()
463437

464438
response.set_cookie(
465439
access_token_key,
466440
f"Bearer {access_token}",
467-
domain=_domain,
441+
domain=cookie_domain,
468442
httponly=True,
469443
samesite="lax",
470444
secure=(parsed_url.scheme == "https"),
@@ -473,6 +447,7 @@ async def refresh_token(
473447
return {"status": "refreshed"}
474448

475449

450+
476451
@router.post("/token", response_model=AccessToken)
477452
async def create_group_token(
478453
group_token_request: GroupTokenRequest,
@@ -500,27 +475,19 @@ async def create_group_token(
500475
engine=db.get_engine(),
501476
token_hash_string=token_hash_string,
502477
group_id=group_token_request.group_id,
503-
expiration_dt=datetime.datetime.fromtimestamp(
504-
group_token_request.expiration, tz=datetime.timezone.utc
505-
),
478+
expiration_dt=datetime.fromtimestamp(group_token_request.expiration, tz=timezone.utc),
506479
)
507480

508481
return AccessToken(group=group_token_request.group_id, token=token)
509482

510483

511484
@router.post("/logout")
512485
async def logout(response: Response):
513-
"""Logout the active user"""
514-
515-
main_domain = get_domain(os.environ["REDIRECT_URI_ENV"])
516-
# Delete all instances of cookies that we might conceivably have set
517-
for domain in [main_domain, "localhost", "127.0.0.1", None]:
518-
response.delete_cookie(key=access_token_key, domain=domain)
519-
response.delete_cookie(key=refresh_token_key, domain=domain)
520-
486+
clear_auth_cookies(response)
521487
return {"status": "success"}
522488

523489

490+
524491
@router.get("/groups")
525492
async def get_security_groups(groups: list[int] = Depends(get_groups)):
526493
"""Get the groups for the current user"""

services/api-v3/api/schemas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ class Token(Base):
111111
DateTime(timezone=True), server_default=func.now()
112112
)
113113
token_type: Mapped[str] = mapped_column(
114-
TEXT, # match DDL (or keep VARCHAR(32) if you change DDL)
114+
TEXT,
115115
nullable=False,
116116
server_default=text("'api'"),
117117
)

0 commit comments

Comments
 (0)