Skip to content

Commit 2738f15

Browse files
committed
implement deterministic implicit rejection for RSA decryption
1 parent eaffaa9 commit 2738f15

File tree

3 files changed

+1405
-15
lines changed

3 files changed

+1405
-15
lines changed

tlslite/utils/compat.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import sys
77
import os
8+
import re
89
import platform
910
import math
1011
import binascii
@@ -68,6 +69,10 @@ def formatExceptionTrace(e):
6869
"""Return exception information formatted as string"""
6970
return str(e)
7071

72+
def remove_whitespace(text):
73+
"""Removes all whitespace from passed in string"""
74+
return re.sub(r"\s+", "", text, flags=re.UNICODE)
75+
7176
else:
7277
# Python 2.6 requires strings instead of bytearrays in a couple places,
7378
# so we define this function so it does the conversion if needed.
@@ -76,9 +81,18 @@ def formatExceptionTrace(e):
7681
if sys.version_info < (2, 7) or sys.version_info < (2, 7, 4) \
7782
or platform.system() == 'Java':
7883
def compat26Str(x): return str(x)
84+
85+
def remove_whitespace(text):
86+
"""Removes all whitespace from passed in string"""
87+
return re.sub(r"\s+", "", text)
88+
7989
else:
8090
def compat26Str(x): return x
8191

92+
def remove_whitespace(text):
93+
"""Removes all whitespace from passed in string"""
94+
return re.sub(r"\s+", "", text, flags=re.UNICODE)
95+
8296
def compatAscii2Bytes(val):
8397
"""Convert ASCII string to bytes."""
8498
return val

tlslite/utils/rsakey.py

Lines changed: 157 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from . import tlshashlib as hashlib
88
from ..errors import MaskTooLongError, MessageTooLongError, EncodingError, \
99
InvalidSignature, UnknownRSAType
10+
from .constanttime import ct_isnonzero_u32, ct_neq_u32, ct_lsb_prop_u8, \
11+
ct_lsb_prop_u16, ct_lt_u32
1012

1113

