Skip to content

Commit 14478d3

Browse files
adds e2e example
1 parent 6e85e79 commit 14478d3

1 file changed

Lines changed: 252 additions & 0 deletions

File tree

examples/e2e.rs

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
use ark_bls12_381::Bls12_381;
2+
use ark_ec::pairing::{Pairing, PairingOutput};
3+
use ark_ec::PrimeGroup;
4+
use ark_std::{test_rng, UniformRand};
5+
use simple_batched_threshold_encryption::bte::{
6+
crs::setup,
7+
decryption::{combine, decrypt, decrypt_fft, partial_decrypt, verify, verify_ciphertext_batch},
8+
encryption::encrypt,
9+
};
10+
use std::env;
11+
use std::time::{Duration, Instant};
12+
13+
type E = Bls12_381;
14+
type Fr = <E as Pairing>::ScalarField;
15+
16+
#[derive(Clone, Copy, Debug)]
17+
enum Mode {
18+
Naive,
19+
Fft,
20+
}
21+
22+
#[derive(Clone, Debug)]
23+
struct Config {
24+
batch_size: usize,
25+
num_parties: usize,
26+
threshold: usize,
27+
iterations: usize,
28+
mode: Mode,
29+
}
30+
31+
#[derive(Default)]
32+
struct Timing {
33+
setup: Duration,
34+
encrypt: Duration,
35+
verify_cts: Duration,
36+
partial_decrypt: Duration,
37+
verify_shares: Duration,
38+
combine: Duration,
39+
decrypt: Duration,
40+
}
41+
42+
fn main() {
43+
let config = parse_args();
44+
print_run_header(&config);
45+
46+
let mut aggregate = Timing::default();
47+
48+
for iter in 0..config.iterations {
49+
let mut rng = test_rng();
50+
51+
let start = Instant::now();
52+
let (ek, dk, sks) =
53+
setup::<E>(config.batch_size, config.num_parties, config.threshold, &mut rng);
54+
let setup_time = start.elapsed();
55+
56+
let messages: Vec<PairingOutput<E>> = (0..config.batch_size)
57+
.map(|_| PairingOutput::<E>::generator() * Fr::rand(&mut rng))
58+
.collect();
59+
60+
let start = Instant::now();
61+
let cts: Vec<_> = messages.iter().map(|m| encrypt(&ek, m, &mut rng)).collect();
62+
let encrypt_time = start.elapsed();
63+
64+
let start = Instant::now();
65+
let cts_ok = verify_ciphertext_batch(&cts, &mut rng);
66+
let verify_cts_time = start.elapsed();
67+
assert!(cts_ok, "ciphertext proof verification failed");
68+
69+
let start = Instant::now();
70+
let pds: Vec<_> = sks[..config.threshold]
71+
.iter()
72+
.map(|sk| partial_decrypt(sk, &cts, &mut rng).expect("valid ciphertext proofs"))
73+
.collect();
74+
let partial_decrypt_time = start.elapsed();
75+
76+
let start = Instant::now();
77+
for pd in &pds {
78+
assert!(verify(&dk, pd, &cts), "share verification failed");
79+
}
80+
let verify_shares_time = start.elapsed();
81+
82+
let start = Instant::now();
83+
let pd = combine::<E>(&pds);
84+
let combine_time = start.elapsed();
85+
86+
let start = Instant::now();
87+
let recovered = match config.mode {
88+
Mode::Naive => decrypt(&dk, &pd, &cts, &mut rng),
89+
Mode::Fft => decrypt_fft(&dk, &pd, &cts, &mut rng),
90+
};
91+
let decrypt_time = start.elapsed();
92+
93+
assert_eq!(recovered.len(), messages.len());
94+
for i in 0..messages.len() {
95+
assert_eq!(recovered[i], messages[i], "decryption mismatch at {i}");
96+
}
97+
98+
aggregate.setup += setup_time;
99+
aggregate.encrypt += encrypt_time;
100+
aggregate.verify_cts += verify_cts_time;
101+
aggregate.partial_decrypt += partial_decrypt_time;
102+
aggregate.verify_shares += verify_shares_time;
103+
aggregate.combine += combine_time;
104+
aggregate.decrypt += decrypt_time;
105+
106+
print_iteration(
107+
iter + 1,
108+
&Timing {
109+
setup: setup_time,
110+
encrypt: encrypt_time,
111+
verify_cts: verify_cts_time,
112+
partial_decrypt: partial_decrypt_time,
113+
verify_shares: verify_shares_time,
114+
combine: combine_time,
115+
decrypt: decrypt_time,
116+
},
117+
);
118+
}
119+
120+
let avg = average_timing(&aggregate, config.iterations as u32);
121+
print_summary(&avg);
122+
}
123+
124+
fn parse_args() -> Config {
125+
let mut config = Config {
126+
batch_size: 2048,
127+
num_parties: 8,
128+
threshold: 5,
129+
iterations: 1,
130+
mode: Mode::Fft,
131+
};
132+
133+
let mut args = env::args().skip(1);
134+
while let Some(arg) = args.next() {
135+
match arg.as_str() {
136+
"--batch-size" | "-b" => {
137+
config.batch_size = parse_usize_arg("--batch-size", args.next());
138+
}
139+
"--num-parties" | "-n" => {
140+
config.num_parties = parse_usize_arg("--num-parties", args.next());
141+
}
142+
"--threshold" | "-t" => {
143+
config.threshold = parse_usize_arg("--threshold", args.next());
144+
}
145+
"--iters" | "-i" => {
146+
config.iterations = parse_usize_arg("--iters", args.next());
147+
}
148+
"--mode" | "-m" => {
149+
config.mode = parse_mode_arg(args.next());
150+
}
151+
"--help" | "-h" => {
152+
print_help_and_exit();
153+
}
154+
other => {
155+
panic!("unknown argument: {other}");
156+
}
157+
}
158+
}
159+
160+
assert!(
161+
config.threshold <= config.num_parties,
162+
"threshold must be <= num_parties"
163+
);
164+
assert!(config.iterations >= 1, "iters must be >= 1");
165+
config
166+
}
167+
168+
fn parse_usize_arg(flag: &str, value: Option<String>) -> usize {
169+
value
170+
.unwrap_or_else(|| panic!("missing value for {flag}"))
171+
.parse::<usize>()
172+
.unwrap_or_else(|_| panic!("invalid usize for {flag}"))
173+
}
174+
175+
fn parse_mode_arg(value: Option<String>) -> Mode {
176+
match value
177+
.unwrap_or_else(|| panic!("missing value for --mode"))
178+
.as_str()
179+
{
180+
"naive" => Mode::Naive,
181+
"fft" => Mode::Fft,
182+
other => panic!("invalid mode: {other} (expected 'naive' or 'fft')"),
183+
}
184+
}
185+
186+
fn print_help_and_exit() -> ! {
187+
println!("Usage: cargo run --release --example decrypt_e2e -- [options]");
188+
println!(" --batch-size, -b Batch size (default: 2048)");
189+
println!(" --num-parties, -n Number of parties (default: 8)");
190+
println!(" --threshold, -t Threshold (default: 5)");
191+
println!(" --iters, -i Iterations (default: 1)");
192+
println!(" --mode, -m Decrypt mode: naive | fft (default: fft)");
193+
std::process::exit(0);
194+
}
195+
196+
fn average_timing(total: &Timing, n: u32) -> Timing {
197+
Timing {
198+
setup: total.setup / n,
199+
encrypt: total.encrypt / n,
200+
verify_cts: total.verify_cts / n,
201+
partial_decrypt: total.partial_decrypt / n,
202+
verify_shares: total.verify_shares / n,
203+
combine: total.combine / n,
204+
decrypt: total.decrypt / n,
205+
}
206+
}
207+
208+
fn print_run_header(config: &Config) {
209+
println!("End-to-End Decryption");
210+
println!("=====================");
211+
println!("batch size : {}", config.batch_size);
212+
println!("num parties: {}", config.num_parties);
213+
println!("threshold : {}", config.threshold);
214+
println!("iterations : {}", config.iterations);
215+
println!("mode : {:?}", config.mode);
216+
println!();
217+
}
218+
219+
fn print_iteration(iter: usize, timing: &Timing) {
220+
println!("Iteration {}", iter);
221+
println!("-----------");
222+
print_timing_block(timing);
223+
println!();
224+
}
225+
226+
fn print_summary(avg: &Timing) {
227+
println!("Average");
228+
println!("-------");
229+
print_timing_block(avg);
230+
}
231+
232+
fn print_timing_block(timing: &Timing) {
233+
println!("setup {}", fmt_duration(timing.setup));
234+
println!("encrypt {}", fmt_duration(timing.encrypt));
235+
println!("verify_cts {}", fmt_duration(timing.verify_cts));
236+
println!("partial_decrypt {}", fmt_duration(timing.partial_decrypt));
237+
println!("verify_shares {}", fmt_duration(timing.verify_shares));
238+
println!("combine {}", fmt_duration(timing.combine));
239+
println!("decrypt {}", fmt_duration(timing.decrypt));
240+
}
241+
242+
fn fmt_duration(duration: Duration) -> String {
243+
if duration.as_secs() > 0 {
244+
format!("{:.3}s", duration.as_secs_f64())
245+
} else if duration.as_millis() > 0 {
246+
format!("{:.3}ms", duration.as_secs_f64() * 1_000.0)
247+
} else if duration.as_micros() > 0 {
248+
format!("{:.3}us", duration.as_secs_f64() * 1_000_000.0)
249+
} else {
250+
format!("{}ns", duration.as_nanos())
251+
}
252+
}

0 commit comments

Comments
 (0)