@@ -24,11 +24,21 @@ where
2424 acss_config : Arc < ACSSConfig > ,
2525
2626 // Attributes used to manage the subtasks
27- acss_tasks : JoinSet < ( SessionId , Result < ( ) , ACSSConfig :: Error > ) > , // set of acss tasks
27+ acss_tasks : JoinSet < ( SessionId , Result < ( ) , MultiAcssError > ) > , // set of acss tasks
2828 acss_receivers : Vec < Option < oneshot:: Receiver < ACSSConfig :: Output > > > ,
29+ acss_leader_sender : Option < oneshot:: Sender < ACSSConfig :: Input > > , // set the leader input
2930 cancels : Vec < CancellationToken > ,
3031}
3132
33+ #[ derive( thiserror:: Error , Debug ) ]
34+ pub enum MultiAcssError {
35+ #[ error( transparent) ]
36+ Acss ( #[ from] Box < dyn std:: error:: Error + Send + Sync + ' static > ) ,
37+
38+ #[ error( "failed to get ACSS input from channel: sender dropped" ) ]
39+ AcssInputDropped ,
40+ }
41+
3242impl < CG , ACSSConfig > MultiAcss < CG , ACSSConfig >
3343where
3444 CG : CurveGroup ,
@@ -42,12 +52,14 @@ where
4252 acss_config,
4353 acss_tasks : JoinSet :: new ( ) ,
4454 acss_receivers : vec ! [ ] ,
55+ acss_leader_sender : None ,
4556 cancels,
4657 }
4758 }
4859
4960 /// Start the n parallel ACSS instances in the background.
50- pub fn start < T > ( & mut self , s : ACSSConfig :: Input , rng : & mut impl AdkgRng , transport : T )
61+ /// Returns a channel used to transmit the ACSS secret.
62+ pub fn start < T > ( & mut self , rng : & mut impl AdkgRng , transport : T )
5163 where
5264 T : TopicBasedTransport < Identity = PartyId > ,
5365 {
5769 . map ( |( sender, receiver) | ( sender, Some ( receiver) ) )
5870 . collect ( ) ;
5971 self . acss_receivers = receivers;
60- let mut s = Some ( s) ; // need an option for interior mutability...
72+
73+ // Create one channel for the ACSS input
74+ let ( input_tx, input_rx) = oneshot:: channel ( ) ;
75+ self . acss_leader_sender = Some ( input_tx) ;
76+ let mut input_rx = Some ( input_rx) ; // need an option for interior mutability...
6177
6278 for ( sid, cancel, sender) in izip ! (
6379 SessionId :: iter_all( self . n_instances) ,
@@ -77,26 +93,31 @@ where
7793 // s is not cloneable, and we only want to move it when sid == node_id
7894 // In order to not move s due to the async move below, we take() s only once
7995 // here, and use None when sid != node_id. This allows to move the value only once.
80- let s = if sid == node_id { s. take ( ) } else { None } ;
96+ let mut input_rx = if sid == node_id {
97+ input_rx. take ( )
98+ } else {
99+ None
100+ } ;
81101
82102 let mut rng = rng
83103 . get ( AdkgRngType :: Acss ( sid) )
84104 . expect ( "failed to obtain acss rng" ) ;
85105 async move {
86106 // Start the acss tasks
87107 let res = if sid == node_id {
88- acss . deal (
89- s . expect ( "can only enter once" ) , // s must be Some(.) since sid == node_id
90- cancellation_token ,
91- sender ,
92- & mut rng ,
93- )
94- . instrument ( tracing :: warn_span! ( "ACSS::deal" , ?sid ) )
95- . await
108+ if let Ok ( s ) = input_rx . take ( ) . expect ( "to enter once" ) . await {
109+ acss . deal ( s , cancellation_token , sender , & mut rng )
110+ . instrument ( tracing :: warn_span! ( "ACSS::deal" , ?sid ) )
111+ . await
112+ . map_err ( |e| MultiAcssError :: Acss ( e . into ( ) ) )
113+ } else {
114+ Err ( MultiAcssError :: AcssInputDropped )
115+ }
96116 } else {
97117 acss. get_share ( sid. into ( ) , cancellation_token, sender, & mut rng)
98118 . instrument ( tracing:: warn_span!( "ACSS::get_share" , ?sid) )
99119 . await
120+ . map_err ( |e| MultiAcssError :: Acss ( e. into ( ) ) )
100121 } ;
101122
102123 ( sid, res)
@@ -105,6 +126,11 @@ where
105126 }
106127 }
107128
129+ /// Get the oneshot sender used to set the leader output of the ACSS where self.node_id == sid
130+ pub fn get_leader_sender ( & mut self ) -> Option < oneshot:: Sender < ACSSConfig :: Input > > {
131+ self . acss_leader_sender . take ( )
132+ }
133+
108134 /// Create an iterator over the remaining ACSS outputs.
109135 pub fn iter_remaining_outputs (
110136 & mut self ,
@@ -124,11 +150,11 @@ where
124150 }
125151
126152 /// Stop the ACSS instances and return Ok(()) if no errors were output, otherwise, return the identifier of failed instances and their errors.
127- pub async fn stop ( self ) -> Result < ( ) , Vec < ( SessionId , ACSSConfig :: Error ) > > {
153+ pub async fn stop ( self ) -> Result < ( ) , Vec < ( SessionId , MultiAcssError ) > > {
128154 // Signal cancellation through each of the cancellation tokens
129155 self . cancels . iter ( ) . for_each ( |cancel| cancel. cancel ( ) ) ;
130156
131- let errors: Vec < ( SessionId , ACSSConfig :: Error ) > = self
157+ let errors: Vec < ( SessionId , MultiAcssError ) > = self
132158 . acss_tasks
133159 . join_all ( )
134160 . await
0 commit comments