diff --git a/crates/adkg/Cargo.toml b/crates/adkg/Cargo.toml index 32fcc6ea..36a6fdd0 100644 --- a/crates/adkg/Cargo.toml +++ b/crates/adkg/Cargo.toml @@ -34,7 +34,6 @@ sha2 = { workspace = true, optional = true} sha3.workspace = true thiserror.workspace = true tracing.workspace = true -tracing-subscriber.workspace = true tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "sync"] } tokio-util.workspace = true futures.workspace = true @@ -45,4 +44,5 @@ dcipher-network = { workspace = true, features = ["in_memory"] } ark-bn254.workspace = true ark-bls12-381.workspace = true utils = { workspace = true, features = ["bls12-381", "bn254", "sha3"] } +tracing-subscriber = { workspace = true, features = ["env-filter"] } rayon = "1.0" diff --git a/crates/adkg/src/aba.rs b/crates/adkg/src/aba.rs index 421e414d..b0a779c5 100644 --- a/crates/adkg/src/aba.rs +++ b/crates/adkg/src/aba.rs @@ -57,7 +57,7 @@ pub trait Aba: Send { } /// A binary estimate can either be 0/1, or \bot. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Debug, Copy, Clone, Ord, PartialOrd, PartialEq, Eq, Hash, Serialize, Deserialize)] #[repr(u8)] pub enum Estimate { Bot, diff --git a/crates/adkg/src/aba/crain20.rs b/crates/adkg/src/aba/crain20.rs index c3437625..3a20930d 100644 --- a/crates/adkg/src/aba/crain20.rs +++ b/crates/adkg/src/aba/crain20.rs @@ -1,29 +1,32 @@ //! Implementation of the Tyler Crain's Asynchronous Byzantine Agreement described in https://arxiv.org/pdf/2002.08765. //! We specifically implement the Good-Case-Coin-Free variant described in https://eprint.iacr.org/2021/1591.pdf, Appendix B. -use futures::StreamExt; +mod broadcast; +mod coin; mod ecdh_coin_toss; pub mod messages; +mod recv_handler; use crate::aba::{Aba, AbaConfig, Estimate}; use crate::helpers::{PartyId, SessionId}; use crate::network::{RetryStrategy, broadcast_with_self}; use ark_ec::CurveGroup; use dcipher_network::topic::TopicBasedTransport; -use dcipher_network::{ReceivedMessage, Transport, TransportSender}; +use dcipher_network::{Transport, TransportSender}; use digest::core_api::BlockSizeUser; use digest::crypto_common::rand_core::CryptoRng; use digest::{DynDigest, FixedOutputReset}; -use ecdh_coin_toss::{Coin, EcdhCoinTossError, EcdhCoinTossEval}; -use messages::{AbaMessage, AuxStage, CoinEvalMessage, View}; -use messages::{AuxiliaryMessage, AuxiliarySetMessage, EstimateMessage}; +use ecdh_coin_toss::{EcdhCoinTossError, EcdhCoinTossEval}; +use futures::future::Either; +use messages::AuxiliarySetMessage; +use messages::{AbaMessage, AuxStage, View}; use rand::RngCore; -use serde::{Deserialize, Serialize}; -use std::collections::HashSet; +use serde::Deserialize; +use std::collections::BTreeSet; use std::hash::Hash; -use std::ops::{Index, IndexMut}; +use std::ops::{Deref, DerefMut, Index, IndexMut}; +use std::pin::pin; use std::{ - borrow::Borrow, collections::{BTreeMap, HashMap, btree_map::Entry}, marker::PhantomData, sync::Arc, @@ -33,7 +36,7 @@ use tokio::sync::oneshot::error::RecvError; use tokio::sync::{Mutex, Notify, oneshot}; use tokio::task::JoinError; use tokio_util::sync::CancellationToken; -use tracing::{Level, debug, error, event, info, trace, warn}; +use tracing::{Level, debug, error, event, info}; use utils::hash_to_curve::HashToCurve; use utils::serialize::fq::FqSerialize; use utils::serialize::point::PointSerializeCompressed; @@ -155,15 +158,14 @@ where } struct AbaState { - notify_count_est: NotifyMap<(u8, AuxStage)>, // notify upon receiving 2t + 1 binary estimates (Algorithm 3, Line 7) - - bin_values: Mutex>, // replace vec by bitset / integer + notify_bin_values: NotifyMap<(u8, AuxStage)>, // notify upon receiving 2t + 1 binary estimates (Algorithm 3, Line 7) + bin_values: Mutex>, notify_count_aux: NotifyMap<(u8, AuxStage)>, // notify upon receiving n - t aux agreements (Algorithm 4, Line 4) - count_aux: Mutex>, // count aux messages + aux_views: Mutex>, // store each views sent through aux messages - notify_count_auxset: NotifyMap, // notify upon receiving n - t auxset agreements (Algorithm 5, Line 7) - count_auxset: Mutex>, // count auxset messages + notify_count_auxset: NotifyMap, // notify upon receiving at least n - t auxset agreements (Algorithm 5, Line 7) + auxset_views: Mutex>, // store each views sent through auxset messages notify_enough_coin_evals: NotifyMap, coin_evals: Mutex>>, @@ -179,10 +181,7 @@ where receiver: T::ReceiveMessageStream, } -struct AbaCrain20Instance -where - TS: TransportSender, -{ +struct AbaCrain20Instance { config: Arc>, sid: SessionId, sender: TS, @@ -229,12 +228,12 @@ where // Initialize the ABA state machine let state = Arc::new(AbaState:: { - count_aux: Mutex::new(PerPartyStorage::new()), + aux_views: Mutex::new(PerPartyStorage::new()), notify_count_aux: NotifyMap::new(), - count_auxset: Mutex::new(PerPartyStorage::new()), + auxset_views: Mutex::new(PerPartyStorage::new()), notify_count_auxset: NotifyMap::new(), bin_values: Mutex::new(HashMap::new()), - notify_count_est: NotifyMap::new(), + notify_bin_values: NotifyMap::new(), notify_enough_coin_evals: NotifyMap::new(), coin_evals: Mutex::new(PerPartyStorage::new()), }); @@ -303,184 +302,6 @@ where } } -impl AbaCrain20 -where - CG: CurveGroup + Copy + HashToCurve + PointSerializeCompressed, - CG::ScalarField: FqSerialize, - EcdhCoinTossEval: for<'de> Deserialize<'de>, - CK: Send + Into> + 'static, - H: Default + DynDigest + BlockSizeUser + Clone + Send + Sync + 'static, - T: Transport, - T::Sender: Clone, -{ - /// Thread responsible for receiving all types of ABA messages and transmitting notifications. - async fn recv_thread( - sid: SessionId, - config: Arc>, - receiver: T::ReceiveMessageStream, - sender: T::Sender, - cancel: CancellationToken, - state: Arc>, - ) { - let id = config.id; - // Stop the thread upon receiving a signal from the cancellation token - tokio::select! { - _ = cancel.cancelled() => { - info!("Node `{id}` in ABA with sid `{sid}` stopping recv_thread"); - } - - _ = Self::recv_loop(config, receiver, sender, state) => {} - } - } - - /// Infinite loop listening for ABA messages and sending notifications. - async fn recv_loop( - config: Arc>, - mut receiver: T::ReceiveMessageStream, - sender: T::Sender, - state: Arc>, - ) { - // Local variables - let mut count_est = PerPartyStorage::new(); - let mut sent_estimate: HashSet = HashSet::new(); - - loop { - let ReceivedMessage { - sender: sender_id, - content, - .. - } = match receiver.next().await { - Some(Ok(m)) => m, - Some(Err(e)) => { - warn!("Node `{}` failed to recv: {e:?}", config.id); - continue; - } - None => { - error!( - "Node `{}` failed to recv: no more items in stream", - config.id - ); - return; - } - }; - - let m: AbaMessage = match bson::from_slice(&content) { - Ok(m) => m, - Err(e) => { - error!(error = ?e, "Node `{}` failed to deserialize message", config.id); - continue; - } - }; - trace!( - "Node `{}` received message {m:?} from {sender_id}", - config.id, - ); - - match m { - // 4: upon receiving BVAL(v) do - AbaMessage::Estimate(est) => { - count_est.insert_once(est, sender_id, true); - let count = count_est.get_count(&est); - - #[allow(clippy::int_plus_one)] - if count >= 2 * config.t + 1 { - // 7: if BVAL(V) received from 2t + 1 different nodes then - // 8: bin_values := bin_values \cup {v} - // add the estimate to the binary values - let mut r_bin_values = state.bin_values.lock().await; - let bin_values = &mut r_bin_values.entry(est.round).or_default()[est.stage]; - if bin_values.contains(&est.estimate) { - drop(r_bin_values); - } else { - bin_values.push(est.estimate); - drop(r_bin_values); - - // notify of update to bin_values - debug!( - "Node {} notifying bin values for round {}", - config.id, est.round - ); - state.notify_count_est.notify_one((est.round, est.stage)) - }; - } else if count >= config.t + 1 && !sent_estimate.contains(&est) { - // 5: if BVAL(v) received from t + 1 different nodes AND BVAL(v) was not sent, then - // 6: Send BVAL(v) to all nodes - let msg_est = AbaMessage::Estimate(est); - if let Err(e) = - broadcast_with_self(&msg_est, &config.retry_strategy, &sender).await - { - error!( - "Node `{}` failed to broadcast estimate message: {e:?}", - config.id - ) - } - sent_estimate.insert(est); - } - } - - AbaMessage::Auxiliary(est) => { - // Insert aux message - let mut count_aux = state.count_aux.lock().await; - count_aux.insert_once(est, sender_id, true); - let count_bot = count_aux.get_count(&AuxiliaryMessage { - round: est.round, - stage: est.stage, - estimate: Estimate::Bot, - }); - let count_0 = count_aux.get_count(&AuxiliaryMessage { - round: est.round, - stage: est.stage, - estimate: Estimate::Zero, - }); - let count_1 = count_aux.get_count(&AuxiliaryMessage { - round: est.round, - stage: est.stage, - estimate: Estimate::One, - }); - drop(count_aux); // drop mutex - - // Did we receive at least n - t aux for {0}, {1}, or {0, 1} - if count_0 + count_1 + count_bot >= config.n - config.t { - state.notify_count_aux.notify_one((est.round, est.stage)); - } - } - - AbaMessage::AuxiliarySet(set_view) => { - // Insert aux message - // lock count_aux mutex - let mut count_auxset = state.count_auxset.lock().await; - count_auxset.insert_once(set_view, sender_id, true); - drop(count_auxset); // drop mutex - - // notify on each new AUXSET message - state.notify_count_auxset.notify_one(set_view.round); - } - - AbaMessage::CoinEval(msg_eval) => { - // Deserialize eval - let Ok(eval): Result, _> = msg_eval.borrow().try_into() - else { - warn!("Failed to deserialize CoinEvalMessage"); - continue; - }; - - // Store one eval per party, per round. We cannot verify it here - // since the node may not be ready to check evaluations yet. - let mut coin_evals = state.coin_evals.lock().await; - coin_evals.insert_once(msg_eval.round, sender_id, eval); - let count = coin_evals.get_count(&msg_eval.round); - drop(coin_evals); // drop lock - - // Notify if we have t + 1 evals - if count > config.t { - state.notify_enough_coin_evals.notify_one(msg_eval.round); - } - } - }; - } - } -} - impl AbaCrain20Instance where CG: CurveGroup + Copy + HashToCurve + PointSerializeCompressed, @@ -545,12 +366,18 @@ where self.config.id ); let view_r_1 = loop { - state.notify_count_auxset.notified(r).await; + // wake up each time after having received n - t auxset, or on bin_values update + future_select_pin( + state.notify_count_auxset.notified(r), + state.notify_bin_values.notified((r, AuxStage::Stage1)), + ) + .await; - let count_auxset = state.count_auxset.lock().await; + let auxset_views = state.auxset_views.lock().await; let bin_values = state.bin_values.lock().await; // warn: two locks + let auxset_views = auxset_views.get(&r).to_owned().unwrap_or_default(); let bin_values = &bin_values.get(&r).cloned().unwrap_or_default()[AuxStage::Stage1]; - if let Some(view) = self.get_view_from_auxset(bin_values, &count_auxset, r) { + if let Some(view) = self.construct_view(bin_values, &auxset_views) { event!( Level::DEBUG, "Node `{}` at round `{r}` obtained valid view `{view:?}`", @@ -569,9 +396,10 @@ where let view_r_2 = self.sbv_broadcast(r, AuxStage::Stage2, est, &state).await; // 11: if view[r, 2] = {v}, v \neq \bot then - if view_r_2 == View::Zero || view_r_2 == View::One { + let v = Estimate::from(view_r_2.clone()); + if v != Estimate::Bot { // est \gets v - est = view_r_2.into(); + est = v; info!( "Node {} sid `{}` decided on estimate `{est:?}`", self.config.id, self.sid @@ -618,14 +446,14 @@ where } // 13: if view[r, 2] = {v, \bot} then est \gets v - if view_r_2 == View::BotZero { + if view_r_2 == View::bot_zero() { est = Estimate::Zero; - } else if view_r_2 == View::BotOne { + } else if view_r_2 == View::bot_one() { est = Estimate::One; } // 14: if view[r, 2] = {\bot} then est \gets Coin() - if view_r_2 == View::Bot { + if view_r_2.is_bot() { let coin_keys = coin_keys .as_ref() .expect("coin_keys cannot be None at this point"); @@ -643,361 +471,33 @@ where } } } +} - /// Binary-value broadcast described in https://dl.acm.org/doi/10.1145/2785953, Figure 1 - /// Send the current party's estimate to all other nodes with an Estimate message. - #[tracing::instrument(skip(self))] - async fn bv_broadcast(&self, r: u8, stage: AuxStage, v: Estimate) { - // 1: broadcast B_VAL(v) to all - let msg_est = AbaMessage::Estimate(EstimateMessage { - round: r, - stage, - estimate: v, - }); - - event!( - Level::DEBUG, - "Node `{}` at round `{r}` sending {:?} to all", - self.config.id, - msg_est - ); - if let Err(e) = - broadcast_with_self(&msg_est, &self.config.retry_strategy, &self.sender).await - { - error!( - "Node `{}` failed to broadcast estimate message: {e:?}", - self.config.id - ) - } - } - - /// Synchronized binary-value broadcast described in https://dl.acm.org/doi/10.1145/2785953, Figure 2 - /// Send the current party's estimate to all other nodes with an Estimate message. - #[tracing::instrument(skip(self, state))] - async fn sbv_broadcast( - &self, - r: u8, - stage: AuxStage, - v: Estimate, - state: &Arc>, - ) -> View { - // 1: BV_Broadcast(v) - self.bv_broadcast(r, stage, v).await; - - event!( - Level::DEBUG, - "Node `{}` waiting for bin values", - self.config.id - ); - let bin_values = loop { - // 2: wait until bin_values \neq \emptyset - state.notify_count_est.notified((r, stage)).await; - - let bin_values = state.bin_values.lock().await; - let bin_values = &bin_values.get(&r).cloned().unwrap_or_default()[stage]; - if !bin_values.is_empty() { - event!( - Level::DEBUG, - "Node `{}` obtained bin_values = `{bin_values:?}`", - self.config.id - ); - break bin_values.clone(); - } - }; - - // 3: Send AUX(w) for w \in bin_values to all - for w in bin_values.iter() { - let msg_aux = AbaMessage::Auxiliary(AuxiliaryMessage { - round: r, - stage, - estimate: *w, - }); - event!( - Level::DEBUG, - "Node `{}` sending {:?} to all", - self.config.id, - msg_aux - ); - - if let Err(e) = - broadcast_with_self(&msg_aux, &self.config.retry_strategy, &self.sender).await - { - error!( - "Node `{}` failed to broadcast aux message: {e:?}", - self.config.id - ) - } - } - - // 4: wait until \exists a set view s.t. - // (1) view \subseteq bin_values, and - // (2) contained in AUX(.) messages received from n - t nodes - let view = loop { - event!( - Level::DEBUG, - "Node `{}` waiting for count_aux notification", - self.config.id - ); - // Wait for condition (2) received from recv thread - state.notify_count_aux.notified((r, stage)).await; - - let count_aux = state.count_aux.lock().await; - let bin_values = state.bin_values.lock().await; // warn: two locks, could deadlock - let bin_values = &bin_values.get(&r).cloned().unwrap_or_default()[stage]; - - let view = self.get_view_from_aux(bin_values, &count_aux, stage, r); - if let Some(view) = view { - event!( - Level::DEBUG, - "Node {} obtained view = `{view:?}`", - self.config.id - ); - break view; - } else { - event!( - Level::DEBUG, - "Node {} received notify_count_aux notification while having no binary estimates / not enough aux", - self.config.id - ); - } - }; - // 5: return view - #[allow(clippy::let_and_return)] // for clarity - view - } - - /// Try to get the output from the coin keys receiver, return an error otherwise. - async fn get_coin_keys( - &self, - r: u8, - coin_keys_receiver: oneshot::Receiver, - ) -> Result { - event!( - Level::DEBUG, - "Node `{}` at round `{r}` has not yet obtained keys for common coin protocol, waiting.", - self.config.id - ); - - // Return coin_keys if sender not dropped, err otherwise - match coin_keys_receiver.await { - Ok(coin_keys) => { - event!( - Level::DEBUG, - "Node `{}` at round `{r}` obtained keys for common coin protocol", - self.config.id - ); - - Ok(coin_keys) - } - Err(_) => { - error!( - "Node `{}` at round `{r}` failed to obtain common coin input through channel: sender dropper. Aborting ABA.", - self.config.id - ); - Err(AbaError::CoinKeysRecv) - } - } - } - - /// Try to generate and send a partial coin evaluation, or return an error otherwise. - async fn send_coin_eval( - &self, - r: u8, - coin_keys: &CoinKeys, - rng: &mut RNG, - ) -> Result<(), Box> - where - RNG: RngCore + CryptoRng, - { - let eval = EcdhCoinTossEval::::eval( - &coin_keys.sk, - &Self::coin_input(usize::from(self.sid), &coin_keys.combined_vk, r)?, - &self.config.g, - rng, - ) - .map_err(|e| AbaError::CoinToss(e, "failed to generate coin toss evaluation: {e}"))?; - - let msg_coin_eval = AbaMessage::CoinEval(CoinEvalMessage::new(eval, r).unwrap()); - - if let Err(e) = - broadcast_with_self(&msg_coin_eval, &self.config.retry_strategy, &self.sender).await - { - error!( - "Node `{}` failed to broadcast coin eval message: {e:?}", - self.config.id - ); - } - - Ok(()) - } - - /// Wait for enough evaluations and try to recover a common coin. Returns an error if too many evaluations are invalid. - async fn get_coin( - &self, - r: u8, - state: &Arc>, - coin_keys: &CoinKeys, - ) -> Result> { - // Get the input of the common coin protocol - let coin_input = Self::coin_input( - usize::from(self.sid), - &coin_keys.combined_vk.into_affine().into(), - r, - )?; - - loop { - // Wait until we have enough valid partial coins evals for the current round - event!( - Level::DEBUG, - "Node `{}` at round `{r}` waiting for coin evaluations", - self.config.id - ); - state.notify_enough_coin_evals.notified(r).await; - - // mutex locked for the entire duration, either that or cloning evals - let coin_evals = state.coin_evals.lock().await; - let Some((senders, evals)) = coin_evals.get_all(&r) else { - event!( - Level::DEBUG, - "Node `{}` at round `{r}` received coin evals notifications while not having evals", - self.config.id - ); - continue; - }; - - if evals.len() < self.config.t + 1 { - event!( - Level::DEBUG, - "Node `{}` at round `{r}` does not have enough evals: {} < {}", - self.config.id, - evals.len(), - self.config.t - ); - continue; // not enough evals for this round yet - }; - - // Try to get and return the common coin - let coin_vks: Vec<_> = senders.iter().map(|&j| coin_keys.vks[j]).collect(); - match EcdhCoinTossEval::get_coin( - &evals, - &senders, - &coin_vks, - &coin_input, - &self.config.g, - self.config.t + 1, - ) { - Ok(coin) => return Ok(coin), - Err(e) => { - // Failed to obtain the common coin, we either continue if we don't have all evals yet, or we abort - event!( - Level::WARN, - "Node `{}` at round `{r}` failed to obtain a common coin due to invalid eval(s): {e:?}", - self.config.id - ); - - if evals.len() < self.config.n { - continue; - } else { - event!( - Level::ERROR, - "Node `{}` at round `{r}` failed to obtain a common coin with n evals: {e:?}. Aborting ABA with error.", - self.config.id - ); - Err(AbaError::CoinToss( - e, - "failed to obtain common coin with all evals", - ))? - } - } - } - } - } - - /// Get the input to the common coin. - fn coin_input(sid: usize, combined_vk: &CG, round: u8) -> Result, Box> { - CoinInput { - combined_vk: *combined_vk, - sid, - round, - } - .serialize() - } - - /// Try to extract a view from binary values and auxiliary messages. - fn get_view_from_aux( - &self, - bin_values: &[Estimate], - count: &PerPartyStorage, - stage: AuxStage, - round: u8, - ) -> Option { - assert!(bin_values.len() <= 2); - - // Get possible views, assuming bin_values contains at most 2 elements - // views = { {bin_values[0]}, {bin_values[1]}, {bin_values[0], bin_values[1]} } - let mut views = vec![vec![bin_values[0]]]; - if bin_values.len() == 2 { - views.append(&mut vec![ - vec![bin_values[1]], // either the other value, - vec![bin_values[0], bin_values[1]], // or, both values - ]); - } - - // Try to find a view such that the sum of its estimate count is >= than n - t - for view in views { - let sum: usize = view - .iter() - .map(|est| { - let m = AuxiliaryMessage { - round, - stage, - estimate: *est, - }; - - count.get_count(&m) - }) - .sum(); - - if sum >= self.config.n - self.config.t { - return View::from_estimates(bin_values); - } - } - - None - } - - /// Try to extract a view from binary values and auxiliary set messages. - fn get_view_from_auxset( - &self, - bin_values: &[Estimate], - count: &PerPartyStorage, - round: u8, - ) -> Option { +impl AbaCrain20Instance { + /// Try to build a view from the union of views sent by other nodes, filtered by local binary values + /// obtained through the BV_broadcast algorithm, Figure 1 of . + /// Implements filtering of line (05), Figure 3 of : + /// \exists a view such that its values (i) belong to bin values and comes from views sent by + /// (n − t) distinct processes. + fn construct_view(&self, bin_values: &[Estimate], views: &[&View]) -> Option { assert!(bin_values.len() <= 2); - // Get possible views, assuming bin_values contains at most 2 elements - // views = { {bin_values[0]}, {bin_values[1]}, {bin_values[0], bin_values[1]} } - let mut views = vec![View::from_estimates(&[bin_values[0]])]; - if bin_values.len() == 2 { - views.append(&mut vec![ - View::from_estimates(&[bin_values[1]]), // either the other value, - View::from_estimates(&[bin_values[0], bin_values[1]]), // or, both values - ]); - } + let bin_values = BTreeSet::from_iter(bin_values.iter().copied()); - // Try to find a view such that its count is >= than n - t - for view in views { - let Some(view) = view else { + // Form a view such that its values (i) belong to bin values and comes from views sent by + // (n − t) distinct processes + let mut count = 0; + let mut view_union = View::default(); + for &view in views { + if !view.is_subset(&bin_values) { + // not a subset of bin_values, ignore continue; - }; + } - let count: usize = view - .get_view_superset() - .into_iter() - .map(|view| count.get_count(&AuxiliarySetMessage { round, view })) - .sum(); + view_union.extend(view.iter()); // equivalent to union + count += 1; if count >= self.config.n - self.config.t { - return Some(view); + return Some(view_union); } } @@ -1005,26 +505,6 @@ where } } -/// Structure used to serialize the input of the coin -#[derive(Serialize)] -#[serde(bound(serialize = "CG: PointSerializeCompressed",))] -struct CoinInput { - #[serde(with = "utils::serialize::point::base64")] - combined_vk: CG, - sid: usize, - round: u8, -} - -impl CoinInput -where - CG: PointSerializeCompressed, -{ - fn serialize(&self) -> Result, Box> { - bson::to_vec(&self) - .map_err(|e| AbaError::BsonSer(e, "failed to serialize CoinInput to bson").into()) - } -} - /// Helper struct used for per key, per party storage. /// Used to quickly insert a value for a specific key and party, and , get/count all values belonging to a key, independently of the parties. struct PerPartyStorage { @@ -1039,6 +519,12 @@ where PerPartyStorage { db: HashMap::new() } } + /// Get the entry to a key + fn entry(&mut self, k: K, party: PartyId) -> Entry<'_, PartyId, V> { + let storage = self.db.entry(k).or_default(); + storage.entry(party) + } + /// Only insert if the key is not already present fn insert_once(&mut self, k: K, party: PartyId, v: V) { let storage = self.db.entry(k).or_default(); @@ -1056,6 +542,12 @@ where Some(storage.iter().map(|(k, v)| (*k, v)).unzip()) } + /// Returns the values stored for key k + fn get(&self, k: &K) -> Option> { + let storage = self.db.get(k)?; + Some(storage.values().collect()) + } + /// Returns the number of values stored for key k amongst all parties. fn get_count(&self, k: &K) -> usize { let Some(storage) = self.db.get(k) else { @@ -1123,34 +615,31 @@ impl IndexMut for BinValues { } } +impl Deref for View { + type Target = BTreeSet; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for View { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + impl View { - /// Creates a view from binary values - pub(crate) fn from_estimates(estimates: &[Estimate]) -> Option { - match estimates { - &[Estimate::Bot] => Some(View::Bot), - &[Estimate::Zero] => Some(View::Zero), - &[Estimate::One] => Some(View::One), - - &[Estimate::Bot, Estimate::Zero] | &[Estimate::Zero, Estimate::Bot] => { - Some(View::BotZero) - } - &[Estimate::Bot, Estimate::One] | &[Estimate::One, Estimate::Bot] => Some(View::BotOne), - &[Estimate::Zero, Estimate::One] | &[Estimate::One, Estimate::Zero] => { - Some(View::ZeroOne) - } + pub(crate) fn bot_zero() -> Self { + Self(BTreeSet::from_iter([Estimate::Bot, Estimate::Zero])) + } - _ => None, - } + pub(crate) fn bot_one() -> Self { + Self(BTreeSet::from_iter([Estimate::Bot, Estimate::One])) } - /// Get the superset of a view, e.g. the superset of Bot is {Bot, BotZero, BotOne} - pub(crate) fn get_view_superset(self) -> Vec { - match self { - View::Bot => vec![View::Bot, View::BotZero, View::BotOne], - View::Zero => vec![View::Zero, View::BotZero, View::ZeroOne], - View::One => vec![View::One, View::BotOne, View::ZeroOne], - View::BotZero | View::BotOne | View::ZeroOne => vec![self], - } + pub(crate) fn is_bot(&self) -> bool { + self.0.first().is_some_and(|est| est == &Estimate::Bot) } } @@ -1167,20 +656,30 @@ where } } +async fn future_select_pin(a: impl Future, b: impl Future) -> Out { + let a = pin!(a); + let b = pin!(b); + match futures::future::select(a, b).await { + Either::Left((o, _)) | Either::Right((o, _)) => o, + } +} + #[cfg(test)] mod tests { - use crate::aba::AbaConfig; - use crate::aba::crain20::{AbaCrain20Config, AbaInput, CoinKeys}; + use crate::aba::crain20::{AbaCrain20, AbaCrain20Config, AbaInput, CoinKeys}; use crate::aba::{Aba, Estimate}; - use crate::helpers::PartyId; + use crate::helpers::{PartyId, SessionId}; use crate::network::RetryStrategy; - use ark_bn254::Bn254; + use ark_bn254::{Bn254, Fr}; use ark_ec::{PrimeGroup, pairing::Pairing}; - use dcipher_network::topic::dispatcher::TopicDispatcher; - use dcipher_network::transports::in_memory::MemoryNetwork; + use ark_poly::univariate::DensePolynomial; + use ark_poly::{DenseUVPolynomial, Polynomial}; + use ark_std::UniformRand; + use dcipher_network::Transport; + use dcipher_network::transports::in_memory::{BusMemoryTransport, MemoryNetwork}; + use itertools::Itertools; use rand::rngs::OsRng; use std::collections::VecDeque; - use std::sync::Arc; use tokio::sync::oneshot; use tokio::task; use tokio::task::JoinSet; @@ -1188,25 +687,86 @@ mod tests { type G = ::G1; + fn gen_keys(n: u16, t: u16, g: G) -> (Vec, G, Vec) { + // Build polynomial from coefficients + let poly_coeffs = (0..t) + .map(|_| ::ScalarField::rand(&mut OsRng)) + .collect::>(); + let p = DensePolynomial::from_coefficients_slice(&poly_coeffs); + + let sk = p.evaluate(&0.into()); + let pk = g * sk; + + let sks = (1..=n).map(|i| p.evaluate(&i.into())).collect::>(); + let pks = sks.iter().map(|ski| g * ski).collect::>(); + + (sks, pk, pks) + } + #[tokio::test] - async fn test_aba_all_parties_est_one() { - _ = tracing_subscriber::fmt() - .with_max_level(tracing::Level::INFO) - .try_init(); + async fn test_aba_agreement() { + let t = 2; + let n = 3 * t + 1; + let g = G::generator(); + let sid = SessionId::const_from(0); + let est = Estimate::One; + + let (sks, pk, pks) = gen_keys(n as u16, t as u16, g); + let estimates: Vec<_> = vec![est; n]; + let final_est = run(n, t, sks, pks, pk, g, estimates, sid).await; + assert_eq!(est, final_est); + } + + #[tokio::test] + async fn test_aba_disagreement() { let t = 2; let n = 3 * t + 1; let g = G::generator(); + let sid = SessionId::const_from(0); + + let (sks, pk, pks) = gen_keys(n as u16, t as u16, g); + let estimates: Vec<_> = PartyId::iter_all(n) + .map(|i| { + // + // let est = if thread_rng().gen_bool(0.5) { + // Estimate::One + // } else { + // Estimate::Zero + // }; + // 50-50 split or so + if i.as_usize() <= n / 2 { + Estimate::One + } else { + Estimate::Zero + } + }) + .collect(); + + run(n, t, sks, pks, pk, g, estimates, sid).await; + } + + #[allow(clippy::too_many_arguments)] + async fn run( + n: usize, + t: usize, + sks: Vec, + pks: Vec, + pk: G, + g: G, + estimates: Vec, + sid: SessionId, + ) -> Estimate { + let mut coin_keys: VecDeque<_> = sks + .into_iter() + .map(|sk| CoinKeys { + sk, + vks: pks.clone(), + combined_vk: pk, + }) + .collect(); - let (dispatchers, mut tbts): (Vec<_>, VecDeque<_>) = - MemoryNetwork::get_transports(PartyId::iter_all(n)) - .into_iter() - .map(|t| { - let mut dispatcher = TopicDispatcher::new(); - let tbt = dispatcher.start(t); - (dispatcher, tbt) - }) - .collect(); + let mut transports: VecDeque<_> = MemoryNetwork::get_transports(PartyId::iter_all(n)); let mut abas: VecDeque<_> = PartyId::iter_all(n) .map(|i| AbaCrain20Config::<_, _, sha3::Sha3_256>::new(i, n, t, g, RetryStrategy::None)) .collect(); @@ -1214,8 +774,10 @@ mod tests { let mut tasks = JoinSet::new(); for i in PartyId::iter_all(n) { tasks.spawn({ - let transport = tbts.pop_front().unwrap(); + let mut transport = transports.pop_front().unwrap(); let aba_config = abas.pop_front().unwrap(); + let coin_keys = coin_keys.pop_front().unwrap(); + let v = estimates[i]; async move { let (isender, ireceiver) = oneshot::channel(); @@ -1225,17 +787,21 @@ mod tests { // Create input with One estimate let (coin_keys_sender, coin_keys_receiver) = oneshot::channel::>(); - drop(coin_keys_sender); // all nodes input the same estimate, coin must not be used + coin_keys_sender.send(coin_keys).unwrap(); + let est = AbaInput { - v: Estimate::One, + v, coin_keys_receiver, }; // Spawn aba task let aba_task = task::spawn(async move { - let aba = aba_config - .new_instance(0.into(), Arc::new(transport)) - .expect("failed to create aba instance"); + let aba = AbaCrain20::<_, _, _, BusMemoryTransport<_>> { + config: aba_config, + receiver: transport.receiver_stream().unwrap(), + sender: transport.sender().unwrap(), + sid, + }; aba.propose(ireceiver, osender, cancel_cloned, &mut OsRng) .await }); @@ -1259,14 +825,15 @@ mod tests { }); } + let mut ests = vec![]; while let Some(res) = tasks.join_next().await { assert!(res.is_ok()); let (_, est) = res.unwrap(); - assert_eq!(est, Estimate::One); + assert!([Estimate::Zero, Estimate::One].contains(&est)); + ests.push(est); } + assert!(ests.iter().all_equal()); - for d in dispatchers { - d.stop().await; - } + *ests.first().unwrap() } } diff --git a/crates/adkg/src/aba/crain20/broadcast.rs b/crates/adkg/src/aba/crain20/broadcast.rs new file mode 100644 index 00000000..df630e0c --- /dev/null +++ b/crates/adkg/src/aba/crain20/broadcast.rs @@ -0,0 +1,146 @@ +//! Implementations of BV_broadcast and SBV_broadcast. + +use crate::aba::crain20::messages::{ + AbaMessage, AuxStage, AuxiliaryMessage, EstimateMessage, View, +}; +use crate::aba::crain20::{AbaCrain20Instance, AbaState}; +use crate::aba::{Estimate, crain20}; +use crate::helpers::PartyId; +use crate::network::broadcast_with_self; +use ark_ec::CurveGroup; +use dcipher_network::TransportSender; +use std::sync::Arc; +use tracing::{Level, error, event}; + +impl AbaCrain20Instance +where + CG: CurveGroup, + TS: TransportSender + Clone, +{ + /// Binary-value broadcast described in https://dl.acm.org/doi/10.1145/2785953, Figure 1 + /// Send the current party's estimate to all other nodes with an Estimate message. + #[tracing::instrument(skip(self))] + async fn bv_broadcast(&self, r: u8, stage: AuxStage, v: Estimate) { + // 1: broadcast B_VAL(v) to all + let msg_est = AbaMessage::Estimate(EstimateMessage { + round: r, + stage, + estimate: v, + }); + + event!( + Level::DEBUG, + "Node `{}` at round `{r}` sending {:?} to all", + self.config.id, + msg_est + ); + if let Err(e) = + broadcast_with_self(&msg_est, &self.config.retry_strategy, &self.sender).await + { + error!( + "Node `{}` failed to broadcast estimate message: {e:?}", + self.config.id + ) + } + } + + /// Synchronized binary-value broadcast described in https://dl.acm.org/doi/10.1145/2785953, Figure 2 + /// Send the current party's estimate to all other nodes with an Estimate message. + #[tracing::instrument(skip(self, state))] + pub(super) async fn sbv_broadcast( + &self, + r: u8, + stage: AuxStage, + v: Estimate, + state: &Arc>, + ) -> View { + // 1: BV_Broadcast(v) + self.bv_broadcast(r, stage, v).await; + + event!( + Level::DEBUG, + "Node `{}` waiting for bin values", + self.config.id + ); + let bin_values = loop { + // 2: wait until bin_values \neq \emptyset + state.notify_bin_values.notified((r, stage)).await; + + let bin_values = state.bin_values.lock().await; + let bin_values = &bin_values.get(&r).cloned().unwrap_or_default()[stage]; + if !bin_values.is_empty() { + event!( + Level::DEBUG, + "Node `{}` obtained bin_values = `{bin_values:?}`", + self.config.id + ); + break bin_values.clone(); + } + }; + + // 3: Send AUX(w) for w \in bin_values to all + for w in bin_values.iter() { + let msg_aux = AbaMessage::Auxiliary(AuxiliaryMessage { + round: r, + stage, + estimate: *w, + }); + event!( + Level::DEBUG, + "Node `{}` sending {:?} to all", + self.config.id, + msg_aux + ); + + if let Err(e) = + broadcast_with_self(&msg_aux, &self.config.retry_strategy, &self.sender).await + { + error!( + "Node `{}` failed to broadcast aux message: {e:?}", + self.config.id + ) + } + } + + // 4: wait until \exists a set view s.t. + // (1) view \subseteq bin_values, and + // (2) contained in AUX(.) messages received from n - t nodes + let view = loop { + event!( + Level::DEBUG, + "Node `{}` waiting for count_aux notification", + self.config.id + ); + + // wake up each time after having received n - t aux, or on bin_values update + crain20::future_select_pin( + state.notify_count_aux.notified((r, stage)), + state.notify_bin_values.notified((r, stage)), + ) + .await; + + let aux_views = state.aux_views.lock().await; + let bin_values = state.bin_values.lock().await; // warn: two locks, could deadlock + let aux_views = aux_views.get(&(r, stage)).to_owned().unwrap_or_default(); + let bin_values = &bin_values.get(&r).cloned().unwrap_or_default()[stage]; + let view = self.construct_view(bin_values, &aux_views); + if let Some(view) = view { + event!( + Level::DEBUG, + "Node {} obtained view = `{view:?}`", + self.config.id + ); + break view; + } else { + event!( + Level::DEBUG, + "Node {} received notify_count_aux notification while having no binary estimates / not enough aux", + self.config.id + ); + } + }; + // 5: return view + #[allow(clippy::let_and_return)] // for clarity + view + } +} diff --git a/crates/adkg/src/aba/crain20/coin.rs b/crates/adkg/src/aba/crain20/coin.rs new file mode 100644 index 00000000..6055d5c5 --- /dev/null +++ b/crates/adkg/src/aba/crain20/coin.rs @@ -0,0 +1,207 @@ +//! Functions used for the coin toss protocol + +use crate::aba::crain20::ecdh_coin_toss::{Coin, EcdhCoinTossEval}; +use crate::aba::crain20::messages::{AbaMessage, CoinEvalMessage}; +use crate::aba::crain20::{AbaCrain20Instance, AbaError, AbaState, CoinKeys}; +use crate::helpers::PartyId; +use crate::network::broadcast_with_self; +use ark_ec::CurveGroup; +use dcipher_network::TransportSender; +use digest::core_api::BlockSizeUser; +use digest::crypto_common::rand_core::CryptoRng; +use digest::{DynDigest, FixedOutputReset}; +use rand::RngCore; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tokio::sync::oneshot; +use tracing::{Level, error, event}; +use utils::hash_to_curve::HashToCurve; +use utils::serialize::fq::FqSerialize; +use utils::serialize::point::PointSerializeCompressed; + +impl AbaCrain20Instance +where + CG: CurveGroup + Copy + HashToCurve + PointSerializeCompressed, + CG::ScalarField: FqSerialize, + EcdhCoinTossEval: for<'de> Deserialize<'de>, + CK: Send + Into> + 'static, + H: Default + DynDigest + FixedOutputReset + BlockSizeUser + Clone + Send + Sync + 'static, + TS: TransportSender, +{ + /// Try to get the output from the coin keys receiver, return an error otherwise. + pub(super) async fn get_coin_keys( + &self, + r: u8, + coin_keys_receiver: oneshot::Receiver, + ) -> Result { + event!( + Level::DEBUG, + "Node `{}` at round `{r}` has not yet obtained keys for common coin protocol, waiting.", + self.config.id + ); + + // Return coin_keys if sender not dropped, err otherwise + match coin_keys_receiver.await { + Ok(coin_keys) => { + event!( + Level::DEBUG, + "Node `{}` at round `{r}` obtained keys for common coin protocol", + self.config.id + ); + + Ok(coin_keys) + } + Err(_) => { + error!( + "Node `{}` at round `{r}` failed to obtain common coin input through channel: sender dropper. Aborting ABA.", + self.config.id + ); + Err(AbaError::CoinKeysRecv) + } + } + } + + /// Try to generate and send a partial coin evaluation, or return an error otherwise. + pub(super) async fn send_coin_eval( + &self, + r: u8, + coin_keys: &CoinKeys, + rng: &mut RNG, + ) -> Result<(), Box> + where + RNG: RngCore + CryptoRng, + { + let eval = EcdhCoinTossEval::::eval( + &coin_keys.sk, + &Self::coin_input(usize::from(self.sid), &coin_keys.combined_vk, r)?, + &self.config.g, + rng, + ) + .map_err(|e| AbaError::CoinToss(e, "failed to generate coin toss evaluation: {e}"))?; + + let msg_coin_eval = AbaMessage::CoinEval(CoinEvalMessage::new(eval, r).unwrap()); + + if let Err(e) = + broadcast_with_self(&msg_coin_eval, &self.config.retry_strategy, &self.sender).await + { + error!( + "Node `{}` failed to broadcast coin eval message: {e:?}", + self.config.id + ); + } + + Ok(()) + } + + /// Wait for enough evaluations and try to recover a common coin. Returns an error if too many evaluations are invalid. + pub(super) async fn get_coin( + &self, + r: u8, + state: &Arc>, + coin_keys: &CoinKeys, + ) -> Result> { + // Get the input of the common coin protocol + let coin_input = Self::coin_input( + usize::from(self.sid), + &coin_keys.combined_vk.into_affine().into(), + r, + )?; + + loop { + // Wait until we have enough valid partial coins evals for the current round + event!( + Level::DEBUG, + "Node `{}` at round `{r}` waiting for coin evaluations", + self.config.id + ); + state.notify_enough_coin_evals.notified(r).await; + + // mutex locked for the entire duration, either that or cloning evals + let coin_evals = state.coin_evals.lock().await; + let Some((senders, evals)) = coin_evals.get_all(&r) else { + event!( + Level::DEBUG, + "Node `{}` at round `{r}` received coin evals notifications while not having evals", + self.config.id + ); + continue; + }; + + if evals.len() < self.config.t + 1 { + event!( + Level::DEBUG, + "Node `{}` at round `{r}` does not have enough evals: {} < {}", + self.config.id, + evals.len(), + self.config.t + ); + continue; // not enough evals for this round yet + }; + + // Try to get and return the common coin + let coin_vks: Vec<_> = senders.iter().map(|&j| coin_keys.vks[j]).collect(); + match EcdhCoinTossEval::get_coin( + &evals, + &senders, + &coin_vks, + &coin_input, + &self.config.g, + self.config.t + 1, + ) { + Ok(coin) => return Ok(coin), + Err(e) => { + // Failed to obtain the common coin, we either continue if we don't have all evals yet, or we abort + event!( + Level::WARN, + "Node `{}` at round `{r}` failed to obtain a common coin due to invalid eval(s): {e:?}", + self.config.id + ); + + if evals.len() < self.config.n { + continue; + } else { + event!( + Level::ERROR, + "Node `{}` at round `{r}` failed to obtain a common coin with n evals: {e:?}. Aborting ABA with error.", + self.config.id + ); + Err(AbaError::CoinToss( + e, + "failed to obtain common coin with all evals", + ))? + } + } + } + } + } + + /// Get the input to the common coin. + fn coin_input(sid: usize, combined_vk: &CG, round: u8) -> Result, Box> { + CoinInput { + combined_vk: *combined_vk, + sid, + round, + } + .serialize() + } +} + +/// Structure used to serialize the input of the coin +#[derive(Serialize)] +#[serde(bound(serialize = "CG: PointSerializeCompressed",))] +struct CoinInput { + #[serde(with = "utils::serialize::point::base64")] + combined_vk: CG, + sid: usize, + round: u8, +} + +impl CoinInput +where + CG: PointSerializeCompressed, +{ + fn serialize(&self) -> Result, Box> { + bson::to_vec(&self) + .map_err(|e| AbaError::BsonSer(e, "failed to serialize CoinInput to bson").into()) + } +} diff --git a/crates/adkg/src/aba/crain20/messages.rs b/crates/adkg/src/aba/crain20/messages.rs index b00f2b95..aaa13078 100644 --- a/crates/adkg/src/aba/crain20/messages.rs +++ b/crates/adkg/src/aba/crain20/messages.rs @@ -2,6 +2,7 @@ use super::ecdh_coin_toss::{Coin, EcdhCoinTossEval}; use crate::aba::Estimate; use ark_ec::CurveGroup; use serde::{Deserialize, Serialize}; +use std::collections::BTreeSet; /// Messages sent during the ABA protocol. #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -37,24 +38,15 @@ pub struct AuxiliaryMessage { } /// Message used to send a set of estimates, i.e., a view. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct AuxiliarySetMessage { pub(crate) round: u8, pub(crate) view: View, } /// Set of all possible views during the ABA. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -#[repr(u8)] -pub enum View { - Bot, - Zero, - One, - - BotZero, - BotOne, - ZeroOne, -} +#[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize, Deserialize)] +pub struct View(pub(crate) BTreeSet); /// Message to send a partial evaluation for the common coin tossing protocol. #[serde_with::serde_as] @@ -105,10 +97,9 @@ impl From for Estimate { /// Convert single element views into its corresponding estimate, otherwise into Estimate::Bot impl From for Estimate { fn from(val: View) -> Self { - match val { - View::Zero => Estimate::Zero, - View::One => Estimate::One, - _ => Estimate::Bot, + if val.0.len() > 1 { + return Estimate::Bot; } + val.0.first().copied().unwrap_or(Estimate::Bot) } } diff --git a/crates/adkg/src/aba/crain20/recv_handler.rs b/crates/adkg/src/aba/crain20/recv_handler.rs new file mode 100644 index 00000000..bede184a --- /dev/null +++ b/crates/adkg/src/aba/crain20/recv_handler.rs @@ -0,0 +1,177 @@ +//! Handles message received by other nodes. + +use crate::aba::crain20::ecdh_coin_toss::EcdhCoinTossEval; +use crate::aba::crain20::messages::{AbaMessage, EstimateMessage}; +use crate::aba::crain20::{AbaCrain20, AbaCrain20Config, AbaState, PerPartyStorage}; +use crate::helpers::{PartyId, SessionId}; +use crate::network::broadcast_with_self; +use ark_ec::CurveGroup; +use dcipher_network::{ReceivedMessage, Transport}; +use futures::StreamExt; +use serde::Deserialize; +use std::borrow::Borrow; +use std::collections::HashSet; +use std::sync::Arc; +use tokio_util::sync::CancellationToken; +use tracing::{Instrument, debug, error, info, trace, warn}; + +impl AbaCrain20 +where + CG: CurveGroup, + EcdhCoinTossEval: for<'de> Deserialize<'de>, + T: Transport, +{ + /// Thread responsible for receiving all types of ABA messages and transmitting notifications. + pub(super) async fn recv_thread( + sid: SessionId, + config: Arc>, + receiver: T::ReceiveMessageStream, + sender: T::Sender, + cancel: CancellationToken, + state: Arc>, + ) { + let id = config.id; + // Stop the thread upon receiving a signal from the cancellation token + tokio::select! { + _ = cancel.cancelled() => { + info!("Node `{id}` in ABA with sid `{sid}` stopping recv_thread"); + } + + _ = Self::recv_loop(config, receiver, sender, state).instrument(tracing::info_span!("recv_loop", ?sid)) => {} + } + } + + /// Infinite loop listening for ABA messages and sending notifications. + async fn recv_loop( + config: Arc>, + mut receiver: T::ReceiveMessageStream, + sender: T::Sender, + state: Arc>, + ) { + // Local variables + let mut count_est = PerPartyStorage::new(); + let mut sent_estimate: HashSet = HashSet::new(); + + loop { + let ReceivedMessage { + sender: sender_id, + content, + .. + } = match receiver.next().await { + Some(Ok(m)) => m, + Some(Err(e)) => { + warn!("Node `{}` failed to recv: {e:?}", config.id); + continue; + } + None => { + error!( + "Node `{}` failed to recv: no more items in stream", + config.id + ); + return; + } + }; + + let m: AbaMessage = match bson::from_slice(&content) { + Ok(m) => m, + Err(e) => { + error!(error = ?e, "Node `{}` failed to deserialize message", config.id); + continue; + } + }; + trace!( + "Node `{}` received message {m:?} from {sender_id}", + config.id, + ); + + match m { + // 4: upon receiving BVAL(v) do + AbaMessage::Estimate(est) => { + count_est.insert_once(est, sender_id, true); + let count = count_est.get_count(&est); + + #[allow(clippy::int_plus_one)] + if count >= 2 * config.t + 1 { + // 7: if BVAL(V) received from 2t + 1 different nodes then + // 8: bin_values := bin_values \cup {v} + // add the estimate to the binary values + let mut r_bin_values = state.bin_values.lock().await; + let bin_values = &mut r_bin_values.entry(est.round).or_default()[est.stage]; + if bin_values.contains(&est.estimate) { + drop(r_bin_values); + } else { + bin_values.push(est.estimate); + drop(r_bin_values); + + // notify of update to bin_values + debug!( + "Node {} notifying bin values for round {}", + config.id, est.round + ); + state.notify_bin_values.notify_one((est.round, est.stage)); + }; + } else if count >= config.t + 1 && !sent_estimate.contains(&est) { + // 5: if BVAL(v) received from t + 1 different nodes AND BVAL(v) was not sent, then + // 6: Send BVAL(v) to all nodes + let msg_est = AbaMessage::Estimate(est); + if let Err(e) = + broadcast_with_self(&msg_est, &config.retry_strategy, &sender).await + { + error!( + "Node `{}` failed to broadcast estimate message: {e:?}", + config.id + ) + } + sent_estimate.insert(est); + } + } + + AbaMessage::Auxiliary(aux) => { + let mut aux_views = state.aux_views.lock().await; + // Add the new estimate to the current view + aux_views + .entry((aux.round, aux.stage), sender_id) + .or_default() + .insert(aux.estimate); + + // notify once we got at least n - t aux messages + if aux_views.get_count(&(aux.round, aux.stage)) >= config.n - config.t { + state.notify_count_aux.notify_one((aux.round, aux.stage)); + } + } + + AbaMessage::AuxiliarySet(aux_set) => { + // Insert auxset view, at most once per sender_id + let mut auxset_views = state.auxset_views.lock().await; + auxset_views.insert_once(aux_set.round, sender_id, aux_set.view); + + // notify once we got at least n - t auxset messages + if auxset_views.get_count(&aux_set.round) >= config.n - config.t { + state.notify_count_auxset.notify_one(aux_set.round); + } + } + + AbaMessage::CoinEval(msg_eval) => { + // Deserialize eval + let Ok(eval): Result, _> = msg_eval.borrow().try_into() + else { + warn!("Failed to deserialize CoinEvalMessage"); + continue; + }; + + // Store one eval per party, per round. We cannot verify it here + // since the node may not be ready to check evaluations yet. + let mut coin_evals = state.coin_evals.lock().await; + coin_evals.insert_once(msg_eval.round, sender_id, eval); + let count = coin_evals.get_count(&msg_eval.round); + drop(coin_evals); // drop lock + + // Notify if we have t + 1 evals + if count > config.t { + state.notify_enough_coin_evals.notify_one(msg_eval.round); + } + } + }; + } + } +} diff --git a/crates/adkg/src/adkg.rs b/crates/adkg/src/adkg.rs index 955c8ffe..56a93ac7 100644 --- a/crates/adkg/src/adkg.rs +++ b/crates/adkg/src/adkg.rs @@ -17,7 +17,7 @@ use crate::rand::{AdkgRng, AdkgRngType}; use crate::rbc::ReliableBroadcastConfig; use crate::rbc::multi_rbc::MultiRbc; use crate::vss::acss::AcssConfig; -use crate::vss::acss::hbacss0::PedersenSecret; +use crate::vss::acss::hbacss0::{Hbacss0Input, PedersenSecret}; use crate::vss::acss::multi_acss::MultiAcss; use ark_ec::{AffineRepr, CurveGroup, PrimeGroup}; use ark_ff::Zero; @@ -186,7 +186,7 @@ where CG::ScalarField: FqSerialize + FqDeserialize, H: Default + DynDigest + FixedOutputReset + BlockSizeUser + Clone + 'static, RBCConfig: ReliableBroadcastConfig<'static, PartyId>, - ACSSConfig: AcssConfig<'static, CG, PartyId, Input = Vec>>, + ACSSConfig: AcssConfig<'static, CG, PartyId, Input = Hbacss0Input>, ACSSConfig::Output: Into>, ABAConfig: AbaConfig<'static, PartyId, Input = AbaCrainInput>, { @@ -286,13 +286,20 @@ where .map_err(|e| AdkgError::Rng(e.into(), "failed to get acss secret rng"))?; // Generate random secret scalars to be used in the node's ACSS - let s: Vec<_> = (0..shares_per_acss) + let pedersen_in: Vec<_> = (0..shares_per_acss) .map(|_| { let a = CG::ScalarField::rand(&mut acss_rng); let a_hat = CG::ScalarField::rand(&mut acss_rng); PedersenSecret { s: a, r: a_hat } }) .collect(); + // Additional feldman secret used in the coin toss of the multi-valued validated byzantine agreement (MVBA) + let feldman_in = CG::ScalarField::rand(&mut acss_rng); + + let s = Hbacss0Input { + feld: feldman_in, + peds: pedersen_in, + }; // Generate predicates for each of the RBCs let rbc_predicates: Vec<_> = PartyId::iter_all(self.n) @@ -919,7 +926,7 @@ where ) -> Result, AdkgError> { let id = state.id; let inner_fn = async move { - let mut final_parties: HashSet = HashSet::new(); + let mut aba_sessions: HashSet = HashSet::new(); let mut remaining_estimates = state.multi_aba.lock().await.iter_remaining_estimates(); loop { let Some((j, res)) = remaining_estimates.next().await else { @@ -938,28 +945,18 @@ where id, j ); - state - .multi_aba - .lock() - .await - .cancel(j) - .await - .unwrap() - .unwrap(); + match state.multi_aba.lock().await.cancel(j).await { + Some(Ok(())) => (), + Some(Err(e)) => { + warn!(error = ?e, "Failed to send cancel signal to node"); + } + None => { + warn!("Failed to send cancel signal to node"); + } + } if let Estimate::One = estimate { - let rbc_parties = match state.rbc_outputs.get(&j) { - Some(rbc_parties) => rbc_parties, - None => { - info!( - "Node `{}` obtained an estimate from ABA with sid `{j}` without having obtained an RBC output, waiting.", - state.id - ); - Self::wait_for_rbc_output(&state.rbc_outputs, &j).await - } - }; - - final_parties.extend(rbc_parties.iter()); + aba_sessions.insert(j); // Input 0 to remaining ABAs let iter_remaining_senders = @@ -986,7 +983,24 @@ where } } } - final_parties.into_iter().collect() + + let mut final_sessions = HashSet::with_capacity(state.n); + for j in aba_sessions { + // Get the list of parties for that session + let rbc_sessions = match state.rbc_outputs.get(&j) { + Some(rbc_sessions) => rbc_sessions, + None => { + info!( + "Node `{}` obtained an estimate from ABA with sid `{j}` without having obtained an RBC output, waiting.", + state.id + ); + Self::wait_for_rbc_output(&state.rbc_outputs, &j).await + } + }; + final_sessions.extend(rbc_sessions) + } + + final_sessions.into_iter().collect() }; tokio::select! { @@ -1016,6 +1030,9 @@ where let id = state.id; let inner_fn = async move { loop { + // Register interest for notification prior to checking to not lose notifications + let notified = state.completed_acss_outputs.notified(); + // Check whether the completed ACSSs is a superset of rbc_parties // i.e., rbc_parties is a subset of completed_acss_outputs if state.completed_acss_outputs.is_superset(&rbc_parties) { @@ -1062,7 +1079,7 @@ where } // If not, wait for an update - state.completed_acss_outputs.wait().await; + notified.await; } }; @@ -1090,6 +1107,9 @@ where let rbc_output = Self::wait_for_rbc_output(&state.rbc_outputs, &sid).await; loop { + // Register interest for notification prior to checking to not lose notifications + let notified = state.completed_acss_outputs.notified(); + if state.completed_acss_outputs.is_superset(&rbc_output) { // Get the acss outputs to be used as the coin keys let outputs: Vec<_> = state @@ -1101,7 +1121,7 @@ where } // If not, wait for an update - state.completed_acss_outputs.wait().await; + notified.await; } }; @@ -1122,11 +1142,14 @@ where ) -> HashSet { let wait_for_rbc = async { loop { + // Register interest for notification prior to checking to not lose notifications + let notified = rbc_outputs.notified(); + if let Some(rbc_output) = rbc_outputs.get(sid) { return rbc_output; } - rbc_outputs.wait().await; + notified.await; } }; wait_for_rbc.await @@ -1152,6 +1175,7 @@ mod tests { use std::collections::{HashMap, VecDeque}; use std::sync::Arc; use tokio::task::JoinSet; + use tracing_subscriber::EnvFilter; use utils::dst::{NamedCurveGroup, NamedDynDigest, Rfc9380DstBuilder}; use utils::hash_to_curve::HashToCurve; use utils::serialize::fq::{FqDeserialize, FqSerialize}; @@ -1183,6 +1207,26 @@ mod tests { CG::hash_to_curve_custom::(b"ADKG_GENERATOR_G", &dst) } + #[tokio::test(flavor = "multi_thread", worker_threads = 32)] + #[ignore] + async fn adkg_loop_bn254() { + // Static configuration and long term keys + let t = 2; + let n = 3 * t + 1; + + const SEED: &[u8] = b"ADKG_BN254_TEST_SEED"; + + // We use h == Bn254 G1 as the generator for the group public key + // and an independent generator g for the ADKG operations. + let g = get_generator_g::<_, sha3::Sha3_256>(); + let h = ark_bn254::G1Projective::generator(); + + // run adkg with reconstruction threshold of t & 2t + loop { + run_adkg_test::<_, sha3::Sha3_256>(t, t, n, g, h, SEED).await; + } + } + #[tokio::test(flavor = "multi_thread", worker_threads = 32)] async fn adkg_test_bn254() { // Static configuration and long term keys @@ -1231,7 +1275,9 @@ mod tests { H: Default + NamedDynDigest + FixedOutputReset + BlockSizeUser + Clone + 'static, { _ = tracing_subscriber::fmt() - .with_max_level(tracing::Level::WARN) + .with_env_filter( + EnvFilter::try_from_env("ADKG_DEBUG").unwrap_or_else(|_| "warn".parse().unwrap()), + ) .try_init(); let sks: VecDeque = (1..=n) diff --git a/crates/adkg/src/adkg/types.rs b/crates/adkg/src/adkg/types.rs index d6c9abf0..e5d7444a 100644 --- a/crates/adkg/src/adkg/types.rs +++ b/crates/adkg/src/adkg/types.rs @@ -6,7 +6,7 @@ use crate::pok::PokProof; use crate::rbc::multi_rbc::MultiRbc; use crate::rbc::{RbcPredicate, ReliableBroadcastConfig}; use crate::vss::acss::AcssConfig; -use crate::vss::acss::hbacss0::PublicPoly; +use crate::vss::acss::hbacss0::{FeldPublicPoly, PedPublicPoly}; use crate::vss::acss::multi_acss::MultiAcss; use crate::vss::pedersen::PedersenPartyShare; use ark_ec::CurveGroup; @@ -15,6 +15,7 @@ use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use tokio::sync::Notify; +use tokio::sync::futures::Notified; use tracing::warn; use utils::serialize::fq::{FqDeserialize, FqSerialize}; use utils::serialize::point::{PointDeserializeCompressed, PointSerializeCompressed}; @@ -32,8 +33,10 @@ pub struct LazyCoinKeys { /// ACSS output required by ADKG. #[derive(Clone)] pub struct ShareWithPoly { + pub mvba_share: CG::ScalarField, + pub mvba_public_poly: FeldPublicPoly, pub shares: Vec>, - pub public_polys: Vec>, + pub public_polys: Vec>, } /// Predicate used by reliable broadcasts. @@ -120,19 +123,19 @@ impl LazyCoinKeys { impl From> for CoinKeys { fn from(val: LazyCoinKeys) -> Self { // Obtain the combined public polynomial as p_j = \sum_{k \in rbc_parties} p_k(x) - // which is the sum of the public polynomial output by each ACSS specified in the j-th RBC + // which is the sum of the MVBA public polynomial output by each ACSS specified in the j-th RBC let public_poly: Vec = (0..=val.t) .map(|i| { val.outputs .iter() - .map(|(_, out)| out.public_polys[0].as_vec()[i]) + .map(|(_, out)| out.mvba_public_poly.0[i]) .sum() }) .collect(); - // Our own secret share, the sum of our ACSS shares + // Our own secret share, the sum of our ACSS MVBA shares // u_{i,j} = \sum_{k \in rbc_parties} s_{k,j} = - let u_i_j: CG::ScalarField = val.outputs.iter().map(|(_, out)| out.shares[0].si).sum(); + let u_i_j: CG::ScalarField = val.outputs.iter().map(|(_, out)| out.mvba_share).sum(); // Obtain commitments to the secret shares of the other parties // (g^{u_{1,j}}, ... g^{u_{n,j}}) = (g^p*(1), ..., g^p*(n)) @@ -171,6 +174,9 @@ impl RbcPredicate for NotifyPredicate { }; loop { + // Register interest for notification prior to checking to not lose notifications + let notified = self.completed_acss.notified(); + // Check that we have at least min_size parties, and that the completed ACSSs is a superset // of the parties in the message if rbc_parties.v.len() >= self.min_size @@ -181,7 +187,7 @@ impl RbcPredicate for NotifyPredicate { } // If not, wait for an update to completed_acss - self.completed_acss.wait().await; + notified.await; } } } @@ -199,8 +205,8 @@ impl NotifyMap { self.notify.notify_waiters() } - pub async fn wait(self: &Arc) { - self.notify.notified().await + pub fn notified(self: &Arc) -> Notified<'_> { + self.notify.notified() } pub fn keys(self: &Arc) -> Vec diff --git a/crates/adkg/src/scheme.rs b/crates/adkg/src/scheme.rs index e9222279..37d0862a 100644 --- a/crates/adkg/src/scheme.rs +++ b/crates/adkg/src/scheme.rs @@ -9,8 +9,7 @@ use crate::network::RetryStrategy; use crate::rbc::ReliableBroadcastConfig; use crate::rbc::r4::Rbc4RoundsConfig; use crate::vss::acss::AcssConfig; -use crate::vss::acss::hbacss0::HbAcss0Config; -use crate::vss::acss::hbacss0::PedersenSecret; +use crate::vss::acss::hbacss0::{HbAcss0Config, Hbacss0Input}; use ark_ec::{CurveGroup, PrimeGroup}; use ark_std::UniformRand; use digest::core_api::BlockSizeUser; @@ -51,7 +50,7 @@ where 'static, Self::Curve, PartyId, - Input = Vec::ScalarField>>, + Input = Hbacss0Input<::ScalarField>, >; type ABAConfig: AbaConfig<'static, PartyId>; diff --git a/crates/adkg/src/vss/acss/hbacss0.rs b/crates/adkg/src/vss/acss/hbacss0.rs index 0c2db622..dbb5b9c2 100644 --- a/crates/adkg/src/vss/acss/hbacss0.rs +++ b/crates/adkg/src/vss/acss/hbacss0.rs @@ -11,9 +11,9 @@ use crate::helpers::PartyId; use crate::network::{RetryStrategy, broadcast_with_self}; use crate::nizk::NIZKDleqProof; use crate::rbc::ReliableBroadcastConfig; -use crate::vss::acss::hbacss0::types::{PedersenPartyShares, ShareRecoveryMessage}; -use crate::vss::pedersen; +use crate::vss::acss::hbacss0::types::{PartyShares, ShareRecoveryMessage}; use crate::vss::pedersen::PedersenPartyShare; +use crate::vss::{feldman, pedersen}; use crate::{ pke::ec_hybrid_chacha20poly1305::{self, EphemeralMultiHybridCiphertext}, rbc::{RbcPredicate, ReliableBroadcast}, @@ -123,19 +123,35 @@ pub struct PedersenSecret { pub r: F, } +/// The input of the Feldman + Pedersen-based ACSS. +pub struct Hbacss0Input { + pub feld: F, + pub peds: Vec>, +} + /// The public polynomial output by the Pedersen-based ACSS. #[derive(Serialize, Deserialize, Clone)] #[serde(bound( serialize = "CG: PointSerializeCompressed", deserialize = "CG: PointDeserializeCompressed" ))] -pub struct PublicPoly(#[serde(with = "utils::serialize::point::base64::vec")] pub Vec); +pub struct PedPublicPoly(#[serde(with = "utils::serialize::point::base64::vec")] pub Vec); + +/// The public polynomial output by the ACSS. +#[derive(Serialize, Deserialize, Clone)] +#[serde(bound( + serialize = "CG: PointSerializeCompressed", + deserialize = "CG: PointDeserializeCompressed" +))] +pub struct FeldPublicPoly(#[serde(with = "utils::serialize::point::base64::vec")] pub Vec); /// The output of the Pedersen-based ACSS. #[derive(Clone)] pub struct Hbacss0Output { + pub feld_share: CG::ScalarField, + pub feld_public_poly: FeldPublicPoly, pub shares: Vec>, - pub public_polys: Vec>, + pub public_polys: Vec>, } impl<'a, CG, H, RBCConfig> AcssConfig<'a, CG, PartyId> for HbAcss0Config @@ -146,7 +162,7 @@ where H: Default + DynDigest + FixedOutputReset + BlockSizeUser + Clone + 'static, RBCConfig: for<'lt_rbc> ReliableBroadcastConfig<'lt_rbc, PartyId> + 'a, { - type Input = Vec>; + type Input = Hbacss0Input; type Output = Hbacss0Output; type Error = Box; @@ -218,12 +234,12 @@ where T: Transport, { type Error = Box; - type Input = Vec>; + type Input = Hbacss0Input; type Output = Hbacss0Output; async fn deal( self, - ped_sks: Self::Input, + s: Self::Input, cancel: CancellationToken, output: oneshot::Sender, rng: &mut RNG, @@ -234,7 +250,7 @@ where let id = self.config.id; let res = select! { res = self.acss_dealer( - ped_sks, + s, cancel.child_token(), output, rng @@ -307,7 +323,7 @@ where /// Beginning of the ACSS protocol as the dealer. async fn acss_dealer( mut self, - ped_sks: impl IntoIterator>, + s: Hbacss0Input, rbc_cancel: CancellationToken, output: oneshot::Sender>, rng: &mut RNG, @@ -321,7 +337,10 @@ where // Pedersen's Polynomial Commitment of degree t where p(0) = s let g = self.config.g; let h = self.config.h; - let vss_shares: Vec<_> = ped_sks + let feld_vss_share = feldman::share(&s.feld, &g, self.config.n, self.config.t, rng); + let feld_public_poly = FeldPublicPoly(feld_vss_share.get_public_poly().to_vec()); + let ped_vss_shares: Vec<_> = s + .peds .into_iter() .map(|ped_sk| { pedersen::share( @@ -335,25 +354,33 @@ where ) }) .collect(); - let public_polys: Vec<_> = vss_shares + let ped_public_polys: Vec<_> = ped_vss_shares .iter() .map(|vss_share| vss_share.get_public_poly().to_vec().into()) .collect(); + let get_party_shares = |i| { + let feld_share = *feld_vss_share + .get_party_secrets(i) + .expect("feldan output less than n shares"); + let ped_shares: Vec<_> = ped_vss_shares + .iter() + .map(|vss_share| { + vss_share + .get_party_secrets(&i) + .expect("feldman output less than n shares") // gen n shares, loop n times + }) + .collect(); + PartyShares { + feld_share, + ped_shares, + } + }; + // Encrypt and Disperse // Each share is encrypted towards the receiving party let shares = PartyId::iter_all(self.config.n) - .map(|i| -> Result, _> { - let shares: Vec<_> = vss_shares - .iter() - .map(|vss_share| { - vss_share - .get_party_secrets(&i) - .expect("feldman output less than n shares") // gen n shares, loop n times - }) - .collect(); - bson::to_vec(&PedersenPartyShares { shares }) - }) + .map(|i| -> Result, _> { bson::to_vec(&get_party_shares(i)) }) .collect::>, _>>() .map_err(|e| AcssError::BsonSer(e, "dealer failed to serialize vss shares"))?; // unexpected error, abort ACSS @@ -364,7 +391,8 @@ where // Disperse encrypted shares and public polynomial through the broadcast channel let broadcast = AcssBroadcastMessage { enc_shares, - public_polys: public_polys.clone(), + feld_public_poly: feld_public_poly.clone(), + ped_public_polys: ped_public_polys.clone(), }; let m = bson::to_vec(&broadcast) .map_err(|e| AcssError::BsonSer(e, "dealer failed to serialize broadcast message"))?; // unexpected error, abort ACSS @@ -384,14 +412,13 @@ where // Continue the execution of the acss protocol as a normal participant. let id = self.config.id; + let shares = get_party_shares(id); self.acss_continue( - vss_shares - .iter() - .map(|vss_share| vss_share.get_party_secrets(&id)) - .collect(), + Some(shares), output, &broadcast.enc_shares, - public_polys, + feld_public_poly, + ped_public_polys, rng, ) .await @@ -446,13 +473,15 @@ where // Decrypt and validate share let enc_shares = &m.enc_shares; let shared_key = enc_shares.derive_shared_key(&self.config.sk); - let public_polys = m.public_polys; + let feld_public_poly = m.feld_public_poly; + let ped_public_polys = m.ped_public_polys; // If the share is valid, the nodes enters the reconstruction process // otherwise, the node enters the recovery process - let share = ped_eval_verify( + let shares = dual_eval_verify( enc_shares, - public_polys.iter(), + &feld_public_poly, + ped_public_polys.iter(), &self.config.g, &self.config.h, self.config.id, @@ -460,17 +489,25 @@ where &self.config.pks[self.config.id], ) .ok(); - self.acss_continue(share, output, &m.enc_shares, public_polys, rng) - .await + self.acss_continue( + shares, + output, + &m.enc_shares, + feld_public_poly, + ped_public_polys, + rng, + ) + .await } /// Execute the agreement / implication / share recovery part of the protocol. async fn acss_continue( self, - shares: Option>>, + shares: Option>, output: oneshot::Sender>, enc_shares: &EphemeralMultiHybridCiphertext, - public_polys: Vec>, + feld_public_poly: FeldPublicPoly, + ped_public_polys: Vec>, rng: &mut RNG, ) -> Result<(), Box> where @@ -582,7 +619,12 @@ where AcssMessage::Ready => { hbacss0 - .ready_handler(sender, &mut state_machine, &public_polys) + .ready_handler( + sender, + &mut state_machine, + &feld_public_poly, + &ped_public_polys, + ) .await } @@ -591,7 +633,8 @@ where .implicate_handler( &ski, enc_shares, - &public_polys, + &feld_public_poly, + &ped_public_polys, sender, &mut state_machine, ) @@ -603,7 +646,8 @@ where .recovery_handler( &shared_key, enc_shares, - &public_polys, + &feld_public_poly, + &ped_public_polys, sender, &mut state_machine, ) @@ -622,13 +666,15 @@ where impl From> for ShareWithPoly { fn from(value: Hbacss0Output) -> Self { Self { + mvba_public_poly: value.feld_public_poly, + mvba_share: value.feld_share, public_polys: value.public_polys, shares: value.shares, } } } -impl PublicPoly { +impl PedPublicPoly { pub fn to_vec(self) -> Vec { self.0 } @@ -638,7 +684,7 @@ impl PublicPoly { } } -impl From> for PublicPoly { +impl From> for PedPublicPoly { fn from(v: Vec) -> Self { Self(v) } @@ -671,10 +717,12 @@ where // Decrypt and validate share let enc_shares = &m.enc_shares; let shared_key = enc_shares.derive_shared_key(&self.sk); - let public_poly = &m.public_polys; - ped_eval_verify( + let feld_public_poly = &m.feld_public_poly; + let ped_public_poly = &m.ped_public_polys; + dual_eval_verify( enc_shares, - public_poly, + feld_public_poly, + ped_public_poly, &self.g, &self.h, self.i, @@ -686,15 +734,17 @@ where } /// Verify that a hybrid encryption ciphertext can be decrypted and is a valid Feldman share for party i. -fn ped_eval_verify<'a, CG>( +#[allow(clippy::too_many_arguments)] +fn dual_eval_verify<'a, CG>( ct: &EphemeralMultiHybridCiphertext, - public_polys: impl IntoIterator, IntoIter: ExactSizeIterator>, + feld_poly: &FeldPublicPoly, + ped_public_polys: impl IntoIterator, IntoIter: ExactSizeIterator>, g: &CG, h: &CG, i: PartyId, shared_key: &CG, recipient_pk: &CG, -) -> Result>, ()> +) -> Result, ()> where CG: CurveGroup + PointSerializeCompressed, CG::ScalarField: FqDeserialize, @@ -704,36 +754,47 @@ where let pt = ct .decrypt_one_with_shared_key(i.as_index(), shared_key, recipient_pk) .map_err(|_| { - warn!("Failed to decrypt pedersen party shares"); + warn!("Failed to decrypt party shares"); })?; - let PedersenPartyShares:: { shares } = bson::from_slice(&pt).map_err(|e| { - warn!(error = ?e, "Failed to deserialize pedersen party shares"); + let party_shares: PartyShares = bson::from_slice(&pt).map_err(|e| { + warn!(error = ?e, "Failed to deserialize party shares"); })?; + let PartyShares { + feld_share, + ped_shares, + } = &party_shares; - let public_polys = public_polys.into_iter(); - if public_polys.len() != shares.len() { + let public_polys = ped_public_polys.into_iter(); + if public_polys.len() != ped_shares.len() { warn!( expected_len = public_polys.len(), - len = shares.len(), + len = ped_shares.len(), "Attempting to verify pedersen shares with invalid length" ); Err(())? } // Try to verify the shares, or return Err(()) - if public_polys.zip(shares.iter()).all(|(public_poly, share)| { - pedersen::eval_verify( - &public_poly.to_owned().to_vec(), - i.into(), - &share.si, - &share.ri, - g, - h, - ) - .is_ok() - }) { - Ok(shares) + if feldman::eval_verify(&feld_poly.0, i, feld_share, g).is_err() { + return Err(()); + } + + if public_polys + .zip(ped_shares.iter()) + .all(|(public_poly, share)| { + pedersen::eval_verify( + &public_poly.to_owned().to_vec(), + i.into(), + &share.si, + &share.ri, + g, + h, + ) + .is_ok() + }) + { + Ok(party_shares) } else { Err(()) } @@ -777,7 +838,9 @@ mod tests { use crate::helpers::PartyId; use crate::rbc::r4::Rbc4RoundsConfig; use crate::vss::acss::AcssConfig; - use crate::vss::acss::hbacss0::{APPNAME, HbAcss0Config, NIZK_DLEQ_SUFFIX, PedersenSecret}; + use crate::vss::acss::hbacss0::{ + APPNAME, HbAcss0Config, Hbacss0Input, NIZK_DLEQ_SUFFIX, PedersenSecret, + }; use crate::{ helpers::{lagrange_interpolate_at, u64_from_usize}, network::RetryStrategy, @@ -785,6 +848,7 @@ mod tests { }; use ark_bn254::Bn254; use ark_ec::{PrimeGroup, pairing::Pairing}; + use ark_ff::Zero; use ark_std::UniformRand; use dcipher_network::topic::dispatcher::TopicDispatcher; use dcipher_network::transports::in_memory::MemoryNetwork; @@ -811,6 +875,7 @@ mod tests { let g = G::generator(); let h = ark_bn254::G1Projective::hash_to_curve(b"PEDERSEN_H", b"TEST_DST_PEDERSEN_H"); + let feld_s = ScalarField::rand(&mut rand::thread_rng()); let s = ScalarField::rand(&mut rand::thread_rng()); let r = ScalarField::rand(&mut rand::thread_rng()); let mut sks: VecDeque = (1..=n) @@ -858,7 +923,10 @@ mod tests { .new_instance_with_prefix("hbacss0".to_owned(), Arc::new(transport)) .expect("failed to create acss instance"); acss.deal( - vec![PedersenSecret { s, r }], + Hbacss0Input { + feld: feld_s, + peds: vec![PedersenSecret { s, r }], + }, cancellation_token, sender, &mut OsRng, @@ -914,18 +982,28 @@ mod tests { }); } - let mut shares = vec![]; + let mut feld_shares = vec![]; + let mut ped_shares = vec![]; while let Some(res) = tasks.join_next().await { assert!(res.is_ok()); let (i, out) = res.unwrap(); - shares.push((i, out.shares[0].si)); + ped_shares.push((i, out.shares[0].si)); + feld_shares.push((i, out.feld_share)); } - let s = lagrange_interpolate_at::(&shares[0..=t], 0); - let s2 = lagrange_interpolate_at::(&shares[t..=2 * t], 0); + let ped_s = lagrange_interpolate_at::(&ped_shares[0..=t], 0); + let ped_s2 = lagrange_interpolate_at::(&ped_shares[t..=2 * t], 0); + let feld_s = lagrange_interpolate_at::(&feld_shares[0..=t], 0); + let feld_s2 = lagrange_interpolate_at::(&feld_shares[t..=2 * t], 0); + + assert_eq!(ped_s, ped_s2); + assert!(!ped_s.is_zero()); + + assert_eq!(feld_s, feld_s2); + assert!(!feld_s.is_zero()); - assert_eq!(s, s2) + assert_ne!(ped_s, feld_s); } #[test] diff --git a/crates/adkg/src/vss/acss/hbacss0/handlers.rs b/crates/adkg/src/vss/acss/hbacss0/handlers.rs index b889bd33..0157e1e8 100644 --- a/crates/adkg/src/vss/acss/hbacss0/handlers.rs +++ b/crates/adkg/src/vss/acss/hbacss0/handlers.rs @@ -1,18 +1,18 @@ //! Handlers for the various messages sent during the ACSS protocol. use super::{ - AcssMessage, AcssStatus, HbAcss0Instance, Hbacss0Output, ImplicateMessage, PublicPoly, - StateMachine, + AcssMessage, AcssStatus, FeldPublicPoly, HbAcss0Instance, Hbacss0Output, ImplicateMessage, + PedPublicPoly, StateMachine, }; use crate::helpers::PartyId; use crate::network::broadcast_with_self; use crate::rbc::ReliableBroadcastConfig; -use crate::vss::acss::hbacss0::types::ShareRecoveryMessage; +use crate::vss::acss::hbacss0::types::{PartyShares, ShareRecoveryMessage}; use crate::vss::pedersen::PedersenPartyShare; use crate::{ helpers::lagrange_interpolate_at, nizk::NIZKDleqProof, pke::ec_hybrid_chacha20poly1305::EphemeralMultiHybridCiphertext, - vss::acss::hbacss0::ped_eval_verify, + vss::acss::hbacss0::dual_eval_verify, }; use ark_ec::CurveGroup; use dcipher_network::TransportSender; @@ -81,7 +81,8 @@ where &self, sender: PartyId, state_machine: &mut StateMachine, - public_polys: &[PublicPoly], + feld_public_poly: &FeldPublicPoly, + ped_public_polys: &[PedPublicPoly], ) { // Skip messages if not waiting for Ok nor Readys nor in ShareRecovery match state_machine.status { @@ -147,8 +148,10 @@ where if output .send(Hbacss0Output { - shares: shares.to_owned(), - public_polys: public_polys.to_vec(), + feld_share: shares.feld_share, + shares: shares.ped_shares.clone(), + feld_public_poly: feld_public_poly.to_owned(), + public_polys: ped_public_polys.to_vec(), }) .is_err() { @@ -184,7 +187,8 @@ where &self, msg: &ImplicateMessage, enc_shares: &EphemeralMultiHybridCiphertext, - public_polys: &[PublicPoly], + feld_public_poly: &FeldPublicPoly, + ped_public_polys: &[PedPublicPoly], sender: PartyId, state_machine: &mut StateMachine, ) where @@ -259,9 +263,10 @@ where } // We know that the sender gave us a valid shared key, try to decrypt the original ciphertext sent by the dealer. - if ped_eval_verify( + if dual_eval_verify( enc_shares, - public_polys, + feld_public_poly, + ped_public_polys, &self.config.g, &self.config.h, sender, @@ -316,7 +321,8 @@ where &self, shared_key: &[u8], enc_shares: &EphemeralMultiHybridCiphertext, - public_polys: &[PublicPoly], + feld_public_poly: &FeldPublicPoly, + ped_public_polys: &[PedPublicPoly], sender: PartyId, state_machine: &mut StateMachine, ) where @@ -344,9 +350,10 @@ where // We don't verify the source / validity of the shared key. // We only need it such that decryption results in a valid dealer's share. - let Ok(shares) = ped_eval_verify( + let Ok(shares) = dual_eval_verify( enc_shares, - public_polys, + feld_public_poly, + ped_public_polys, &self.config.g, &self.config.h, sender, @@ -368,10 +375,13 @@ where #[allow(clippy::int_plus_one)] if state_machine.shares_recovery.len() >= self.config.t + 1 { // Enough valid shares, interpolate the polynomial + let n = state_machine.shares_recovery.len(); let mut points_peds: Vec<(Vec<_>, Vec<_>)> = vec![]; + let mut points_feld = vec![]; for (&k, shares) in state_machine.shares_recovery.iter() { - for (share_idx, share) in shares.iter().enumerate() { + points_feld.push((k.into(), shares.feld_share)); + for (share_idx, share) in shares.ped_shares.iter().enumerate() { if points_peds.len() <= share_idx { points_peds.push((Vec::with_capacity(n), Vec::with_capacity(n))); } @@ -381,6 +391,7 @@ where } } + let feld_share = lagrange_interpolate_at::(&points_feld, self.config.id.into()); let Some(new_shares) = points_peds .into_iter() .map(|(points_si, points_ri)| { @@ -426,7 +437,10 @@ where } // Update state machine - state_machine.status = AcssStatus::WaitingForReadys(new_shares); + state_machine.status = AcssStatus::WaitingForReadys(PartyShares { + feld_share, + ped_shares: new_shares, + }); } } } diff --git a/crates/adkg/src/vss/acss/hbacss0/types.rs b/crates/adkg/src/vss/acss/hbacss0/types.rs index b959b874..6735c3b3 100644 --- a/crates/adkg/src/vss/acss/hbacss0/types.rs +++ b/crates/adkg/src/vss/acss/hbacss0/types.rs @@ -4,13 +4,15 @@ use crate::nizk::NizkError; use crate::pke::ec_hybrid_chacha20poly1305::{ EphemeralMultiHybridCiphertext, HybridEncryptionError, }; -use crate::vss::acss::hbacss0::{Hbacss0Output, PublicPoly}; +use crate::vss::acss::hbacss0::{FeldPublicPoly, Hbacss0Output, PedPublicPoly}; use crate::vss::pedersen::PedersenPartyShare; use ark_ec::CurveGroup; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use thiserror::Error; use tokio::sync::oneshot; +use utils::serialize::fq::FqDeserialize; +use utils::serialize::fq::FqSerialize; use utils::serialize::{ SerializationError, point::{PointDeserializeCompressed, PointSerializeCompressed}, @@ -86,10 +88,10 @@ pub(super) enum AcssStatus { ShareRecovery, /// A valid share was received, waiting for 2t + 1 oks. - WaitingForOks(Vec>), + WaitingForOks(PartyShares), /// Enough ok / readys were received, waiting for 2t + 1 readys. - WaitingForReadys(Vec>), + WaitingForReadys(PartyShares), /// A share was recovered, about to exit. Complete, @@ -103,17 +105,17 @@ pub(super) enum AcssStatus { ))] pub(super) struct AcssBroadcastMessage { pub(super) enc_shares: EphemeralMultiHybridCiphertext, - pub(super) public_polys: Vec>, + pub(super) feld_public_poly: FeldPublicPoly, + pub(super) ped_public_polys: Vec>, } -/// Wrapper around Vec>> for serde -#[derive(Serialize, Deserialize)] -#[serde(bound( - serialize = "PedersenPartyShare: Serialize", - deserialize = "PedersenPartyShare: Deserialize<'de>" -))] -pub(super) struct PedersenPartyShares { - pub(super) shares: Vec>, +/// Shares obtained by the ACSS +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound(serialize = "F: FqSerialize", deserialize = "F: FqDeserialize"))] +pub(super) struct PartyShares { + #[serde(with = "utils::serialize::fq::base64_or_bytes")] + pub(super) feld_share: F, + pub(super) ped_shares: Vec>, } /// State machine used by handlers to update the state of the node. @@ -123,7 +125,7 @@ pub(super) struct StateMachine { // could be replaced by a bitmap pub(super) nodes_oks: HashMap, // count the number of parties that are OK pub(super) nodes_readys: HashMap, // count the number of parties that are ready - pub(super) shares_recovery: HashMap>>, // store the parties currently recovering + pub(super) shares_recovery: HashMap>, // store the parties currently recovering pub(super) output: Option>>, // require an option since we move the sender upon sending }