-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcomb_attack.py
More file actions
162 lines (138 loc) · 5.04 KB
/
comb_attack.py
File metadata and controls
162 lines (138 loc) · 5.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import random
import itertools
import time
from multiprocessing import Pool
import sys
import logging
# Module-level logger. It logs at WARNING by default (the self-test in the
# __main__ section raises it to INFO) and emits "name [LEVEL] message" lines
# through its own StreamHandler (which writes to stderr by default).
logger = logging.getLogger(__name__)
formatter = logging.Formatter('%(name)s [%(levelname)s] %(message)s')
logger_sh = logging.StreamHandler()
logger_sh.setFormatter(formatter)
logger.setLevel(logging.WARNING)
logger.addHandler(logger_sh)
def generate_pair(n_bits, n_unknowns):
    """
    Generate two integers alpha, beta < 2^n_bits such that alpha*beta = 1 mod 2^n_bits,
    together with copies of them in which exactly n_unknowns bits are hidden.

    bin_alpha, bin_beta are returned as binary strings with '*' for unknown bits, lsb
    first, zero-padded to n_bits characters; so for instance 11 mod 32 would be
    represented as '11010'. Bit 0 is never hidden: every unit mod 2^n_bits is odd, so
    that bit is always '1'. alpha and beta are returned as integers mod 2^n_bits.

    Args:
        n_bits: bit length of the modulus 2^n_bits (>= 1).
        n_unknowns: number of bits hidden in each string, 0 <= n_unknowns <= n_bits - 1.

    Returns:
        (bin_alpha, bin_beta, alpha, beta)

    Raises:
        ValueError: if n_bits < 1 or n_unknowns is outside [0, n_bits - 1].
    """
    if n_bits < 1:
        raise ValueError(f'n_bits must be >= 1, got {n_bits}')
    if not 0 <= n_unknowns <= n_bits - 1:
        raise ValueError(
            f'n_unknowns must be in [0, {n_bits - 1}] (bit 0 is always known), '
            f'got {n_unknowns}')
    modulus = 2**n_bits
    # The units mod 2^n_bits are exactly the odd residues; draw one uniformly.
    alpha = random.randrange(1, modulus, 2)
    beta = pow(alpha, -1, modulus)

    def hide_bits(value):
        """Return value as an lsb-first bit string with n_unknowns bits set to '*'."""
        digits = list(bin(value)[2:].zfill(n_bits)[::-1])
        # Pick distinct positions among 1..n_bits-1 only: hiding bit 0 and then
        # revealing it again would leave one unknown fewer than requested.
        for pos in random.sample(range(1, n_bits), n_unknowns):
            digits[pos] = '*'
        return ''.join(digits)

    return hide_bits(alpha), hide_bits(beta), alpha, beta
def _solve_step(data):
    """
    Extend one partial solution of alpha * beta == 1 by one bit.

    data is a tuple (alpha_i, beta_i, lvl): integers alpha_i, beta_i known
    mod 2^lvl, as produced by solve_product_bfs. Bit number lvl of each number
    is read from the module globals _bin_alpha / _bin_beta (lsb-first strings,
    '*' = unknown bit). Every choice of those two bits that matches the
    patterns and satisfies alpha * beta == 1 mod 2^(lvl+1) is kept.

    Returns:
        list of tuples (alpha_{i+1}, beta_{i+1}, lvl + 1), ordered by beta's new
        bit first and then alpha's new bit (0 before 1).
    """
    alpha_i, beta_i, lvl = data
    # The digit strings are only read here, so no `global` statement is needed.
    a_char = _bin_alpha[lvl]
    b_char = _bin_beta[lvl]
    a_bits = (0, 1) if a_char == '*' else (int(a_char),)
    b_bits = (0, 1) if b_char == '*' else (int(b_char),)
    weight = 2**lvl
    modulus = 2**(lvl + 1)
    valid_steps = []
    # Beta's bit is the outer loop so results come out in a stable order.
    for b_bit in b_bits:
        check_b = beta_i + weight * b_bit
        for a_bit in a_bits:
            check_a = alpha_i + weight * a_bit
            if (check_a * check_b) % modulus == 1:
                valid_steps.append((check_a, check_b, lvl + 1))
    return valid_steps
def _init_worker(bin_alpha, bin_beta):
    """Install the digit strings as module globals inside a pool worker.

    _solve_step reads _bin_alpha / _bin_beta as module globals. Under the
    'spawn' / 'forkserver' start methods a worker re-imports this module instead
    of inheriting the parent's memory, so globals assigned at run time in the
    parent would be missing there; this initializer sets them explicitly.
    """
    global _bin_alpha, _bin_beta
    _bin_alpha = bin_alpha
    _bin_beta = bin_beta


