diff --git a/Cargo.toml b/Cargo.toml index d38c9ce..8761bb8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,26 +19,27 @@ description = "A toolkit for auto-generated implementations of Σ-protocols" exclude = [ ".gitignore" ] +rust-version = "1.81" [features] default = ["std"] -std = ["thiserror", "rand", "num-bigint/std", "num-traits/std", "sha3/std", "rand_core/std"] +std = ["rand", "num-bigint/std", "num-traits/std", "sha3/std", "rand_core/std"] [dependencies] +ahash = { version = "0.8", default-features = false } ff = { version = "0.13", features = ["derive"] } group = "0.13.0" +hashbrown = { version = "0.15", default-features = false } +keccak = { version = "0.1.5", default-features = false } num-bigint = { version = "0.4.6", default-features = false } num-traits = { version = "0.2.19", default-features = false, features = ["libm"] } rand = { version = "0.8.5", optional = true } rand_core = { version = "0.6", default-features = false } sha3 = { version = "0.10.8", default-features = false } subtle = { version = "2.6.1", default-features = false } -thiserror = { version = "1", optional = true } -keccak = { version = "0.1.5", default-features = false } +thiserror = { version = "2.0.16", default-features = false } zerocopy = { version = "0.8", default-features = false } zeroize = { version = "1.8.1", default-features = false, features = ["alloc"] } -hashbrown = { version = "0.15", default-features = false } -ahash = { version = "0.8", default-features = false } [dev-dependencies] bls12_381 = "0.8.0" diff --git a/examples/schnorr.rs b/examples/schnorr.rs index d60dc50..7fc9a8a 100644 --- a/examples/schnorr.rs +++ b/examples/schnorr.rs @@ -6,64 +6,83 @@ //! //! where $G$ is a generator of a prime-order group $\mathbb{G}$ and $P$ is a public group element. +use std::process::ExitCode; + use curve25519_dalek::scalar::Scalar; use curve25519_dalek::RistrettoPoint; use group::Group; use rand::rngs::OsRng; use sigma_proofs::errors::Error; +use sigma_proofs::linear_relation::{Allocator, GroupVar, ScalarVar}; use sigma_proofs::LinearRelation; -type ProofResult = Result; - /// Create a discrete logarithm relation for the given public key P #[allow(non_snake_case)] -fn create_relation(P: RistrettoPoint) -> LinearRelation { +fn create_relation() -> ( + LinearRelation, + ScalarVar, + GroupVar, +) { let mut relation = LinearRelation::new(); let x = relation.allocate_scalar(); - let G = relation.allocate_element(); - let P_var = relation.allocate_eq(x * G); - relation.set_element(G, RistrettoPoint::generator()); - relation.set_element(P_var, P); + let G = relation.allocate_element_with(RistrettoPoint::generator()); + let P = relation.allocate_eq(x * G); - relation + (relation, x, P) } /// Prove knowledge of the discrete logarithm: given witness x and public key P, /// generate a proof that P = x * G #[allow(non_snake_case)] -fn prove(x: Scalar, P: RistrettoPoint) -> ProofResult> { - let nizk = create_relation(P).into_nizk(b"sigma-proofs-example"); - nizk?.prove_batchable(&vec![x], &mut OsRng) +fn prove(x: Scalar) -> Result, Error> { + let (mut relation, x_var, _) = create_relation(); + let witness = [(x_var, x)]; + relation.compute_image(witness)?; + relation + .into_nizk(b"sigma-proofs-example")? + .prove_batchable(witness, &mut OsRng) } /// Verify a proof of knowledge of discrete logarithm for the given public key P #[allow(non_snake_case)] -fn verify(P: RistrettoPoint, proof: &[u8]) -> ProofResult<()> { - let nizk = create_relation(P).into_nizk(b"sigma-proofs-example"); - nizk?.verify_batchable(proof) +fn verify(P: RistrettoPoint, proof: &[u8]) -> Result<(), Error> { + let (mut relation, _, P_var) = create_relation(); + relation.assign_element(P_var, P); + relation + .into_nizk(b"sigma-proofs-example")? + .verify_batchable(proof) } #[allow(non_snake_case)] -fn main() { +fn main() -> ExitCode { let x = Scalar::random(&mut OsRng); // Private key (witness) let P = RistrettoPoint::generator() * x; // Public key (statement) println!("Generated new key pair:"); println!("Public key P: {:?}", hex::encode(P.compress().as_bytes())); - match prove(x, P) { + let proof = match prove(x) { Ok(proof) => { println!("Proof generated successfully:"); println!("Proof (hex): {}", hex::encode(&proof)); + proof + } + Err(e) => { + println!("✗ Failed to generate proof: {e:?}"); + return ExitCode::FAILURE; + } + }; - // Verify the proof - match verify(P, &proof) { - Ok(()) => println!("✓ Proof verified successfully!"), - Err(e) => println!("✗ Proof verification failed: {e:?}"), - } + // Verify the proof + match verify(P, &proof) { + Ok(()) => println!("✓ Proof verified successfully!"), + Err(e) => { + println!("✗ Proof verification failed: {e:?}"); + return ExitCode::FAILURE; } - Err(e) => println!("✗ Failed to generate proof: {e:?}"), } + + ExitCode::SUCCESS } diff --git a/examples/simple_composition.rs b/examples/simple_composition.rs index 4256ee8..320c952 100644 --- a/examples/simple_composition.rs +++ b/examples/simple_composition.rs @@ -8,6 +8,7 @@ use sigma_proofs::{ codec::Shake128DuplexSponge, composition::{ComposedRelation, ComposedWitness}, errors::Error, + linear_relation::{Allocator, ScalarVar}, LinearRelation, Nizk, }; @@ -18,29 +19,30 @@ type ProofResult = Result; /// 1. Knowledge of discrete log: P1 = x1 * G /// 2. Knowledge of DLEQ: (P2 = x2 * G, Q = x2 * H) #[allow(non_snake_case)] -fn create_relation(P1: G, P2: G, Q: G, H: G) -> ComposedRelation { +fn create_relation(P1: G, P2: G, Q: G, H: G) -> (ComposedRelation, ScalarVar, ScalarVar) { // First relation: discrete logarithm P1 = x1 * G let mut rel1 = LinearRelation::::new(); let x1 = rel1.allocate_scalar(); - let G1 = rel1.allocate_element(); + let G1 = rel1.allocate_element_with(G::generator()); let P1_var = rel1.allocate_eq(x1 * G1); - rel1.set_element(G1, G::generator()); - rel1.set_element(P1_var, P1); + rel1.assign_element(P1_var, P1); // Second relation: DLEQ (P2 = x2 * G, Q = x2 * H) let mut rel2 = LinearRelation::::new(); let x2 = rel2.allocate_scalar(); - let G2 = rel2.allocate_element(); - let H_var = rel2.allocate_element(); + let G2 = rel2.allocate_element_with(G::generator()); + let H_var = rel2.allocate_element_with(H); let P2_var = rel2.allocate_eq(x2 * G2); let Q_var = rel2.allocate_eq(x2 * H_var); - rel2.set_element(G2, G::generator()); - rel2.set_element(H_var, H); - rel2.set_element(P2_var, P2); - rel2.set_element(Q_var, Q); + rel2.assign_element(P2_var, P2); + rel2.assign_element(Q_var, Q); // Compose into OR protocol - ComposedRelation::or([rel1.canonical().unwrap(), rel2.canonical().unwrap()]) + ( + ComposedRelation::or([rel1.canonical().unwrap(), rel2.canonical().unwrap()]), + x1, + x2, + ) } /// Prove knowledge of one of the witnesses (we know x2 for the DLEQ) @@ -50,22 +52,18 @@ fn prove(P1: G, x2: Scalar, H: G) -> ProofResult> { let P2 = G::generator() * x2; let Q = H * x2; - let instance = create_relation(P1, P2, Q, H); + let (relation, var_x1, var_x2) = create_relation(P1, P2, Q, H); // Create OR witness with branch 1 being the real one (index 1) - let witness = ComposedWitness::Or(vec![ - ComposedWitness::Simple(vec![Scalar::from(0u64)]), - ComposedWitness::Simple(vec![x2]), - ]); - let nizk = Nizk::<_, Shake128DuplexSponge>::new(b"or_proof_example", instance); - - nizk.prove_batchable(&witness, &mut OsRng) + let witness = ComposedWitness::or([[(var_x1, Scalar::from(0u64))], [(var_x2, x2)]]); + let nizk = Nizk::<_, Shake128DuplexSponge>::new(b"or_proof_example", relation); + nizk.prove_batchable(witness, &mut OsRng) } /// Verify an OR proof given the public values #[allow(non_snake_case)] fn verify(P1: G, P2: G, Q: G, H: G, proof: &[u8]) -> ProofResult<()> { - let protocol = create_relation(P1, P2, Q, H); - let nizk = Nizk::<_, Shake128DuplexSponge>::new(b"or_proof_example", protocol); + let (relation, _, _) = create_relation(P1, P2, Q, H); + let nizk = Nizk::<_, Shake128DuplexSponge>::new(b"or_proof_example", relation); nizk.verify_batchable(proof) } diff --git a/src/composition.rs b/src/composition.rs index c3d689e..e92b53a 100644 --- a/src/composition.rs +++ b/src/composition.rs @@ -28,8 +28,10 @@ use rand_core::{CryptoRng, RngCore as Rng}; use sha3::{Digest, Sha3_256}; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; +use crate::codec::Codec; use crate::errors::InvalidInstance; use crate::group::serialization::{deserialize_scalars, serialize_scalars}; +use crate::linear_relation::{Allocator, ScalarVar}; use crate::{ codec::Shake128DuplexSponge, errors::Error, @@ -51,7 +53,7 @@ pub enum ComposedRelation { Or(Vec>), } -impl ComposedRelation { +impl ComposedRelation { /// Create a [ComposedRelation] for an AND relation from the given list of relations. pub fn and>>(witness: impl IntoIterator) -> Self { Self::And(witness.into_iter().map(|x| x.into()).collect()) @@ -69,10 +71,10 @@ impl From> for ComposedRelation { } } -impl TryFrom> for ComposedRelation { +impl> TryFrom> for ComposedRelation { type Error = InvalidInstance; - fn try_from(value: LinearRelation) -> Result { + fn try_from(value: LinearRelation) -> Result { Ok(Self::Simple(CanonicalLinearRelation::try_from(value)?)) } } @@ -223,6 +225,18 @@ impl ComposedWitness { } } +impl From<[(ScalarVar, G::Scalar); N]> for ComposedWitness { + fn from(value: [(ScalarVar, G::Scalar); N]) -> Self { + Self::Simple(value.into()) + } +} + +impl From, G::Scalar)>> for ComposedWitness { + fn from(value: Vec<(ScalarVar, G::Scalar)>) -> Self { + Self::Simple(value.into()) + } +} + impl From< as SigmaProtocol>::Witness> for ComposedWitness { @@ -237,6 +251,29 @@ const fn composed_challenge_size() -> usize { (G::Scalar::NUM_BITS as usize).div_ceil(8) } +impl ComposedRelation +where + Self: SigmaProtocol, +{ + /// Convert this Protocol into a non-interactive zero-knowledge proof + /// using the Shake128DuplexSponge codec and a specified session identifier. + /// + /// This method provides a convenient way to create a NIZK from a Protocol + /// without exposing the specific codec type to the API caller. + /// + /// # Parameters + /// - `session_identifier`: Domain separator bytes for the Fiat-Shamir transform + /// + /// # Returns + /// A `Nizk` instance ready for proving and verification + pub fn into_nizk(self, session_identifier: &[u8]) -> Nizk> + where + Shake128DuplexSponge: Codec::Challenge>, + { + Nizk::new(session_identifier, self) + } +} + impl ComposedRelation { fn is_witness_valid(&self, witness: &ComposedWitness) -> Choice { match (self, witness) { @@ -261,7 +298,7 @@ impl ComposedRelation< fn prover_commit_simple( protocol: &CanonicalLinearRelation, - witness: & as SigmaProtocol>::Witness, + witness: as SigmaProtocol>::Witness, rng: &mut (impl Rng + CryptoRng), ) -> Result<(ComposedCommitment, ComposedProverState), Error> { protocol.prover_commit(witness, rng).map(|(c, s)| { @@ -284,7 +321,7 @@ impl ComposedRelation< fn prover_commit_and( protocols: &[ComposedRelation], - witnesses: &[ComposedWitness], + witnesses: Vec>, rng: &mut (impl Rng + CryptoRng), ) -> Result<(ComposedCommitment, ComposedProverState), Error> { if protocols.len() != witnesses.len() { @@ -294,7 +331,7 @@ impl ComposedRelation< let mut commitments = Vec::with_capacity(protocols.len()); let mut prover_states = Vec::with_capacity(protocols.len()); - for (p, w) in protocols.iter().zip(witnesses.iter()) { + for (p, w) in protocols.iter().zip(witnesses.into_iter()) { let (c, s) = p.prover_commit(w, rng)?; commitments.push(c); prover_states.push(s); @@ -326,7 +363,7 @@ impl ComposedRelation< fn prover_commit_or( instances: &[ComposedRelation], - witnesses: &[ComposedWitness], + witnesses: Vec>, rng: &mut (impl Rng + CryptoRng), ) -> Result<(ComposedCommitment, ComposedProverState), Error> where @@ -341,14 +378,15 @@ impl ComposedRelation< // Selector value set when the first valid witness is found. let mut valid_witness_found = Choice::from(0); - for (i, w) in witnesses.iter().enumerate() { + for (i, w) in witnesses.into_iter().enumerate() { + // Determine whether or not to use the real witness for this relation. This rule uses + // the first valid witness found in the given list. + let select_witness = instances[i].is_witness_valid(&w) & !valid_witness_found; + let (commitment, prover_state) = instances[i].prover_commit(w, rng)?; let (simulated_commitment, simulated_challenge, simulated_response) = instances[i].simulate_transcript(rng)?; - let valid_witness = instances[i].is_witness_valid(w) & !valid_witness_found; - let select_witness = valid_witness; - let commitment = ComposedCommitment::conditional_select( &simulated_commitment, &commitment, @@ -363,7 +401,7 @@ impl ComposedRelation< simulated_response, )); - valid_witness_found |= valid_witness; + valid_witness_found |= select_witness; } if valid_witness_found.unwrap_u8() == 0 { @@ -439,7 +477,7 @@ impl SigmaProtocol fn prover_commit( &self, - witness: &Self::Witness, + witness: Self::Witness, rng: &mut (impl Rng + CryptoRng), ) -> Result<(Self::Commitment, Self::ProverState), Error> { match (self, witness) { @@ -788,23 +826,3 @@ impl SigmaProtocolSimu } } } - -impl ComposedRelation { - /// Convert this Protocol into a non-interactive zero-knowledge proof - /// using the Shake128DuplexSponge codec and a specified session identifier. - /// - /// This method provides a convenient way to create a NIZK from a Protocol - /// without exposing the specific codec type to the API caller. - /// - /// # Parameters - /// - `session_identifier`: Domain separator bytes for the Fiat-Shamir transform - /// - /// # Returns - /// A `Nizk` instance ready for proving and verification - pub fn into_nizk( - self, - session_identifier: &[u8], - ) -> Nizk, Shake128DuplexSponge> { - Nizk::new(session_identifier, self) - } -} diff --git a/src/errors.rs b/src/errors.rs index 3350d40..2a8827d 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -8,19 +8,27 @@ //! - Mismatched parameter lengths (e.g., during batch verification), //! - Access to unassigned group variables in constraint systems. -use alloc::string::String; -#[cfg(not(feature = "std"))] -use core::fmt; +use alloc::string::{String, ToString}; + +// Publicly export the unassigned variable errors from this module. +pub use crate::linear_relation::collections::{UnassignedGroupVarError, UnassignedScalarVarError}; /// Represents an invalid instance error. -#[derive(Debug)] -#[cfg_attr(feature = "std", derive(thiserror::Error))] -#[cfg_attr(feature = "std", error("Invalid instance: {message}"))] +#[derive(Debug, thiserror::Error)] +#[error("Invalid instance: {message}")] pub struct InvalidInstance { /// The error message describing what's invalid about the instance. pub message: String, } +impl From for InvalidInstance { + fn from(value: UnassignedGroupVarError) -> Self { + Self { + message: value.to_string(), + } + } +} + impl InvalidInstance { /// Create a new InvalidInstance error with the given message. pub fn new(message: impl Into) -> Self { @@ -31,6 +39,7 @@ impl InvalidInstance { } impl From for Error { + // TODO: Don't drop the error message here. fn from(_err: InvalidInstance) -> Self { Error::InvalidInstanceWitnessPair } @@ -40,45 +49,18 @@ impl From for Error { /// /// This may occur during proof generation, response computation, or verification. #[non_exhaustive] -#[derive(Debug)] -#[cfg_attr(feature = "std", derive(thiserror::Error))] +#[derive(Debug, thiserror::Error)] pub enum Error { /// The proof is invalid: verification failed. - #[cfg_attr(feature = "std", error("Verification failed."))] + #[error("Verification failed.")] VerificationFailure, /// Indicates an invalid statement/witness pair - #[cfg_attr(feature = "std", error("Invalid instance/witness pair."))] + #[error("Invalid instance/witness pair.")] InvalidInstanceWitnessPair, - /// Uninitialized group element variable. - #[cfg_attr( - feature = "std", - error("Uninitialized group element variable: {var_debug}") - )] - UnassignedGroupVar { - /// Debug representation of the unassigned variable. - var_debug: String, - }, -} - -// Manual Display implementation for no_std compatibility -#[cfg(not(feature = "std"))] -impl fmt::Display for InvalidInstance { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Invalid instance: {}", self.message) - } -} - -#[cfg(not(feature = "std"))] -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Error::VerificationFailure => write!(f, "Verification failed."), - Error::InvalidInstanceWitnessPair => write!(f, "Invalid instance/witness pair."), - Error::UnassignedGroupVar { var_debug } => { - write!(f, "Uninitialized group element variable: {}", var_debug) - } - } - } + #[error(transparent)] + UnassignedScalarVar(#[from] UnassignedScalarVarError), + #[error(transparent)] + UnassignedGroupVarError(#[from] UnassignedGroupVarError), } pub type Result = core::result::Result; diff --git a/src/fiat_shamir.rs b/src/fiat_shamir.rs index 28c69d7..0950b34 100644 --- a/src/fiat_shamir.rs +++ b/src/fiat_shamir.rs @@ -40,12 +40,7 @@ type Transcript

