Skip to content

Commit ededa27

Browse files
committed
chore: send PRSS computation to rayon in batches
1 parent 8105c7c commit ededa27

File tree

9 files changed

+374
-305
lines changed

9 files changed

+374
-305
lines changed

core/threshold/src/execution/endpoints/decryption_non_wasm.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ use tfhe::shortint::PBSOrder;
6666
use tokio::task::JoinSet;
6767
use tokio::time::{Duration, Instant};
6868
use tracing::info_span;
69-
use tracing::instrument;
69+
use tracing::{instrument, Instrument};
7070

7171
#[cfg(any(test, feature = "testing"))]
7272
use super::decryption::DecryptionMode;
@@ -126,7 +126,6 @@ where
126126
/// Load precomputed init data for noise flooding.
127127
///
128128
/// Note: this is actually a synchronous function. It just needs to be async to implement the trait (which is async in the Large case)
129-
/// TODO: we should move the slow parts to rayon
130129
async fn init_prep_noiseflooding(
131130
&mut self,
132131
num_ctxt: usize,
@@ -136,17 +135,13 @@ where
136135
let own_role = session.my_role();
137136

138137
let prss_span = info_span!("PRSS-MASK.Next", batch_size = num_ctxt);
139-
let masks = prss_span.in_scope(|| {
140-
(0..num_ctxt)
141-
.map(|_| {
142-
self.session
143-
.get_mut()
144-
.prss_as_mut()
145-
.mask_next(own_role, B_SWITCH_SQUASH)
146-
})
147-
.try_collect()
148-
})?;
149-
138+
let masks = self
139+
.session
140+
.get_mut()
141+
.prss_as_mut()
142+
.mask_next_vec(own_role, B_SWITCH_SQUASH, num_ctxt)
143+
.instrument(prss_span)
144+
.await?;
150145
sns_preprocessing.append_masks(masks);
151146
Ok(sns_preprocessing)
152147
}

core/threshold/src/execution/online/preprocessing/dummy.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ where
365365
}
366366

