Skip to content

Commit 4ca384f

Browse files
committed
Split HKDF into extract and expand functions
1 parent f1cefb8 commit 4ca384f

File tree

1 file changed

+37
-29
lines changed

1 file changed

+37
-29
lines changed

lib/Crypto/Protocol/KDF.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@
3838
create_string_buffer,
3939
get_raw_buffer, c_size_t)
4040

41-
_raw_salsa20_lib = load_pycryptodome_raw_lib("Crypto.Cipher._Salsa20",
41+
_raw_salsa20_lib = load_pycryptodome_raw_lib(
42+
"Crypto.Cipher._Salsa20",
4243
"""
4344
int Salsa20_8_core(const uint8_t *x, const uint8_t *y,
4445
uint8_t *out);
4546
""")
4647

47-
_raw_scrypt_lib = load_pycryptodome_raw_lib("Crypto.Protocol._scrypt",
48+
_raw_scrypt_lib = load_pycryptodome_raw_lib(
49+
"Crypto.Protocol._scrypt",
4850
"""
4951
typedef int (core_t)(const uint8_t [64], const uint8_t [64], uint8_t [64]);
5052
int scryptROMix(const uint8_t *data_in, uint8_t *data_out,
@@ -156,7 +158,7 @@ def PBKDF2(password, salt, dkLen=16, count=1000, prf=None, hmac_hash_module=None
156158
# Generic (and slow) implementation
157159

158160
if prf is None:
159-
prf = lambda p,s: HMAC.new(p, s, hmac_hash_module).digest()
161+
prf = lambda p, s: HMAC.new(p, s, hmac_hash_module).digest()
160162

161163
def link(s):
162164
s[0], s[1] = s[1], prf(password, s[1])
@@ -165,15 +167,15 @@ def link(s):
165167
key = b''
166168
i = 1
167169
while len(key) < dkLen:
168-
s = [ prf(password, salt + struct.pack(">I", i)) ] * 2
169-
key += reduce(strxor, (link(s) for j in range(count)) )
170+
s = [prf(password, salt + struct.pack(">I", i))] * 2
171+
key += reduce(strxor, (link(s) for j in range(count)))
170172
i += 1
171173

172174
else:
173175
# Optimized implementation
174176
key = b''
175177
i = 1
176-
while len(key)<dkLen:
178+
while len(key) < dkLen:
177179
base = HMAC.new(password, b"", hmac_hash_module)
178180
first_digest = base.copy().update(salt + struct.pack(">I", i)).digest()
179181
key += base._pbkdf2_hmac_assist(first_digest, count)
@@ -230,7 +232,7 @@ def new(key, ciphermod):
230232
return _S2V(key, ciphermod)
231233

232234
def _double(self, bs):
233-
doubled = bytes_to_long(bs)<<1
235+
doubled = bytes_to_long(bs) << 1
234236
if bord(bs[0]) & 0x80:
235237
doubled ^= 0x87
236238
return long_to_bytes(doubled, len(bs))[-len(bs):]
@@ -278,6 +280,24 @@ def derive(self):
278280
return mac.digest()
279281

280282

283+
def _HKDF_extract(salt, ikm, hashmod):
284+
prk = HMAC.new(salt, ikm, digestmod=hashmod).digest()
285+
return prk
286+
287+
288+
def _HKDF_expand(prk, info, L, hashmod):
289+
t = [b""]
290+
n = 1
291+
tlen = 0
292+
while tlen < L:
293+
hmac = HMAC.new(prk, t[-1] + info + struct.pack('B', n), digestmod=hashmod)
294+
t.append(hmac.digest())
295+
tlen += hashmod.digest_size
296+
n += 1
297+
okm = b"".join(t)
298+
return okm
299+
300+
281301
def HKDF(master, key_len, salt, hashmod, num_keys=1, context=None):
282302
"""Derive one or more keys from a master secret using
283303
the HMAC-based KDF defined in RFC5869_.
@@ -318,28 +338,16 @@ def HKDF(master, key_len, salt, hashmod, num_keys=1, context=None):
318338
if context is None:
319339
context = b""
320340

321-
# Step 1: extract
322-
hmac = HMAC.new(salt, master, digestmod=hashmod)
323-
prk = hmac.digest()
341+
prk = _HKDF_extract(salt, master, hashmod)
342+
okm = _HKDF_expand(prk, context, output_len, hashmod)
324343

325-
# Step 2: expand
326-
t = [ b"" ]
327-
n = 1
328-
tlen = 0
329-
while tlen < output_len:
330-
hmac = HMAC.new(prk, t[-1] + context + struct.pack('B', n), digestmod=hashmod)
331-
t.append(hmac.digest())
332-
tlen += hashmod.digest_size
333-
n += 1
334-
derived_output = b"".join(t)
335344
if num_keys == 1:
336-
return derived_output[:key_len]
337-
kol = [derived_output[idx:idx + key_len]
345+
return okm[:key_len]
346+
kol = [okm[idx:idx + key_len]
338347
for idx in iter_range(0, output_len, key_len)]
339348
return list(kol[:num_keys])
340349

341350

342-
343351
def scrypt(password, salt, key_len, N, r, p, num_keys=1):
344352
"""Derive one or more keys from a passphrase.
345353
@@ -383,7 +391,7 @@ def scrypt(password, salt, key_len, N, r, p, num_keys=1):
383391
raise ValueError("N must be a power of 2")
384392
if N >= 2 ** 32:
385393
raise ValueError("N is too big")
386-
if p > ((2 ** 32 - 1) * 32) // (128 * r):
394+
if p > ((2 ** 32 - 1) * 32) // (128 * r):
387395
raise ValueError("p or r are too big")
388396

389397
prf_hmac_sha256 = lambda p, s: HMAC.new(p, s, SHA256).digest()
@@ -398,14 +406,14 @@ def scrypt(password, salt, key_len, N, r, p, num_keys=1):
398406
for flow in iter_range(p):
399407
idx = flow * 128 * r
400408
buffer_out = create_string_buffer(128 * r)
401-
result = scryptROMix(stage_1[idx : idx + 128 * r],
409+
result = scryptROMix(stage_1[idx: idx + 128 * r],
402410
buffer_out,
403411
c_size_t(128 * r),
404412
N,
405413
core)
406414
if result:
407415
raise ValueError("Error %X while running scrypt" % result)
408-
data_out += [ get_raw_buffer(buffer_out) ]
416+
data_out += [get_raw_buffer(buffer_out)]
409417

410418
dk = PBKDF2(password,
411419
b"".join(data_out),
@@ -429,7 +437,7 @@ def _bcrypt_encode(data):
429437
bits.append(bstr(bits_c))
430438
bits = b"".join(bits)
431439

432-
bits6 = [ bits[idx:idx+6] for idx in range(0, len(bits), 6) ]
440+
bits6 = [bits[idx:idx+6] for idx in range(0, len(bits), 6)]
433441

434442
result = []
435443
for g in bits6[:-1]:
@@ -462,7 +470,7 @@ def _bcrypt_decode(data):
462470
elif modulo4 == 3:
463471
bits = bits[:-2]
464472

465-
bits8 = [ bits[idx:idx+8] for idx in range(0, len(bits), 8) ]
473+
bits8 = [bits[idx:idx+8] for idx in range(0, len(bits), 8)]
466474

467475
result = []
468476
for g in bits8:
@@ -570,7 +578,7 @@ def bcrypt_check(password, bcrypt_hash):
570578

571579
salt = _bcrypt_decode(r.group(2))
572580

573-
bcrypt_hash2 = bcrypt(password, cost, salt)
581+
bcrypt_hash2 = bcrypt(password, cost, salt)
574582

575583
secret = get_random_bytes(16)
576584

0 commit comments

Comments
 (0)