diff --git a/src/compressed.rs b/src/compressed.rs new file mode 100644 index 0000000..62713c9 --- /dev/null +++ b/src/compressed.rs @@ -0,0 +1,334 @@ +//xxx example, not zk. + +use crate::codec::ByteSchnorrCodec; +use crate::codec::ShakeDuplexSponge; +use crate::errors::Error as ProofError; +use crate::errors::Result as ProofResult; +use crate::fiat_shamir::MultiRoundNizk; +use crate::linear_relation; +use crate::serialization::deserialize_elements; +use crate::serialization::deserialize_scalars; +use crate::traits::InteractiveProof; +use ff::Field; +use group::prime::PrimeGroup; + +use crate::{ + group::msm::VariableMultiScalarMul, + linear_relation::CanonicalLinearRelation, + serialization::{read_elements, serialize_elements, serialize_scalars}, +}; + +struct SquashedLinearRelation { + generators: Vec, + image: G, +} + +pub(crate) fn powers(element: F, len: usize) -> Vec { + let mut powers = vec![F::ONE; len]; + for i in 1..len { + powers[i] = element * powers[i - 1]; + } + powers +} + +impl CanonicalLinearRelation { + // not really needed but will simplify the code. + fn squash(&self, challenge: G::Scalar) -> SquashedLinearRelation { + let powers = powers(challenge, self.image.len()); + + let squashed_image = G::msm(&powers, &self.image); + + // Determine the number of scalar variables + let num_scalars = self.num_scalars; + + let mut squashed_generators = vec![G::identity(); num_scalars]; + + // the matrix for a linear relation is sparse, and stored in yale format. + for (row, linear_combination) in self.linear_combinations.iter().enumerate() { + for (scalar_var, group_var) in linear_combination.iter() { + let col = scalar_var.index(); + let element = self.group_elements.get(*group_var).unwrap(); + squashed_generators[col] += element * powers[row]; + } + } + + SquashedLinearRelation { + generators: squashed_generators, + image: squashed_image, + } + } +} + +fn fold_generators( + left: &[G], + right: &[G], + x_inv: &G::Scalar, + x: &G::Scalar, +) -> Vec { + left.iter() + .zip(right.iter()) + .map(|(l, r)| *l * (*x_inv) + *r * (*x)) + .collect() +} + +fn fold_scalars(left: &[F], right: &[F], x: &F, x_inv: &F) -> Vec { + left.iter() + .zip(right.iter()) + .map(|(&l, &r)| l * x + r * x_inv) + .collect() +} + +enum CompressedProofMessage { + FinalMessage(G::Scalar), + IntermediateMessage([G; 2]), +} + +#[derive(Clone)] +struct CompressedWitness(Vec); + +impl CompressedProofMessage { + fn new_from_intermediate_message(intermediate_message: [G; 2]) -> Self { + Self::IntermediateMessage(intermediate_message) + } + + fn new_from_final_message(final_message: G::Scalar) -> Self { + Self::FinalMessage(final_message) + } +} + +impl InteractiveProof for SquashedLinearRelation { + type ProverState = (CompressedWitness, SquashedLinearRelation); + + type ProverMessage = CompressedProofMessage; + + type VerifierState = SquashedLinearRelation; + + type Challenge = G::Scalar; + + type Witness = CompressedWitness; + + fn get_initial_prover_state(&self, witness: &Self::Witness) -> Self::ProverState { + ( + witness.clone(), + SquashedLinearRelation { + generators: self.generators.clone(), + image: self.image, + }, + ) + } + + fn get_initial_verifier_state(&self) -> Self::VerifierState { + SquashedLinearRelation { + generators: self.generators.clone(), + image: self.image, + } + } + + fn prover_message( + &self, + state: &mut Self::ProverState, + challenge: &Self::Challenge, + ) -> Result { + let (witness, statement) = state; + assert_eq!(witness.0.len(), statement.generators.len()); + assert_eq!( + G::msm(&witness.0, &statement.generators), + statement.image, + "Invalid witness" + ); + if statement.generators.len() == 1 { + let computed = statement.generators[0] * witness.0[0]; + let final_message = witness.0[0]; + assert_eq!(statement.image, computed); + return Ok(CompressedProofMessage::new_from_final_message( + final_message, + )); + } + let n = witness.0.len() / 2; + let (w_left, w_right) = witness.0.split_at(n); + let (g_left, g_right) = statement.generators.split_at(n); + + // round messages + let A = G::msm_unchecked(w_left, &g_right); + let B = G::msm_unchecked(w_right, &g_left); + let new_witness = fold_scalars(w_left, w_right, &G::Scalar::ONE, &challenge); + let new_generators = fold_generators(g_left, g_right, &challenge, &G::Scalar::ONE); + let new_image = A + statement.image * challenge + B * challenge.square(); + statement.generators = new_generators; + statement.image = new_image; + witness.0 = new_witness; + + Ok(CompressedProofMessage::new_from_intermediate_message([ + A, B, + ])) + } + + fn update_verifier_state( + prover_message: &Self::ProverMessage, + state: &mut Self::VerifierState, + challenge: &Self::Challenge, + ) -> Result<(), ProofError> { + if state.generators.len() == 1 { + match prover_message { + CompressedProofMessage::FinalMessage(witness) => { + let computed = state.generators[0] * witness; + if computed == state.image { + return Ok(()); + } else { + return Err(ProofError::VerificationFailure); + } + } + CompressedProofMessage::IntermediateMessage(_) => { + return Err(ProofError::VerificationFailure); + } + } + } + match prover_message { + CompressedProofMessage::FinalMessage(_) => { + return Err(ProofError::VerificationFailure); + } + CompressedProofMessage::IntermediateMessage([A, B]) => { + let n = state.generators.len() / 2; + let (g_left, g_right) = state.generators.split_at(n); + let new_generators = fold_generators(g_left, g_right, &challenge, &G::Scalar::ONE); + let new_image = *A + state.image * challenge + *B * challenge.square(); + state.generators = new_generators; + state.image = new_image; + Ok(()) + } + } + } + + fn serialize_message(&self, prover_message: &Self::ProverMessage) -> Vec { + match prover_message { + CompressedProofMessage::FinalMessage(witness) => serialize_scalars::(&[*witness]), + CompressedProofMessage::IntermediateMessage(prover_message) => { + serialize_elements(prover_message) + } + } + } + + fn serialize_challenge(&self, challenge: &Self::Challenge) -> Vec { + serialize_scalars::(&[*challenge]) + } + + fn deserialize_message( + &self, + data: &[u8], + is_final_message: bool, + ) -> Result { + if is_final_message { + let witness = + deserialize_scalars::(data, 1).ok_or(ProofError::VerificationFailure)?; + Ok(CompressedProofMessage::new_from_final_message(witness[0])) + } else { + let elements = + deserialize_elements::(data, 2).ok_or(ProofError::VerificationFailure)?; + let intermediate_message: [G; 2] = [elements[0], elements[1]]; + Ok(CompressedProofMessage::IntermediateMessage( + intermediate_message, + )) + } + } + + fn deserialize_challenge(&self, data: &[u8]) -> Result { + let scalars = deserialize_scalars::(data, 1).ok_or(ProofError::VerificationFailure)?; + Ok(scalars[0]) + } + + fn protocol_identifier(&self) -> impl AsRef<[u8]> { + "TODO" + } + + fn instance_label(&self) -> impl AsRef<[u8]> { + "TODO" + } + + fn num_rounds(&self) -> usize { + self.generators.len().next_power_of_two().ilog2() as usize + 1 + } +} + +#[test] +fn test_compressed_bbs_nyms() { + use curve25519_dalek::ristretto::RistrettoPoint as G; + use curve25519_dalek::Scalar; + + let rng = &mut rand::thread_rng(); + let mut statement = linear_relation::LinearRelation::::new(); + // bbs variables + const N: usize = 127; + let var_ms = statement.allocate_scalars::(); + let var_G0 = statement.allocate_element(); + let var_Gs = statement.allocate_elements::(); + // xxx + // let var_X = statement.allocate_element(); + let var_e = statement.allocate_scalar(); + let var_A = statement.allocate_element(); + // nym group elements + let var_Ts = statement.allocate_elements::(); + + // bbs verification equation + // x A = G_0 + sum_{i=1}^n m_i G_i + e A + let var_Z = statement.allocate_eq( + var_Gs + .iter() + .zip(var_ms.iter()) + .map(|(g, m)| *g * *m) + .sum::>() + + var_G0 + + var_e * var_A, + ); + // pseudonym + let var_NYM = statement.allocate_eq( + var_Ts + .iter() + .zip(var_ms) + .map(|(t, m)| *t * m) + .sum::>(), + ); + + let challenge = Scalar::random(rng); // Random squash challenge + let G0 = G::random(rng); + let Gs = (0..N).map(|_| G::random(rng)).collect::>(); + let ms = (0..N).map(|_| Scalar::random(rng)).collect::>(); + // xxx + let Ts = (0..N).map(|_| G::random(rng)).collect::>(); + let x = Scalar::random(rng); + // computed by the server + let e = Scalar::random(rng); + let A = (x - e).invert() * (G0 + G::msm(&ms, &Gs)); + let Z = x * A; + // computed by the user + let NYM = G::msm(&ms, &Ts); + + // the public elements + statement.set_elements( + [(var_G0, G0), (var_A, A), (var_NYM, NYM), (var_Z, Z)] + .into_iter() + .chain(var_Gs.iter().copied().zip(Gs.iter().copied())) + .chain(var_Ts.iter().copied().zip(Ts.iter().copied())), + ); + // the private witness + let witness = [ms.as_slice(), &[e]].concat(); + + assert_eq!( + statement + .canonical() + .unwrap() + .is_witness_valid(&witness) + .unwrap_u8(), + 1 + ); + // All random challenges now + let squashed_statement = statement.canonical().unwrap().squash(challenge); + let witness_check = G::msm(&witness, &squashed_statement.generators); + + let multi_round_nizk: MultiRoundNizk< + SquashedLinearRelation, + ByteSchnorrCodec, + > = MultiRoundNizk::new("Pseudonym Proof".as_bytes(), squashed_statement); + let (prover_messages, _) = multi_round_nizk.prove(&CompressedWitness(witness)).unwrap(); + let verification_result = multi_round_nizk.verify(&prover_messages); + assert!(verification_result.is_ok()); +} diff --git a/src/fiat_shamir.rs b/src/fiat_shamir.rs index 28c69d7..ea3a335 100644 --- a/src/fiat_shamir.rs +++ b/src/fiat_shamir.rs @@ -13,7 +13,7 @@ //! - `C`: the codec ([`Codec`] trait). use crate::errors::Error; -use crate::traits::SigmaProtocol; +use crate::traits::{InteractiveProof, SigmaProtocol}; use crate::{codec::Codec, traits::SigmaProtocolSimulator}; use alloc::vec::Vec; @@ -28,6 +28,12 @@ type Transcript

