7
7
from . import tlshashlib as hashlib
8
8
from ..errors import MaskTooLongError , MessageTooLongError , EncodingError , \
9
9
InvalidSignature , UnknownRSAType
10
+ from .constanttime import ct_isnonzero_u32 , ct_neq_u32 , ct_lsb_prop_u8 , \
11
+ ct_lsb_prop_u16 , ct_lt_u32
10
12
11
13
12
14
class RSAKey (object ):
@@ -34,6 +36,7 @@ def __init__(self, n=0, e=0):
34
36
:type e: int
35
37
:param e: RSA public exponent.
36
38
"""
39
+ self ._key_hash = None
37
40
raise NotImplementedError ()
38
41
39
42
def __len__ (self ):
@@ -376,35 +379,174 @@ def encrypt(self, bytes):
376
379
paddedBytes = self ._addPKCS1Padding (bytes , 2 )
377
380
return self ._raw_public_key_op_bytes (paddedBytes )
378
381
382
+ def _dec_prf (self , key , label , out_len ):
383
+ """PRF for deterministic implicit rejection in the RSA decryption.
384
+
385
+ :param bytes key: key to use for derivation
386
+ :param bytes label: name of the keystream generated
387
+ :param int out_len: length of output, in bits
388
+ :rtype: bytes
389
+ :returns: a random bytestring
390
+ """
391
+ out = bytearray ()
392
+
393
+ if out_len % 8 != 0 :
394
+ raise ValueError ("only multiples of 8 supported as output size" )
395
+
396
+ iterator = 0
397
+ while len (out ) < out_len // 8 :
398
+ out += secureHMAC (
399
+ key ,
400
+ numberToByteArray (iterator , 2 ) + label +
401
+ numberToByteArray (out_len , 2 ),
402
+ "sha256" )
403
+ iterator += 1
404
+
405
+ return out [:out_len // 8 ]
406
+
379
407
def decrypt (self , encBytes ):
380
408
"""Decrypt the passed-in bytes.
381
409
382
410
This requires the key to have a private component. It performs
383
- PKCS1 decryption of the passed-in data.
411
+ PKCS#1 v1.5 decryption operation of the passed-in data.
412
+
413
+ Note: as a workaround against Bleichenbacher-like attacks, it will
414
+ return a deterministically selected random message in case the padding
415
+ checks failed. It returns an error (None) only in case the ciphertext
416
+ is of incorrect length or encodes an integer bigger than the modulus
417
+ of the key (i.e. it's publically invalid).
384
418
385
419
:type encBytes: bytearray
386
420
:param encBytes: The value which will be decrypted.
387
421
388
422
:rtype: bytearray or None
389
- :returns: A PKCS1 decryption of the passed-in data or None if
390
- the data is not properly formatted.
423
+ :returns: A PKCS#1 v1.5 decryption of the passed-in data or None if
424
+ the provided data is not properly formatted. Note: encrypting
425
+ an empty string is correct, so it may return an empty bytearray
426
+ for some ciphertexts.
391
427
"""
392
428
if not self .hasPrivateKey ():
393
429
raise AssertionError ()
394
430
try :
395
- decBytes = self ._raw_private_key_op_bytes (encBytes )
431
+ dec_bytes = self ._raw_private_key_op_bytes (encBytes )
396
432
except ValueError :
433
+ # _raw_private_key_op_bytes fails only when encBytes >= self.n,
434
+ # or when len(encBytes) != numBytes(self.n) and that's public
435
+ # information, so we don't have to handle it
436
+ # in sidechannel secure way
397
437
return None
398
- #Check first two bytes
399
- if decBytes [0 ] != 0 or decBytes [1 ] != 2 :
400
- return None
401
- #Scan through for zero separator
402
- for x in range (1 , len (decBytes )- 1 ):
403
- if decBytes [x ]== 0 :
404
- break
405
- else :
406
- return None
407
- return decBytes [x + 1 :] #Return everything after the separator
438
+
439
+ ###################
440
+ # here be dragons #
441
+ ###################
442
+ # While the code is written as-if it was side-channel secure, in
443
+ # practice, because of cPython implementation details IT IS NOT
444
+ # see:
445
+ # https://securitypitfalls.wordpress.com/2018/08/03/constant-time-compare-in-python/
446
+
447
+ n = self .n
448
+
449
+ # maximum length we can return is reduced by the mandatory prefix:
450
+ # (0x00 0x02), 8 bytes of padding, so this is the position of the
451
+ # null separator byte, as counted from the last position
452
+ max_sep_offset = numBytes (n ) - 10
453
+
454
+ # the private exponent (d) doesn't change so `_key_hash` doesn't
455
+ # change, calculate it only once
456
+ if not hasattr (self , '_key_hash' ) or not self ._key_hash :
457
+ self ._key_hash = secureHash (numberToByteArray (self .d , numBytes (n )),
458
+ "sha256" )
459
+
460
+ kdk = secureHMAC (self ._key_hash , encBytes , "sha256" )
461
+
462
+ # we need 128 2-byte numbers, encoded as the number of bits
463
+ length_randoms = self ._dec_prf (kdk , b"length" , 128 * 2 * 8 )
464
+
465
+ message_random = self ._dec_prf (kdk , b"message" , numBytes (n ) * 8 )
466
+
467
+ # select the last length that's not too large to return
468
+ synth_length = 0
469
+ length_rand_iter = iter (length_randoms )
470
+ length_mask = (1 << numBits (max_sep_offset )) - 1
471
+ for high , low in zip (length_rand_iter , length_rand_iter ):
472
+ # interpret the two bytes from the PRF output as 16-bit big-endian
473
+ # integer
474
+ len_candidate = (high << 8 ) + low
475
+ len_candidate &= length_mask
476
+ # equivalent to:
477
+ # if len_candidate < max_sep_offset:
478
+ # synth_length = len_candidate
479
+ mask = ct_lt_u32 (len_candidate , max_sep_offset )
480
+ mask = ct_lsb_prop_u16 (mask )
481
+ synth_length = synth_length & (0xffff ^ mask ) \
482
+ | len_candidate & mask
483
+
484
+ synth_msg_start = numBytes (n ) - synth_length
485
+
486
+ error_detected = 0
487
+
488
+ # enumerate over all decrypted bytes
489
+ em_bytes = enumerate (dec_bytes )
490
+ # first check if first two bytes specify PKCS#1 v1.5 encryption padding
491
+ _ , val = next (em_bytes )
492
+ error_detected |= ct_isnonzero_u32 (val )
493
+ _ , val = next (em_bytes )
494
+ error_detected |= ct_neq_u32 (val , 0x02 )
495
+ # then look for for the null separator byte among the padding bytes
496
+ # but inspect all decrypted bytes, even if we already find the
497
+ # separator earlier
498
+ msg_start = 0
499
+ for pos , val in em_bytes :
500
+ # padding must be at least 8 bytes long, fail if any of the first
501
+ # 8 bytes of it are zero
502
+ # equivalent to:
503
+ # if pos < 10 and not val:
504
+ # error_detected = 0x01
505
+ error_detected |= ct_lt_u32 (pos , 10 ) & (1 ^ ct_isnonzero_u32 (val ))
506
+
507
+ # update the msg_start only once; when it's 0
508
+ # (pos+1) because we want to skip the null separator
509
+ # equivalent to:
510
+ # if pos >= 10 and not msg_start and not val:
511
+ # msg_start = pos+1
512
+ mask = (1 ^ ct_lt_u32 (pos , 10 )) & (1 ^ ct_isnonzero_u32 (val )) \
513
+ & (1 ^ ct_isnonzero_u32 (msg_start ))
514
+ mask = ct_lsb_prop_u16 (mask )
515
+ msg_start = msg_start & (0xffff ^ mask ) | (pos + 1 ) & mask
516
+
517
+ # if separator wasn't found, it's an error
518
+ # equivalent to:
519
+ # if not msg_start:
520
+ # error_detected = 0x01
521
+ error_detected |= 1 ^ ct_isnonzero_u32 (msg_start )
522
+
523
+ # equivalent to:
524
+ # if error_detected:
525
+ # ret_msg_start = synth_msg_start
526
+ # else:
527
+ # ret_msg_start = msg_start
528
+ mask = ct_lsb_prop_u16 (error_detected )
529
+ ret_msg_start = msg_start & (0xffff ^ mask ) | synth_msg_start & mask
530
+
531
+ # as at this point the length doesn't leak the information if the
532
+ # padding was correct or not, we don't have to worry about the
533
+ # length of the returned value (and thus the size of the buffer we
534
+ # pass to the caller); but we still need to read both buffers
535
+ # to ensure that the memory access patern is preserved (that both
536
+ # buffers are accessed, not just the one we return)
537
+
538
+ # equivalent to:
539
+ # if error_detected:
540
+ # return message_random[ret_msg_start:]
541
+ # else:
542
+ # return dec_bytes[ret_msg_start:]
543
+ mask = ct_lsb_prop_u8 (error_detected )
544
+ not_mask = 0xff ^ mask
545
+ ret = bytearray (
546
+ x & not_mask | y & mask for x , y in
547
+ zip (dec_bytes [ret_msg_start :], message_random [ret_msg_start :]))
548
+
549
+ return ret
408
550
409
551
def _rawPrivateKeyOp (self , m ):
410
552
raise NotImplementedError ()
@@ -427,7 +569,7 @@ def _raw_public_key_op_bytes(self, ciphertext):
427
569
if len (ciphertext ) != numBytes (n ):
428
570
raise ValueError ("Message has incorrect length for the key size" )
429
571
c_int = bytesToNumber (ciphertext )
430
- if c_int > n :
572
+ if c_int >= n :
431
573
raise ValueError ("Provided message value exceeds modulus" )
432
574
enc_int = self ._rawPublicKeyOp (c_int )
433
575
return numberToByteArray (enc_int , numBytes (n ))
0 commit comments