Skip to content

Commit bfd73a9

Browse files
azixusAnomalRoil
andauthored
feat(adkg): execute ACSS upon receiving start signal (#263)
Co-authored-by: AnomalRoil <AnomalRoil@users.noreply.github.com>
1 parent 123683b commit bfd73a9

3 files changed

Lines changed: 129 additions & 36 deletions

File tree

bin/adkg-cli/src/adkg_dxkr23.rs

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ async fn adkg_pairing_out_g2<'a, E, S, TBT>(
128128
adkg_scheme: S,
129129
topic_transport: Arc<TBT>,
130130
writer: Option<InMemoryWriter>,
131-
mut rng: impl AdkgRng + 'static,
131+
rng: impl AdkgRng + 'static,
132132
) -> anyhow::Result<()>
133133
where
134134
E: Pairing,
@@ -171,7 +171,7 @@ where
171171
group_config,
172172
topic_transport,
173173
adkg_scheme,
174-
&mut rng,
174+
rng,
175175
tx_adkg_out,
176176
)
177177
.await
@@ -315,7 +315,7 @@ async fn adkg_dxkr23<S, TBT>(
315315
group_config: GroupConfig,
316316
topic_transport: Arc<TBT>,
317317
adkg_scheme: S,
318-
rng: &mut impl AdkgRng,
318+
rng: impl AdkgRng + 'static,
319319
out: oneshot::Sender<AdkgOutput<S::Curve>>,
320320
) -> anyhow::Result<()>
321321
where
@@ -325,9 +325,9 @@ where
325325
S::ABAConfig: AbaConfig<'static, PartyId, Input = AbaCrainInput<S::Curve>>,
326326
<S::ACSSConfig as AcssConfig<'static, S::Curve, PartyId>>::Output:
327327
Into<ShareWithPoly<S::Curve>>,
328-
TBT: TopicBasedTransport<Identity = PartyId>,
328+
TBT: TopicBasedTransport<Identity = PartyId> + Send + Sync + 'static,
329329
{
330-
let mut adkg = adkg_scheme.new_adkg(
330+
let adkg = adkg_scheme.new_adkg(
331331
adkg_config.id,
332332
group_config.n,
333333
group_config.t,
@@ -336,6 +336,10 @@ where
336336
pks.clone(),
337337
)?;
338338

339+
let (adkg_start_tx, adkg_start_rx) = oneshot::channel();
340+
let (adkg_stop_tx, adkg_stop_rx) = oneshot::channel();
341+
let adkg_out = adkg.run(adkg_start_rx, adkg_stop_rx, rng, topic_transport);
342+
339343
// Calculate time to sleep before actively executing the adkg
340344
let sleep_duration = (group_config.start_time - chrono::Utc::now())
341345
.to_std() // TimeDelta to positive duration
@@ -353,11 +357,14 @@ where
353357
"Executing ADKG with a timeout of {}",
354358
humantime::format_duration(adkg_config.timeout)
355359
);
360+
if adkg_start_tx.send(()).is_err() {
361+
anyhow::bail!("Failed to send ADKG start signal");
362+
}
356363

357364
let res = tokio::select! {
358-
output = adkg.start(rng, topic_transport) => {
359-
let output = match output {
360-
Ok(adkg_out) => {
365+
output = adkg_out => {
366+
let output: anyhow::Result<_> = match output {
367+
Some(Ok(adkg_out)) => {
361368
tracing::info!(used_sessions = ?adkg_out.used_sessions, "Successfully obtained secret key & output from ADKG");
362369
if out.send(adkg_out).is_err() {
363370
// fails if the receiver side is dropped early
@@ -368,9 +375,13 @@ where
368375
tokio::time::sleep(adkg_config.grace_period).await;
369376
Ok(())
370377
}
371-
Err(e) => {
378+
Some(Err(e)) => {
372379
tracing::error!("failed to obtain output from ADKG: {e:?}");
373-
Err(e)
380+
Err(e.into())
381+
}
382+
None => {
383+
tracing::error!("failed to obtain output from ADKG: stopped before an output");
384+
Err(anyhow!("ADKG stopped before output"))
374385
}
375386
};
376387

@@ -384,9 +395,9 @@ where
384395
};
385396

386397
tracing::warn!("Stopping ADKG...");
387-
adkg.stop().await;
398+
let _ = adkg_stop_tx.send(());
388399

389-
Ok(res??)
400+
res?
390401
}
391402

392403
/// Pairing-based DLEQ proof that there exists an s_j s.t. P_1 = [s_j] G_1 \land P_2 = [s_j] G_2,

crates/adkg/src/adkg.rs

Lines changed: 66 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use futures::{FutureExt, pin_mut};
12
mod randex;
23
pub(crate) mod types;
34

@@ -190,6 +191,7 @@ where
190191
ACSSConfig::Output: Into<ShareWithPoly<CG>>,
191192
ABAConfig: AbaConfig<'static, PartyId, Input = AbaCrainInput<CG>>,
192193
{
194+
/// Start the ADKG immediately
193195
pub async fn start<T>(
194196
&mut self,
195197
rng: &mut impl AdkgRng,
@@ -198,7 +200,47 @@ where
198200
where
199201
T: TopicBasedTransport<Identity = PartyId>,
200202
{
201-
self.execute(rng, transport).await
203+
self.execute_internal(std::future::ready(()), rng, transport)
204+
.await
205+
}
206+
207+
/// An alternative way to execute the adkg by managing the lifecycle asynchronously.
208+
/// The function executes the ADKG once `start` is resolved, and stops `stop` is resolved.
209+
///
210+
/// The function returns immediately with a future that resolves upon obtaining an output.
211+
pub fn run<T>(
212+
mut self,
213+
start: impl Future + Send + 'static,
214+
stop: impl Future<Output: Send> + Send + 'static,
215+
mut rng: impl AdkgRng + 'static,
216+
transport: Arc<T>,
217+
) -> impl Future<Output = Option<Result<AdkgOutput<CG>, AdkgError>>>
218+
where
219+
T: TopicBasedTransport<Identity = PartyId> + Send + Sync + 'static,
220+
{
221+
let (output_tx, output_rx) = tokio::sync::oneshot::channel();
222+
tokio::spawn({
223+
async move {
224+
pin_mut!(stop);
225+
226+
tokio::select! {
227+
out = self.execute_internal(start, &mut rng, transport) => {
228+
// Send output
229+
let _ = output_tx.send(out);
230+
231+
// Wait for the stop signal
232+
stop.await;
233+
},
234+
_ = &mut stop => (),
235+
}
236+
237+
// stop signal received, stop ADKG
238+
info!("Stop signal received, stopping ADKG");
239+
self.stop().await;
240+
}
241+
});
242+
243+
output_rx.map(Result::ok)
202244
}
203245

204246
pub async fn stop(mut self) {
@@ -259,8 +301,9 @@ where
259301
}
260302
}
261303

262-
async fn execute<T>(
304+
async fn execute_internal<T>(
263305
&mut self,
306+
start_signal: impl Future,
264307
rng: &mut impl AdkgRng,
265308
transport: Arc<T>,
266309
) -> Result<AdkgOutput<CG>, AdkgError>
@@ -304,25 +347,29 @@ where
304347
.collect();
305348

306349
// Start the multi RBC, ACSS and ABA
307-
state
308-
.multi_acss
309-
.lock()
310-
.await
311-
.start(s, rng, transport.clone());
350+
state.multi_acss.lock().await.start(rng, transport.clone());
312351
state
313352
.multi_rbc
314353
.lock()
315354
.await
316355
.start(rbc_predicates, transport.clone());
317356
state.multi_aba.lock().await.start(rng, transport.clone());
318357

358+
// Get the ACSS sender
359+
let acss_leader_sender = state
360+
.multi_acss
361+
.lock()
362+
.await
363+
.get_leader_sender()
364+
.expect("failed to get acss leader sender");
365+
319366
// Get the node's own RBC
320-
let leader_sender = state
367+
let rbc_leader_sender = state
321368
.multi_rbc
322369
.lock()
323370
.await
324371
.get_leader_sender()
325-
.expect("failed to get leader sender");
372+
.expect("failed to get rbc leader sender");
326373

327374
// Create cancellation tokens for each subtask
328375
let acss_cancel = self.cancel.child_token();
@@ -331,7 +378,7 @@ where
331378

332379
// Handler for the key set proposal phase. Manages the termination of
333380
self.acss_task = Some(task::spawn(Self::acss_task(
334-
leader_sender,
381+
rbc_leader_sender,
335382
state.clone(),
336383
acss_cancel.clone(),
337384
)));
@@ -344,6 +391,15 @@ where
344391
// Upon termination of jth ABA
345392
let abas_task = task::spawn(Self::aba_outputs_task(state.clone(), aba_cancel.clone()));
346393

394+
// Everything has been set-up, wait for the start signal
395+
start_signal.await;
396+
if acss_leader_sender.send(s).is_err() {
397+
error!(
398+
"ADKG main thread of node `{}` failed to set ACSS input",
399+
self.id
400+
);
401+
}
402+
347403
// Try to join ABAs task, and obtain the final list of parties.
348404
info!(
349405
"ADKG main thread of node `{}` waiting on ABA task to complete",

crates/adkg/src/vss/acss/multi_acss.rs

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,21 @@ where
2424
acss_config: Arc<ACSSConfig>,
2525

2626
// Attributes used to manage the subtasks
27-
acss_tasks: JoinSet<(SessionId, Result<(), ACSSConfig::Error>)>, // set of acss tasks
27+
acss_tasks: JoinSet<(SessionId, Result<(), MultiAcssError>)>, // set of acss tasks
2828
acss_receivers: Vec<Option<oneshot::Receiver<ACSSConfig::Output>>>,
29+
acss_leader_sender: Option<oneshot::Sender<ACSSConfig::Input>>, // set the leader input
2930
cancels: Vec<CancellationToken>,
3031
}
3132

33+
#[derive(thiserror::Error, Debug)]
34+
pub enum MultiAcssError {
35+
#[error(transparent)]
36+
Acss(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
37+
38+
#[error("failed to get ACSS input from channel: sender dropped")]
39+
AcssInputDropped,
40+
}
41+
3242
impl<CG, ACSSConfig> MultiAcss<CG, ACSSConfig>
3343
where
3444
CG: CurveGroup,
@@ -42,12 +52,14 @@ where
4252
acss_config,
4353
acss_tasks: JoinSet::new(),
4454
acss_receivers: vec![],
55+
acss_leader_sender: None,
4556
cancels,
4657
}
4758
}
4859

4960
/// Start the n parallel ACSS instances in the background.
50-
pub fn start<T>(&mut self, s: ACSSConfig::Input, rng: &mut impl AdkgRng, transport: T)
61+
/// Returns a channel used to transmit the ACSS secret.
62+
pub fn start<T>(&mut self, rng: &mut impl AdkgRng, transport: T)
5163
where
5264
T: TopicBasedTransport<Identity = PartyId>,
5365
{
@@ -57,7 +69,11 @@ where
5769
.map(|(sender, receiver)| (sender, Some(receiver)))
5870
.collect();
5971
self.acss_receivers = receivers;
60-
let mut s = Some(s); // need an option for interior mutability...
72+
73+
// Create one channel for the ACSS input
74+
let (input_tx, input_rx) = oneshot::channel();
75+
self.acss_leader_sender = Some(input_tx);
76+
let mut input_rx = Some(input_rx); // need an option for interior mutability...
6177

6278
for (sid, cancel, sender) in izip!(
6379
SessionId::iter_all(self.n_instances),
@@ -77,26 +93,31 @@ where
7793
// s is not cloneable, and we only want to move it when sid == node_id
7894
// In order to not move s due to the async move below, we take() s only once
7995
// here, and use None when sid != node_id. This allows to move the value only once.
80-
let s = if sid == node_id { s.take() } else { None };
96+
let mut input_rx = if sid == node_id {
97+
input_rx.take()
98+
} else {
99+
None
100+
};
81101

82102
let mut rng = rng
83103
.get(AdkgRngType::Acss(sid))
84104
.expect("failed to obtain acss rng");
85105
async move {
86106
// Start the acss tasks
87107
let res = if sid == node_id {
88-
acss.deal(
89-
s.expect("can only enter once"), // s must be Some(.) since sid == node_id
90-
cancellation_token,
91-
sender,
92-
&mut rng,
93-
)
94-
.instrument(tracing::warn_span!("ACSS::deal", ?sid))
95-
.await
108+
if let Ok(s) = input_rx.take().expect("to enter once").await {
109+
acss.deal(s, cancellation_token, sender, &mut rng)
110+
.instrument(tracing::warn_span!("ACSS::deal", ?sid))
111+
.await
112+
.map_err(|e| MultiAcssError::Acss(e.into()))
113+
} else {
114+
Err(MultiAcssError::AcssInputDropped)
115+
}
96116
} else {
97117
acss.get_share(sid.into(), cancellation_token, sender, &mut rng)
98118
.instrument(tracing::warn_span!("ACSS::get_share", ?sid))
99119
.await
120+
.map_err(|e| MultiAcssError::Acss(e.into()))
100121
};
101122

102123
(sid, res)
@@ -105,6 +126,11 @@ where
105126
}
106127
}
107128

129+
/// Get the oneshot sender used to set the leader output of the ACSS where self.node_id == sid
130+
pub fn get_leader_sender(&mut self) -> Option<oneshot::Sender<ACSSConfig::Input>> {
131+
self.acss_leader_sender.take()
132+
}
133+
108134
/// Create an iterator over the remaining ACSS outputs.
109135
pub fn iter_remaining_outputs(
110136
&mut self,
@@ -124,11 +150,11 @@ where
124150
}
125151

126152
/// Stop the ACSS instances and return Ok(()) if no errors were output, otherwise, return the identifier of failed instances and their errors.
127-
pub async fn stop(self) -> Result<(), Vec<(SessionId, ACSSConfig::Error)>> {
153+
pub async fn stop(self) -> Result<(), Vec<(SessionId, MultiAcssError)>> {
128154
// Signal cancellation through each of the cancellation tokens
129155
self.cancels.iter().for_each(|cancel| cancel.cancel());
130156

131-
let errors: Vec<(SessionId, ACSSConfig::Error)> = self
157+
let errors: Vec<(SessionId, MultiAcssError)> = self
132158
.acss_tasks
133159
.join_all()
134160
.await

0 commit comments

Comments
 (0)