def solve_product_bfs(bin_alpha, bin_beta, *, parallel_threshold=1000):
    """
    Given numbers with holes bin_alpha, bin_beta, find every alpha matching its
    pattern for which some beta matching its pattern gives alpha * beta == 1 mod 2^k.

    The search is breadth-first: all partial solutions mod 2^i are computed for
    i = 1, ..., k. Strings are lsb first ('0', '1', or '*' for an unknown bit)
    and must have the same length k.

    Args:
        bin_alpha, bin_beta: lsb-first digit strings of equal length k >= 1.
        parallel_threshold: frontier size from which a step is evaluated with a
            multiprocessing Pool instead of serially.

    Returns:
        list[int]: every candidate alpha, as an integer in [0, 2^k) (already an
        int, no string decoding needed).

    Raises:
        ValueError: if the strings are empty or differ in length, or if bit 0 of
            either one is '0' (units mod 2^k are odd).
    """
    global _bin_alpha, _bin_beta
    k = len(bin_alpha)
    if k == 0 or len(bin_beta) != k:
        raise ValueError('bin_alpha and bin_beta must be non-empty and of equal length')
    if '0' in (bin_alpha[0], bin_beta[0]):
        raise ValueError('bit 0 of both numbers must be 1 or unknown: units mod 2^k are odd')
    # _solve_step reads these module globals during the serial steps.
    _bin_alpha = bin_alpha
    _bin_beta = bin_beta

    pool = None  # created lazily: small searches never need worker processes
    q = [(1, 1, 1)]  # (alpha, beta, level): both numbers are known to be 1 mod 2
    last_deg = 0
    t0 = t_last = time.time()
    try:
        # Stop early if the frontier dies out: no solution can appear later.
        while last_deg < k - 1 and q:
            t_new = time.time()
            logger.info('Done deg %d (%s|%s) - n_paths: %d --> next deg %d (%s|%s)',
                        last_deg, bin_alpha[last_deg], bin_beta[last_deg], len(q),
                        last_deg + 1, bin_alpha[last_deg + 1], bin_beta[last_deg + 1])
            if last_deg > 1:
                eta = (k - 1 - last_deg) * (t_new - t0) / last_deg
                logger.info('- step time: %.3fs / tot time: %.3fs / eta: %.3fs (%d st)',
                            t_new - t_last, t_new - t0, eta, k - 1 - last_deg)
            t_last = t_new
            last_deg += 1
            if len(q) < parallel_threshold:
                q_new = map(_solve_step, q)
            else:
                if pool is None:
                    pool = Pool(initializer=_init_worker, initargs=(bin_alpha, bin_beta))
                q_new = pool.map(_solve_step, q)
            q = list(itertools.chain.from_iterable(q_new))
    finally:
        # Always reap the worker processes, including on error or Ctrl-C.
        if pool is not None:
            pool.terminate()
            pool.join()
    return [alpha for alpha, _beta, _lvl in q]
if __name__ == '__main__':
    # Usage:
    #   python3 comb_attack.py              runs a self-test on a random 250-bit pair
    #   python3 comb_attack.py ALPHA BETA   runs the attack on two lsb-first patterns
    #                                       of '0'/'1'/'*' and writes every candidate
    #                                       alpha, one per line, to _output.txt
    arg = sys.argv
    if len(arg) == 1:
        logger.setLevel(logging.INFO)
        k = 250
        noise = k // 2
        bin_alpha, bin_beta, alpha, beta = generate_pair(k, noise)
        check_alpha = bin(alpha)[2:].zfill(k)[::-1]
        check_beta = bin(beta)[2:].zfill(k)[::-1]
        print(f'{check_alpha = } {check_beta = }')
        sols = solve_product_bfs(bin_alpha, bin_beta)
        print(f'{len(sols) = }')
        if alpha in sols:
            print('FOUND')
            A = alpha
            B = pow(A, -1, 2**k)
            # Self-test sanity check: the recovered inverse must be the generated beta.
            assert B == beta
            print(f'{A = } {B = }')
        else:
            print('NOT FOUND')
    elif len(arg) == 3:
        bin_alpha, bin_beta = arg[1:]
        sols = solve_product_bfs(bin_alpha, bin_beta)
        print(f'Found {len(sols)} solutions')
        sols_str = '\n'.join(str(i) for i in sols)
        with open('_output.txt', 'w', encoding='utf-8') as fh:
            fh.write(sols_str)
    else:
        # Not an assert: asserts are stripped under `python -O`.
        sys.exit(f'wrong usage: {arg[0]} [ALPHA BETA]')