Skip to content

Commit 22ba132

Browse files
authored
api_jwt: add a strict_aud option (#902)
* api_jwt: add a `strict_aud` option Signed-off-by: William Woodruff <[email protected]> * CHANGELOG: record changes Signed-off-by: William Woodruff <[email protected]> --------- Signed-off-by: William Woodruff <[email protected]>
1 parent 6db5df7 commit 22ba132

File tree

4 files changed

+103
-1
lines changed

4 files changed

+103
-1
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ Fixed
1616
Added
1717
~~~~~
1818

19+
- Add ``strict_aud`` as an option to ``jwt.decode`` by @woodruffw in `#902 <https://github.com/jpadilla/pyjwt/pull/902>`__
20+
1921
`v2.7.0 <https://github.com/jpadilla/pyjwt/compare/2.6.0...2.7.0>`__
2022
-----------------------------------------------------------------------
2123

docs/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ API Reference
5353
* ``verify_exp=verify_signature`` check that ``exp`` (expiration) claim value is in the future
5454
* ``verify_iat=verify_signature`` check that ``iat`` (issued at) claim value is an integer
5555
* ``verify_nbf=verify_signature`` check that ``nbf`` (not before) claim value is in the past
56+
* ``strict_aud=False`` check that the ``aud`` claim is a single value (not a list), and matches ``audience`` exactly
5657

5758
.. warning::
5859

jwt/api_jwt.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,9 @@ def _validate_claims(
251251
self._validate_iss(payload, issuer)
252252

253253
if options["verify_aud"]:
254-
self._validate_aud(payload, audience)
254+
self._validate_aud(
255+
payload, audience, strict=options.get("strict_aud", False)
256+
)
255257

256258
def _validate_required_claims(
257259
self,
@@ -307,6 +309,8 @@ def _validate_aud(
307309
self,
308310
payload: dict[str, Any],
309311
audience: str | Iterable[str] | None,
312+
*,
313+
strict: bool = False,
310314
) -> None:
311315
if audience is None:
312316
if "aud" not in payload or not payload["aud"]:
@@ -322,6 +326,22 @@ def _validate_aud(
322326

323327
audience_claims = payload["aud"]
324328

329+
# In strict mode, we forbid list matching: the supplied audience
330+
# must be a string, and it must exactly match the audience claim.
331+
if strict:
332+
# Only a single audience is allowed in strict mode.
333+
if not isinstance(audience, str):
334+
raise InvalidAudienceError("Invalid audience (strict)")
335+
336+
# Only a single audience claim is allowed in strict mode.
337+
if not isinstance(audience_claims, str):
338+
raise InvalidAudienceError("Invalid claim format in token (strict)")
339+
340+
if audience != audience_claims:
341+
raise InvalidAudienceError("Audience doesn't match (strict)")
342+
343+
return
344+
325345
if isinstance(audience_claims, str):
326346
audience_claims = [audience_claims]
327347
if not isinstance(audience_claims, list):

tests/test_api_jwt.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,3 +723,82 @@ def test_decode_complete_warns_on_unsupported_kwarg(self, jwt, payload):
723723
jwt.decode_complete(jwt_message, secret, algorithms=["HS256"], foo="bar")
724724
assert len(record) == 1
725725
assert "foo" in str(record[0].message)
726+
727+
def test_decode_strict_aud_forbids_list_audience(self, jwt, payload):
728+
secret = "secret"
729+
payload["aud"] = "urn:foo"
730+
jwt_message = jwt.encode(payload, secret)
731+
732+
# Decodes without `strict_aud`.
733+
jwt.decode(
734+
jwt_message,
735+
secret,
736+
audience=["urn:foo", "urn:bar"],
737+
options={"strict_aud": False},
738+
algorithms=["HS256"],
739+
)
740+
741+
# Fails with `strict_aud`.
742+
with pytest.raises(InvalidAudienceError, match=r"Invalid audience \(strict\)"):
743+
jwt.decode(
744+
jwt_message,
745+
secret,
746+
audience=["urn:foo", "urn:bar"],
747+
options={"strict_aud": True},
748+
algorithms=["HS256"],
749+
)
750+
751+
def test_decode_strict_aud_forbids_list_claim(self, jwt, payload):
752+
secret = "secret"
753+
payload["aud"] = ["urn:foo", "urn:bar"]
754+
jwt_message = jwt.encode(payload, secret)
755+
756+
# Decodes without `strict_aud`.
757+
jwt.decode(
758+
jwt_message,
759+
secret,
760+
audience="urn:foo",
761+
options={"strict_aud": False},
762+
algorithms=["HS256"],
763+
)
764+
765+
# Fails with `strict_aud`.
766+
with pytest.raises(
767+
InvalidAudienceError, match=r"Invalid claim format in token \(strict\)"
768+
):
769+
jwt.decode(
770+
jwt_message,
771+
secret,
772+
audience="urn:foo",
773+
options={"strict_aud": True},
774+
algorithms=["HS256"],
775+
)
776+
777+
def test_decode_strict_aud_does_not_match(self, jwt, payload):
778+
secret = "secret"
779+
payload["aud"] = "urn:foo"
780+
jwt_message = jwt.encode(payload, secret)
781+
782+
with pytest.raises(
783+
InvalidAudienceError, match=r"Audience doesn't match \(strict\)"
784+
):
785+
jwt.decode(
786+
jwt_message,
787+
secret,
788+
audience="urn:bar",
789+
options={"strict_aud": True},
790+
algorithms=["HS256"],
791+
)
792+
793+
def test_decode_strict_ok(self, jwt, payload):
794+
secret = "secret"
795+
payload["aud"] = "urn:foo"
796+
jwt_message = jwt.encode(payload, secret)
797+
798+
jwt.decode(
799+
jwt_message,
800+
secret,
801+
audience="urn:foo",
802+
options={"strict_aud": True},
803+
algorithms=["HS256"],
804+
)

0 commit comments

Comments
 (0)