1214
class RSAKey(object):
@@ -34,6 +36,7 @@ def __init__(self, n=0, e=0):
3436
:type e: int
3537
:param e: RSA public exponent.
3638
"""
39+
self._key_hash = None
3740
raise NotImplementedError()
3841

3942
def __len__(self):
@@ -376,35 +379,174 @@ def encrypt(self, bytes):
376379
paddedBytes = self._addPKCS1Padding(bytes, 2)
377380
return self._raw_public_key_op_bytes(paddedBytes)
378381

382+
def _dec_prf(self, key, label, out_len):
383+
"""PRF for deterministic implicit rejection in the RSA decryption.
384+
385+
:param bytes key: key to use for derivation
386+
:param bytes label: name of the keystream generated
387+
:param int out_len: length of output, in bits
388+
:rtype: bytes
389+
:returns: a random bytestring
390+
"""
391+
out = bytearray()
392+
393+
if out_len % 8 != 0:
394+
raise ValueError("only multiples of 8 supported as output size")
395+
396+
iterator = 0
397+
while len(out) < out_len // 8:
398+
out += secureHMAC(
399+
key,
400+
numberToByteArray(iterator, 2) + label +
401+
numberToByteArray(out_len, 2),
402+
"sha256")
403+
iterator += 1
404+
405+
return out[:out_len//8]
406+
379407
def decrypt(self, encBytes):
380408
"""Decrypt the passed-in bytes.
381409
382410
This requires the key to have a private component. It performs
383-
PKCS1 decryption of the passed-in data.
411+
PKCS#1 v1.5 decryption operation of the passed-in data.
412+
413+
Note: as a workaround against Bleichenbacher-like attacks, it will
414+
return a deterministically selected random message in case the padding
415+
checks failed. It returns an error (None) only in case the ciphertext
416+
is of incorrect length or encodes an integer bigger than the modulus
417+
of the key (i.e. it's publically invalid).
384418
385419
:type encBytes: bytearray
386420
:param encBytes: The value which will be decrypted.
387421
388422
:rtype: bytearray or None
389-
:returns: A PKCS1 decryption of the passed-in data or None if
390-
the data is not properly formatted.
423+
:returns: A PKCS#1 v1.5 decryption of the passed-in data or None if
424+
the provided data is not properly formatted. Note: encrypting
425+
an empty string is correct, so it may return an empty bytearray
426+
for some ciphertexts.
391427
"""
392428
if not self.hasPrivateKey():
393429
raise AssertionError()
394430
try:
395-
decBytes = self._raw_private_key_op_bytes(encBytes)
431+
dec_bytes = self._raw_private_key_op_bytes(encBytes)
396432
except ValueError:
433+
# _raw_private_key_op_bytes fails only when encBytes >= self.n,
434+
# or when len(encBytes) != numBytes(self.n) and that's public
435+
# information, so we don't have to handle it
436+
# in sidechannel secure way
397437
return None
398-
#Check first two bytes
399-
if decBytes[0] != 0 or decBytes[1] != 2:
400-
return None
401-
#Scan through for zero separator
402-
for x in range(1, len(decBytes)-1):
403-
if decBytes[x]== 0:
404-
break
405-
else:
406-
return None
407-
return decBytes[x+1:] #Return everything after the separator
438+
439+
###################
440+
# here be dragons #
441+
###################
442+
# While the code is written as-if it was side-channel secure, in
443+
# practice, because of cPython implementation details IT IS NOT
444+
# see:
445+
# https://securitypitfalls.wordpress.com/2018/08/03/constant-time-compare-in-python/
446+
447+
n = self.n
448+
449+
# maximum length we can return is reduced by the mandatory prefix:
450+
# (0x00 0x02), 8 bytes of padding, so this is the position of the
451+
# null separator byte, as counted from the last position
452+
max_sep_offset = numBytes(n) - 10
453+
454+
# the private exponent (d) doesn't change so `_key_hash` doesn't
455+
# change, calculate it only once
456+
if not hasattr(self, '_key_hash') or not self._key_hash:
457+
self._key_hash = secureHash(numberToByteArray(self.d, numBytes(n)),
458+
"sha256")
459+
460+
kdk = secureHMAC(self._key_hash, encBytes, "sha256")
461+
462+
# we need 128 2-byte numbers, encoded as the number of bits
463+
length_randoms = self._dec_prf(kdk, b"length", 128 * 2 * 8)
464+
465+
message_random = self._dec_prf(kdk, b"message", numBytes(n) * 8)
466+
467+
# select the last length that's not too large to return
468+
synth_length = 0
469+
length_rand_iter = iter(length_randoms)
470+
length_mask = (1 << numBits(max_sep_offset)) - 1
471+
for high, low in zip(length_rand_iter, length_rand_iter):
472+
# interpret the two bytes from the PRF output as 16-bit big-endian
473+
# integer
474+
len_candidate = (high << 8) + low
475+
len_candidate &= length_mask
476+
# equivalent to:
477+
# if len_candidate < max_sep_offset:
478+
# synth_length = len_candidate
479+
mask = ct_lt_u32(len_candidate, max_sep_offset)
480+
mask = ct_lsb_prop_u16(mask)
481+
synth_length = synth_length & (0xffff ^ mask) \
482+
| len_candidate & mask
483+
484+
synth_msg_start = numBytes(n) - synth_length
485+
486+
error_detected = 0
487+
488+
# enumerate over all decrypted bytes
489+
em_bytes = enumerate(dec_bytes)
490+
# first check if first two bytes specify PKCS#1 v1.5 encryption padding
491+
_, val = next(em_bytes)
492+
error_detected |= ct_isnonzero_u32(val)
493+
_, val = next(em_bytes)
494+
error_detected |= ct_neq_u32(val, 0x02)
495+
# then look for for the null separator byte among the padding bytes
496+
# but inspect all decrypted bytes, even if we already find the
497+
# separator earlier
498+
msg_start = 0
499+
for pos, val in em_bytes:
500+
# padding must be at least 8 bytes long, fail if any of the first
501+
# 8 bytes of it are zero
502+
# equivalent to:
503+
# if pos < 10 and not val:
504+
# error_detected = 0x01
505+
error_detected |= ct_lt_u32(pos, 10) & (1 ^ ct_isnonzero_u32(val))
506+
507+
# update the msg_start only once; when it's 0
508+
# (pos+1) because we want to skip the null separator
509+
# equivalent to:
510+
# if pos >= 10 and not msg_start and not val:
511+
# msg_start = pos+1
512+
mask = (1 ^ ct_lt_u32(pos, 10)) & (1 ^ ct_isnonzero_u32(val)) \
513+
& (1 ^ ct_isnonzero_u32(msg_start))
514+
mask = ct_lsb_prop_u16(mask)
515+
msg_start = msg_start & (0xffff ^ mask) | (pos+1) & mask
516+
517+
# if separator wasn't found, it's an error
518+
# equivalent to:
519+
# if not msg_start:
520+
# error_detected = 0x01
521+
error_detected |= 1 ^ ct_isnonzero_u32(msg_start)
522+
523+
# equivalent to:
524+
# if error_detected:
525+
# ret_msg_start = synth_msg_start
526+
# else:
527+
# ret_msg_start = msg_start
528+
mask = ct_lsb_prop_u16(error_detected)
529+
ret_msg_start = msg_start & (0xffff ^ mask) | synth_msg_start & mask
530+
531+
# as at this point the length doesn't leak the information if the
532+
# padding was correct or not, we don't have to worry about the
533+
# length of the returned value (and thus the size of the buffer we
534+
# pass to the caller); but we still need to read both buffers
535+
# to ensure that the memory access patern is preserved (that both
536+
# buffers are accessed, not just the one we return)
537+
538+
# equivalent to:
539+
# if error_detected:
540+
# return message_random[ret_msg_start:]
541+
# else:
542+
# return dec_bytes[ret_msg_start:]
543+
mask = ct_lsb_prop_u8(error_detected)
544+
not_mask = 0xff ^ mask
545+
ret = bytearray(
546+
x & not_mask | y & mask for x, y in
547+
zip(dec_bytes[ret_msg_start:], message_random[ret_msg_start:]))
548+
549+
return ret
408550

409551
def _rawPrivateKeyOp(self, m):
410552
raise NotImplementedError()
@@ -427,7 +569,7 @@ def _raw_public_key_op_bytes(self, ciphertext):
427569
if len(ciphertext) != numBytes(n):
428570
raise ValueError("Message has incorrect length for the key size")
429571
c_int = bytesToNumber(ciphertext)
430-
if c_int > n:
572+
if c_int >= n:
431573
raise ValueError("Provided message value exceeds modulus")
432574
enc_int = self._rawPublicKeyOp(c_int)
433575
return numberToByteArray(enc_int, numBytes(n))

0 commit comments

Comments
 (0)