= (

::Response, ); +#[allow(unused)] +type MultiRoundTranscript

= ( + Vec<

::ProverMessage>, + Vec<

::Challenge>, +); + /// A Fiat-Shamir transformation of a [`SigmaProtocol`] into a non-interactive proof. /// /// [`Nizk`] wraps an interactive Sigma protocol `P` @@ -304,3 +310,86 @@ where self.verify(&commitment, &challenge, &response) } } + +#[allow(unused)] +pub struct MultiRoundNizk +where + P: InteractiveProof, + P::Challenge: PartialEq, + C: Codec, +{ + /// Current codec state. + pub hash_state: C, + /// Underlying interactive proof. + pub interactive_proof: P, +} + +#[allow(unused)] +impl MultiRoundNizk +where + P: InteractiveProof, + P::Challenge: PartialEq, + C: Codec + Clone, +{ + /// Constructs a new [`MultiRoundNizk`] instance. + /// + /// # Parameters + /// - `iv`: Domain separation tag for the hash function (e.g., protocol name or context). + /// - `instance`: An instance of the [`InteractiveProof`]. + /// + /// # Returns + /// A new [`MultiRoundNizk`] that can generate and verify non-interactive proofs. + pub fn new(session_identifier: &[u8], interactive_proof: P) -> Self { + let hash_state = C::new( + interactive_proof.protocol_identifier().as_ref(), + session_identifier, + interactive_proof.instance_label().as_ref(), + ); + Self { + hash_state, + interactive_proof, + } + } + + pub fn from_iv(iv: [u8; 64], interactive_proof: P) -> Self { + let hash_state = C::from_iv(iv); + Self { + hash_state, + interactive_proof, + } + } + + pub fn prove(&self, witness: &P::Witness) -> Result, Error> { + let mut hash_state = self.hash_state.clone(); + let num_rounds = self.interactive_proof.num_rounds(); + let mut statement = self.interactive_proof.get_initial_prover_state(witness); + let mut messages = vec![]; + let mut challenges = vec![]; + (0..num_rounds).for_each(|_| { + let challenge = hash_state.verifier_challenge(); + let message = self + .interactive_proof + .prover_message(&mut statement, &challenge) + .unwrap(); + let serialized_message = self.interactive_proof.serialize_message(&message); + hash_state.prover_message(&serialized_message); + messages.push(message); + challenges.push(challenge); + }); + Ok((messages, challenges)) + } + + pub fn verify(&self, prover_messages: &[P::ProverMessage]) -> Result<(), Error> { + let mut hash_state = self.hash_state.clone(); + let num_rounds = self.interactive_proof.num_rounds(); + assert_eq!(prover_messages.len(), num_rounds); + let mut statement = self.interactive_proof.get_initial_verifier_state(); + for message in prover_messages { + let challenge = hash_state.verifier_challenge(); + P::update_verifier_state(message, &mut statement, &challenge)?; + let serialized_message = self.interactive_proof.serialize_message(&message); + hash_state.prover_message(&serialized_message); + } + Ok(()) + } +} diff --git a/src/group/msm.rs b/src/group/msm.rs index adebc78..b3a1af0 100644 --- a/src/group/msm.rs +++ b/src/group/msm.rs @@ -27,11 +27,11 @@ const fn ln_without_floats(a: usize) -> usize { /// ``` /// Implementations can override this with optimized algorithms for specific groups, /// while a default naive implementation is provided for all [`PrimeGroup`] types. -pub trait VariableMultiScalarMul { +pub trait VariableMultiScalarMul: Sized { /// The scalar field type associated with the group. type Scalar; /// The group element (point) type. - type Point; + type Point: PrimeGroup; /// Computes the multi-scalar multiplication (MSM) over the provided scalars and points. /// @@ -44,7 +44,12 @@ pub trait VariableMultiScalarMul { /// /// # Panics /// Panics if `scalars.len() != bases.len()`. - fn msm(scalars: &[Self::Scalar], bases: &[Self::Point]) -> Self; + fn msm(scalars: &[Self::Scalar], bases: &[Self::Point]) -> Self { + assert_eq!(scalars.len(), bases.len()); + Self::msm_unchecked(scalars, bases) + } + + fn msm_unchecked(scalars: &[Self::Scalar], bases: &[Self::Point]) -> Self; } impl VariableMultiScalarMul for G { @@ -61,7 +66,7 @@ impl VariableMultiScalarMul for G { /// /// # Panics /// Panics if `scalars.len() != bases.len()`. - fn msm(scalars: &[Self::Scalar], bases: &[Self::Point]) -> Self { + fn msm_unchecked(scalars: &[Self::Scalar], bases: &[Self::Point]) -> Self { assert_eq!(scalars.len(), bases.len()); // NOTE: Based on the msm benchmark in this repo, msm_pippenger provides improvements over diff --git a/src/group/serialization.rs b/src/group/serialization.rs index ec2f711..b72b5a1 100644 --- a/src/group/serialization.rs +++ b/src/group/serialization.rs @@ -118,3 +118,10 @@ pub fn deserialize_scalars(data: &[u8], count: usize) -> Option(data: &[u8], count: usize) -> Option<(Vec, &[u8])> { + let element_len = group_elt_serialized_len::(); + let elements = deserialize_elements::(data, count)?; + Some((elements, &data[count * element_len..])) +} diff --git a/src/lib.rs b/src/lib.rs index 993d549..ef9d007 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -89,3 +89,6 @@ pub use linear_relation::LinearRelation; #[deprecated = "Use sigma_proofs::group::serialization instead"] pub use group::serialization; + +#[allow(unused)] +mod compressed; diff --git a/src/linear_relation/mod.rs b/src/linear_relation/mod.rs index 8fe6383..d49eceb 100644 --- a/src/linear_relation/mod.rs +++ b/src/linear_relation/mod.rs @@ -34,7 +34,7 @@ pub use canonical::CanonicalLinearRelation; /// /// Used to reference scalars in sparse linear combinations. #[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct ScalarVar(usize, PhantomData); +pub struct ScalarVar(pub(crate) usize, PhantomData); impl ScalarVar { pub fn index(&self) -> usize { diff --git a/src/traits.rs b/src/traits.rs index f669656..86b584f 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -142,3 +142,47 @@ pub trait SigmaProtocolSimulator: SigmaProtocol { rng: &mut R, ) -> Result, Error>; } + +pub trait InteractiveProof { + type ProverState; + type ProverMessage; + type VerifierState; + type Challenge; + type Witness; + + fn get_initial_prover_state(&self, witness: &Self::Witness) -> Self::ProverState; + + fn get_initial_verifier_state(&self) -> Self::VerifierState; + + fn prover_message( + &self, + state: &mut Self::ProverState, + challenge: &Self::Challenge, + ) -> Result; + + fn update_verifier_state( + prover_message: &Self::ProverMessage, + state: &mut Self::VerifierState, + challenge: &Self::Challenge, + ) -> Result<(), Error>; + + fn serialize_message(&self, prover_message: &Self::ProverMessage) -> Vec; + + /// Serializes a challenge to bytes. + fn serialize_challenge(&self, challenge: &Self::Challenge) -> Vec; + + fn deserialize_message( + &self, + data: &[u8], + is_final_message: bool, + ) -> Result; + + /// Deserializes a challenge from bytes. + fn deserialize_challenge(&self, data: &[u8]) -> Result; + + fn protocol_identifier(&self) -> impl AsRef<[u8]>; + + fn instance_label(&self) -> impl AsRef<[u8]>; + + fn num_rounds(&self) -> usize; +}