@@ -2,7 +2,7 @@ use anyhow::Context;
22use itertools:: Itertools ;
33use std:: collections:: HashMap ;
44use tonic:: async_trait;
5- use tracing:: { info_span , instrument} ;
5+ use tracing:: instrument;
66
77use super :: prss:: PRSSPrimitives ;
88use crate :: error:: error_handler:: log_error_wrapper;
@@ -12,6 +12,7 @@ use crate::execution::online::preprocessing::memory::InMemoryBasePreprocessing;
1212use crate :: execution:: online:: preprocessing:: { RandomPreprocessing , TriplePreprocessing } ;
1313use crate :: execution:: runtime:: session:: BaseSessionHandles ;
1414use crate :: execution:: sharing:: shamir:: RevealOp ;
15+ use crate :: thread_handles:: spawn_compute_bound;
1516use crate :: {
1617 algebra:: structure_traits:: { ErrorCorrect , Ring } ,
1718 execution:: {
@@ -114,17 +115,13 @@ async fn next_random_batch<Z: Ring, Ses: SmallSessionHandles<Z>>(
114115) -> anyhow:: Result < Vec < Share < Z > > > {
115116 let my_role = session. my_role ( ) ;
116117 //Create telemetry span to record all calls to PRSS.Next
117- let prss_span = info_span ! ( "PRSS.Next" , batch_size = amount) ;
118- let res = prss_span. in_scope ( || {
119- let mut res = Vec :: with_capacity ( amount) ;
120- for _ in 0 ..amount {
121- res. push ( Share :: new (
122- my_role,
123- session. prss_as_mut ( ) . prss_next ( my_role) ?,
124- ) ) ;
125- }
126- Ok :: < _ , anyhow:: Error > ( res)
127- } ) ?;
118+ let res = session
119+ . prss_as_mut ( )
120+ . prss_next_vec ( my_role, amount)
121+ . await ?
122+ . into_iter ( )
123+ . map ( |x| Share :: new ( my_role, x) )
124+ . collect ( ) ;
128125 Ok ( res)
129126}
130127
@@ -139,75 +136,78 @@ async fn next_triple_batch<Z: ErrorCorrect, Ses: SmallSessionHandles<Z>, BCast:
139136 broadcast : & BCast ,
140137) -> anyhow:: Result < Vec < Triple < Z > > > {
141138 let counters = session. prss ( ) . get_counters ( ) ;
139+ let my_role = session. my_role ( ) ;
140+ let threshold = session. threshold ( ) ;
142141 let prss_base_ctr = counters. prss_ctr ;
143142 let przs_base_ctr = counters. przs_ctr ;
144143
145- let vec_x_single = prss_list ( session, amount) ?;
146- let vec_y_single = prss_list ( session, amount) ?;
147- let vec_v_single = prss_list ( session, amount) ?;
148- let vec_z_double = przs_list ( session, amount) ?;
144+ let all_prss = session
145+ . prss_as_mut ( )
146+ . prss_next_vec ( my_role, 3 * amount)
147+ . await ?;
148+ let vec_z_double: Vec < _ > = session
149+ . prss_as_mut ( )
150+ . przs_next_vec ( my_role, threshold, amount)
151+ . await ?;
149152
150- let mut vec_d_double = Vec :: with_capacity ( amount) ;
151- for i in 0 ..amount {
152- let x_single = vec_x_single
153- . get ( i)
154- . with_context ( || log_error_wrapper ( "Expected x does not exist" ) ) ?
155- . to_owned ( ) ;
156- let y_single = vec_y_single
157- . get ( i)
158- . with_context ( || log_error_wrapper ( "Expected y does not exist" ) ) ?
159- . to_owned ( ) ;
160- let v_single = vec_v_single
161- . get ( i)
162- . with_context ( || log_error_wrapper ( "Expected v does not exist" ) ) ?
163- . to_owned ( ) ;
164- let z_double = vec_z_double
165- . get ( i)
166- . with_context ( || log_error_wrapper ( "Expected z does not exist" ) ) ?
167- . to_owned ( ) ;
168- let v_double = z_double + v_single;
169- let d_double = x_single * y_single + v_double;
170- vec_d_double. push ( d_double)
153+ let all_prss_cloned = all_prss. clone ( ) ;
154+ let vec_z_double_cloned = vec_z_double. clone ( ) ;
155+ let vec_d_double = spawn_compute_bound ( move ||{
156+ let mut all_prss = all_prss_cloned. into_iter ( ) ;
157+ let vec_x_single: Vec < _ > = all_prss. by_ref ( ) . take ( amount) . collect ( ) ;
158+ let vec_y_single: Vec < _ > = all_prss. by_ref ( ) . take ( amount) . collect ( ) ;
159+ let vec_v_single: Vec < _ > = all_prss. by_ref ( ) . take ( amount) . collect ( ) ;
160+
161+ if vec_x_single. len ( ) != amount
162+ || vec_y_single. len ( ) != amount
163+ || vec_v_single. len ( ) != amount
164+ || vec_z_double. len ( ) != amount
165+ {
166+ return Err ( anyhow:: anyhow!(
167+ "BUG: Not all expected values were generated, x={}, y={}, v={}, z={}. Expected {amount}." ,
168+ vec_x_single. len( ) ,
169+ vec_y_single. len( ) ,
170+ vec_v_single. len( ) ,
171+ vec_z_double. len( ) ,
172+ ) ) ;
171173 }
172174
175+ Ok ( vec_x_single
176+ . into_iter ( )
177+ . zip_eq ( vec_y_single. into_iter ( ) )
178+ . zip_eq ( vec_v_single. into_iter ( ) )
179+ . zip_eq ( vec_z_double_cloned. into_iter ( ) )
180+ . map ( |( ( ( x, y) , v) , z) | x * y + ( z + v) )
181+ . collect_vec ( ) )
182+ } ) . await ??;
183+
173184 let broadcast_res = broadcast
174- . broadcast_from_all_w_corrupt_set_update ( session, vec_d_double. clone ( ) . into ( ) )
185+ . broadcast_from_all_w_corrupt_set_update ( session, vec_d_double. into ( ) )
175186 . await ?;
176187
177188 //Try reconstructing 2t sharings of d, a None means reconstruction failed.
178- let recons_vec_d = reconstruct_d_values ( session, amount, broadcast_res. clone ( ) ) ?;
189+ let recons_vec_d = reconstruct_d_values ( session, amount, broadcast_res. clone ( ) ) . await ?;
179190
191+ let mut all_prss = all_prss. into_iter ( ) ;
192+ let vec_x_single: Vec < _ > = all_prss. by_ref ( ) . take ( amount) . collect ( ) ;
193+ let vec_y_single: Vec < _ > = all_prss. by_ref ( ) . take ( amount) . collect ( ) ;
194+ let vec_v_single: Vec < _ > = all_prss. by_ref ( ) . take ( amount) . collect ( ) ;
180195 let mut triples = Vec :: with_capacity ( amount) ;
181196 let mut bad_triples_idx = Vec :: new ( ) ;
182- for i in 0 ..amount {
197+ for ( i, ( x, ( y, z) ) ) in vec_x_single
198+ . into_iter ( )
199+ . zip_eq ( vec_y_single. into_iter ( ) . zip_eq ( vec_v_single. into_iter ( ) ) )
200+ . enumerate ( )
201+ {
183202 //If we managed to reconstruct, we store the triple
184203 if let Some ( d) = recons_vec_d
185204 . get ( i)
186205 . with_context ( || log_error_wrapper ( "Not all expected d values exist" ) ) ?
187206 {
188207 triples. push ( Triple {
189- a : Share :: new (
190- session. my_role ( ) ,
191- vec_x_single
192- . get ( i)
193- . with_context ( || log_error_wrapper ( "Not all expected x values exist" ) ) ?
194- . to_owned ( ) ,
195- ) ,
196- b : Share :: new (
197- session. my_role ( ) ,
198- vec_y_single
199- . get ( i)
200- . with_context ( || log_error_wrapper ( "Not all expected y values exist" ) ) ?
201- . to_owned ( ) ,
202- ) ,
203- c : Share :: new (
204- session. my_role ( ) ,
205- d. to_owned ( )
206- - vec_v_single
207- . get ( i)
208- . with_context ( || log_error_wrapper ( "Not all expected v values exist" ) ) ?
209- . to_owned ( ) ,
210- ) ,
208+ a : Share :: new ( session. my_role ( ) , x) ,
209+ b : Share :: new ( session. my_role ( ) , y) ,
210+ c : Share :: new ( session. my_role ( ) , d. to_owned ( ) - z) ,
211211 } ) ;
212212 //If reconstruction failed, it's a bad triple and we will run cheater identification
213213 } else {
@@ -239,7 +239,7 @@ async fn next_triple_batch<Z: ErrorCorrect, Ses: SmallSessionHandles<Z>, BCast:
239239/// Helper method to parse the result of the broadcast by taking the ith share from each party and combine them in a vector for which reconstruction is then computed.
240240/// Returns a list of length `amount` which contains the reconstructed values.
241241/// In case a wrong amount of elements or a wrong type is returned then the culprit is added to the list of corrupt parties.
242- fn reconstruct_d_values < Z , Ses : BaseSessionHandles > (
242+ async fn reconstruct_d_values < Z , Ses : BaseSessionHandles > (
243243 session : & mut Ses ,
244244 amount : usize ,
245245 d_recons : HashMap < Role , BroadcastValue < Z > > ,
@@ -289,20 +289,19 @@ where
289289
290290 //We know we may not be able to correct all errors, thus we set max_errors to maximum number of errors the code can correct,
291291 //and deal with failure with the cheater identification strategy
292- let max_errors = ( session. num_parties ( )
293- - session. corrupt_roles ( ) . len ( )
294- - ( 2 * session. threshold ( ) as usize + 1 ) )
295- / 2 ;
296-
297- Ok ( collected_shares
298- . into_iter ( )
299- . map ( |cur_collection| {
300- let sharing = ShamirSharings :: create ( cur_collection) ;
301- sharing
302- . err_reconstruct ( 2 * session. threshold ( ) as usize , max_errors)
303- . ok ( )
304- } )
305- . collect_vec ( ) )
292+ let degree = 2 * session. threshold ( ) as usize ;
293+ let max_errors = ( session. num_parties ( ) - session. corrupt_roles ( ) . len ( ) - ( degree + 1 ) ) / 2 ;
294+
295+ spawn_compute_bound ( move || {
296+ collected_shares
297+ . into_iter ( )
298+ . map ( |cur_collection| {
299+ let sharing = ShamirSharings :: create ( cur_collection) ;
300+ sharing. err_reconstruct ( degree, max_errors) . ok ( )
301+ } )
302+ . collect_vec ( )
303+ } )
304+ . await
306305}
307306
308307/// Helper method which takes the list of d shares of each party (the result of the broadcast)
@@ -345,35 +344,6 @@ fn parse_d_shares<Z: Ring, Ses: BaseSessionHandles>(
345344 Ok ( res)
346345}
347346
348- /// Output amount of PRSS.Next() calls
349- #[ instrument( name="PRSS.Next" , skip( session, amount) , fields( sid=?session. session_id( ) , own_identity=?session. own_identity( ) , batch_size=?amount) ) ]
350- fn prss_list < Z : Ring , Ses : SmallSessionHandles < Z > > (
351- session : & mut Ses ,
352- amount : usize ,
353- ) -> anyhow:: Result < Vec < Z > > {
354- let my_id = session. my_role ( ) ;
355- let mut vec_prss = Vec :: with_capacity ( amount) ;
356- for _i in 0 ..amount {
357- vec_prss. push ( session. prss_as_mut ( ) . prss_next ( my_id) ?) ;
358- }
359- Ok ( vec_prss)
360- }
361-
362- /// Output amount of PRZS.Next() calls
363- #[ instrument( name="PRZS.Next" , skip( session, amount) , fields( sid=?session. session_id( ) , own_identity=?session. own_identity( ) , batch_size=?amount) ) ]
364- fn przs_list < Z : Ring , Ses : SmallSessionHandles < Z > > (
365- session : & mut Ses ,
366- amount : usize ,
367- ) -> anyhow:: Result < Vec < Z > > {
368- let my_id = session. my_role ( ) ;
369- let threshold = session. threshold ( ) ;
370- let mut vec_przs = Vec :: with_capacity ( amount) ;
371- for _i in 0 ..amount {
372- vec_przs. push ( session. prss_as_mut ( ) . przs_next ( my_id, threshold) ?) ;
373- }
374- Ok ( vec_przs)
375- }
376-
377347/// Helper method for validating results when corruption has happened (by the reconstruction not being successful).
378348/// The method finds the corrupt parties (based on what they broadcast) and adds them to the list of corrupt parties in the session.
379349///
@@ -842,8 +812,8 @@ mod test {
842812 /// Unit testing of [`reconstruct_d_values`]
843813 /// Test what happens when a party send a wrong type of value
844814 #[ tracing_test:: traced_test]
845- #[ test]
846- fn test_wrong_type ( ) {
815+ #[ tokio :: test]
816+ async fn test_wrong_type ( ) {
847817 let mut session = get_networkless_base_session_for_parties ( 4 , 1 , Role :: indexed_from_one ( 1 ) ) ;
848818 // Observe party 1 inputs a vector of size 1 and party 2 inputs a single element
849819 let d_recons = HashMap :: from ( [
@@ -859,7 +829,9 @@ mod test {
859829 ) ,
860830 ] ) ;
861831 assert ! ( session. corrupt_roles( ) . is_empty( ) ) ;
862- let res = reconstruct_d_values ( & mut session, 1 , d_recons) . unwrap ( ) ;
832+ let res = reconstruct_d_values ( & mut session, 1 , d_recons)
833+ . await
834+ . unwrap ( ) ;
863835 assert_eq ! ( 1 , session. corrupt_roles( ) . len( ) ) ;
864836 assert ! ( session. corrupt_roles( ) . contains( & Role :: indexed_from_one( 2 ) ) ) ;
865837 assert ! ( logs_contain(
0 commit comments