Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def decode_complete(
issuer: str | Container[str] | None = None,
subject: str | None = None,
leeway: float | timedelta = 0,
now: float | datetime | None = None,
# kwargs
**kwargs: Any,
) -> dict[str, Any]:
Expand Down Expand Up @@ -213,6 +214,8 @@ def decode_complete(
:type issuer: str or typing.Container[str] or None
:param leeway: a time margin in seconds for the expiration check
:type leeway: float or datetime.timedelta
:param now: optional, a Unix time or datetime that is considered to be now
:type now: float or datetime or None
:rtype: dict[str, typing.Any]
:returns: Decoded JWT with the JOSE Header on the key ``header``, the JWS
Payload on the key ``payload``, and the JWS Signature on the key ``signature``.
Expand Down Expand Up @@ -262,6 +265,7 @@ def decode_complete(
issuer=issuer,
leeway=leeway,
subject=subject,
now=now,
)

decoded["payload"] = payload
Expand Down Expand Up @@ -299,6 +303,7 @@ def decode(
subject: str | None = None,
issuer: str | Container[str] | None = None,
leeway: float | timedelta = 0,
now: float | datetime | None = None,
# kwargs
**kwargs: Any,
) -> dict[str, Any]:
Expand Down Expand Up @@ -337,6 +342,8 @@ def decode(
:type issuer: str or typing.Container[str] or None
:param leeway: a time margin in seconds for the expiration check
:type leeway: float or datetime.timedelta
:param now: optional, a Unix time or datetime that is considered to be now
:type now: float or datetime or None
:rtype: dict[str, typing.Any]
:returns: the JWT claims
"""
Expand All @@ -359,6 +366,7 @@ def decode(
subject=subject,
issuer=issuer,
leeway=leeway,
now=now,
)
return decoded["payload"]

Expand All @@ -370,6 +378,7 @@ def _validate_claims(
issuer: Container[str] | str | None = None,
subject: str | None = None,
leeway: float | timedelta = 0,
now: float | datetime | None = None,
) -> None:
if isinstance(leeway, timedelta):
leeway = leeway.total_seconds()
Expand All @@ -379,7 +388,12 @@ def _validate_claims(

self._validate_required_claims(payload, options["require"])

now = datetime.now(tz=timezone.utc).timestamp()
if now is None:
now = datetime.now(tz=timezone.utc).timestamp()
elif isinstance(now, datetime):
now = now.timestamp()
elif not isinstance(now, (int, float)):
raise TypeError("now must be a number (int or float)")

if "iat" in payload and options["verify_iat"]:
self._validate_iat(payload, now, leeway)
Expand Down
34 changes: 34 additions & 0 deletions tests/test_api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,40 @@ def test_decode_with_expiration(self, jwt, payload):
with pytest.raises(ExpiredSignatureError):
jwt.decode(jwt_message, secret, algorithms=["HS256"])

def test_decode_with_expiration_and_manual_now(self, jwt, payload):
exp = utc_timestamp() - 1000
payload["exp"] = exp
secret = "secret"
jwt_message = jwt.encode(payload, secret)

decoded = jwt.decode(jwt_message, secret, algorithms=["HS256"], now=exp - 0.1)
assert decoded == payload

with pytest.raises(ExpiredSignatureError):
jwt.decode(jwt_message, secret, algorithms=["HS256"], now=exp + 0.1)

def test_decode_with_datetime_expiration_and_manual_now(self, jwt, payload):
exp = datetime.now(tz=timezone.utc) - timedelta(days=10)
payload["exp"] = int(exp.timestamp())
secret = "secret"
jwt_message = jwt.encode(payload, secret)

decoded = jwt.decode(
jwt_message,
secret,
algorithms=["HS256"],
now=exp - timedelta(seconds=1),
)
assert decoded == payload

with pytest.raises(ExpiredSignatureError):
jwt.decode(
jwt_message,
secret,
algorithms=["HS256"],
now=exp + timedelta(seconds=1),
)

def test_decode_with_notbefore(self, jwt, payload):
payload["nbf"] = utc_timestamp() + 10
secret = "secret"
Expand Down