Skip to content

Commit c387281

Browse files
authored
Add JWK support to JWT encode (#979)
* Allow JWK for JWS encode. * Add PyJWK to JWT encode. * Update CHANGELOG. * Remove `DEFAULT_ALGORITHM`
1 parent 44d8605 commit c387281

File tree

4 files changed

+47
-7
lines changed

4 files changed

+47
-7
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Changed
1212

1313
- Use ``Sequence`` for parameter types rather than ``List`` where applicable by @imnotjames in `#970 <https://github.com/jpadilla/pyjwt/pull/970>`__
1414
- Remove algorithm requirement from JWT API, instead relying on JWS API for enforcement, by @luhn in `#975 <https://github.com/jpadilla/pyjwt/pull/975>`__
15+
- Add JWK support to JWT encode by @luhn in `#979 <https://github.com/jpadilla/pyjwt/pull/979>`__
1516

1617
Fixed
1718
~~~~~

jwt/api_jws.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
105105
def encode(
106106
self,
107107
payload: bytes,
108-
key: AllowedPrivateKeys | str | bytes,
109-
algorithm: str | None = "HS256",
108+
key: AllowedPrivateKeys | PyJWK | str | bytes,
109+
algorithm: str | None = None,
110110
headers: dict[str, Any] | None = None,
111111
json_encoder: type[json.JSONEncoder] | None = None,
112112
is_payload_detached: bool = False,
@@ -115,7 +115,13 @@ def encode(
115115
segments = []
116116

117117
# declare a new var to narrow the type for type checkers
118-
algorithm_: str = algorithm if algorithm is not None else "none"
118+
if algorithm is None:
119+
if isinstance(key, PyJWK):
120+
algorithm_ = key.algorithm_name
121+
else:
122+
algorithm_ = "HS256"
123+
else:
124+
algorithm_ = algorithm
119125

120126
# Prefer headers values if present to function parameters.
121127
if headers:
@@ -159,6 +165,8 @@ def encode(
159165
signing_input = b".".join(segments)
160166

161167
alg_obj = self.get_algorithm_by_name(algorithm_)
168+
if isinstance(key, PyJWK):
169+
key = key.key
162170
key = alg_obj.prepare_key(key)
163171
signature = alg_obj.sign(signing_input, key)
164172

jwt/api_jwt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def _get_default_options() -> dict[str, bool | list[str]]:
4545
def encode(
4646
self,
4747
payload: dict[str, Any],
48-
key: AllowedPrivateKeys | str | bytes,
49-
algorithm: str | None = "HS256",
48+
key: AllowedPrivateKeys | PyJWK | str | bytes,
49+
algorithm: str | None = None,
5050
headers: dict[str, Any] | None = None,
5151
json_encoder: type[json.JSONEncoder] | None = None,
5252
sort_headers: bool = True,

tests/test_api_jws.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,18 @@ def test_decode_with_non_mapping_header_throws_exception(self, jws):
158158
exception = context.value
159159
assert str(exception) == "Invalid header string: must be a json object"
160160

161+
def test_encode_default_algorithm(self, jws, payload):
162+
msg = jws.encode(payload, "secret")
163+
decoded = jws.decode_complete(msg, "secret", algorithms=["HS256"])
164+
assert decoded == {
165+
"header": {"alg": "HS256", "typ": "JWT"},
166+
"payload": payload,
167+
"signature": (
168+
b"H\x8a\xf4\xdf3:\xe1\xac\x16E\xd3\xeb\x00\xcf\xfa\xd5\x05\xac"
169+
b"e\xc8@\xb6\x00\xd5\xde\x9aa|s\xcfZB"
170+
),
171+
}
172+
161173
def test_encode_algorithm_param_should_be_case_sensitive(self, jws, payload):
162174
jws.encode(payload, "secret", algorithm="HS256")
163175

@@ -193,6 +205,25 @@ def test_encode_with_alg_hs256_and_headers_alg_es256(self, jws, payload):
193205
msg = jws.encode(payload, priv_key, algorithm="HS256", headers={"alg": "ES256"})
194206
assert b"hello world" == jws.decode(msg, pub_key, algorithms=["ES256"])
195207

208+
def test_encode_with_jwk(self, jws, payload):
209+
jwk = PyJWK(
210+
{
211+
"kty": "oct",
212+
"alg": "HS256",
213+
"k": "c2VjcmV0", # "secret"
214+
}
215+
)
216+
msg = jws.encode(payload, key=jwk)
217+
decoded = jws.decode_complete(msg, key=jwk, algorithms=["HS256"])
218+
assert decoded == {
219+
"header": {"alg": "HS256", "typ": "JWT"},
220+
"payload": payload,
221+
"signature": (
222+
b"H\x8a\xf4\xdf3:\xe1\xac\x16E\xd3\xeb\x00\xcf\xfa\xd5\x05\xac"
223+
b"e\xc8@\xb6\x00\xd5\xde\x9aa|s\xcfZB"
224+
),
225+
}
226+
196227
def test_decode_algorithm_param_should_be_case_sensitive(self, jws):
197228
example_jws = (
198229
"eyJhbGciOiJoczI1NiIsInR5cCI6IkpXVCJ9" # alg = hs256
@@ -531,13 +562,13 @@ def test_decode_invalid_crypto_padding(self, jws):
531562
assert "Invalid crypto padding" in str(exc.value)
532563

533564
def test_decode_with_algo_none_should_fail(self, jws, payload):
534-
jws_message = jws.encode(payload, key=None, algorithm=None)
565+
jws_message = jws.encode(payload, key=None, algorithm="none")
535566

536567
with pytest.raises(DecodeError):
537568
jws.decode(jws_message, algorithms=["none"])
538569

539570
def test_decode_with_algo_none_and_verify_false_should_pass(self, jws, payload):
540-
jws_message = jws.encode(payload, key=None, algorithm=None)
571+
jws_message = jws.encode(payload, key=None, algorithm="none")
541572
jws.decode(jws_message, options={"verify_signature": False})
542573

543574
def test_get_unverified_header_returns_header_values(self, jws, payload):

0 commit comments

Comments
 (0)