Skip to content

Commit e028124

Browse files
committed
WIP: Parallelize bottlenecks
1 parent 412078a commit e028124

File tree

4 files changed

+111
-70
lines changed

4 files changed

+111
-70
lines changed

mix-node/src/crypto.rs

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use anyhow::Context;
2-
use bitvec::vec::BitVec;
32
use elastic_elgamal::{
43
group::Ristretto,
54
sharing::{ActiveParticipant, PublicKeySet},
@@ -11,6 +10,7 @@ use std::sync::LazyLock;
1110
use thiserror::Error;
1211

1312
pub type Ciphertext = elastic_elgamal::Ciphertext<Ristretto>;
13+
pub type Bits = bitvec::vec::BitVec;
1414

1515
#[derive(Debug, Clone, Serialize, Deserialize)]
1616
pub struct DecryptionShare {
@@ -52,6 +52,26 @@ pub fn remix(
5252
Ok(())
5353
}
5454

55+
pub fn encrypt(pub_key: &PublicKey<Ristretto>, bits: &Bits) -> Vec<Ciphertext> {
56+
bits.as_raw_slice()
57+
.into_par_iter()
58+
.enumerate()
59+
.flat_map(|(chunk_idx, &chunk)| {
60+
let chunk_size = 8 * std::mem::size_of_val(&chunk);
61+
let start_bit = chunk_idx * chunk_size;
62+
let end_bit = std::cmp::min(start_bit + chunk_size, bits.len());
63+
64+
(0..end_bit - start_bit)
65+
.into_par_iter()
66+
.map(move |bit_offset| {
67+
let mut rng = rand::thread_rng();
68+
let bit = (chunk >> bit_offset) & 1;
69+
pub_key.encrypt(bit as u64, &mut rng)
70+
})
71+
})
72+
.collect::<Vec<_>>()
73+
}
74+
5575
pub fn decryption_share_for(
5676
active_participant: &ActiveParticipant<Ristretto>,
5777
ciphertext: &[Ciphertext],
@@ -66,47 +86,46 @@ pub fn decryption_share_for(
6686
DecryptionShare::new(active_participant.index(), share)
6787
}
6888

69-
// PERF: parallelize and maybe async this
89+
// PERF: parallelize this
7090
pub fn decrypt_shares(
7191
key_set: &PublicKeySet<Ristretto>,
7292
enc: &[Ciphertext],
7393
shares: &[DecryptionShare],
74-
) -> anyhow::Result<BitVec> {
94+
) -> anyhow::Result<Bits> {
7595
if shares.iter().any(|s| s.share.len() != enc.len()) {
76-
return Err(anyhow::anyhow!(
77-
"mismatch of lengths between encrypted ciphertext a decryption shares"
78-
));
96+
anyhow::bail!("mismatch of lengths between encrypted ciphertext a decryption shares");
7997
}
8098
// Transpose vectors
81-
let rows = shares.len();
82-
let cols = enc.len();
83-
let transposed = (0..cols).map(|col| {
84-
(0..rows)
85-
.map(|row| (shares[row].index, shares[row].share[col]))
86-
.collect::<Vec<_>>()
99+
let transposed = (0..enc.len()).into_par_iter().map(|ct_idx| {
100+
shares
101+
.into_par_iter()
102+
.map(move |s| (s.index, s.share[ct_idx]))
87103
});
88104

89105
transposed
90106
.zip(enc)
91107
.map(|(shares, enc)| {
92-
let dec_iter = shares.into_iter().filter_map(|(i, (share, proof))| {
93-
let share = CandidateDecryption::from_bytes(&share.to_bytes())?;
94-
let verification = key_set.verify_share(share, *enc, i, &proof).ok()?;
95-
Some((i, verification))
96-
});
108+
let dec_iter: Vec<_> = shares
109+
.filter_map(|(i, (share, proof))| {
110+
let share = CandidateDecryption::from_bytes(&share.to_bytes())?;
111+
let verification = key_set.verify_share(share, *enc, i, &proof).ok()?;
112+
Some((i, verification))
113+
})
114+
.collect();
97115
let combined = key_set
98116
.params()
99-
.combine_shares(dec_iter)
117+
.combine_shares(dec_iter.into_iter())
100118
.context("failed to combine shares")?;
101119
Ok(combined
102120
.decrypt(*enc, &LOOKUP_TABLE)
103121
.context("decrypted values out of range of lookup table")?
104122
== 1u64)
105123
})
106-
.collect::<anyhow::Result<BitVec>>()
124+
.collect::<anyhow::Result<Vec<_>>>()
125+
.map(Bits::from_iter)
107126
}
108127

109-
pub fn hamming_distance(x_code: BitVec, y_code: BitVec) -> usize {
128+
pub fn hamming_distance(x_code: Bits, y_code: Bits) -> usize {
110129
// Q: What if x and y are different sizes?
111130
(x_code ^ y_code).count_ones()
112131
}

mix-node/src/rest/routes.rs

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
use super::error::Error;
22
use crate::{
3-
crypto::{self, Ciphertext, CryptoError, DecryptionShare},
3+
crypto::{self, Bits, Ciphertext, CryptoError, DecryptionShare},
44
rokio, AppState, EncryptedCodes,
55
};
66
use anyhow::Context;
77
use axum::{extract::State, response::Json};
8-
use bitvec::vec::BitVec;
98
use elastic_elgamal::{group::Ristretto, sharing::PublicKeySet};
10-
use rayon::prelude::*;
119
use reqwest::Client;
1210
use serde::{Deserialize, Serialize};
1311
use std::sync::Arc;
@@ -45,21 +43,14 @@ pub async fn public_key_set(State(state): State<Arc<AppState>>) -> Json<PublicKe
4543
Json(state.pub_key_set().clone())
4644
}
4745

48-
#[tracing::instrument(skip(state, plaintext))]
46+
#[tracing::instrument(skip(state, bits))]
4947
pub async fn encrypt(
5048
State(state): State<Arc<AppState>>,
51-
Json(plaintext): Json<BitVec>,
49+
Json(bits): Json<Bits>,
5250
) -> Json<Vec<Ciphertext>> {
5351
let ciphertexts = rokio::spawn(move || {
5452
let pub_key = state.crypto.active_participant.key_set().shared_key();
55-
plaintext
56-
.into_iter()
57-
.par_bridge()
58-
.map(|msg| {
59-
let mut rng = rand::thread_rng();
60-
pub_key.encrypt(msg as u64, &mut rng)
61-
})
62-
.collect::<Vec<_>>()
53+
crypto::encrypt(pub_key, &bits)
6354
})
6455
.await;
6556

@@ -127,13 +118,38 @@ pub async fn hamming_distance(
127118
let (x_code, y_code) = (Arc::new(x_code), Arc::new(y_code));
128119

129120
// Decrypt
130-
let (x_shares, y_shares) = tokio::join!(
131-
request_all_shares(&x_code, &state),
132-
request_all_shares(&y_code, &state)
133-
);
134-
// PERF: parallelize and rokio
135-
let x_decrypt = crypto::decrypt_shares(state.pub_key_set(), &x_code, &x_shares)?;
136-
let y_decrypt = crypto::decrypt_shares(state.pub_key_set(), &y_code, &y_shares)?;
121+
let (x_shares, y_shares) = {
122+
let (x_state, y_state) = (Arc::clone(&state), Arc::clone(&state));
123+
let (x_inner_code, y_inner_code) = (Arc::clone(&x_code), Arc::clone(&y_code));
124+
125+
let (mut x_shares, mut y_shares, x_self_share, y_self_share) = tokio::join!(
126+
request_all_shares(&x_code, &state),
127+
request_all_shares(&y_code, &state),
128+
rokio::spawn(move || {
129+
crypto::decryption_share_for(&x_state.crypto.active_participant, &x_inner_code)
130+
}),
131+
rokio::spawn(move || {
132+
crypto::decryption_share_for(&y_state.crypto.active_participant, &y_inner_code)
133+
})
134+
);
135+
x_shares.push(x_self_share);
136+
y_shares.push(y_self_share);
137+
138+
(x_shares, y_shares)
139+
};
140+
141+
let x_decrypt = {
142+
let state = Arc::clone(&state);
143+
rokio::spawn(move || crypto::decrypt_shares(state.pub_key_set(), &x_code, &x_shares))
144+
};
145+
let y_decrypt = {
146+
let state = Arc::clone(&state);
147+
rokio::spawn(move || crypto::decrypt_shares(state.pub_key_set(), &y_code, &y_shares))
148+
};
149+
let (x_decrypt, y_decrypt) = {
150+
let (x_decrypt, y_decrypt) = tokio::join!(x_decrypt, y_decrypt);
151+
(x_decrypt?, y_decrypt?)
152+
};
137153

138154
// Hamming
139155
let hamming_distance = crypto::hamming_distance(x_decrypt, y_decrypt);

mix-node/tests/common/mod.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
#![allow(dead_code)]
2+
13
use bitvec::prelude::*;
24
use elastic_elgamal::{group::Ristretto, DiscreteLogTable, Keypair, PublicKey, SecretKey};
3-
use mix_node::{crypto::Ciphertext, EncryptedCodes};
5+
use mix_node::{
6+
crypto::{self, Bits, Ciphertext},
7+
EncryptedCodes,
8+
};
49
use rand::{CryptoRng, Rng};
510
use std::iter;
611

712
pub const N_BITS: usize = mix_node::N_BITS / 2;
813

9-
#[allow(unused)]
1014
pub fn set_up_payload() -> (EncryptedCodes, Keypair<Ristretto>) {
1115
let mut rng = rand::thread_rng();
1216
let new_iris_code = BitVec::<_, Lsb0>::from_slice(&rng.gen::<[u8; N_BITS / 8]>());
@@ -20,11 +24,10 @@ pub fn set_up_payload() -> (EncryptedCodes, Keypair<Ristretto>) {
2024

2125
// Encrypt
2226
let receiver = Keypair::generate(&mut rng);
23-
let dec_key = receiver.secret().clone();
2427
let enc_key = receiver.public().clone();
2528

26-
let enc_new_user: Vec<_> = encrypt_bits(&new_user[..], &enc_key, &mut rng).collect();
27-
let enc_archived_user: Vec<_> = encrypt_bits(&archived_user[..], &enc_key, &mut rng).collect();
29+
let enc_new_user: Vec<_> = crypto::encrypt(receiver.public(), &new_user);
30+
let enc_archived_user: Vec<_> = crypto::encrypt(receiver.public(), &archived_user);
2831

2932
(
3033
EncryptedCodes {
@@ -36,7 +39,11 @@ pub fn set_up_payload() -> (EncryptedCodes, Keypair<Ristretto>) {
3639
)
3740
}
3841

39-
#[allow(unused)]
42+
pub fn set_up_iris_code(size: usize) -> Bits {
43+
let mut rng = rand::thread_rng();
44+
(0..size).map(|_| rng.gen::<bool>()).collect::<Bits>()
45+
}
46+
4047
pub fn encode_bits<T: BitStore, O: BitOrder>(
4148
bits: &BitSlice<T, O>,
4249
) -> impl Iterator<Item = bool> + '_ {
@@ -49,7 +56,6 @@ pub fn encode_bits<T: BitStore, O: BitOrder>(
4956
})
5057
}
5158

52-
#[allow(unused)]
5359
pub fn decode_bits<T: BitStore, O: BitOrder>(
5460
bits: &BitSlice<T, O>,
5561
) -> impl Iterator<Item = bool> + '_ {
@@ -63,7 +69,6 @@ pub fn decode_bits<T: BitStore, O: BitOrder>(
6369
})
6470
}
6571

66-
#[allow(unused)]
6772
pub fn encrypt_bits<'a, T: BitStore, O: BitOrder>(
6873
bits: &'a BitSlice<T, O>,
6974
ek: &'a PublicKey<Ristretto>,
@@ -72,7 +77,6 @@ pub fn encrypt_bits<'a, T: BitStore, O: BitOrder>(
7277
bits.iter().map(|bit| ek.encrypt(*bit as u32, rng))
7378
}
7479

75-
#[allow(unused)]
7680
pub fn decrypt_bits<'a>(
7781
ct: &'a [Ciphertext],
7882
pk: &'a SecretKey<Ristretto>,

mix-node/tests/web_integration.rs

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use elastic_elgamal::{group::Ristretto, sharing::PublicKeySet, Ciphertext};
55
use format as f;
66
use mix_node::{
77
config::get_configuration,
8-
crypto::{self, DecryptionShare},
8+
crypto::{self, Bits, DecryptionShare},
99
rest::routes::HammingResponse,
1010
test_helpers::{self, TestApp},
1111
EncryptedCodes,
@@ -41,9 +41,8 @@ async fn test_mix_node() -> anyhow::Result<()> {
4141
.await?;
4242

4343
// Decrypt
44-
let dec_new_user: BitVec<u8, Lsb0> =
45-
common::decrypt_bits(&enc_new_user, receiver.secret()).collect();
46-
let dec_archived_user: BitVec<u8, Lsb0> =
44+
let dec_new_user: Bits = common::decrypt_bits(&enc_new_user, receiver.secret()).collect();
45+
let dec_archived_user: Bits =
4746
common::decrypt_bits(&enc_archived_user, receiver.secret()).collect();
4847

4948
// Assert result
@@ -63,7 +62,9 @@ async fn test_mix_node_bad_request() -> anyhow::Result<()> {
6362
config.application.auth_token = None;
6463
let TestApp { port, .. } = test_helpers::create_app(config).await;
6564

66-
let (mut codes, _dec_key) = common::set_up_payload();
65+
// let code = common::set_up_iris_code(mix_node::N_BITS);
66+
67+
let (mut codes, _receiver) = common::set_up_payload();
6768
// Remove elements to cause a size mismatch
6869
codes.x_code.pop();
6970
codes.x_code.pop();
@@ -107,9 +108,9 @@ async fn test_mix_node_authorized() -> anyhow::Result<()> {
107108
config.application.auth_token = Some(Secret::new(auth_token.to_string()));
108109
let TestApp { port, .. } = test_helpers::create_app(config).await;
109110

110-
let (codes, _dec_key) = common::set_up_payload();
111+
let (codes, _receiver) = common::set_up_payload();
111112

112-
// Bad request + Serialization
113+
// Auth
113114
let client = reqwest::Client::new();
114115
let response = client
115116
.post(format!("http://localhost:{port}/remix"))
@@ -136,7 +137,6 @@ async fn test_mix_node_public_key() -> anyhow::Result<()> {
136137
assert_eq!(response.status(), StatusCode::OK);
137138

138139
let _body: PublicKeySet<Ristretto> = response.json().await?;
139-
140140
Ok(())
141141
}
142142

@@ -145,7 +145,7 @@ async fn test_mix_node_encrypt() -> anyhow::Result<()> {
145145
let config = get_configuration()?;
146146
let TestApp { port, .. } = test_helpers::create_app(config).await;
147147

148-
let payload = bitvec![0, 1, 0, 1, 1, 0, 0, 1];
148+
let payload = common::set_up_iris_code(mix_node::N_BITS);
149149

150150
let client = reqwest::Client::new();
151151
let response = client
@@ -173,14 +173,17 @@ async fn test_network_decrypt_shares() -> anyhow::Result<()> {
173173
.await?;
174174
assert_eq!(response.status(), StatusCode::OK);
175175

176-
// Encrypt client side
177-
let mut rng = rand::thread_rng();
178-
let payload = bitvec![0, 1, 0, 1, 1, 0, 0, 1];
176+
// Encrypt server side
177+
let payload: Bits = bitvec![0, 1, 0, 1, 1, 0, 0, 1];
178+
179179
let pub_key: PublicKeySet<Ristretto> = response.json().await?;
180-
let encrypted: Vec<_> = payload
181-
.iter()
182-
.map(|pt| pub_key.shared_key().encrypt(*pt as u64, &mut rng))
183-
.collect();
180+
let encrypted: Vec<_> = client
181+
.post(f!("http://localhost:{}/encrypt", nodes[0].port))
182+
.json(&payload)
183+
.send()
184+
.await?
185+
.json()
186+
.await?;
184187

185188
// Decrypt
186189
let mut shares = vec![];
@@ -207,7 +210,6 @@ async fn test_network_hamming_distance() -> anyhow::Result<()> {
207210
let [TestApp { port, .. }, ..] = test_helpers::create_network(3, 2).await[..] else {
208211
return Err(anyhow::anyhow!("there's not at least one node"));
209212
};
210-
let mut rng = rand::thread_rng();
211213

212214
// Request public key
213215
let client = reqwest::Client::new();
@@ -220,10 +222,10 @@ async fn test_network_hamming_distance() -> anyhow::Result<()> {
220222

221223
// Codes are the same, expected hamming distance = 0
222224
let payload = {
223-
let code: Vec<_> = [0, 1, 0, 1, 1, 0, 0, 1u64]
224-
.into_iter()
225-
.map(|m| pub_key.shared_key().encrypt(m, &mut rng))
226-
.collect();
225+
let code = crypto::encrypt(
226+
pub_key.shared_key(),
227+
&common::set_up_iris_code(mix_node::N_BITS),
228+
);
227229
EncryptedCodes {
228230
x_code: code.clone(),
229231
y_code: code,

0 commit comments

Comments
 (0)