Skip to content

Commit 491265f

Browse files
committed
Encoding EC keys with a fixed bit length
1 parent 6c7cc61 commit 491265f

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

jwt/algorithms.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -581,13 +581,20 @@ def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
581581
obj: dict[str, Any] = {
582582
"kty": "EC",
583583
"crv": crv,
584-
"x": to_base64url_uint(public_numbers.x).decode(),
585-
"y": to_base64url_uint(public_numbers.y).decode(),
584+
"x": to_base64url_uint(
585+
public_numbers.x,
586+
bit_length=key_obj.curve.key_size,
587+
).decode(),
588+
"y": to_base64url_uint(
589+
public_numbers.y,
590+
bit_length=key_obj.curve.key_size,
591+
).decode(),
586592
}
587593

588594
if isinstance(key_obj, EllipticCurvePrivateKey):
589595
obj["d"] = to_base64url_uint(
590-
key_obj.private_numbers().private_value
596+
key_obj.private_numbers().private_value,
597+
bit_length=key_obj.curve.key_size,
591598
).decode()
592599

593600
if as_dict:

jwt/utils.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ def base64url_encode(input: bytes) -> bytes:
3737
return base64.urlsafe_b64encode(input).replace(b"=", b"")
3838

3939

40-
def to_base64url_uint(val: int) -> bytes:
40+
def to_base64url_uint(val: int, *, bit_length: int | None = None) -> bytes:
4141
if val < 0:
4242
raise ValueError("Must be a positive integer")
4343

44-
int_bytes = bytes_from_int(val)
44+
int_bytes = bytes_from_int(val, bit_length=bit_length)
4545

4646
if len(int_bytes) == 0:
4747
int_bytes = b"\x00"
@@ -63,13 +63,10 @@ def bytes_to_number(string: bytes) -> int:
6363
return int(binascii.b2a_hex(string), 16)
6464

6565

66-
def bytes_from_int(val: int) -> bytes:
67-
remaining = val
68-
byte_length = 0
69-
70-
while remaining != 0:
71-
remaining >>= 8
72-
byte_length += 1
66+
def bytes_from_int(val: int, *, bit_length: int | None = None) -> bytes:
67+
if bit_length is None:
68+
bit_length = val.bit_length()
69+
byte_length = (bit_length + 7) // 8
7370

7471
return val.to_bytes(byte_length, "big", signed=False)
7572

0 commit comments

Comments
 (0)