Skip to content

Commit dc44b55

Browse files
committed
Token usage security updates
1 parent 242fdb3 commit dc44b55

File tree

4 files changed

+389
-8
lines changed

4 files changed

+389
-8
lines changed

gefapi/__init__.py

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,58 @@ def api_docs():
763763

764764
jwt = JWTManager(app)
765765

766+
# Token blocklist storage for revoked access tokens
767+
# Uses Redis for production scalability, falls back to in-memory for testing
768+
_revoked_tokens = set()
769+
770+
771+
def get_revoked_tokens_storage():
772+
"""Get the revoked tokens storage (Redis or in-memory fallback)."""
773+
try:
774+
import redis
775+
776+
redis_url = SETTINGS.get("CELERY_BROKER_URL")
777+
if redis_url:
778+
return redis.from_url(redis_url)
779+
except Exception as e:
780+
logger.debug(f"Redis not available for token blocklist, using in-memory: {e}")
781+
return None
782+
783+
784+
def add_token_to_blocklist(jti: str, expires_in_seconds: int = 3600) -> None:
785+
"""Add a token JTI to the blocklist."""
786+
redis_client = get_revoked_tokens_storage()
787+
if redis_client:
788+
try:
789+
# Store in Redis with expiration matching token lifetime
790+
redis_client.setex(f"blocklist:{jti}", expires_in_seconds, "revoked")
791+
return
792+
except Exception as e:
793+
logger.warning(f"Failed to add token to Redis blocklist: {e}")
794+
# Fallback to in-memory (note: not shared across workers)
795+
_revoked_tokens.add(jti)
796+
797+
798+
def is_token_in_blocklist(jti: str) -> bool:
799+
"""Check if a token JTI is in the blocklist."""
800+
redis_client = get_revoked_tokens_storage()
801+
if redis_client:
802+
try:
803+
return redis_client.exists(f"blocklist:{jti}") > 0
804+
except Exception as e:
805+
logger.warning(f"Failed to check Redis blocklist: {e}")
806+
# Fallback to in-memory
807+
return jti in _revoked_tokens
808+
809+
810+
@jwt.token_in_blocklist_loader
811+
def check_if_token_in_blocklist(jwt_header, jwt_payload):
812+
"""Check if JWT access token has been revoked."""
813+
jti = jwt_payload.get("jti")
814+
if not jti:
815+
return False
816+
return is_token_in_blocklist(jti)
817+
766818

767819
from gefapi.models import User # noqa:E402
768820
from gefapi.services import UserService # noqa:E402
@@ -845,7 +897,27 @@ def refresh_token():
845897
@jwt_required()
846898
def logout():
847899
logger.info("[JWT]: User logout...")
848-
refresh_token_string = request.json.get("refresh_token", None)
900+
from flask_jwt_extended import get_jwt
901+
902+
# Revoke the current access token by adding its JTI to blocklist
903+
try:
904+
jwt_data = get_jwt()
905+
jti = jwt_data.get("jti")
906+
if jti:
907+
# Get remaining token lifetime for blocklist expiration
908+
exp = jwt_data.get("exp", 0)
909+
import time
910+
911+
remaining_seconds = max(int(exp - time.time()), 0) + 60 # Add buffer
912+
add_token_to_blocklist(jti, remaining_seconds)
913+
logger.info(f"[JWT]: Access token {jti[:8]}... added to blocklist")
914+
except Exception as e:
915+
logger.warning(f"[JWT]: Failed to revoke access token: {e}")
916+
917+
# Revoke refresh token if provided
918+
refresh_token_string = None
919+
if request.json:
920+
refresh_token_string = request.json.get("refresh_token", None)
849921

850922
if refresh_token_string:
851923
# Import here to avoid circular imports
@@ -878,6 +950,47 @@ def user_lookup_callback(_jwt_header, jwt_data):
878950
return User.query.filter_by(id=identity).one_or_none()
879951

880952

953+
@jwt.expired_token_loader
954+
def expired_token_callback(jwt_header, jwt_payload):
955+
"""Handle expired JWT tokens with consistent error response."""
956+
logger.debug("[JWT]: Expired token detected")
957+
return jsonify(
958+
{"status": 401, "detail": "Token has expired", "error": "token_expired"}
959+
), 401
960+
961+
962+
@jwt.invalid_token_loader
963+
def invalid_token_callback(error_message):
964+
"""Handle invalid JWT tokens with consistent error response."""
965+
logger.warning(f"[JWT]: Invalid token: {error_message}")
966+
return jsonify(
967+
{"status": 401, "detail": "Invalid token", "error": "invalid_token"}
968+
), 401
969+
970+
971+
@jwt.unauthorized_loader
972+
def missing_token_callback(error_message):
973+
"""Handle missing JWT tokens with consistent error response."""
974+
logger.debug(f"[JWT]: Missing token: {error_message}")
975+
return jsonify(
976+
{
977+
"status": 401,
978+
"detail": "Authorization token required",
979+
"error": "authorization_required",
980+
}
981+
), 401
982+
983+
984+
@jwt.revoked_token_loader
985+
def revoked_token_callback(jwt_header, jwt_payload):
986+
"""Handle revoked JWT tokens with consistent error response."""
987+
jti = jwt_payload.get("jti", "unknown")
988+
logger.info(f"[JWT]: Revoked token access attempt: {jti[:8]}...")
989+
return jsonify(
990+
{"status": 401, "detail": "Token has been revoked", "error": "token_revoked"}
991+
), 401
992+
993+
881994
@app.errorhandler(403)
882995
def forbidden(e):
883996
return error(status=403, detail="Forbidden")