= ( /// - `P`: the Sigma protocol implementation. /// - `C`: the codec used for Fiat-Shamir. #[derive(Debug)] -pub struct Nizk -where - P: SigmaProtocol, - P::Challenge: PartialEq, - C: Codec, -{ +pub struct Nizk { /// Current codec state. pub hash_state: C, /// Underlying interactive proof. @@ -55,14 +50,14 @@ where impl Nizk where P: SigmaProtocol, - P::Challenge: PartialEq, - C: Codec + Clone, + C: Codec, { /// Constructs a new [`Nizk`] instance. /// /// # Parameters - /// - `iv`: Domain separation tag for the hash function (e.g., protocol name or context). - /// - `instance`: An instance of the interactive Sigma protocol. + /// - `session_identifier`: Domain separation tag for the protocol session (e.g. the name of + /// the application such as "private_wallet_protocol"). Should be globally unique. + /// - `interactive_proof`: An instance of the interactive Sigma protocol. /// /// # Returns /// A new [`Nizk`] that can generate and verify non-interactive proofs. @@ -78,6 +73,8 @@ where } } + /// Construct a new [`Nizk`] instance with the hash state instantiated from the given + /// initialization vector (IV). pub fn from_iv(iv: [u8; 64], interactive_proof: P) -> Self { let hash_state = C::from_iv(iv); Self { @@ -85,7 +82,14 @@ where interactive_proof, } } +} +impl Nizk +where + P: SigmaProtocol, + P::Challenge: PartialEq, + C: Codec + Clone, +{ /// Generates a non-interactive proof for a witness. /// /// Executes the interactive protocol steps (commit, derive challenge via hash, respond), @@ -105,12 +109,13 @@ where /// Panics if local verification fails. fn prove( &self, - witness: &P::Witness, + witness: impl Into, rng: &mut (impl RngCore + CryptoRng), ) -> Result, Error> { let mut hash_state = self.hash_state.clone(); - let (commitment, prover_state) = self.interactive_proof.prover_commit(witness, rng)?; + let (commitment, prover_state) = + self.interactive_proof.prover_commit(witness.into(), rng)?; // Fiat Shamir challenge let serialized_commitment = self.interactive_proof.serialize_commitment(&commitment); hash_state.prover_message(&serialized_commitment); @@ -171,7 +176,7 @@ where /// Panics if serialization fails (should not happen under correct implementation). pub fn prove_batchable( &self, - witness: &P::Witness, + witness: impl Into, rng: &mut (impl RngCore + CryptoRng), ) -> Result, Error> { let (commitment, _challenge, response) = self.prove(witness, rng)?; @@ -250,7 +255,7 @@ where /// Panics if serialization fails. pub fn prove_compact( &self, - witness: &P::Witness, + witness: impl Into, rng: &mut (impl RngCore + CryptoRng), ) -> Result, Error> { let (_commitment, challenge, response) = self.prove(witness, rng)?; @@ -277,6 +282,9 @@ where /// - The recomputed commitment or response is invalid under the Sigma protocol. pub fn verify_compact(&self, proof: &[u8]) -> Result<(), Error> { // Deserialize challenge and response from compact proof + // TODO: This way of deserializing the proof, with framing based on deserializing and the + // serialization, is non-standard and quite error prone if a given message ever has more + // than one valid encoding. let challenge = self.interactive_proof.deserialize_challenge(proof)?; let challenge_size = self.interactive_proof.serialize_challenge(&challenge).len(); let response = self diff --git a/src/lib.rs b/src/lib.rs index 993d549..7975e66 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,12 +30,15 @@ //! instance.set_elements([(var_G, RistrettoPoint::generator()), (var_H, RistrettoPoint::random(&mut rng))]); //! //! // Assign the image of the linear map. -//! let witness = vec![Scalar::random(&mut rng), Scalar::random(&mut rng)]; -//! instance.compute_image(&witness); +//! let witness = [ +//! (var_x, Scalar::random(&mut rng)), +//! (var_r, Scalar::random(&mut rng)) +//! ]; +//! instance.compute_image(witness); //! //! // Create a non-interactive argument for the instance. //! let nizk = instance.into_nizk(b"your session identifier").unwrap(); -//! let narg_string: Vec = nizk.prove_batchable(&witness, &mut rng).unwrap(); +//! let narg_string: Vec = nizk.prove_batchable(witness, &mut rng).unwrap(); //! // Print the narg string. //! println!("{}", hex::encode(narg_string)); //! ``` diff --git a/src/linear_relation/allocator.rs b/src/linear_relation/allocator.rs new file mode 100644 index 0000000..a7c9361 --- /dev/null +++ b/src/linear_relation/allocator.rs @@ -0,0 +1,262 @@ +use core::{array, iter::zip, marker::PhantomData}; + +use group::{prime::PrimeGroup, Group}; + +use crate::{ + errors::UnassignedGroupVarError, + linear_relation::{GroupMap, GroupVar, ScalarMap, ScalarVar}, + LinearRelation, +}; + +pub trait Allocator { + type G; + + /// Allocates a scalar variable for use in the linear map. + fn allocate_scalar(&mut self) -> ScalarVar; + + /// Allocates `N` new scalar variables, with `N` known at compile-time. + /// + /// # Returns + /// An array of [`ScalarVar`] representing the newly allocated scalar references. + /// + /// # Example + /// ``` + /// # use sigma_proofs::LinearRelation; + /// use curve25519_dalek::RistrettoPoint as G; + /// + /// let mut relation = LinearRelation::::new(); + /// let [var_x, var_y] = relation.allocate_scalars(); + /// let vars = relation.allocate_scalars::<10>(); + /// ``` + fn allocate_scalars(&mut self) -> [ScalarVar; N] { + array::from_fn(|_| self.allocate_scalar()) + } + + /// Allocates `n` new scalar variables, with `n` decided at runtime. + /// + /// # Returns + /// A `Vec` of [`ScalarVar`] representing the newly allocated scalar references. + /// + /// # Example + /// ``` + /// # use sigma_proofs::LinearRelation; + /// use curve25519_dalek::RistrettoPoint as G; + /// + /// let mut relation = LinearRelation::::new(); + /// let vars = relation.allocate_scalars_vec(2); + /// assert_eq!(vars.len(), 2); + /// ``` + fn allocate_scalars_vec(&mut self, n: usize) -> Vec> { + (0..n).map(|_| self.allocate_scalar()).collect() + } + + /// Allocates a group element variable (i.e. elliptic curve point) for use in the linear map. + fn allocate_element(&mut self) -> GroupVar; + + /// Allocates `N` group element variables, with `N` known at compile-time. + /// + /// # Returns + /// An array of [`GroupVar`] representing the newly allocated group element references. + /// + /// # Example + /// ``` + /// # use sigma_proofs::LinearRelation; + /// use curve25519_dalek::RistrettoPoint as G; + /// + /// let mut relation = LinearRelation::::new(); + /// let [var_g, var_h] = relation.allocate_elements(); + /// let vars = relation.allocate_elements::<10>(); + /// ``` + fn allocate_elements(&mut self) -> [GroupVar; N] { + array::from_fn(|_| self.allocate_element()) + } + + /// Allocates `N` group element variables, with `N` decided at runtime. + /// + /// # Returns + /// A `Vec` of [`GroupVar`] representing the newly allocated group element references. + /// + /// # Example + /// ``` + /// # use sigma_proofs::LinearRelation; + /// use curve25519_dalek::RistrettoPoint as G; + /// + /// let mut relation = LinearRelation::::new(); + /// let vars = relation.allocate_elements_vec(2); + /// assert_eq!(vars.len(), 2); + /// ``` + fn allocate_elements_vec(&mut self, n: usize) -> Vec> { + (0..n).map(|_| self.allocate_element()).collect() + } + + fn allocate>(&mut self) -> T { + T::allocate(self) + } + + // TODO(victor/scalarvars): Should this be part of this trait, or should it be split off into + // its own trait? + fn assign_element(&mut self, var: GroupVar, element: Self::G); + + fn assign_elements( + &mut self, + assignments: impl IntoIterator, Self::G)>, + ) { + for (var, elem) in assignments.into_iter() { + self.assign_element(var, elem); + } + } + + fn get_element(&self, var: GroupVar) -> Result; +} + +pub trait Allocate { + type G; + + fn allocate + ?Sized>(alloc: &mut A) -> Self; +} + +impl Allocate for ScalarVar { + type G = G; + + fn allocate + ?Sized>(alloc: &mut A) -> Self { + alloc.allocate_scalar() + } +} + +impl Allocate for GroupVar { + type G = G; + + fn allocate + ?Sized>(alloc: &mut A) -> Self { + alloc.allocate_element() + } +} + +impl Allocate for [T; N] { + type G = T::G; + + fn allocate + ?Sized>(alloc: &mut A) -> Self { + array::from_fn(|_| alloc.allocate()) + } +} + +impl Allocate for (T1, T2) +where + T1: Allocate, + T2: Allocate, +{ + type G = T1::G; + + fn allocate + ?Sized>(alloc: &mut A) -> Self { + (T1::allocate(alloc), T2::allocate(alloc)) + } +} + +// TODO(victor/scalarvars) Rename this from Heap. Its not really a heap. +#[derive(Clone, Debug)] +pub struct Heap { + pub elements: GroupMap, + // TODO(victor/scalarvars): Should this be a ScalarMap? I hesitate to do so because I don't + // really want to store witness values on a struct like this. One particular reason for this is + // that this is a member of LinearRelation, which seems ok, but we do not want to carry the + // witness assignments in that struct as we convert it to a CanonicalRelation or Nizk. + pub num_scalars: usize, +} + +impl Default for Heap { + fn default() -> Self { + Self { + elements: Default::default(), + num_scalars: 0, + } + } +} + +impl Allocator for Heap { + type G = G; + + fn allocate_scalar(&mut self) -> ScalarVar { + self.num_scalars += 1; + ScalarVar(self.num_scalars - 1, PhantomData) + } + + fn allocate_element(&mut self) -> GroupVar { + self.elements.allocate_element() + } + + fn assign_element(&mut self, var: GroupVar, element: Self::G) { + self.elements.assign_element(var, element) + } + + fn get_element(&self, var: GroupVar) -> Result { + self.elements.get(var) + } +} + +pub trait ScalarAssignment { + type G: Group; + type Assignment; + + fn assign(&self, map: &mut ScalarMap, value: Self::Assignment); + + fn assignments(&self, value: Self::Assignment) -> ScalarMap { + let mut map = ScalarMap::default(); + map.assign(self, value); + map + } +} + +impl ScalarAssignment for ScalarVar { + type G = G; + type Assignment = G::Scalar; + + fn assign(&self, map: &mut ScalarMap, value: Self::Assignment) { + map.assign_scalar(*self, value) + } +} + +impl ScalarAssignment for [T; N] { + type G = T::G; + type Assignment = [T::Assignment; N]; + + fn assign(&self, map: &mut ScalarMap, value: Self::Assignment) { + for (var, value) in zip(self, value) { + var.assign(map, value); + } + } +} + +impl ScalarAssignment for (T1, T2) +where + T1: ScalarAssignment, + T2: ScalarAssignment, +{ + type G = T1::G; + type Assignment = (T1::Assignment, T2::Assignment); + + fn assign(&self, map: &mut ScalarMap, value: Self::Assignment) { + self.0.assign(map, value.0); + self.1.assign(map, value.1); + } +} + +impl ScalarMap { + pub fn assign + ?Sized>(&mut self, var: &A, value: A::Assignment) { + var.assign(self, value) + } +} + +#[non_exhaustive] +pub struct StructuredRelation { + pub vars: Vars, + pub relation: LinearRelation, +} + +impl> StructuredRelation { + fn new() -> Self { + let mut relation = LinearRelation::new(); + Self { + vars: relation.allocate(), + relation, + } + } +} diff --git a/src/linear_relation/canonical.rs b/src/linear_relation/canonical.rs index 84d2723..b9539fe 100644 --- a/src/linear_relation/canonical.rs +++ b/src/linear_relation/canonical.rs @@ -5,17 +5,20 @@ use alloc::vec::Vec; use core::iter; use core::marker::PhantomData; #[cfg(not(feature = "std"))] -use hashbrown::HashMap; +use hashbrown::{HashMap, HashSet}; #[cfg(feature = "std")] -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use ff::Field; use group::prime::PrimeGroup; use subtle::{Choice, ConstantTimeEq}; -use super::{GroupMap, GroupVar, LinearCombination, LinearRelation, ScalarTerm, ScalarVar}; +use super::{ + GroupMap, GroupVar, LinearCombination, LinearRelation, ScalarAssignments, ScalarTerm, ScalarVar, +}; use crate::errors::{Error, InvalidInstance}; use crate::group::msm::VariableMultiScalarMul; +use crate::linear_relation::Allocator; use crate::serialization::serialize_elements; // XXX. this definition is uncomfortably similar to LinearRelation, exception made for the weights. @@ -26,6 +29,7 @@ use crate::serialization::serialize_elements; /// constraint is of the form: image_i = Σ (scalar_j * group_element_k) /// without weights or extra scalars. #[derive(Clone, Debug, Default)] +#[non_exhaustive] pub struct CanonicalLinearRelation { /// The image group elements (left-hand side of equations) pub image: Vec>, @@ -34,21 +38,10 @@ pub struct CanonicalLinearRelation { pub linear_combinations: Vec, GroupVar)>>, /// The group elements map pub group_elements: GroupMap, - /// Number of scalar variables - pub num_scalars: usize, + /// Set of scalar variables used in this relation. + pub scalar_vars: HashSet>, } -/// Private type alias used to simplify function signatures below. -/// -/// The cache is essentially a mapping (GroupVar, Scalar) => GroupVar, which maps the original -/// weighted group vars to a new assignment, such that if a pair appears more than once, it will -/// map to the same group variable in the canonical linear relation. -#[cfg(feature = "std")] -type WeightedGroupCache = HashMap, Vec<(::Scalar, GroupVar)>>; -#[cfg(not(feature = "std"))] -type WeightedGroupCache = - HashMap, Vec<(::Scalar, GroupVar)>, RandomState>; - impl CanonicalLinearRelation { /// Create a new empty canonical linear relation. /// @@ -59,10 +52,12 @@ impl CanonicalLinearRelation { image: Vec::new(), linear_combinations: Vec::new(), group_elements: GroupMap::default(), - num_scalars: 0, + scalar_vars: HashSet::default(), } } + // QUESTION: Why does this currently panic when a variable is unassigned? Should this return + // Result instead? /// Evaluate the canonical linear relation with the provided scalars /// /// This returns a list of image points produced by evaluating each linear combination in the @@ -73,13 +68,13 @@ impl CanonicalLinearRelation { /// Panics if the number of scalars given is less than the number of scalar variables in this /// linear relation. /// If the vector of scalars if longer than the number of terms in each linear combinations, the extra terms are ignored. - pub fn evaluate(&self, scalars: &[G::Scalar]) -> Vec { + pub fn evaluate(&self, scalars: impl ScalarAssignments) -> Vec { self.linear_combinations .iter() .map(|lc| { let scalars = lc .iter() - .map(|(scalar_var, _)| scalars[scalar_var.index()]) + .map(|(scalar_var, _)| scalars.get(*scalar_var).unwrap()) .collect::>(); let bases = lc .iter() @@ -90,99 +85,6 @@ impl CanonicalLinearRelation { .collect() } - /// Get or create a GroupVar for a weighted group element, with deduplication - fn get_or_create_weighted_group_var( - &mut self, - group_var: GroupVar, - weight: &G::Scalar, - original_group_elements: &GroupMap, - weighted_group_cache: &mut WeightedGroupCache, - ) -> Result, InvalidInstance> { - // Check if we already have this (weight, group_var) combination - let entry = weighted_group_cache.entry(group_var).or_default(); - - // Find if we already have this weight for this group_var - if let Some((_, existing_var)) = entry.iter().find(|(w, _)| w == weight) { - return Ok(*existing_var); - } - - // Create new weighted group element - // Use a special case for one, as this is the most common weight. - let original_group_val = original_group_elements.get(group_var)?; - let weighted_group = match *weight == G::Scalar::ONE { - true => original_group_val, - false => original_group_val * weight, - }; - - // Add to our group elements with new index (length) - let new_var = self.group_elements.push(weighted_group); - - // Cache the mapping for this group_var and weight - entry.push((*weight, new_var)); - - Ok(new_var) - } - - /// Process a single constraint equation and add it to the canonical relation. - fn process_constraint( - &mut self, - &image_var: &GroupVar, - equation: &LinearCombination, - original_relation: &LinearRelation, - weighted_group_cache: &mut WeightedGroupCache, - ) -> Result<(), InvalidInstance> { - let mut rhs_terms = Vec::new(); - - // Collect RHS terms that have scalar variables and apply weights - for weighted_term in equation.terms() { - if let ScalarTerm::Var(scalar_var) = weighted_term.term.scalar { - let group_var = weighted_term.term.elem; - let weight = &weighted_term.weight; - - if weight.is_zero_vartime() { - continue; // Skip zero weights - } - - let canonical_group_var = self.get_or_create_weighted_group_var( - group_var, - weight, - &original_relation.linear_map.group_elements, - weighted_group_cache, - )?; - - rhs_terms.push((scalar_var, canonical_group_var)); - } - } - - // Compute the canonical image by subtracting constant terms from the original image - let mut canonical_image = original_relation.linear_map.group_elements.get(image_var)?; - for weighted_term in equation.terms() { - if let ScalarTerm::Unit = weighted_term.term.scalar { - let group_val = original_relation - .linear_map - .group_elements - .get(weighted_term.term.elem)?; - canonical_image -= group_val * weighted_term.weight; - } - } - - // Only include constraints that are non-trivial (not zero constraints). - if rhs_terms.is_empty() { - if canonical_image.is_identity().into() { - return Ok(()); - } - return Err(InvalidInstance::new( - "trivially false constraint: constraint has empty right-hand side and non-identity left-hand side", - )); - } - - let canonical_image_group_var = self.group_elements.push(canonical_image); - self.image.push(canonical_image_group_var); - self.linear_combinations.push(rhs_terms); - - Ok(()) - } - /// Serialize the linear relation to bytes. /// /// The output format is: @@ -372,12 +274,14 @@ impl CanonicalLinearRelation { // Build the canonical relation let mut canonical = Self::new(); - canonical.num_scalars = (max_scalar_index + 1) as usize; + canonical.scalar_vars = (0..=max_scalar_index as usize) + .map(|i| ScalarVar(i, PhantomData)) + .collect(); // Add all group elements to the map let mut group_var_map = Vec::new(); for elem in &group_elements_ordered { - let var = canonical.group_elements.push(*elem); + let var = canonical.group_elements.allocate_element_with(*elem); group_var_map.push(var); } @@ -410,37 +314,33 @@ impl CanonicalLinearRelation { } } -impl TryFrom> for CanonicalLinearRelation { +impl> TryFrom> + for CanonicalLinearRelation +{ type Error = InvalidInstance; - fn try_from(value: LinearRelation) -> Result { + fn try_from(value: LinearRelation) -> Result { Self::try_from(&value) } } -impl TryFrom<&LinearRelation> for CanonicalLinearRelation { +impl> TryFrom<&LinearRelation> + for CanonicalLinearRelation +{ type Error = InvalidInstance; - fn try_from(relation: &LinearRelation) -> Result { - if relation.image.len() != relation.linear_map.linear_combinations.len() { + fn try_from(relation: &LinearRelation) -> Result { + if relation.image.len() != relation.linear_combinations.len() { return Err(InvalidInstance::new( "Number of equations must be equal to number of image elements.", )); } - let mut canonical = CanonicalLinearRelation::new(); - canonical.num_scalars = relation.linear_map.num_scalars; - - // Cache for deduplicating weighted group elements - #[cfg(feature = "std")] - let mut weighted_group_cache = HashMap::new(); - #[cfg(not(feature = "std"))] - let mut weighted_group_cache = HashMap::with_hasher(RandomState::new()); - - // Process each constraint using the modular helper method - for (lhs, rhs) in iter::zip(&relation.image, &relation.linear_map.linear_combinations) { + // Process each constraint using the canonical linear relation builder. + let mut builder = CanonicalLinearRelationBuilder::default(); + for (lhs, rhs) in iter::zip(&relation.image, &relation.linear_combinations) { // If any group element in the image is not assigned, return `InvalidInstance`. - let lhs_value = relation.linear_map.group_elements.get(*lhs)?; + let lhs_value = relation.heap.get_element(*lhs)?; // Compute the constant terms on the right-hand side of the equation. // If any group element in the linear constraints is not assigned, return `InvalidInstance`. @@ -449,11 +349,11 @@ impl TryFrom<&LinearRelation> for CanonicalLinearRelation { .iter() .filter(|term| matches!(term.term.scalar, ScalarTerm::Unit)) .map(|term| { - let elem = relation.linear_map.group_elements.get(term.term.elem)?; + let elem = relation.heap.get_element(term.term.elem)?; let scalar = term.weight; Ok((elem, scalar)) }) - .collect::, Vec), _>>()?; + .collect::, Vec), Self::Error>>()?; let rhs_constant_term = G::msm(&rhs_constant_terms.1, &rhs_constant_terms.0); @@ -477,10 +377,10 @@ impl TryFrom<&LinearRelation> for CanonicalLinearRelation { return Err(InvalidInstance::new("Trivial kernel in this relation")); } - canonical.process_constraint(lhs, rhs, relation, &mut weighted_group_cache)?; + builder.process_constraint(lhs, rhs, relation)?; } - Ok(canonical) + Ok(builder.build()) } } @@ -493,10 +393,138 @@ impl CanonicalLinearRelation { /// /// Panics if the number of scalars given is less than the number of scalar variables. /// If the number of scalars is more than the number of scalar variables, the extra elements are ignored. - pub fn is_witness_valid(&self, witness: &[G::Scalar]) -> Choice { + pub fn is_witness_valid(&self, witness: impl ScalarAssignments) -> Choice { let got = self.evaluate(witness); self.image_elements() .zip(got) .fold(Choice::from(1), |acc, (lhs, rhs)| acc & lhs.ct_eq(&rhs)) } } + +/// Private type alias used to simplify function signatures below. +/// +/// The cache is essentially a mapping (GroupVar, Scalar) => GroupVar, which maps the original +/// weighted group vars to a new assignment, such that if a pair appears more than once, it will +/// map to the same group variable in the canonical linear relation. +#[cfg(feature = "std")] +type WeightedGroupCache = HashMap, Vec<(::Scalar, GroupVar)>>; +#[cfg(not(feature = "std"))] +type WeightedGroupCache = + HashMap, Vec<(::Scalar, GroupVar)>, RandomState>; + +#[derive(Debug)] +struct CanonicalLinearRelationBuilder { + relation: CanonicalLinearRelation, + weighted_group_cache: WeightedGroupCache, +} + +impl CanonicalLinearRelationBuilder { + /// Get or create a GroupVar for a weighted group element, with deduplication + fn get_or_create_weighted_group_var>( + &mut self, + group_var: GroupVar, + weight: &G::Scalar, + original_alloc: &A, + ) -> Result, InvalidInstance> { + // Check if we already have this (weight, group_var) combination + let entry = self.weighted_group_cache.entry(group_var).or_default(); + + // Find if we already have this weight for this group_var + if let Some((_, existing_var)) = entry.iter().find(|(w, _)| w == weight) { + return Ok(*existing_var); + } + + // Create new weighted group element + // Use a special case for one, as this is the most common weight. + let original_group_val = original_alloc.get_element(group_var)?; + let weighted_group = match *weight == G::Scalar::ONE { + true => original_group_val, + false => original_group_val * weight, + }; + + // Add to our group elements with new index (length) + let new_var = self + .relation + .group_elements + .allocate_element_with(weighted_group); + + // Cache the mapping for this group_var and weight + entry.push((*weight, new_var)); + + Ok(new_var) + } + + /// Process a single constraint equation and add it to the canonical relation. + fn process_constraint>( + &mut self, + &image_var: &GroupVar, + equation: &LinearCombination, + allocator: &A, + ) -> Result<(), InvalidInstance> { + let mut rhs_terms = Vec::new(); + + // Collect RHS terms that have scalar variables and apply weights + for weighted_term in equation.terms() { + if let ScalarTerm::Var(scalar_var) = weighted_term.term.scalar { + let group_var = weighted_term.term.elem; + let weight = &weighted_term.weight; + + if weight.is_zero_vartime() { + continue; // Skip zero weights + } + + let canonical_group_var = + self.get_or_create_weighted_group_var(group_var, weight, allocator)?; + + rhs_terms.push((scalar_var, canonical_group_var)); + self.relation.scalar_vars.insert(scalar_var); + } + } + + // Compute the canonical image by subtracting constant terms from the original image + let mut canonical_image = allocator.get_element(image_var)?; + for weighted_term in equation.terms() { + if let ScalarTerm::Unit = weighted_term.term.scalar { + let group_val = allocator.get_element(weighted_term.term.elem)?; + canonical_image -= group_val * weighted_term.weight; + } + } + + // Only include constraints that are non-trivial (not zero constraints). + if rhs_terms.is_empty() { + if canonical_image.is_identity().into() { + return Ok(()); + } + return Err(InvalidInstance::new( + "trivially false constraint: constraint has empty right-hand side and non-identity left-hand side", + )); + } + + let canonical_image_group_var = self + .relation + .group_elements + .allocate_element_with(canonical_image); + self.relation.image.push(canonical_image_group_var); + self.relation.linear_combinations.push(rhs_terms); + + Ok(()) + } + + fn build(self) -> CanonicalLinearRelation { + self.relation + } +} + +impl Default for CanonicalLinearRelationBuilder { + fn default() -> Self { + #[cfg(feature = "std")] + let weighted_group_cache = HashMap::new(); + #[cfg(not(feature = "std"))] + let weighted_group_cache = HashMap::with_hasher(RandomState::new()); + + Self { + relation: CanonicalLinearRelation::new(), + weighted_group_cache, + } + } +} diff --git a/src/linear_relation/collections.rs b/src/linear_relation/collections.rs new file mode 100644 index 0000000..f2cb227 --- /dev/null +++ b/src/linear_relation/collections.rs @@ -0,0 +1,329 @@ +//! # Collections for Group and Scalar Vars +//! +//! This module provides collections of group elements and scalars, [GroupMap] and [ScalarMap]. +//! These collections act as a mapping of opaque variable references to values. + +use alloc::vec::Vec; +use core::marker::PhantomData; +use group::Group; + +use super::{GroupVar, ScalarVar}; + +/// Ordered mapping of [GroupVar] to group elements assignments. +#[derive(Clone, Debug)] +pub struct GroupMap(Vec>); + +impl GroupMap { + pub fn allocate_element(&mut self) -> GroupVar { + self.0.push(None); + GroupVar(self.0.len() - 1, PhantomData) + } + + /// Add a new group element to the map and return its variable reference + pub fn allocate_element_with(&mut self, element: G) -> GroupVar { + self.0.push(Some(element)); + GroupVar(self.0.len() - 1, PhantomData) + } + + /// Assign a group element value to a variable. + /// + /// # Parameters + /// + /// - `var`: The variable to assign. + /// - `element`: The value to assign to the variable. + /// + /// # Panics + /// + /// Panics if the given assignment conflicts with the existing assignment. + pub fn assign_element(&mut self, var: GroupVar, element: G) { + if self.0.len() <= var.0 { + self.0.resize(var.0 + 1, None); + } else if let Some(assignment) = self.0[var.0] { + assert_eq!( + assignment, element, + "conflicting assignments for var {var:?}" + ) + } + self.0[var.0] = Some(element); + } + + /// Assigns specific group elements to variables. + /// + /// # Parameters + /// + /// - `assignments`: A collection of `(GroupVar, G)` pairs that can be iterated over. + /// + /// # Panics + /// + /// Panics if the collection contains two conflicting assignments for the same variable. + pub fn assign_elements(&mut self, assignments: impl IntoIterator, G)>) { + for (var, elem) in assignments.into_iter() { + self.assign_element(var, elem); + } + } + + /// Get the element value assigned to the given variable. + /// + /// Returns [`InvalidInstance`] if a value is not assigned. + pub fn get(&self, var: GroupVar) -> Result { + match self.0.get(var.0) { + Some(Some(elem)) => Ok(*elem), + Some(None) | None => Err(UnassignedGroupVarError(var.to_elided())), + } + } + + /// Iterate over the assigned variable and group element pairs in this map. + // NOTE: Not implemented as `IntoIterator` for now because doing so requires explicitly + // defining an iterator type, See https://github.com/rust-lang/rust/issues/63063 + #[allow(clippy::should_implement_trait)] + pub fn into_iter(self) -> impl Iterator, Option)> { + self.0 + .into_iter() + .enumerate() + .map(|(i, x)| (GroupVar(i, PhantomData), x)) + } + + pub fn iter(&self) -> impl Iterator, Option<&G>)> { + self.0 + .iter() + .enumerate() + .map(|(i, opt)| (GroupVar(i, PhantomData), opt.as_ref())) + } + + pub fn vars(&self) -> impl Iterator> { + (0..self.len()).map(|i| GroupVar(i, PhantomData)) + } + + /// Get the number of elements in the map + pub fn len(&self) -> usize { + self.0.len() + } + + /// Check if the map is empty + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl Default for GroupMap { + fn default() -> Self { + Self(Vec::default()) + } +} + +impl FromIterator<(GroupVar, G)> for GroupMap { + fn from_iter, G)>>(iter: T) -> Self { + iter.into_iter() + .fold(Self::default(), |mut instance, (var, val)| { + instance.assign_element(var, val); + instance + }) + } +} + +/// Ordered mapping of [ScalarVar] to scalar assignments. +#[derive(Clone, Debug)] +pub struct ScalarMap(Vec>); + +impl ScalarMap { + pub fn allocate_scalar(&mut self) -> ScalarVar { + self.0.push(None); + ScalarVar(self.0.len() - 1, PhantomData) + } + + /// Add a new scalar to the map and return its variable reference + pub fn allocate_scalar_with(&mut self, scalar: G::Scalar) -> ScalarVar { + self.0.push(Some(scalar)); + ScalarVar(self.0.len() - 1, PhantomData) + } + + /// Assign a scalar value to a variable. + /// + /// # Parameters + /// + /// - `var`: The variable to assign. + /// - `scalar`: The value to assign to the variable. + /// + /// # Panics + /// + /// Panics if the given assignment conflicts with the existing assignment. + pub fn assign_scalar(&mut self, var: ScalarVar, scalar: G::Scalar) { + if self.0.len() <= var.0 { + self.0.resize(var.0 + 1, None); + } else if let Some(assignment) = self.0[var.0] { + assert_eq!( + assignment, scalar, + "conflicting assignments for var {var:?}" + ) + } + self.0[var.0] = Some(scalar); + } + + /// Assigns specific scalars to variables. + /// + /// # Parameters + /// + /// - `assignments`: A collection of `(ScalarVar, G::Scalar)` pairs that can be iterated over. + /// + /// # Panics + /// + /// Panics if the collection contains two conflicting assignments for the same variable. + pub fn assign_scalars( + &mut self, + assignments: impl IntoIterator, G::Scalar)>, + ) { + for (var, elem) in assignments.into_iter() { + self.assign_scalar(var, elem); + } + } + + /// Get the scalar value assigned to the given variable. + /// + /// Returns [`InvalidInstance`] if a value is not assigned. + pub fn get(&self, var: ScalarVar) -> Result { + match self.0.get(var.0) { + Some(Some(elem)) => Ok(*elem), + Some(None) | None => Err(UnassignedScalarVarError(var.to_elided())), + } + } + + /// Iterate over the assigned variable and scalar pairs in this map. + // NOTE: Not implemented as `IntoIterator` for now because doing so requires explicitly + // defining an iterator type, See https://github.com/rust-lang/rust/issues/63063 + #[allow(clippy::should_implement_trait)] + pub fn into_iter(self) -> impl Iterator, Option)> { + self.0 + .into_iter() + .enumerate() + .map(|(i, x)| (ScalarVar(i, PhantomData), x)) + } + + /// Iterate over the assigned variable and scalar pairs in this map. + pub fn iter(&self) -> impl Iterator, Option<&G::Scalar>)> { + self.0 + .iter() + .enumerate() + .map(|(i, opt)| (ScalarVar(i, PhantomData), opt.as_ref())) + } + + /// Iterate over the scalar variable references in this scalar map. + pub fn vars(&self) -> impl Iterator> { + (0..self.0.len()).map(|i| ScalarVar(i, PhantomData)) + } + + pub fn zip<'a>( + &'a self, + other: &'a Self, + ) -> impl Iterator, Option, Option)> + use<'a, G> + { + // NOTE: Due to the packed representation, we know that var `i` is stored at position `i`. + // This simplifies the implementation by allowing iteration over the longer of the two to + // consider all allocated variables. `left` is the longer if different. + let (left, right) = match self.len() >= other.len() { + true => (self, other), + false => (other, self), + }; + left.vars() + .map(|var| (var, left.get(var).ok(), right.get(var).ok())) + } + + /// Get the number of scalars in the map + pub fn len(&self) -> usize { + self.0.len() + } + + /// Check if the map is empty + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl Default for ScalarMap { + fn default() -> Self { + Self(Vec::default()) + } +} + +impl From, G::Scalar)>> for ScalarMap { + fn from(value: Vec<(ScalarVar, G::Scalar)>) -> Self { + Self::from_iter(value) + } +} + +impl From<[(ScalarVar, G::Scalar); N]> for ScalarMap { + fn from(value: [(ScalarVar, G::Scalar); N]) -> Self { + Self::from_iter(value) + } +} + +impl FromIterator<(ScalarVar, G::Scalar)> for ScalarMap { + fn from_iter, G::Scalar)>>(iter: T) -> Self { + iter.into_iter() + .fold(Self::default(), |mut instance, (var, val)| { + instance.assign_scalar(var, val); + instance + }) + } +} + +// TODO(victor/scalarvars): Potentially fold this into the definitions in allocator. +// A trait providing a mapping from [ScalarVar] for scalar values of type `G::Scalar`. +// TODO: The generic should at least by an associated type instead. A single struct will not +// implement multiple parameterizations of ScalarAssignments. +pub trait ScalarAssignments { + fn get(&self, var: ScalarVar) -> Result; +} + +impl ScalarAssignments for ScalarMap { + fn get(&self, var: ScalarVar) -> Result { + self.get(var) + } +} + +impl ScalarAssignments for &ScalarMap { + fn get(&self, var: ScalarVar) -> Result { + (*self).get(var) + } +} + +impl, G::Scalar)]>> ScalarAssignments for A { + /// Access the assignment of a [ScalarVar] from an array-like struct (e.g. `[_; N]` or `Vec`). + /// + /// The variable is fetched via a linear search. For small arrays, this is optimal and avoids + /// allocation into a [ScalarMap]. For statements with a large number of scalars, this will not + /// be as effcicient as allocating a [ScalarMap]. + fn get(&self, var: ScalarVar) -> Result<::Scalar, UnassignedScalarVarError> { + self.as_ref() + .iter() + .copied() + .find_map(|(var_i, scalar)| (var == var_i).then_some(scalar)) + .ok_or(UnassignedScalarVarError(var.to_elided())) + } +} + +/// An uninhabited type used to elide the type parameter on [UnassignedScalarVarError] and +/// [UnassignedGroupVarError]. +#[derive(Copy, Clone, Debug)] +enum Elided {} + +impl GroupVar { + fn to_elided(self) -> GroupVar { + GroupVar(self.0, PhantomData) + } +} + +impl ScalarVar { + fn to_elided(self) -> ScalarVar { + ScalarVar(self.0, PhantomData) + } +} + +/// Error for an attempted access to an unassigned [GroupVar]. +#[derive(Clone, Debug, thiserror::Error)] +#[error("Unassigned group variable: {0:?}")] +pub struct UnassignedGroupVarError(GroupVar); + +/// Error for an attempted access to an unassigned [ScalarVar]. +#[derive(Clone, Debug, thiserror::Error)] +#[error("Unassigned scalar variable: {0:?}")] +pub struct UnassignedScalarVarError(ScalarVar); diff --git a/src/linear_relation/mod.rs b/src/linear_relation/mod.rs index 8fe6383..7a1ce43 100644 --- a/src/linear_relation/mod.rs +++ b/src/linear_relation/mod.rs @@ -8,8 +8,24 @@ //! - [`LinearMap`]: a collection of linear combinations acting on group elements. //! - [`LinearRelation`]: a higher-level structure managing linear maps and their associated images. -use alloc::format; +/// Implementations of conversion operations such as From and FromIterator for var and term types. +mod convert; +/// Implementations of core ops for the linear combination types. +mod ops; + +/// Implementation of canonical linear relation. +mod canonical; +pub use canonical::CanonicalLinearRelation; + +/// Collections for group elements and scalars, used in the linear maps. +pub(crate) mod collections; +pub use collections::{GroupMap, ScalarAssignments, ScalarMap}; + +mod allocator; +pub use allocator::Allocator; + use alloc::vec::Vec; +use collections::{UnassignedGroupVarError, UnassignedScalarVarError}; use core::iter; use core::marker::PhantomData; @@ -19,39 +35,42 @@ use group::prime::PrimeGroup; use crate::codec::Shake128DuplexSponge; use crate::errors::{Error, InvalidInstance}; use crate::group::msm::VariableMultiScalarMul; +use crate::linear_relation::allocator::Heap; use crate::Nizk; -/// Implementations of conversion operations such as From and FromIterator for var and term types. -mod convert; -/// Implementations of core ops for the linear combination types. -mod ops; - -/// Implementation of canonical linear relation. -mod canonical; -pub use canonical::CanonicalLinearRelation; - -/// A wrapper representing an index for a scalar variable. +/// A wrapper representing an reference for a scalar variable. /// /// Used to reference scalars in sparse linear combinations. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq)] pub struct ScalarVar(usize, PhantomData); impl ScalarVar { + // QUESTION: Should I mark this method as deprecated? It currently leaks the internal + // representation of the variable and may not be stable. It's not clear what valid use cases + // there are for this index. pub fn index(&self) -> usize { self.0 } } +// Implement copy and clone for all G +impl Copy for ScalarVar {} +impl Clone for ScalarVar { + fn clone(&self) -> Self { + *self + } +} + impl core::hash::Hash for ScalarVar { fn hash(&self, state: &mut H) { self.0.hash(state) } } -/// A wrapper representing an index for a group element (point). +/// A wrapper representing a reference for a group element (i.e. elliptic curve point). /// /// Used to reference group elements in sparse linear combinations. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq)] pub struct GroupVar(usize, PhantomData); impl GroupVar { @@ -60,6 +79,14 @@ impl GroupVar { } } +// Implement copy and clone for all G +impl Copy for GroupVar {} +impl Clone for GroupVar { + fn clone(&self) -> Self { + *self + } +} + impl core::hash::Hash for GroupVar { fn hash(&self, state: &mut H) { self.0.hash(state) @@ -73,13 +100,16 @@ pub enum ScalarTerm { } impl ScalarTerm { - // NOTE: This function is private intentionally as it would be replaced if a ScalarMap struct - // were to be added. - fn value(self, scalars: &[G::Scalar]) -> G::Scalar { - match self { - Self::Var(var) => scalars[var.0], + // TODO: Move this function onto ScalarMap instead? Maybe ScalarMap should have an associated + // valuation function. + fn value( + self, + scalars: &impl ScalarAssignments, + ) -> Result { + Ok(match self { + Self::Var(var) => scalars.get(var)?, Self::Unit => G::Scalar::ONE, - } + }) } } @@ -124,235 +154,41 @@ impl core::iter::Sum for Sum { /// where: /// - `(s_i * P_i)` are the terms, with `s_i` scalars (referenced by `scalar_vars`) and `P_i` group elements (referenced by `element_vars`). /// - `w_i` are the constant weight scalars -/// -/// The indices refer to external lists managed by the containing LinearMap. pub type LinearCombination = Sum, ::Scalar>>; -impl LinearMap { - fn map(&self, scalars: &[G::Scalar]) -> Result, InvalidInstance> { - self.linear_combinations - .iter() - .map(|lc| { - let weighted_coefficients = - lc.0.iter() - .map(|weighted| weighted.term.scalar.value(scalars) * weighted.weight) - .collect::>(); - let elements = - lc.0.iter() - .map(|weighted| self.group_elements.get(weighted.term.elem)) - .collect::, InvalidInstance>>(); - match elements { - Ok(elements) => Ok(G::msm(&weighted_coefficients, &elements)), - Err(error) => Err(error), - } - }) - .collect::, InvalidInstance>>() - } -} - -/// Ordered mapping of [GroupVar] to group elements assignments. -#[derive(Clone, Debug)] -pub struct GroupMap(Vec>); - -impl GroupMap { - /// Assign a group element value to a point variable. - /// - /// # Parameters - /// - /// - `var`: The variable to assign. - /// - `element`: The value to assign to the variable. - /// - /// # Panics - /// - /// Panics if the given assignment conflicts with the existing assignment. - pub fn assign_element(&mut self, var: GroupVar, element: G) { - if self.0.len() <= var.0 { - self.0.resize(var.0 + 1, None); - } else if let Some(assignment) = self.0[var.0] { - assert_eq!( - assignment, element, - "conflicting assignments for var {var:?}" - ) - } - self.0[var.0] = Some(element); - } - - /// Assigns specific group elements to point variables (indices). - /// - /// # Parameters - /// - /// - `assignments`: A collection of `(GroupVar, GroupElement)` pairs that can be iterated over. - /// - /// # Panics - /// - /// Panics if the collection contains two conflicting assignments for the same variable. - pub fn assign_elements(&mut self, assignments: impl IntoIterator, G)>) { - for (var, elem) in assignments.into_iter() { - self.assign_element(var, elem); - } - } - - /// Get the element value assigned to the given point var. - /// - /// Returns [`InvalidInstance`] if a value is not assigned. - pub fn get(&self, var: GroupVar) -> Result { - match self.0.get(var.0) { - Some(Some(elem)) => Ok(*elem), - Some(None) | None => Err(InvalidInstance::new(format!( - "unassigned group variable {}", - var.0 - ))), - } - } - - /// Iterate over the assigned variable and group element pairs in this mapping. - // NOTE: Not implemented as `IntoIterator` for now because doing so requires explicitly - // defining an iterator type, See https://github.com/rust-lang/rust/issues/63063 - #[allow(clippy::should_implement_trait)] - pub fn into_iter(self) -> impl Iterator, Option)> { - self.0 - .into_iter() - .enumerate() - .map(|(i, x)| (GroupVar(i, PhantomData), x)) - } - - pub fn iter(&self) -> impl Iterator, Option<&G>)> { - self.0 - .iter() - .enumerate() - .map(|(i, opt)| (GroupVar(i, PhantomData), opt.as_ref())) - } - - /// Add a new group element to the map and return its variable index - pub fn push(&mut self, element: G) -> GroupVar { - let index = self.0.len(); - self.0.push(Some(element)); - GroupVar(index, PhantomData) - } - - /// Get the number of elements in the map - pub fn len(&self) -> usize { - self.0.len() - } - - /// Check if the map is empty - pub fn is_empty(&self) -> bool { - self.0.is_empty() - } -} - -impl Default for GroupMap { - fn default() -> Self { - Self(Vec::default()) - } -} - -impl FromIterator<(GroupVar, G)> for GroupMap { - fn from_iter, G)>>(iter: T) -> Self { - iter.into_iter() - .fold(Self::default(), |mut instance, (var, val)| { - instance.assign_element(var, val); - instance - }) - } -} - -/// A LinearMap represents a list of linear combinations over group elements. +/// This structure represents the *preimage problem* for a group linear map: given a set of scalar inputs, +/// determine whether their image under the linear map matches a target set of group elements. /// -/// It supports dynamic allocation of scalars and elements, -/// and evaluates by performing multi-scalar multiplications. +/// Internally, the constraint system is defined through: +/// - A list of group elements and linear equations. +/// - A list of [`GroupVar`] references (`image`) that specify the expected output for each constraint. +#[non_exhaustive] #[derive(Clone, Default, Debug)] -pub struct LinearMap { +pub struct LinearRelation> { /// The set of linear combination constraints (equations). pub linear_combinations: Vec>, - /// The list of group elements referenced in the linear map. - /// - /// Uninitialized group elements are presented with `None`. - pub group_elements: GroupMap, - /// The total number of scalar variables allocated. - pub num_scalars: usize, - /// The total number of group element variables allocated. - pub num_elements: usize, + pub heap: A, + /// References pointing to elements representing the "target" images for each constraint. + pub image: Vec>, } -impl LinearMap { - /// Creates a new empty [`LinearMap`]. - /// - /// # Returns - /// - /// A [`LinearMap`] instance with empty linear combinations and group elements, - /// and zero allocated scalars and elements. +impl LinearRelation { + /// Create a new empty [`LinearRelation`]. pub fn new() -> Self { Self { linear_combinations: Vec::new(), - group_elements: GroupMap::default(), - num_scalars: 0, - num_elements: 0, + heap: Default::default(), + image: Vec::new(), } } - - /// Returns the number of constraints (equations) in this linear map. - pub fn num_constraints(&self) -> usize { - self.linear_combinations.len() - } - - /// Adds a new linear combination constraint to the linear map. - /// - /// # Parameters - /// - `lc`: The [`LinearCombination`] to add. - pub fn append(&mut self, lc: LinearCombination) { - self.linear_combinations.push(lc); - } - - /// Evaluates all linear combinations in the linear map with the provided scalars. - /// - /// # Parameters - /// - `scalars`: A slice of scalar values corresponding to the scalar variables. - /// - /// # Returns - /// - /// A vector of group elements, each being the result of evaluating one linear combination with the scalars. - pub fn evaluate(&self, scalars: &[G::Scalar]) -> Result, Error> { - self.linear_combinations - .iter() - .map(|lc| { - // TODO: The multiplication by the (public) weight is potentially wasteful in the - // weight is most commonly 1, but multiplication is constant time. - let weighted_coefficients = - lc.0.iter() - .map(|weighted| weighted.term.scalar.value(scalars) * weighted.weight) - .collect::>(); - let elements = - lc.0.iter() - .map(|weighted| self.group_elements.get(weighted.term.elem)) - .collect::, _>>()?; - Ok(G::msm(&weighted_coefficients, &elements)) - }) - .collect() - } } -/// A wrapper struct coupling a [`LinearMap`] with the corresponding expected output (image) elements. -/// -/// This structure represents the *preimage problem* for a group linear map: given a set of scalar inputs, -/// determine whether their image under the linear map matches a target set of group elements. -/// -/// Internally, the constraint system is defined through: -/// - A list of group elements and linear equations (held in the [`LinearMap`] field), -/// - A list of [`GroupVar`] indices (`image`) that specify the expected output for each constraint. -#[derive(Clone, Default, Debug)] -pub struct LinearRelation { - /// The underlying linear map describing the structure of the statement. - pub linear_map: LinearMap, - /// Indices pointing to elements representing the "target" images for each constraint. - pub image: Vec>, -} - -impl LinearRelation { - /// Create a new empty [`LinearRelation`]. - pub fn new() -> Self { +impl> LinearRelation { + /// Create a new empty [`LinearRelation`] using the given [`Allocator`]. + pub fn new_in(allocator: A) -> Self { Self { - linear_map: LinearMap::new(), + linear_combinations: Vec::new(), + heap: allocator, image: Vec::new(), } } @@ -364,7 +200,7 @@ impl LinearRelation { /// - `lhs`: The image group element variable (left-hand side of the equation). /// - `rhs`: An instance of [`LinearCombination`] representing the linear combination on the right-hand side. pub fn append_equation(&mut self, lhs: GroupVar, rhs: impl Into>) { - self.linear_map.append(rhs.into()); + self.linear_combinations.push(rhs.into()); self.image.push(lhs); } @@ -379,104 +215,13 @@ impl LinearRelation { var } - /// Allocates a scalar variable for use in the linear map. - pub fn allocate_scalar(&mut self) -> ScalarVar { - self.linear_map.num_scalars += 1; - ScalarVar(self.linear_map.num_scalars - 1, PhantomData) - } - - /// Allocates space for `N` new scalar variables. - /// - /// # Returns - /// An array of [`ScalarVar`] representing the newly allocated scalar indices. - /// - /// # Example - /// ``` - /// # use sigma_proofs::LinearRelation; - /// use curve25519_dalek::RistrettoPoint as G; - /// - /// let mut relation = LinearRelation::::new(); - /// let [var_x, var_y] = relation.allocate_scalars(); - /// let vars = relation.allocate_scalars::<10>(); - /// ``` - pub fn allocate_scalars(&mut self) -> [ScalarVar; N] { - let mut vars = [ScalarVar(usize::MAX, PhantomData); N]; - for var in vars.iter_mut() { - *var = self.allocate_scalar(); - } - vars - } - - /// Allocates a vector of new scalar variables. - /// - /// # Returns - /// A vector of [`ScalarVar`] representing the newly allocated scalar indices. - /// /// # Example - /// ``` - /// # use sigma_proofs::LinearRelation; - /// use curve25519_dalek::RistrettoPoint as G; - /// - /// let mut relation = LinearRelation::::new(); - /// let [var_x, var_y] = relation.allocate_scalars(); - /// let vars = relation.allocate_scalars_vec(10); - /// ``` - pub fn allocate_scalars_vec(&mut self, n: usize) -> Vec> { - (0..n).map(|_| self.allocate_scalar()).collect() - } - - /// Allocates a point variable (group element) for use in the linear map. - pub fn allocate_element(&mut self) -> GroupVar { - self.linear_map.num_elements += 1; - GroupVar(self.linear_map.num_elements - 1, PhantomData) - } - - /// Allocates a point variable (group element) and sets it immediately to the given value + /// Allocates a group element variable (i.e. elliptic curve point) and sets it immediately to the given value pub fn allocate_element_with(&mut self, element: G) -> GroupVar { let var = self.allocate_element(); - self.set_element(var, element); + self.assign_element(var, element); var } - /// Allocates `N` point variables (group elements) for use in the linear map. - /// - /// # Returns - /// An array of [`GroupVar`] representing the newly allocated group element indices. - /// - /// # Example - /// ``` - /// # use sigma_proofs::LinearRelation; - /// use curve25519_dalek::RistrettoPoint as G; - /// - /// let mut relation = LinearRelation::::new(); - /// let [var_g, var_h] = relation.allocate_elements(); - /// let vars = relation.allocate_elements::<10>(); - /// ``` - pub fn allocate_elements(&mut self) -> [GroupVar; N] { - let mut vars = [GroupVar(usize::MAX, PhantomData); N]; - for var in vars.iter_mut() { - *var = self.allocate_element(); - } - vars - } - - /// Allocates a vector of new point variables (group elements). - /// - /// # Returns - /// A vector of [`GroupVar`] representing the newly allocated group element indices. - /// - /// # Example - /// ``` - /// # use sigma_proofs::LinearRelation; - /// use curve25519_dalek::RistrettoPoint as G; - /// let mut relation = LinearRelation::::new(); - /// let [var_g, var_h - /// ] = relation.allocate_elements(); - /// let vars = relation.allocate_elements_vec(10); - /// ``` - pub fn allocate_elements_vec(&mut self, n: usize) -> Vec> { - (0..n).map(|_| self.allocate_element()).collect() - } - /// Allocates a point variable (group element) and sets it immediately to the given value. pub fn allocate_elements_with(&mut self, elements: &[G]) -> Vec> { elements @@ -485,33 +230,6 @@ impl LinearRelation { .collect() } - /// Assign a group element value to a point variable. - /// - /// # Parameters - /// - /// - `var`: The variable to assign. - /// - `element`: The value to assign to the variable. - /// - /// # Panics - /// - /// Panics if the given assignment conflicts with the existing assignment. - pub fn set_element(&mut self, var: GroupVar, element: G) { - self.linear_map.group_elements.assign_element(var, element) - } - - /// Assigns specific group elements to point variables (indices). - /// - /// # Parameters - /// - /// - `assignments`: A collection of `(GroupVar, GroupElement)` pairs that can be iterated over. - /// - /// # Panics - /// - /// Panics if the collection contains two conflicting assignments for the same variable. - pub fn set_elements(&mut self, assignments: impl IntoIterator, G)>) { - self.linear_map.group_elements.assign_elements(assignments) - } - /// Evaluates all linear combinations in the linear map with the provided scalars, computing the /// left-hand side of this constraints (i.e. the image). /// @@ -525,20 +243,17 @@ impl LinearRelation { /// /// Return `Ok` on success, and an error if unassigned elements prevent the image from being /// computed. Modifies the group elements assigned in the [LinearRelation]. - pub fn compute_image(&mut self, scalars: &[G::Scalar]) -> Result<(), Error> { - if self.linear_map.num_constraints() != self.image.len() { + pub fn compute_image(&mut self, scalars: impl ScalarAssignments) -> Result<(), Error> { + if self.linear_combinations.len() != self.image.len() { // NOTE: This is a panic, rather than a returned error, because this can only happen if // this implementation has a bug. panic!("invalid LinearRelation: different number of constraints and image variables"); } - let mapped_scalars = self.linear_map.map(scalars)?; + let mapped_scalars: Vec<(GroupVar, G)> = + iter::zip(self.image.iter().copied(), self.evaluate(scalars)?).collect(); - for (mapped_scalar, lhs) in iter::zip(mapped_scalars, &self.image) { - self.linear_map - .group_elements - .assign_element(*lhs, mapped_scalar) - } + self.heap.assign_elements(mapped_scalars); Ok(()) } @@ -548,10 +263,43 @@ impl LinearRelation { /// /// A vector of group elements (`Vec`) representing the linear map's image. // TODO: Should this return GroupMap? - pub fn image(&self) -> Result, InvalidInstance> { + pub fn image(&self) -> Result, UnassignedGroupVarError> { self.image .iter() - .map(|&var| self.linear_map.group_elements.get(var)) + .map(|&var| self.heap.get_element(var)) + .collect() + } + + /// Evaluates all linear combinations in the linear relation with the provided scalars. + /// + /// # Parameters + /// - `scalars`: A slice of scalar values corresponding to the scalar variables. + /// + /// # Returns + /// + /// A vector of group elements, each being the result of evaluating one linear combination with the scalars. + pub fn evaluate(&self, scalars: impl ScalarAssignments) -> Result, Error> { + self.linear_combinations + .iter() + .map(|lc| { + // TODO: The multiplication by the (public) weight is potentially wasteful in the + // weight is most commonly 1, but multiplication is constant time. + let weighted_coefficients = + lc.0.iter() + .map(|weighted| { + weighted + .term + .scalar + .value(&scalars) + .map(|scalar| scalar * weighted.weight) + }) + .collect::, UnassignedScalarVarError>>()?; + let elements = + lc.0.iter() + .map(|weighted| self.heap.get_element(weighted.term.elem)) + .collect::, _>>()?; + Ok(G::msm(&weighted_coefficients, &elements)) + }) .collect() } @@ -579,11 +327,11 @@ impl LinearRelation { /// /// relation.set_element(g_var, G::generator()); /// let x = Scalar::random(&mut OsRng); - /// relation.compute_image(&[x]).unwrap(); + /// relation.compute_image([(x_var, x)]).unwrap(); /// /// // Convert to NIZK with custom context /// let nizk = relation.into_nizk(b"my-protocol-v1").unwrap(); - /// let proof = nizk.prove_batchable(&vec![x], &mut OsRng).unwrap(); + /// let proof = nizk.prove_batchable([(x_var, x)], &mut OsRng).unwrap(); /// assert!(nizk.verify_batchable(&proof).is_ok()); /// ``` pub fn into_nizk( @@ -600,3 +348,25 @@ impl LinearRelation { self.try_into() } } + +impl> Allocator for LinearRelation { + type G = G; + + /// Allocates a scalar variable for use in the linear map. + fn allocate_scalar(&mut self) -> ScalarVar { + self.heap.allocate_scalar() + } + + /// Allocates a group element variable (i.e. elliptic curve point) for use in the linear map. + fn allocate_element(&mut self) -> GroupVar { + self.heap.allocate_element() + } + + fn assign_element(&mut self, var: GroupVar, element: Self::G) { + self.heap.assign_element(var, element) + } + + fn get_element(&self, var: GroupVar) -> Result { + self.heap.get_element(var) + } +} diff --git a/src/schnorr_protocol.rs b/src/schnorr_protocol.rs index ea54e74..13c9fea 100644 --- a/src/schnorr_protocol.rs +++ b/src/schnorr_protocol.rs @@ -8,7 +8,7 @@ use crate::errors::Error; use crate::group::serialization::{ deserialize_elements, deserialize_scalars, serialize_elements, serialize_scalars, }; -use crate::linear_relation::CanonicalLinearRelation; +use crate::linear_relation::{CanonicalLinearRelation, ScalarMap}; use crate::traits::{SigmaProtocol, SigmaProtocolSimulator}; use alloc::vec::Vec; @@ -21,10 +21,14 @@ use rand_core::{CryptoRng, RngCore, RngCore as Rng}; impl SigmaProtocol for CanonicalLinearRelation { type Commitment = Vec; - type ProverState = (Vec, Vec); - type Response = Vec; - type Witness = Vec; type Challenge = G::Scalar; + /// Prover response to the challenge. Includes one scalar per witness scalar. + // NOTE: This could be a ScalarMap in that each scalar here is a associated with a variable, + // however this type is part of the public interface and is linked to the wire format. + type Response = Vec; + /// Prover state is a pair of (nonces, witness). Each scalar in the witness has a nonce. + type ProverState = (ScalarMap, Self::Witness); + type Witness = ScalarMap; /// Prover's first message: generates a commitment using random nonces. /// @@ -43,10 +47,10 @@ impl SigmaProtocol for CanonicalLinearRelation { /// If the witness vector is larger, extra variables are ignored. fn prover_commit( &self, - witness: &Self::Witness, + witness: Self::Witness, rng: &mut (impl RngCore + CryptoRng), ) -> Result<(Self::Commitment, Self::ProverState), Error> { - if witness.len() < self.num_scalars { + if witness.len() < self.scalar_vars.len() { return Err(Error::InvalidInstanceWitnessPair); } @@ -61,12 +65,13 @@ impl SigmaProtocol for CanonicalLinearRelation { return Err(Error::InvalidInstanceWitnessPair); } - let nonces = (0..self.num_scalars) - .map(|_| G::Scalar::random(&mut *rng)) - .collect::>(); + let nonces = witness + .vars() + .map(|var| (var, G::Scalar::random(&mut *rng))) + .collect::>(); let commitment = self.evaluate(&nonces); - let prover_state = (nonces.to_vec(), witness.to_vec()); + let prover_state = (nonces, witness.clone()); Ok((commitment, prover_state)) } @@ -88,10 +93,12 @@ impl SigmaProtocol for CanonicalLinearRelation { ) -> Result { let (nonces, witness) = prover_state; + // NOTE: It should only be possible to fail to unwrap here if there is an error in this + // library, or if it is used in an unintended way (e.g. manually constructing the prover + // state). Also note that this drops the explicit link with a given variable. let responses = nonces - .into_iter() - .zip(witness) - .map(|(r, w)| r + w * challenge) + .zip(&witness) + .map(|(_, r, w)| r.unwrap() + w.unwrap() * challenge) .collect(); Ok(responses) } @@ -117,11 +124,18 @@ impl SigmaProtocol for CanonicalLinearRelation { challenge: &Self::Challenge, response: &Self::Response, ) -> Result<(), Error> { - if commitment.len() != self.image.len() || response.len() != self.num_scalars { + if commitment.len() != self.image.len() || response.len() != self.scalar_vars.len() { return Err(Error::InvalidInstanceWitnessPair); } - let lhs = self.evaluate(response); + let response_map = self + .scalar_vars + .iter() + .copied() + .zip(response.iter().copied()) + .collect::>(); + + let lhs = self.evaluate(response_map); let mut rhs = Vec::new(); for (img, g) in self.image_elements().zip(commitment) { rhs.push(img * challenge + g); @@ -226,7 +240,7 @@ impl SigmaProtocol for CanonicalLinearRelation { /// # Errors /// - Returns [`Error::VerificationFailure`] if the byte data is malformed or the length is incorrect. fn deserialize_response(&self, data: &[u8]) -> Result { - deserialize_scalars::(data, self.num_scalars).ok_or(Error::VerificationFailure) + deserialize_scalars::(data, self.scalar_vars.len()).ok_or(Error::VerificationFailure) } fn instance_label(&self) -> impl AsRef<[u8]> { @@ -251,7 +265,9 @@ where /// # Returns /// - A commitment and response forming a valid proof for the given challenge. fn simulate_response(&self, rng: &mut R) -> Self::Response { - let response: Vec = (0..self.num_scalars) + let response: Vec = self + .scalar_vars + .iter() .map(|_| G::Scalar::random(&mut *rng)) .collect(); response @@ -290,11 +306,18 @@ where challenge: &Self::Challenge, response: &Self::Response, ) -> Result { - if response.len() != self.num_scalars { + if response.len() != self.scalar_vars.len() { return Err(Error::InvalidInstanceWitnessPair); } - let response_image = self.evaluate(response); + let response_map = self + .scalar_vars + .iter() + .copied() + .zip(response.iter().copied()) + .collect::>(); + + let response_image = self.evaluate(response_map); let commitment = response_image .iter() .zip(self.image_elements()) diff --git a/src/tests/spec/custom_schnorr_protocol.rs b/src/tests/spec/custom_schnorr_protocol.rs index 330d681..a26b0b5 100644 --- a/src/tests/spec/custom_schnorr_protocol.rs +++ b/src/tests/spec/custom_schnorr_protocol.rs @@ -2,16 +2,18 @@ use group::prime::PrimeGroup; use rand::{CryptoRng, Rng}; use crate::errors::Error; -use crate::linear_relation::{CanonicalLinearRelation, LinearRelation}; +use crate::linear_relation::{Allocator, CanonicalLinearRelation, LinearRelation, ScalarMap}; use crate::tests::spec::random::SRandom; use crate::traits::{SigmaProtocol, SigmaProtocolSimulator}; pub struct DeterministicSchnorrProof(pub CanonicalLinearRelation); -impl TryFrom> for DeterministicSchnorrProof { +impl> TryFrom> + for DeterministicSchnorrProof +{ type Error = Error; - fn try_from(linear_relation: LinearRelation) -> Result { + fn try_from(linear_relation: LinearRelation) -> Result { let relation = CanonicalLinearRelation::try_from(&linear_relation)?; Ok(Self(relation)) } @@ -32,15 +34,15 @@ impl SigmaProtocol for DeterministicSchnorrProof { fn prover_commit( &self, - witness: &Self::Witness, + witness: Self::Witness, rng: &mut (impl Rng + CryptoRng), ) -> Result<(Self::Commitment, Self::ProverState), Error> { - let mut nonces: Vec = Vec::new(); - for _i in 0..self.0.num_scalars { - nonces.push(::random_scalar_elt(rng)); - } + let nonces = witness + .vars() + .map(|var| (var, ::random_scalar_elt(rng))) + .collect::>(); let commitment = self.0.evaluate(&nonces); - let prover_state = (nonces.to_vec(), witness.to_vec()); + let prover_state = (nonces, witness.clone()); Ok((commitment, prover_state)) } diff --git a/src/tests/spec/test_vectors.rs b/src/tests/spec/test_vectors.rs index c21c9ca..a4e4af9 100644 --- a/src/tests/spec/test_vectors.rs +++ b/src/tests/spec/test_vectors.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use crate::codec::KeccakByteSchnorrCodec; use crate::fiat_shamir::Nizk; -use crate::linear_relation::CanonicalLinearRelation; +use crate::linear_relation::{CanonicalLinearRelation, ScalarMap}; use crate::tests::spec::{custom_schnorr_protocol::DeterministicSchnorrProof, rng::TestDRNG}; type SchnorrNizk = Nizk, KeccakByteSchnorrCodec>; @@ -58,11 +58,17 @@ fn test_spec_testvectors() { .expect("Failed to parse statement"); // Decode the witness from the test vector - let witness = crate::group::serialization::deserialize_scalars::( + let witness_vec = crate::group::serialization::deserialize_scalars::( &vector.witness, - parsed_instance.num_scalars, + parsed_instance.scalar_vars.len(), ) .expect("Failed to deserialize witness"); + let witness = parsed_instance + .scalar_vars + .iter() + .copied() + .zip(witness_vec) + .collect::>(); // Verify the parsed instance can be re-serialized to the same label assert_eq!( @@ -91,7 +97,7 @@ fn test_spec_testvectors() { // Generate proof with the proof generation RNG let mut proof_rng = TestDRNG::new(proof_generation_rng_seed); - let proof_bytes = nizk.prove_batchable(&witness, &mut proof_rng).unwrap(); + let proof_bytes = nizk.prove_batchable(witness, &mut proof_rng).unwrap(); // Verify the proof matches assert_eq!( diff --git a/src/tests/test_composition.rs b/src/tests/test_composition.rs index 9d27318..ff0f113 100644 --- a/src/tests/test_composition.rs +++ b/src/tests/test_composition.rs @@ -26,9 +26,10 @@ fn test_composition_example() { let (relation4, witness4) = pedersen_commitment(&mut rng); let (relation5, witness5) = bbs_blind_commitment(&mut rng); - let wrong_witness2 = (0..witness2.len()) - .map(|_| ::Scalar::random(&mut rng)) - .collect::>(); + let wrong_witness2 = witness2 + .vars() + .map(|var| (var, ::Scalar::random(&mut rng))) + .collect(); // second layer protocol definitions let or_protocol1 = ComposedRelation::::or([relation1, relation2]); let or_witness1 = ComposedWitness::or([witness1, wrong_witness2]); @@ -43,8 +44,8 @@ fn test_composition_example() { let nizk = instance.into_nizk(domain_sep); // Batchable and compact proofs - let proof_batchable_bytes = nizk.prove_batchable(&witness, &mut rng).unwrap(); - let proof_compact_bytes = nizk.prove_compact(&witness, &mut rng).unwrap(); + let proof_batchable_bytes = nizk.prove_batchable(witness.clone(), &mut rng).unwrap(); + let proof_compact_bytes = nizk.prove_compact(witness, &mut rng).unwrap(); // Verify proofs assert!(nizk.verify_batchable(&proof_batchable_bytes).is_ok()); assert!(nizk.verify_compact(&proof_compact_bytes).is_ok()); @@ -60,12 +61,14 @@ fn test_or_one_true() { let (relation1, witness1) = dleq::(&mut rng); let (relation2, witness2) = dleq::(&mut rng); - let wrong_witness1 = (0..witness1.len()) - .map(|_| ::Scalar::random(&mut rng)) - .collect::>(); - let wrong_witness2 = (0..witness2.len()) - .map(|_| ::Scalar::random(&mut rng)) - .collect::>(); + let wrong_witness1 = witness1 + .vars() + .map(|var| (var, ::Scalar::random(&mut rng))) + .collect(); + let wrong_witness2 = witness2 + .vars() + .map(|var| (var, ::Scalar::random(&mut rng))) + .collect(); let or_protocol = ComposedRelation::or([relation1, relation2]); @@ -77,8 +80,8 @@ fn test_or_one_true() { for witness in [witness_or_1, witness_or_2] { // Batchable and compact proofs - let proof_batchable_bytes = nizk.prove_batchable(&witness, &mut rng).unwrap(); - let proof_compact_bytes = nizk.prove_compact(&witness, &mut rng).unwrap(); + let proof_batchable_bytes = nizk.prove_batchable(witness.clone(), &mut rng).unwrap(); + let proof_compact_bytes = nizk.prove_compact(witness, &mut rng).unwrap(); // Verify proofs assert!(nizk.verify_batchable(&proof_batchable_bytes).is_ok()); assert!(nizk.verify_compact(&proof_compact_bytes).is_ok()); @@ -101,8 +104,8 @@ fn test_or_both_true() { let nizk = or_protocol.into_nizk(b"test_or_both_true"); // Batchable and compact proofs - let proof_batchable_bytes = nizk.prove_batchable(&witness, &mut rng).unwrap(); - let proof_compact_bytes = nizk.prove_compact(&witness, &mut rng).unwrap(); + let proof_batchable_bytes = nizk.prove_batchable(witness.clone(), &mut rng).unwrap(); + let proof_compact_bytes = nizk.prove_compact(witness, &mut rng).unwrap(); // Verify proofs assert!(nizk.verify_batchable(&proof_batchable_bytes).is_ok()); assert!(nizk.verify_compact(&proof_compact_bytes).is_ok()); diff --git a/src/tests/test_relations.rs b/src/tests/test_relations.rs index 83dca64..aa501d6 100644 --- a/src/tests/test_relations.rs +++ b/src/tests/test_relations.rs @@ -1,16 +1,18 @@ +use std::iter; + use ff::Field; use group::prime::PrimeGroup; use rand::RngCore; use crate::codec::Shake128DuplexSponge; use crate::fiat_shamir::Nizk; -use crate::linear_relation::{CanonicalLinearRelation, LinearRelation, Sum}; +use crate::linear_relation::{Allocator, CanonicalLinearRelation, LinearRelation, ScalarMap, Sum}; /// LinearMap for knowledge of a discrete logarithm relative to a fixed basepoint. #[allow(non_snake_case)] pub fn discrete_logarithm( rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { let x = G::Scalar::random(rng); let mut relation = LinearRelation::new(); @@ -19,13 +21,13 @@ pub fn discrete_logarithm( let var_X = relation.allocate_eq(var_x * var_G); - relation.set_element(var_G, G::generator()); - relation.compute_image(&[x]).unwrap(); - - let X = relation.linear_map.group_elements.get(var_X).unwrap(); + relation.assign_element(var_G, G::generator()); + let witness = ScalarMap::from_iter([(var_x, x)]); + relation.compute_image(&witness).unwrap(); + let X = relation.get_element(var_X).unwrap(); assert_eq!(X, G::generator() * x); - let witness = vec![x]; + let instance = (&relation).try_into().unwrap(); (instance, witness) } @@ -34,7 +36,7 @@ pub fn discrete_logarithm( #[allow(non_snake_case)] pub fn shifted_dlog( rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { let x = G::Scalar::random(rng); let mut relation = LinearRelation::new(); @@ -45,19 +47,17 @@ pub fn shifted_dlog( // another way of writing this is: relation.append_equation(var_X, (var_x + G::Scalar::from(1)) * var_G); - relation.set_element(var_G, G::generator()); - relation.compute_image(&[x]).unwrap(); + relation.assign_element(var_G, G::generator()); + let witness = ScalarMap::from_iter([(var_x, x)]); + relation.compute_image(&witness).unwrap(); - let witness = vec![x]; let instance = (&relation).try_into().unwrap(); (instance, witness) } /// LinearMap for knowledge of a discrete logarithm equality between two pairs. #[allow(non_snake_case)] -pub fn dleq( - rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +pub fn dleq(rng: &mut R) -> (CanonicalLinearRelation, ScalarMap) { let H = G::random(&mut *rng); let x = G::Scalar::random(&mut *rng); let mut relation = LinearRelation::new(); @@ -68,15 +68,15 @@ pub fn dleq( let var_X = relation.allocate_eq(var_x * var_G); let var_Y = relation.allocate_eq(var_x * var_H); - relation.set_elements([(var_G, G::generator()), (var_H, H)]); - relation.compute_image(&[x]).unwrap(); - - let X = relation.linear_map.group_elements.get(var_X).unwrap(); - let Y = relation.linear_map.group_elements.get(var_Y).unwrap(); + relation.assign_elements([(var_G, G::generator()), (var_H, H)]); + let witness = ScalarMap::from_iter([(var_x, x)]); + relation.compute_image(&witness).unwrap(); + let X = relation.heap.elements.get(var_X).unwrap(); + let Y = relation.heap.elements.get(var_Y).unwrap(); assert_eq!(X, G::generator() * x); assert_eq!(Y, H * x); - let witness = vec![x]; + let instance = (&relation).try_into().unwrap(); (instance, witness) } @@ -85,7 +85,7 @@ pub fn dleq( #[allow(non_snake_case)] pub fn shifted_dleq( rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { let H = G::random(&mut *rng); let x = G::Scalar::random(&mut *rng); let mut relation = LinearRelation::new(); @@ -96,15 +96,15 @@ pub fn shifted_dleq( let var_X = relation.allocate_eq(var_x * var_G + var_H); let var_Y = relation.allocate_eq(var_x * var_H + var_G); - relation.set_elements([(var_G, G::generator()), (var_H, H)]); - relation.compute_image(&[x]).unwrap(); - - let X = relation.linear_map.group_elements.get(var_X).unwrap(); - let Y = relation.linear_map.group_elements.get(var_Y).unwrap(); + relation.assign_elements([(var_G, G::generator()), (var_H, H)]); + let witness = ScalarMap::from_iter([(var_x, x)]); + relation.compute_image(&witness).unwrap(); + let X = relation.heap.elements.get(var_X).unwrap(); + let Y = relation.heap.elements.get(var_Y).unwrap(); assert_eq!(X, G::generator() * x + H); assert_eq!(Y, H * x + G::generator()); - let witness = vec![x]; + let instance = (&relation).try_into().unwrap(); (instance, witness) } @@ -113,7 +113,7 @@ pub fn shifted_dleq( #[allow(non_snake_case)] pub fn pedersen_commitment( rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { let H = G::random(&mut *rng); let x = G::Scalar::random(&mut *rng); let r = G::Scalar::random(&mut *rng); @@ -124,13 +124,13 @@ pub fn pedersen_commitment( let var_C = relation.allocate_eq(var_x * var_G + var_r * var_H); - relation.set_elements([(var_H, H), (var_G, G::generator())]); - relation.compute_image(&[x, r]).unwrap(); - - let C = relation.linear_map.group_elements.get(var_C).unwrap(); + relation.assign_elements([(var_H, H), (var_G, G::generator())]); + let witness = ScalarMap::from_iter([(var_x, x), (var_r, r)]); + relation.compute_image(&witness).unwrap(); - let witness = vec![x, r]; + let C = relation.heap.elements.get(var_C).unwrap(); assert_eq!(C, G::generator() * x + H * r); + let instance = (&relation).try_into().unwrap(); (instance, witness) } @@ -138,7 +138,7 @@ pub fn pedersen_commitment( #[allow(non_snake_case)] pub fn twisted_pedersen_commitment( rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { let H = G::random(&mut *rng); let x = G::Scalar::random(&mut *rng); let r = G::Scalar::random(&mut *rng); @@ -152,10 +152,10 @@ pub fn twisted_pedersen_commitment( + (var_r * G::Scalar::from(2) + G::Scalar::from(3)) * var_H, ); - relation.set_elements([(var_H, H), (var_G, G::generator())]); - relation.compute_image(&[x, r]).unwrap(); + relation.assign_elements([(var_H, H), (var_G, G::generator())]); + let witness = ScalarMap::from_iter([(var_x, x), (var_r, r)]); + relation.compute_image(&witness).unwrap(); - let witness = vec![x, r]; let instance = (&relation).try_into().unwrap(); (instance, witness) } @@ -166,7 +166,7 @@ pub fn range_instance_generation( mut rng: &mut R, input: u64, range: std::ops::Range, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { let G = G::generator(); let H = G::random(&mut rng); @@ -243,20 +243,20 @@ pub fn range_instance_generation( let s2 = (0..bases.len()) .map(|i| (G::Scalar::ONE - b[i]) * s[i]) .collect::>(); - let witness = [x, r] - .iter() - .chain(&b) - .chain(&s) - .chain(&s2) - .copied() - .collect::>(); - instance.set_elements([(var_G, G), (var_H, H)]); - instance.set_element(var_C, G * x + H * r); + instance.assign_elements([(var_G, G), (var_H, H)]); + instance.assign_element(var_C, G * x + H * r); for i in 0..bases.len() { - instance.set_element(var_Ds[i], G * b[i] + H * s[i]); + instance.assign_element(var_Ds[i], G * b[i] + H * s[i]); } + let witness = [(var_x, x), (var_r, r)] + .into_iter() + .chain(iter::zip(vars_b, b)) + .chain(iter::zip(vars_s, s)) + .chain(iter::zip(var_s2, s2)) + .collect::>(); + (instance.canonical().unwrap(), witness) } @@ -264,7 +264,7 @@ pub fn range_instance_generation( #[allow(non_snake_case)] pub fn test_range( mut rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { range_instance_generation(&mut rng, 822, 0..1337) } @@ -273,7 +273,7 @@ pub fn test_range( #[allow(non_snake_case)] pub fn bbs_blind_commitment( rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { let [Q_2, J_1, J_2, J_3] = [ G::random(&mut *rng), G::random(&mut *rng), @@ -307,7 +307,7 @@ pub fn bbs_blind_commitment( + var_msg_3 * var_J_3, ); - relation.set_elements([ + relation.assign_elements([ (var_Q_2, Q_2), (var_J_1, J_1), (var_J_2, J_2), @@ -315,9 +315,14 @@ pub fn bbs_blind_commitment( (var_C, C), ]); - let witness = vec![secret_prover_blind, msg_1, msg_2, msg_3]; + let witness = ScalarMap::from_iter([ + (var_secret_prover_blind, secret_prover_blind), + (var_msg_1, msg_1), + (var_msg_2, msg_2), + (var_msg_3, msg_3), + ]); - assert!(vec![C] == relation.linear_map.evaluate(&witness).unwrap()); + assert!(vec![C] == relation.evaluate(&witness).unwrap()); let instance = (&relation).try_into().unwrap(); (instance, witness) } @@ -325,7 +330,7 @@ pub fn bbs_blind_commitment( /// LinearMap for the user's specific relation: A * 1 + gen__disj1_x_r * B pub fn weird_linear_combination( rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { let B = G::random(&mut *rng); let gen__disj1_x_r = G::Scalar::random(&mut *rng); let mut sigma__lr = LinearRelation::new(); @@ -337,23 +342,23 @@ pub fn weird_linear_combination( let sigma__eq1 = sigma__lr.allocate_eq(A * G::Scalar::from(1) + gen__disj1_x_r_var * var_B); // Set the group elements - sigma__lr.set_elements([(A, G::generator()), (var_B, B)]); - sigma__lr.compute_image(&[gen__disj1_x_r]).unwrap(); + sigma__lr.assign_elements([(A, G::generator()), (var_B, B)]); + let witness = ScalarMap::from_iter([(gen__disj1_x_r_var, gen__disj1_x_r)]); + sigma__lr.compute_image(&witness).unwrap(); - let result = sigma__lr.linear_map.group_elements.get(sigma__eq1).unwrap(); + let result = sigma__lr.heap.elements.get(sigma__eq1).unwrap(); // Verify the relation computes correctly let expected = G::generator() + B * gen__disj1_x_r; assert_eq!(result, expected); - let witness = vec![gen__disj1_x_r]; let instance = (&sigma__lr).try_into().unwrap(); (instance, witness) } fn simple_subtractions( mut rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { let x = G::Scalar::random(&mut rng); let B = G::random(&mut rng); let X = B * (x - G::Scalar::from(1)); @@ -362,17 +367,17 @@ fn simple_subtractions( let var_x = linear_relation.allocate_scalar(); let var_B = linear_relation.allocate_element(); let var_X = linear_relation.allocate_eq((var_x + (-G::Scalar::from(1))) * var_B); - linear_relation.set_element(var_B, B); - linear_relation.set_element(var_X, X); + linear_relation.assign_element(var_B, B); + linear_relation.assign_element(var_X, X); let instance = (&linear_relation).try_into().unwrap(); - let witness = vec![x]; + let witness = ScalarMap::from_iter([(var_x, x)]); (instance, witness) } fn subtractions_with_shift( rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { let B = G::generator(); let x = G::Scalar::random(rng); let X = B * (x - G::Scalar::from(2)); @@ -382,17 +387,17 @@ fn subtractions_with_shift( let var_B = linear_relation.allocate_element(); let var_X = linear_relation.allocate_eq((var_x + (-G::Scalar::from(1))) * var_B + (-var_B)); - linear_relation.set_element(var_B, B); - linear_relation.set_element(var_X, X); + linear_relation.assign_element(var_B, B); + linear_relation.assign_element(var_X, X); let instance = (&linear_relation).try_into().unwrap(); - let witness = vec![x]; + let witness = ScalarMap::from_iter([(var_x, x)]); (instance, witness) } #[allow(non_snake_case)] fn cmz_wallet_spend_relation( mut rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { // Simulate the wallet spend relation from cmz let P_W = G::random(&mut rng); let A = G::random(&mut rng); @@ -419,25 +424,27 @@ fn cmz_wallet_spend_relation( let var_C = relation .allocate_eq((var_n_balance + var_i_price + fee) * var_P_W + var_z_w_balance * var_A); - relation.set_elements([(var_P_W, P_W), (var_A, A)]); + relation.assign_elements([(var_P_W, P_W), (var_A, A)]); // Include fee in the witness - relation - .compute_image(&[n_balance, i_price, z_w_balance]) - .unwrap(); + let witness = ScalarMap::from_iter([ + (var_n_balance, n_balance), + (var_i_price, i_price), + (var_z_w_balance, z_w_balance), + ]); + relation.compute_image(&witness).unwrap(); - let C = relation.linear_map.group_elements.get(var_C).unwrap(); + let C = relation.heap.elements.get(var_C).unwrap(); let expected = P_W * w_balance + A * z_w_balance; assert_eq!(C, expected); - let witness = vec![n_balance, i_price, z_w_balance]; let instance = (&relation).try_into().unwrap(); (instance, witness) } fn nested_affine_relation( mut rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { let mut instance = LinearRelation::::new(); let var_r = instance.allocate_scalar(); let var_A = instance.allocate_element(); @@ -450,18 +457,18 @@ fn nested_affine_relation( let B = G::random(&mut rng); let r = G::Scalar::random(&mut rng); let C = A * G::Scalar::from(4) + B * (r * G::Scalar::from(2) + G::Scalar::from(3)); - instance.set_element(var_A, A); - instance.set_element(var_B, B); - instance.set_element(eq1, C); + instance.assign_element(var_A, A); + instance.assign_element(var_B, B); + instance.assign_element(eq1, C); - let witness = vec![r]; + let witness = ScalarMap::from_iter([(var_r, r)]); let instance = CanonicalLinearRelation::try_from(&instance).unwrap(); (instance, witness) } fn pedersen_commitment_equality( rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { let mut instance = LinearRelation::new(); let [m, r1, r2] = instance.allocate_scalars(); @@ -470,13 +477,13 @@ fn pedersen_commitment_equality( instance.allocate_eq(var_G * m + var_H * r1); instance.allocate_eq(var_G * m + var_H * r2); - instance.set_elements([(var_G, G::generator()), (var_H, G::random(&mut *rng))]); + instance.assign_elements([(var_G, G::generator()), (var_H, G::random(&mut *rng))]); - let witness = vec![ - G::Scalar::from(42), - G::Scalar::random(&mut *rng), - G::Scalar::random(&mut *rng), - ]; + let witness = ScalarMap::from_iter([ + (m, G::Scalar::from(42)), + (r1, G::Scalar::random(&mut *rng)), + (r2, G::Scalar::random(&mut *rng)), + ]); instance.compute_image(&witness).unwrap(); (instance.canonical().unwrap(), witness) @@ -484,7 +491,7 @@ fn pedersen_commitment_equality( fn elgamal_subtraction( rng: &mut R, -) -> (CanonicalLinearRelation, Vec) { +) -> (CanonicalLinearRelation, ScalarMap) { let mut instance = LinearRelation::new(); let [dk, a, r] = instance.allocate_scalars(); let [ek, C, D, H, G] = instance.allocate_elements(); @@ -498,13 +505,13 @@ fn elgamal_subtraction( instance.append_equation(C, G * v + dk * D + a * G); // set dk for testing to - let witness = vec![ - G::Scalar::from(4242), - G::Scalar::from(1000), - G::Scalar::random(&mut *rng), - ]; + let witness = ScalarMap::from_iter([ + (dk, G::Scalar::from(4242)), + (a, G::Scalar::from(1000)), + (r, G::Scalar::random(&mut *rng)), + ]); let alt_gen = G::random(&mut *rng); - instance.set_elements([(G, G::generator()), (H, alt_gen)]); + instance.assign_elements([(G, G::generator()), (H, alt_gen)]); instance.compute_image(&witness).unwrap(); (instance.canonical().unwrap(), witness) @@ -542,14 +549,17 @@ fn test_cmz_wallet_with_fee() { + var_z_w_balance * var_A, ); - relation.set_elements([(var_P_W, P_W), (var_A, A)]); - relation - .compute_image(&[n_balance, i_price, z_w_balance]) - .unwrap(); + relation.assign_elements([(var_P_W, P_W), (var_A, A)]); + let witness = ScalarMap::from_iter([ + (var_n_balance, n_balance), + (var_i_price, i_price), + (var_z_w_balance, z_w_balance), + ]); + relation.compute_image(&witness).unwrap(); // Try to convert to CanonicalLinearRelation - this should fail let nizk = relation.into_nizk(b"session_identifier").unwrap(); - let result = nizk.prove_batchable(&vec![n_balance, i_price, z_w_balance], &mut rng); + let result = nizk.prove_batchable(witness, &mut rng); assert!(result.is_ok()); let proof = result.unwrap(); let verify_result = nizk.verify_batchable(&proof); @@ -594,10 +604,10 @@ fn test_relations() { // Test both proof types let proof_batchable = nizk - .prove_batchable(&witness, &mut rng) + .prove_batchable(witness.clone(), &mut rng) .unwrap_or_else(|_| panic!("Failed to create batchable proof for {relation_name}")); let proof_compact = nizk - .prove_compact(&witness, &mut rng) + .prove_compact(witness, &mut rng) .unwrap_or_else(|_| panic!("Failed to create compact proof for {relation_name}")); // Verify both proof types diff --git a/src/tests/test_validation_criteria.rs b/src/tests/test_validation_criteria.rs index fd45920..164fa01 100644 --- a/src/tests/test_validation_criteria.rs +++ b/src/tests/test_validation_criteria.rs @@ -5,7 +5,7 @@ #[cfg(test)] mod instance_validation { - use crate::linear_relation::{CanonicalLinearRelation, LinearRelation}; + use crate::linear_relation::{Allocator, CanonicalLinearRelation, LinearRelation}; use bls12_381::{G1Projective as G, Scalar}; use ff::Field; use group::Group; @@ -21,7 +21,7 @@ mod instance_validation { // Set only one element, leaving var_g unassigned let x_val = G::generator() * Scalar::from(42u64); - relation.set_element(var_x_g, x_val); + relation.assign_element(var_x_g, x_val); // Add equation: X = x * G (but G is not set) relation.append_equation(var_x_g, var_x * var_g); @@ -39,8 +39,8 @@ mod instance_validation { let [var_x] = relation.allocate_scalars(); let [var_G] = relation.allocate_elements(); let var_X = relation.allocate_eq(var_G * var_x); - relation.set_element(var_G, G::generator()); - relation.set_element(var_X, G::identity()); + relation.assign_element(var_G, G::generator()); + relation.assign_element(var_X, G::identity()); let result = CanonicalLinearRelation::try_from(&relation); assert!(result.is_err()); @@ -49,8 +49,8 @@ mod instance_validation { let mut relation = LinearRelation::::new(); let [var_B] = relation.allocate_elements(); let var_X = relation.allocate_eq(var_B * Scalar::from(0)); - relation.set_element(var_B, G::generator()); - relation.set_element(var_X, G::identity()); + relation.assign_element(var_B, G::generator()); + relation.assign_element(var_X, G::identity()); let result = CanonicalLinearRelation::try_from(&relation); assert!(result.is_ok()); @@ -60,8 +60,8 @@ mod instance_validation { let [var_x] = relation.allocate_scalars(); let [var_C] = relation.allocate_elements(); let var_X = relation.allocate_eq(var_C * var_x * Scalar::from(0)); - relation.set_element(var_C, G::generator()); - relation.set_element(var_X, G::identity()); + relation.assign_element(var_C, G::generator()); + relation.assign_element(var_X, G::identity()); let result = CanonicalLinearRelation::try_from(&relation); assert!(result.is_ok()); } @@ -75,7 +75,7 @@ mod instance_validation { let x = relation.allocate_scalar(); let var_B = relation.allocate_element(); let var_X = relation.allocate_eq((x + (-Scalar::ONE)) * var_B + (-var_B)); - relation.set_element(var_X, G::identity()); + relation.assign_element(var_X, G::identity()); assert!(CanonicalLinearRelation::try_from(&relation).is_err()); // 2. because var_X is not assigned @@ -83,7 +83,7 @@ mod instance_validation { let x = relation.allocate_scalar(); let var_B = relation.allocate_element(); let _var_X = relation.allocate_eq((x + (-Scalar::ONE)) * var_B + (-var_B)); - relation.set_element(var_B, G::generator()); + relation.assign_element(var_B, G::generator()); assert!(CanonicalLinearRelation::try_from(&relation).is_err()); } @@ -93,7 +93,7 @@ mod instance_validation { let mut relation = LinearRelation::::new(); let [var_x] = relation.allocate_scalars(); let [var_g, var_h] = relation.allocate_elements(); - relation.set_elements([ + relation.assign_elements([ (var_g, G::generator()), (var_h, G::generator() * Scalar::from(2u64)), ]); @@ -101,9 +101,9 @@ mod instance_validation { // Add two equations but only one image element let var_img_1 = relation.allocate_eq(var_x * var_g + var_h); relation.allocate_eq(var_x * var_h + var_g); - relation.set_element(var_g, G::generator()); - relation.set_element(var_h, G::generator() * Scalar::from(2)); - relation.set_element(var_img_1, G::generator() * Scalar::from(3)); + relation.assign_element(var_g, G::generator()); + relation.assign_element(var_h, G::generator() * Scalar::from(2)); + relation.assign_element(var_img_1, G::generator() * Scalar::from(3)); assert!(relation.canonical().is_err()); } @@ -112,13 +112,13 @@ mod instance_validation { let rng = &mut rand::thread_rng(); let relation = LinearRelation::::new(); let nizk = relation.into_nizk(b"test_session").unwrap(); - let narg_string = nizk.prove_batchable(&vec![], rng).unwrap(); + let narg_string = nizk.prove_batchable([], rng).unwrap(); assert!(narg_string.is_empty()); let mut relation = LinearRelation::::new(); let var_B = relation.allocate_element(); let var_C = relation.allocate_eq(var_B * Scalar::from(1)); - relation.set_elements([(var_B, G::generator()), (var_C, G::generator())]); + relation.assign_elements([(var_B, G::generator()), (var_C, G::generator())]); assert!(CanonicalLinearRelation::try_from(&relation).is_ok()); } @@ -136,7 +136,7 @@ mod instance_validation { let mut linear_relation = LinearRelation::::new(); let B_var = linear_relation.allocate_element(); let C_var = linear_relation.allocate_eq(B_var); - linear_relation.set_elements([(B_var, B), (C_var, C)]); + linear_relation.assign_elements([(B_var, B), (C_var, C)]); assert!(linear_relation .canonical() .err() @@ -149,7 +149,7 @@ mod instance_validation { let mut linear_relation = LinearRelation::::new(); let [B_var, A_var] = linear_relation.allocate_elements(); let X_var = linear_relation.allocate_eq(B_var * pub_scalar + A_var * Scalar::from(3)); - linear_relation.set_elements([(B_var, B), (A_var, A), (X_var, X)]); + linear_relation.assign_elements([(B_var, B), (A_var, A), (X_var, X)]); assert!(linear_relation .canonical() .err() @@ -161,7 +161,7 @@ mod instance_validation { let mut linear_relation = LinearRelation::::new(); let B_var = linear_relation.allocate_element(); let C_var = linear_relation.allocate_eq(B_var); - linear_relation.set_elements([(B_var, B), (C_var, B)]); + linear_relation.assign_elements([(B_var, B), (C_var, B)]); assert!(linear_relation.canonical().is_ok()); // The following relation is valid and should pass. @@ -169,7 +169,7 @@ mod instance_validation { let mut linear_relation = LinearRelation::::new(); let [B_var, A_var] = linear_relation.allocate_elements(); let C_var = linear_relation.allocate_eq(B_var * pub_scalar + A_var * Scalar::from(3)); - linear_relation.set_elements([(B_var, B), (A_var, A), (C_var, C)]); + linear_relation.assign_elements([(B_var, B), (A_var, A), (C_var, C)]); assert!(linear_relation.canonical().is_ok()); // The following relation is for @@ -180,7 +180,7 @@ mod instance_validation { let [B_var, A_var] = linear_relation.allocate_elements(); let X_var = linear_relation .allocate_eq(B_var * x_var + B_var * pub_scalar + A_var * Scalar::from(3)); - linear_relation.set_elements([(B_var, B), (A_var, A), (X_var, X)]); + linear_relation.assign_elements([(B_var, B), (A_var, A), (X_var, X)]); assert!(linear_relation.canonical().is_ok()); } @@ -201,7 +201,7 @@ mod instance_validation { // The equation 0 = x*A + y*B + C // Has a non-trivial solution. - linear_relation.set_elements([(Z_var, Z), (A_var, A), (B_var, B), (C_var, C)]); + linear_relation.assign_elements([(Z_var, Z), (A_var, A), (B_var, B), (C_var, C)]); assert!(linear_relation.canonical().is_ok()); // Adding more non-trivial statements does not affect the validity of the relation. @@ -210,7 +210,7 @@ mod instance_validation { linear_relation.append_equation(F_var, f_var * A_var); let f = Scalar::random(&mut rng); let F = A * f; - linear_relation.set_elements([(F_var, F), (A_var, A)]); + linear_relation.assign_elements([(F_var, F), (A_var, A)]); assert!(linear_relation.canonical().is_ok()); } } @@ -220,7 +220,7 @@ mod proof_validation { use crate::codec::KeccakByteSchnorrCodec; use crate::composition::{ComposedRelation, ComposedWitness}; use crate::fiat_shamir::Nizk; - use crate::linear_relation::{CanonicalLinearRelation, LinearRelation}; + use crate::linear_relation::{Allocator, CanonicalLinearRelation, LinearRelation}; use bls12_381::{G1Projective as G, Scalar}; use ff::Field; use rand::RngCore; @@ -239,13 +239,12 @@ mod proof_validation { let x = Scalar::from(42u64); let x_g = G::generator() * x; - relation.set_elements([(var_g, G::generator()), (var_x_g, x_g)]); + relation.assign_elements([(var_g, G::generator()), (var_x_g, x_g)]); relation.append_equation(var_x_g, var_x * var_g); let nizk = TestNizk::new(b"test_session", relation.canonical().unwrap()); - let witness = vec![x]; - let proof = nizk.prove_batchable(&witness, &mut rng).unwrap(); + let proof = nizk.prove_batchable([(var_x, x)], &mut rng).unwrap(); (proof, nizk) } @@ -407,15 +406,15 @@ mod proof_validation { let x_var = lr1.allocate_scalar(); let A_var = lr1.allocate_element(); let eq1 = lr1.allocate_eq(x_var * A_var); - lr1.set_element(A_var, A); - lr1.set_element(eq1, C); + lr1.assign_element(A_var, A); + lr1.assign_element(eq1, C); // Create the second branch: C = y*B let mut lr2 = LinearRelation::new(); let y_var = lr2.allocate_scalar(); let B_var = lr2.allocate_element(); let eq2 = lr2.allocate_eq(y_var * B_var); - lr2.set_element(B_var, B); - lr2.set_element(eq2, C); + lr2.assign_element(B_var, B); + lr2.assign_element(eq2, C); // Create OR composition let or_relation = @@ -424,11 +423,8 @@ mod proof_validation { // Create a correct witness for branch 1 (C = y*B) // Note: x is NOT a valid witness for branch 0 because C ≠ x*A - let witness_correct = ComposedWitness::Or(vec![ - ComposedWitness::Simple(vec![x]), - ComposedWitness::Simple(vec![y]), - ]); - let proof = nizk.prove_batchable(&witness_correct, &mut rng).unwrap(); + let witness_correct = ComposedWitness::or([[(x_var, x)], [(y_var, y)]]); + let proof = nizk.prove_batchable(witness_correct, &mut rng).unwrap(); assert!( nizk.verify_batchable(&proof).is_ok(), "Valid proof should verify" @@ -437,23 +433,17 @@ mod proof_validation { // Now test with ONLY invalid witnesses (neither branch satisfied) // Branch 0 requires C = x*A, but we use random x // Branch 1 requires C = y*B, but we use a different random value - let wrong_y = Scalar::random(&mut rng); - let witness_wrong = ComposedWitness::Or(vec![ - ComposedWitness::Simple(vec![x]), - ComposedWitness::Simple(vec![wrong_y]), - ]); - let proof_result = nizk.prove_batchable(&witness_wrong, &mut rng); + let witness_wrong = + ComposedWitness::or([[(x_var, x)], [(y_var, Scalar::random(&mut rng))]]); + let proof_result = nizk.prove_batchable(witness_wrong, &mut rng); assert!( proof_result.is_err(), "Proof should fail with invalid witnesses" ); // Create a correct witness for both branches - let witness_correct = ComposedWitness::Or(vec![ - ComposedWitness::Simple(vec![y]), - ComposedWitness::Simple(vec![y]), - ]); - let proof = nizk.prove_batchable(&witness_correct, &mut rng).unwrap(); + let witness_correct = ComposedWitness::or([[(y_var, y)], [(y_var, y)]]); + let proof = nizk.prove_batchable(witness_correct, &mut rng).unwrap(); assert!( nizk.verify_batchable(&proof).is_ok(), "Prover fails when all witnesses in an OR proof are valid" diff --git a/src/traits.rs b/src/traits.rs index f669656..76cfec2 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -57,7 +57,7 @@ pub trait SigmaProtocol { /// - The internal state to use when computing the response. fn prover_commit( &self, - witness: &Self::Witness, + witness: Self::Witness, rng: &mut (impl Rng + CryptoRng), ) -> Result<(Self::Commitment, Self::ProverState), Error>;