367367
/// Fill the masks directly from the [`crate::execution::small_execution::prss::PRSSState`] available from [`SmallSession`]
368-
fn fill_from_small_session(
368+
async fn fill_from_small_session(
369369
&mut self,
370370
_session: &mut SmallSession<ResiduePoly<Z128, EXTENSION_DEGREE>>,
371371
_amount: usize,

core/threshold/src/execution/online/preprocessing/mod.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -155,20 +155,17 @@ where
155155
) -> anyhow::Result<Vec<ResiduePoly<Z128, EXTENSION_DEGREE>>>;
156156

157157
/// Fill the masks directly from the [`crate::execution::small_execution::prss::PRSSState`] available from [`SmallSession`]
158-
fn fill_from_small_session(
158+
async fn fill_from_small_session(
159159
&mut self,
160160
session: &mut SmallSession<ResiduePoly<Z128, EXTENSION_DEGREE>>,
161161
amount: usize,
162162
) -> anyhow::Result<()> {
163163
let own_role = session.my_role();
164164

165-
let prss_span = tracing::info_span!("PRSS-MASK.Next", batch_size = amount);
166-
let masks = prss_span.in_scope(|| {
167-
(0..amount)
168-
.map(|_| session.prss_state.mask_next(own_role, B_SWITCH_SQUASH))
169-
.try_collect()
170-
})?;
171-
165+
let masks = session
166+
.prss_state
167+
.mask_next_vec(own_role, B_SWITCH_SQUASH, amount)
168+
.await?;
172169
self.append_masks(masks);
173170

174171
Ok(())

core/threshold/src/execution/small_execution/offline.rs

Lines changed: 81 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use anyhow::Context;
22
use itertools::Itertools;
33
use std::collections::HashMap;
44
use tonic::async_trait;
5-
use tracing::{info_span, instrument};
5+
use tracing::instrument;
66

77
use super::prss::PRSSPrimitives;
88
use crate::error::error_handler::log_error_wrapper;
@@ -12,6 +12,7 @@ use crate::execution::online::preprocessing::memory::InMemoryBasePreprocessing;
1212
use crate::execution::online::preprocessing::{RandomPreprocessing, TriplePreprocessing};
1313
use crate::execution::runtime::session::BaseSessionHandles;
1414
use crate::execution::sharing::shamir::RevealOp;
15+
use crate::thread_handles::spawn_compute_bound;
1516
use crate::{
1617
algebra::structure_traits::{ErrorCorrect, Ring},
1718
execution::{
@@ -114,17 +115,13 @@ async fn next_random_batch<Z: Ring, Ses: SmallSessionHandles<Z>>(
114115
) -> anyhow::Result<Vec<Share<Z>>> {
115116
let my_role = session.my_role();
116117
//Create telemetry span to record all calls to PRSS.Next
117-
let prss_span = info_span!("PRSS.Next", batch_size = amount);
118-
let res = prss_span.in_scope(|| {
119-
let mut res = Vec::with_capacity(amount);
120-
for _ in 0..amount {
121-
res.push(Share::new(
122-
my_role,
123-
session.prss_as_mut().prss_next(my_role)?,
124-
));
125-
}
126-
Ok::<_, anyhow::Error>(res)
127-
})?;
118+
let res = session
119+
.prss_as_mut()
120+
.prss_next_vec(my_role, amount)
121+
.await?
122+
.into_iter()
123+
.map(|x| Share::new(my_role, x))
124+
.collect();
128125
Ok(res)
129126
}
130127

@@ -139,75 +136,78 @@ async fn next_triple_batch<Z: ErrorCorrect, Ses: SmallSessionHandles<Z>, BCast:
139136
broadcast: &BCast,
140137
) -> anyhow::Result<Vec<Triple<Z>>> {
141138
let counters = session.prss().get_counters();
139+
let my_role = session.my_role();
140+
let threshold = session.threshold();
142141
let prss_base_ctr = counters.prss_ctr;
143142
let przs_base_ctr = counters.przs_ctr;
144143

145-
let vec_x_single = prss_list(session, amount)?;
146-
let vec_y_single = prss_list(session, amount)?;
147-
let vec_v_single = prss_list(session, amount)?;
148-
let vec_z_double = przs_list(session, amount)?;
144+
let all_prss = session
145+
.prss_as_mut()
146+
.prss_next_vec(my_role, 3 * amount)
147+
.await?;
148+
let vec_z_double: Vec<_> = session
149+
.prss_as_mut()
150+
.przs_next_vec(my_role, threshold, amount)
151+
.await?;
149152

150-
let mut vec_d_double = Vec::with_capacity(amount);
151-
for i in 0..amount {
152-
let x_single = vec_x_single
153-
.get(i)
154-
.with_context(|| log_error_wrapper("Expected x does not exist"))?
155-
.to_owned();
156-
let y_single = vec_y_single
157-
.get(i)
158-
.with_context(|| log_error_wrapper("Expected y does not exist"))?
159-
.to_owned();
160-
let v_single = vec_v_single
161-
.get(i)
162-
.with_context(|| log_error_wrapper("Expected v does not exist"))?
163-
.to_owned();
164-
let z_double = vec_z_double
165-
.get(i)
166-
.with_context(|| log_error_wrapper("Expected z does not exist"))?
167-
.to_owned();
168-
let v_double = z_double + v_single;
169-
let d_double = x_single * y_single + v_double;
170-
vec_d_double.push(d_double)
153+
let all_prss_cloned = all_prss.clone();
154+
let vec_z_double_cloned = vec_z_double.clone();
155+
let vec_d_double = spawn_compute_bound( move ||{
156+
let mut all_prss = all_prss_cloned.into_iter();
157+
let vec_x_single: Vec<_> = all_prss.by_ref().take(amount).collect();
158+
let vec_y_single: Vec<_> = all_prss.by_ref().take(amount).collect();
159+
let vec_v_single: Vec<_> = all_prss.by_ref().take(amount).collect();
160+
161+
if vec_x_single.len() != amount
162+
|| vec_y_single.len() != amount
163+
|| vec_v_single.len() != amount
164+
|| vec_z_double.len() != amount
165+
{
166+
return Err(anyhow::anyhow!(
167+
"BUG: Not all expected values were generated, x={}, y={}, v={}, z={}. Expected {amount}.",
168+
vec_x_single.len(),
169+
vec_y_single.len(),
170+
vec_v_single.len(),
171+
vec_z_double.len(),
172+
));
171173
}
172174

175+
Ok(vec_x_single
176+
.into_iter()
177+
.zip_eq(vec_y_single.into_iter())
178+
.zip_eq(vec_v_single.into_iter())
179+
.zip_eq(vec_z_double_cloned.into_iter())
180+
.map(|(((x, y), v), z)| x * y + (z + v))
181+
.collect_vec())
182+
}).await??;
183+
173184
let broadcast_res = broadcast
174-
.broadcast_from_all_w_corrupt_set_update(session, vec_d_double.clone().into())
185+
.broadcast_from_all_w_corrupt_set_update(session, vec_d_double.into())
175186
.await?;
176187

177188
//Try reconstructing 2t sharings of d, a None means reconstruction failed.
178-
let recons_vec_d = reconstruct_d_values(session, amount, broadcast_res.clone())?;
189+
let recons_vec_d = reconstruct_d_values(session, amount, broadcast_res.clone()).await?;
179190

191+
let mut all_prss = all_prss.into_iter();
192+
let vec_x_single: Vec<_> = all_prss.by_ref().take(amount).collect();
193+
let vec_y_single: Vec<_> = all_prss.by_ref().take(amount).collect();
194+
let vec_v_single: Vec<_> = all_prss.by_ref().take(amount).collect();
180195
let mut triples = Vec::with_capacity(amount);
181196
let mut bad_triples_idx = Vec::new();
182-
for i in 0..amount {
197+
for (i, (x, (y, z))) in vec_x_single
198+
.into_iter()
199+
.zip_eq(vec_y_single.into_iter().zip_eq(vec_v_single.into_iter()))
200+
.enumerate()
201+
{
183202
//If we managed to reconstruct, we store the triple
184203
if let Some(d) = recons_vec_d
185204
.get(i)
186205
.with_context(|| log_error_wrapper("Not all expected d values exist"))?
187206
{
188207
triples.push(Triple {
189-
a: Share::new(
190-
session.my_role(),
191-
vec_x_single
192-
.get(i)
193-
.with_context(|| log_error_wrapper("Not all expected x values exist"))?
194-
.to_owned(),
195-
),
196-
b: Share::new(
197-
session.my_role(),
198-
vec_y_single
199-
.get(i)
200-
.with_context(|| log_error_wrapper("Not all expected y values exist"))?
201-
.to_owned(),
202-
),
203-
c: Share::new(
204-
session.my_role(),
205-
d.to_owned()
206-
- vec_v_single
207-
.get(i)
208-
.with_context(|| log_error_wrapper("Not all expected v values exist"))?
209-
.to_owned(),
210-
),
208+
a: Share::new(session.my_role(), x),
209+
b: Share::new(session.my_role(), y),
210+
c: Share::new(session.my_role(), d.to_owned() - z),
211211
});
212212
//If reconstruction failed, it's a bad triple and we will run cheater identification
213213
} else {
@@ -239,7 +239,7 @@ async fn next_triple_batch<Z: ErrorCorrect, Ses: SmallSessionHandles<Z>, BCast:
239239
/// Helper method to parse the result of the broadcast by taking the ith share from each party and combine them in a vector for which reconstruction is then computed.
240240
/// Returns a list of length `amount` which contains the reconstructed values.
241241
/// In case a wrong amount of elements or a wrong type is returned then the culprit is added to the list of corrupt parties.
242-
fn reconstruct_d_values<Z, Ses: BaseSessionHandles>(
242+
async fn reconstruct_d_values<Z, Ses: BaseSessionHandles>(
243243
session: &mut Ses,
244244
amount: usize,
245245
d_recons: HashMap<Role, BroadcastValue<Z>>,
@@ -289,20 +289,19 @@ where
289289

290290
//We know we may not be able to correct all errors, thus we set max_errors to maximum number of errors the code can correct,
291291
//and deal with failure with the cheater identification strategy
292-
let max_errors = (session.num_parties()
293-
- session.corrupt_roles().len()
294-
- (2 * session.threshold() as usize + 1))
295-
/ 2;
296-
297-
Ok(collected_shares
298-
.into_iter()
299-
.map(|cur_collection| {
300-
let sharing = ShamirSharings::create(cur_collection);
301-
sharing
302-
.err_reconstruct(2 * session.threshold() as usize, max_errors)
303-
.ok()
304-
})
305-
.collect_vec())
292+
let degree = 2 * session.threshold() as usize;
293+
let max_errors = (session.num_parties() - session.corrupt_roles().len() - (degree + 1)) / 2;
294+
295+
spawn_compute_bound(move || {
296+
collected_shares
297+
.into_iter()
298+
.map(|cur_collection| {
299+
let sharing = ShamirSharings::create(cur_collection);
300+
sharing.err_reconstruct(degree, max_errors).ok()
301+
})
302+
.collect_vec()
303+
})
304+
.await
306305
}
307306

308307
/// Helper method which takes the list of d shares of each party (the result of the broadcast)
@@ -345,35 +344,6 @@ fn parse_d_shares<Z: Ring, Ses: BaseSessionHandles>(
345344
Ok(res)
346345
}
347346

348-
/// Output amount of PRSS.Next() calls
349-
#[instrument(name="PRSS.Next",skip(session,amount),fields(sid=?session.session_id(),own_identity=?session.own_identity(),batch_size=?amount))]
350-
fn prss_list<Z: Ring, Ses: SmallSessionHandles<Z>>(
351-
session: &mut Ses,
352-
amount: usize,
353-
) -> anyhow::Result<Vec<Z>> {
354-
let my_id = session.my_role();
355-
let mut vec_prss = Vec::with_capacity(amount);
356-
for _i in 0..amount {
357-
vec_prss.push(session.prss_as_mut().prss_next(my_id)?);
358-
}
359-
Ok(vec_prss)
360-
}
361-
362-
/// Output amount of PRZS.Next() calls
363-
#[instrument(name="PRZS.Next",skip(session,amount),fields(sid=?session.session_id(),own_identity=?session.own_identity(),batch_size=?amount))]
364-
fn przs_list<Z: Ring, Ses: SmallSessionHandles<Z>>(
365-
session: &mut Ses,
366-
amount: usize,
367-
) -> anyhow::Result<Vec<Z>> {
368-
let my_id = session.my_role();
369-
let threshold = session.threshold();
370-
let mut vec_przs = Vec::with_capacity(amount);
371-
for _i in 0..amount {
372-
vec_przs.push(session.prss_as_mut().przs_next(my_id, threshold)?);
373-
}
374-
Ok(vec_przs)
375-
}
376-
377347
/// Helper method for validating results when corruption has happened (by the reconstruction not being successful).
378348
/// The method finds the corrupt parties (based on what they broadcast) and adds them to the list of corrupt parties in the session.
379349
///
@@ -842,8 +812,8 @@ mod test {
842812
/// Unit testing of [`reconstruct_d_values`]
843813
/// Test what happens when a party send a wrong type of value
844814
#[tracing_test::traced_test]
845-
#[test]
846-
fn test_wrong_type() {
815+
#[tokio::test]
816+
async fn test_wrong_type() {
847817
let mut session = get_networkless_base_session_for_parties(4, 1, Role::indexed_from_one(1));
848818
// Observe party 1 inputs a vector of size 1 and party 2 inputs a single element
849819
let d_recons = HashMap::from([
@@ -859,7 +829,9 @@ mod test {
859829
),
860830
]);
861831
assert!(session.corrupt_roles().is_empty());
862-
let res = reconstruct_d_values(&mut session, 1, d_recons).unwrap();
832+
let res = reconstruct_d_values(&mut session, 1, d_recons)
833+
.await
834+
.unwrap();
863835
assert_eq!(1, session.corrupt_roles().len());
864836
assert!(session.corrupt_roles().contains(&Role::indexed_from_one(2)));
865837
assert!(logs_contain(

0 commit comments

Comments
 (0)