gefapi/config/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@
7272
"JWT_ACCESS_TOKEN_EXPIRES": timedelta(seconds=60 * 60 * 1),
7373
"JWT_REFRESH_TOKEN_EXPIRES": timedelta(days=30), # 30 days for refresh tokens
7474
"JWT_TOKEN_LOCATION": ["headers"],
75+
"JWT_IDENTITY_CLAIM": "sub", # Standard JWT subject claim for identity
76+
"JWT_BLOCKLIST_ENABLED": True, # Enable token blocklist for revocation
77+
"JWT_BLOCKLIST_TOKEN_CHECKS": ["access"], # Check access tokens against blocklist
7578
"TRUSTED_PROXY_COUNT": int(os.getenv("TRUSTED_PROXY_COUNT", "0")),
7679
"INTERNAL_NETWORKS": [
7780
net.strip().strip("\"'") # Remove quotes and whitespace

gefapi/models/refresh_token.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ class RefreshToken(db.Model):
2020
token = db.Column(db.String(255), nullable=False, index=True)
2121
expires_at = db.Column(db.DateTime, nullable=False, index=True)
2222
created_at = db.Column(
23-
db.DateTime, default=datetime.datetime.utcnow, nullable=False
23+
db.DateTime(timezone=True),
24+
default=lambda: datetime.datetime.now(datetime.UTC),
25+
nullable=False,
2426
)
2527
is_revoked = db.Column(db.Boolean, default=False, nullable=False)
2628
device_info = db.Column(db.String(500)) # Store user agent, IP, etc.
@@ -48,19 +50,62 @@ def generate_token():
4850
@staticmethod
4951
def default_expiry():
5052
"""Default expiry time (30 days from now)"""
51-
return datetime.datetime.utcnow() + datetime.timedelta(days=30)
52-
53-
def is_valid(self):
54-
"""Check if token is valid (not expired and not revoked)"""
55-
return not self.is_revoked and self.expires_at > datetime.datetime.utcnow()
53+
return datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=30)
54+
55+
def is_valid(self, verify_client_ip=False, current_ip=None):
56+
"""Check if token is valid (not expired and not revoked).
57+
58+
Args:
59+
verify_client_ip: If True, optionally verify the client IP matches
60+
current_ip: Current client IP address for verification
61+
62+
Returns:
63+
bool: True if token is valid, False otherwise
64+
"""
65+
if self.is_revoked or self.expires_at <= datetime.datetime.now(datetime.UTC):
66+
return False
67+
68+
# Optional client IP verification for additional security
69+
if verify_client_ip and current_ip and self.device_info:
70+
stored_ip = self._extract_ip_from_device_info()
71+
if stored_ip and stored_ip != current_ip:
72+
# Log suspicious activity but don't fail by default
73+
# This provides visibility without breaking existing clients
74+
import logging
75+
76+
logger = logging.getLogger(__name__)
77+
logger.warning(
78+
f"Token IP mismatch: stored={stored_ip}, current={current_ip}, "
79+
f"token_id={self.id}"
80+
)
81+
82+
return True
83+
84+
def _extract_ip_from_device_info(self):
85+
"""Extract IP address from stored device_info string."""
86+
if not self.device_info:
87+
return None
88+
# device_info format: "IP: x.x.x.x | UA: ..."
89+
if self.device_info.startswith("IP: "):
90+
parts = self.device_info.split(" | ")
91+
if parts:
92+
return parts[0].replace("IP: ", "").strip()
93+
return None
94+
95+
def get_client_fingerprint(self):
96+
"""Get a fingerprint of the client that created this token."""
97+
return {
98+
"ip_address": self._extract_ip_from_device_info(),
99+
"device_info": self.device_info,
100+
}
56101

57102
def revoke(self):
58103
"""Revoke the refresh token"""
59104
self.is_revoked = True
60105

61106
def update_last_used(self):
62107
"""Update last used timestamp"""
63-
self.last_used_at = datetime.datetime.utcnow()
108+
self.last_used_at = datetime.datetime.now(datetime.UTC)
64109

65110
def serialize(self):
66111
"""Return object data in easily serializable format"""

0 commit comments

Comments
 (0)