diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ad2d371a0..d5edb3c42 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -144,7 +144,7 @@ jobs: strategy: fail-fast: false matrix: - example: [ethsign, zklogin, sha256, sha512, keccak] + example: [ethsign, zklogin, sha256, sha512, keccak, hash_based_sig] steps: - name: Checkout Repository uses: actions/checkout@v4 diff --git a/Cargo.toml b/Cargo.toml index d87bdd4e3..59c094f46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,7 @@ rand = { version = "0.9.1", default-features = false, features = [ rayon = "1.10.0" regex = "1.10" rsa = { version = "0.9.8", features = ["sha2"] } +rstest = "0.26.1" seq-macro = "0.3.5" serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" diff --git a/crates/examples/examples/hash_based_sig.rs b/crates/examples/examples/hash_based_sig.rs new file mode 100644 index 000000000..c715ab91f --- /dev/null +++ b/crates/examples/examples/hash_based_sig.rs @@ -0,0 +1,204 @@ +use std::array; + +use anyhow::Result; +use binius_core::Word; +use binius_examples::{Cli, ExampleCircuit}; +use binius_frontend::{ + circuits::hash_based_sig::{ + winternitz_ots::WinternitzSpec, + witness_utils::{ValidatorSignatureData, XmssHasherData, populate_xmss_hashers}, + xmss::XmssSignature, + xmss_aggregate::{XmssMultisigHashers, circuit_xmss_multisig}, + }, + compiler::{CircuitBuilder, Wire, circuit::WitnessFiller}, + util::pack_bytes_into_wires_le, +}; +use clap::Args; +use rand::{RngCore, SeedableRng, rngs::StdRng}; + +/// Hash-based multi-signature verification example circuit +struct HashBasedSigExample { + spec: WinternitzSpec, + tree_height: usize, + num_validators: usize, + param: Vec, + message: Vec, + epoch: Wire, + validator_roots: Vec<[Wire; 4]>, + validator_signatures: Vec, + hashers: XmssMultisigHashers, +} + +#[derive(Args, Debug)] +struct Params { + /// Number of validators in the multi-signature + #[arg(short = 'n', long, default_value_t = 3)] + num_validators: usize, + + /// Height of the Merkle tree (2^height slots) + #[arg(short = 't', long, default_value_t = 3)] + tree_height: usize, + + /// Winternitz spec: 1 or 2 + #[arg(short = 's', long, default_value_t = 1)] + spec: u8, +} + +#[derive(Args, Debug)] +struct Instance {} + +impl ExampleCircuit for HashBasedSigExample { + type Params = Params; + type Instance = Instance; + + fn build(params: Params, builder: &mut CircuitBuilder) -> Result { + let spec = match params.spec { + 1 => WinternitzSpec::spec_1(), + 2 => WinternitzSpec::spec_2(), + _ => anyhow::bail!("Invalid spec: must be 1 or 2"), + }; + + let tree_height = params.tree_height; + if tree_height >= 10 { + anyhow::bail!("tree_height {} exceeds the maximum supported height of 10", tree_height); + } + let num_validators = params.num_validators; + + let param_wire_count = spec.domain_param_len.div_ceil(8); + let param: Vec = (0..param_wire_count).map(|_| builder.add_inout()).collect(); + let message: Vec = (0..4).map(|_| builder.add_inout()).collect(); + let epoch = builder.add_inout(); + + let validator_roots: Vec<[Wire; 4]> = (0..num_validators) + .map(|_| array::from_fn(|_| builder.add_inout())) + .collect(); + + let validator_signatures: Vec = (0..num_validators) + .map(|_| XmssSignature { + nonce: (0..3).map(|_| builder.add_witness()).collect(), + epoch, // All validators use the same epoch wire + signature_hashes: (0..spec.dimension()) + .map(|_| array::from_fn(|_| builder.add_witness())) + .collect(), + public_key_hashes: (0..spec.dimension()) + .map(|_| array::from_fn(|_| builder.add_witness())) + .collect(), + auth_path: (0..tree_height) + .map(|_| array::from_fn(|_| builder.add_witness())) + .collect(), + }) + .collect(); + + let hashers = circuit_xmss_multisig( + builder, + &spec, + ¶m, + &message, + epoch, + &validator_roots, + &validator_signatures, + ); + + Ok(Self { + spec, + tree_height, + num_validators, + param, + message, + epoch, + validator_roots, + validator_signatures, + hashers, + }) + } + + fn populate_witness(&self, _instance: Instance, w: &mut WitnessFiller) -> Result<()> { + let mut rng = StdRng::seed_from_u64(0); + + let mut param_bytes = vec![0u8; self.spec.domain_param_len]; + rng.fill_bytes(&mut param_bytes); + + let mut message_bytes = [0u8; 32]; + rng.fill_bytes(&mut message_bytes); + + // Safe because tree_height is validated to be < 10 in build() + let epoch = rng.next_u32() % (1u32 << self.tree_height); + + // Pack param_bytes (pad to match wire count) + let mut padded_param = vec![0u8; self.param.len() * 8]; + padded_param[..param_bytes.len()].copy_from_slice(¶m_bytes); + pack_bytes_into_wires_le(w, &self.param, &padded_param); + pack_bytes_into_wires_le(w, &self.message, &message_bytes); + w[self.epoch] = Word::from_u64(epoch as u64); + + // Generate a signature for each validator + for val_idx in 0..self.num_validators { + let validator_data = ValidatorSignatureData::generate( + &mut rng, + ¶m_bytes, + &message_bytes, + epoch, + &self.spec, + self.tree_height, + ); + + pack_bytes_into_wires_le(w, &self.validator_roots[val_idx], &validator_data.root); + + let mut nonce_padded = [0u8; 24]; + nonce_padded[..23].copy_from_slice(&validator_data.nonce); + pack_bytes_into_wires_le(w, &self.validator_signatures[val_idx].nonce, &nonce_padded); + + for (i, sig_hash) in validator_data.signature_hashes.iter().enumerate() { + pack_bytes_into_wires_le( + w, + &self.validator_signatures[val_idx].signature_hashes[i], + sig_hash, + ); + } + + for (i, pk_hash) in validator_data.public_key_hashes.iter().enumerate() { + pack_bytes_into_wires_le( + w, + &self.validator_signatures[val_idx].public_key_hashes[i], + pk_hash, + ); + } + + for (i, auth_node) in validator_data.auth_path.iter().enumerate() { + pack_bytes_into_wires_le( + w, + &self.validator_signatures[val_idx].auth_path[i], + auth_node, + ); + } + + let hasher_data = XmssHasherData { + param_bytes: param_bytes.clone(), + message_bytes, + nonce_bytes: validator_data.nonce.to_vec(), + epoch: epoch as u64, + coords: validator_data.coords, + sig_hashes: validator_data.signature_hashes, + pk_hashes: validator_data.public_key_hashes, + auth_path: validator_data.auth_path, + }; + + populate_xmss_hashers( + w, + &self.hashers.validator_hashers[val_idx], + &self.spec, + &hasher_data, + ); + } + + Ok(()) + } +} + +fn main() -> Result<()> { + let _tracing_guard = tracing_profile::init_tracing()?; + + Cli::::new("hash_based_sig") + .about("Hash-based multi-signature (XMSS) verification example") + .run() +} diff --git a/crates/examples/snapshots/hash_based_sig.snap b/crates/examples/snapshots/hash_based_sig.snap new file mode 100644 index 000000000..d508ea5fd --- /dev/null +++ b/crates/examples/snapshots/hash_based_sig.snap @@ -0,0 +1,11 @@ +hash_based_sig circuit +-- +Number of gates: 4544883 +Number of evaluation instructions: 4634982 +Number of AND constraints: 4580460 +Number of MUL constraints: 0 +Length of value vec: 8388608 + Constants: 376 + Inout: 20 + Witness: 29886 + Internal: 4536540 diff --git a/crates/frontend/Cargo.toml b/crates/frontend/Cargo.toml index 300cda176..cf1cba533 100644 --- a/crates/frontend/Cargo.toml +++ b/crates/frontend/Cargo.toml @@ -31,3 +31,4 @@ rsa = { workspace = true, features = ["sha2"] } hex-literal = { workspace = true } num-traits = { workspace = true } proptest = { workspace = true } +rstest = { workspace = true } diff --git a/crates/frontend/src/circuits/hash_based_sig/mod.rs b/crates/frontend/src/circuits/hash_based_sig/mod.rs index 0245d8f3e..ca7903541 100644 --- a/crates/frontend/src/circuits/hash_based_sig/mod.rs +++ b/crates/frontend/src/circuits/hash_based_sig/mod.rs @@ -3,3 +3,6 @@ pub mod codeword; pub mod hashing; pub mod merkle_tree; pub mod winternitz_ots; +pub mod witness_utils; +pub mod xmss; +pub mod xmss_aggregate; diff --git a/crates/frontend/src/circuits/hash_based_sig/witness_utils.rs b/crates/frontend/src/circuits/hash_based_sig/witness_utils.rs new file mode 100644 index 000000000..5ddf5f29a --- /dev/null +++ b/crates/frontend/src/circuits/hash_based_sig/witness_utils.rs @@ -0,0 +1,340 @@ +//! Witness population utilities for hash-based signature verification. +//! +//! This module provides helper functions for populating witness data +//! in hash-based signature circuits, including XMSS and Winternitz OTS. + +use rand::{RngCore, rngs::StdRng}; + +use super::{ + hashing::{ + build_chain_hash, build_message_hash, build_public_key_hash, build_tree_hash, + hash_chain_keccak, hash_message, hash_public_key_keccak, hash_tree_node_keccak, + }, + winternitz_ots::{WinternitzSpec, grind_nonce}, + xmss::XmssHashers, +}; +use crate::compiler::circuit::WitnessFiller; + +/// Builds a complete Merkle tree from leaf nodes. +/// +/// This function assumes the number of leaves is a power of 2. +/// +/// # Returns +/// A tuple containing: +/// - Vector of tree levels (index 0 = leaves, last index = root) +/// - The root hash +/// +/// # Panics +/// Panics if leaves.len() is not a power of 2 +pub fn build_merkle_tree(param: &[u8], leaves: &[[u8; 32]]) -> (Vec>, [u8; 32]) { + assert!(leaves.len().is_power_of_two(), "Number of leaves must be a power of 2"); + + let tree_depth = leaves.len().trailing_zeros() as usize; + let mut tree_levels = vec![leaves.to_vec()]; + + for level in 0..tree_depth { + let current_level = &tree_levels[level]; + let mut next_level = Vec::new(); + + for i in (0..current_level.len()).step_by(2) { + let parent = hash_tree_node_keccak( + param, + ¤t_level[i], + ¤t_level[i + 1], + level as u32, + (i / 2) as u32, + ); + next_level.push(parent); + } + + tree_levels.push(next_level); + } + + let root = tree_levels[tree_depth][0]; + (tree_levels, root) +} + +/// Extracts the authentication path for a given leaf index in a Merkle tree. +/// +/// This function assumes the tree has power-of-2 leaves. +/// +/// # Arguments +/// * `tree_levels` - All levels of the tree (from build_merkle_tree) +/// * `leaf_index` - Index of the leaf to build path for +/// +/// # Returns +/// Vector of sibling hashes from leaf to root +pub fn extract_auth_path(tree_levels: &[Vec<[u8; 32]>], leaf_index: usize) -> Vec<[u8; 32]> { + let mut auth_path = Vec::new(); + let mut idx = leaf_index; + let tree_height = tree_levels.len() - 1; + + for level in 0..tree_height { + let sibling_idx = idx ^ 1; + auth_path.push(tree_levels[level][sibling_idx]); + idx /= 2; + } + + auth_path +} + +/// Helper structure containing signature data for a validator. +/// +/// This is useful for generating test data or populating witness values +/// in multi-signature scenarios. +pub struct ValidatorSignatureData { + /// Root hash of the validator's Merkle tree + pub root: [u8; 32], + /// Nonce (23 bytes) + pub nonce: [u8; 23], + /// Signature hashes for each Winternitz chain + pub signature_hashes: Vec<[u8; 32]>, + /// Public key hashes for each Winternitz chain + pub public_key_hashes: Vec<[u8; 32]>, + /// Authentication path in the Merkle tree + pub auth_path: Vec<[u8; 32]>, + /// Codeword coordinates + pub coords: Vec, +} + +impl ValidatorSignatureData { + /// Generate a valid signature for a validator at a given epoch. + /// + /// This function generates all the cryptographic data needed for a validator's + /// signature including the Winternitz OTS signature, public key, and Merkle tree + /// authentication path. + /// + /// # Panics + /// Panics if: + /// - The epoch is greater than the number of leaves in the tree. + /// - A `grind_nonce` fails to find a valid nonce + /// - A coordinate returned by `grind_nonce` is invalid. + pub fn generate( + rng: &mut StdRng, + param_bytes: &[u8], + message_bytes: &[u8; 32], + epoch: u32, + spec: &WinternitzSpec, + tree_height: usize, + ) -> Self { + assert!( + tree_height <= 10, + "Tree height {} exceeds maximum supported height of 10", + tree_height, + ); + + // Validate epoch is within valid range for the tree + let num_leaves = 1usize << tree_height; + assert!( + (epoch as usize) < num_leaves, + "Epoch {} exceeds maximum leaf index {} for tree height {}", + epoch, + num_leaves - 1, + tree_height + ); + + let grind_result = + grind_nonce(spec, rng, param_bytes, message_bytes).expect("Failed to find valid nonce"); + + let mut nonce = [0u8; 23]; + nonce.copy_from_slice(&grind_result.nonce); + let coords = grind_result.coords; + + // Generate Winternitz signature and public key + let mut signature_hashes = Vec::new(); + let mut public_key_hashes = Vec::new(); + + for (chain_idx, &coord) in coords.iter().enumerate() { + assert!( + (coord as usize) < spec.chain_len(), + "Coordinate {} exceeds chain length {}", + coord, + spec.chain_len() + ); + + let mut sig_hash = [0u8; 32]; + rng.fill_bytes(&mut sig_hash); + signature_hashes.push(sig_hash); + + let pk_hash = hash_chain_keccak( + param_bytes, + chain_idx, + &sig_hash, + coord as usize, + spec.chain_len() - 1 - coord as usize, + ); + public_key_hashes.push(pk_hash); + } + + // Build a Merkle tree with 2^tree_height leaves + let mut leaves = vec![[0u8; 32]; num_leaves]; + leaves[epoch as usize] = hash_public_key_keccak(param_bytes, &public_key_hashes); + for (i, leaf) in leaves.iter_mut().enumerate() { + if i != epoch as usize { + rng.fill_bytes(leaf); + } + } + + let (tree_levels, root) = build_merkle_tree(param_bytes, &leaves); + let auth_path = extract_auth_path(&tree_levels, epoch as usize); + + ValidatorSignatureData { + root, + nonce, + signature_hashes, + public_key_hashes, + auth_path, + coords, + } + } +} + +/// Data structure containing all the information needed to populate XMSS hashers. +pub struct XmssHasherData { + /// Parameter bytes (variable length based on spec) + pub param_bytes: Vec, + /// Message bytes (32 bytes) + pub message_bytes: [u8; 32], + /// Nonce bytes (variable length, typically 23) + pub nonce_bytes: Vec, + /// Epoch/leaf index + pub epoch: u64, + /// Codeword coordinates + pub coords: Vec, + /// Signature hashes for each chain + pub sig_hashes: Vec<[u8; 32]>, + /// Public key hashes for each chain + pub pk_hashes: Vec<[u8; 32]>, + /// Authentication path for Merkle tree + pub auth_path: Vec<[u8; 32]>, +} + +/// Populates all hashers in an XmssHashers struct with witness data. +/// +/// This function fills in the message hasher, chain hashers, public key hasher, +/// and Merkle path hashers with the appropriate witness data for verification. +/// +/// # Arguments +/// +/// * `w` - The witness filler to populate +/// * `hashers` - The XMSS hashers to populate +/// * `spec` - The Winternitz specification +/// * `data` - The data to use for population +/// +/// # Panics +/// Panics if: +/// - `data.coord.len()` is not equal to `spec.dimension()` +/// - `data.sig_hashes.len()` is not equal to `spec.dimension()` +/// - `data.pk_hashes.len()` is not equal to `spec.dimension()` +pub fn populate_xmss_hashers( + w: &mut WitnessFiller, + hashers: &XmssHashers, + spec: &WinternitzSpec, + data: &XmssHasherData, +) { + assert_eq!( + data.coords.len(), + spec.dimension(), + "Coordinates length {} doesn't match spec dimension {}", + data.coords.len(), + spec.dimension() + ); + assert_eq!( + data.sig_hashes.len(), + spec.dimension(), + "Signature hashes length {} doesn't match spec dimension {}", + data.sig_hashes.len(), + spec.dimension() + ); + assert_eq!( + data.pk_hashes.len(), + spec.dimension(), + "Public key hashes length {} doesn't match spec dimension {}", + data.pk_hashes.len(), + spec.dimension() + ); + + // Populate message hasher + let message_hash = hash_message(&data.param_bytes, &data.nonce_bytes, &data.message_bytes); + let tweaked_message = + build_message_hash(&data.param_bytes, &data.nonce_bytes, &data.message_bytes); + + hashers + .winternitz_ots + .message_hasher + .populate_message(w, &tweaked_message); + hashers + .winternitz_ots + .message_hasher + .populate_digest(w, message_hash); + + // Populate chain hashers + let mut hasher_idx = 0; + for (chain_idx, &coord) in data.coords.iter().enumerate() { + let mut current_hash = data.sig_hashes[chain_idx]; + + for step in 0..spec.chain_len() { + let position = step + coord as usize; + let position_plus_one = position + 1; + + let next_hash = + hash_chain_keccak(&data.param_bytes, chain_idx, ¤t_hash, position, 1); + + let hasher = &hashers.winternitz_ots.chain_hashers[hasher_idx]; + let chain_message = build_chain_hash( + &data.param_bytes, + ¤t_hash, + chain_idx as u64, + position_plus_one as u64, + ); + hasher.populate_message(w, &chain_message); + hasher.populate_digest(w, next_hash); + + if position_plus_one < spec.chain_len() { + current_hash = next_hash; + } + + hasher_idx += 1; + } + } + + // Populate public key hasher + let pk_message = build_public_key_hash(&data.param_bytes, &data.pk_hashes); + let pk_hash = hash_public_key_keccak(&data.param_bytes, &data.pk_hashes); + hashers.public_key_hasher.populate_message(w, &pk_message); + hashers.public_key_hasher.populate_digest(w, pk_hash); + + // Populate merkle path hashers + let mut current_hash = pk_hash; + let mut current_index = data.epoch as usize; + + for (level, auth_sibling) in data.auth_path.iter().enumerate() { + let (left, right) = if current_index % 2 == 0 { + (¤t_hash, auth_sibling) + } else { + (auth_sibling, ¤t_hash) + }; + + let parent = hash_tree_node_keccak( + &data.param_bytes, + left, + right, + level as u32, + (current_index / 2) as u32, + ); + + let tree_message = build_tree_hash( + &data.param_bytes, + left, + right, + level as u32, + (current_index / 2) as u32, + ); + + hashers.merkle_path_hashers[level].populate_message(w, &tree_message); + hashers.merkle_path_hashers[level].populate_digest(w, parent); + + current_hash = parent; + current_index /= 2; + } +} diff --git a/crates/frontend/src/circuits/hash_based_sig/xmss.rs b/crates/frontend/src/circuits/hash_based_sig/xmss.rs new file mode 100644 index 000000000..545618b1e --- /dev/null +++ b/crates/frontend/src/circuits/hash_based_sig/xmss.rs @@ -0,0 +1,442 @@ +use super::{ + hashing::circuit_public_key_hash, + merkle_tree::circuit_merkle_path, + winternitz_ots::{WinternitzOtsHashers, WinternitzSpec, circuit_winternitz_ots}, +}; +use crate::{ + circuits::keccak::Keccak, + compiler::{CircuitBuilder, Wire}, +}; + +/// An XMSS signature. +/// +/// This structure contains all the witness data for an XMSS signature to be +/// verified. +#[derive(Clone)] +pub struct XmssSignature { + /// Nonce (23 bytes in 3 wires) + pub nonce: Vec, + /// The epoch is the index of key-pair used in the signature + pub epoch: Wire, + /// Winternitz signature hash values + pub signature_hashes: Vec<[Wire; 4]>, + /// Winternitz public key hashes + pub public_key_hashes: Vec<[Wire; 4]>, + /// Merkle authentication path + pub auth_path: Vec<[Wire; 4]>, +} + +/// The collection of Keccak hashers used in XMSS verification. +pub struct XmssHashers { + /// Winternitz OTS hashers containing the message hasher and chain verification hashers. + /// See `WinternitzOtsHashers` documentation for details on populating these. + pub winternitz_ots: WinternitzOtsHashers, + + /// Keccak hasher for computing the OTS public key hash from individual Winternitz public keys. + /// Computes: `hash(param || TWEAK_PUBLIC_KEY || pk_hash[0] || pk_hash[1] || ... || + /// pk_hash[D-1])` Must be populated with: + /// - Message: The concatenated public key data (use `hashing::build_public_key_hash`) + /// - Digest: The resulting public key hash (which becomes a leaf in the Merkle tree) + pub public_key_hasher: Keccak, + + /// Vector of Keccak hashers for verifying the Merkle tree authentication path. + /// Contains one hasher per level of the tree that needs to be computed. + /// Each hasher computes: `hash(param || TWEAK_TREE || level || index || left_child || + /// right_child)` Must be populated with: + /// - Message: The tree node hash message (use `hashing::build_tree_hash`) + /// - Digest: The parent node hash at that level + /// + /// The hashers are ordered from leaf level upward to the root. + pub merkle_path_hashers: Vec, +} + +/// Verifies an XMSS (eXtended Merkle Signature Scheme) signature. +/// +/// This circuit combines: +/// 1. Winternitz OTS verification for the one-time signature +/// 2. Computation of public key hash from Winternitz public key +/// 3. Merkle tree path verification to prove the public key is in the tree +/// +/// # Arguments +/// +/// * `builder` - Circuit builder for constructing constraints +/// * `spec` - Winternitz specification parameters (including domain_param_len) +/// * `domain_param` - Cryptographic domain parameter as 64-bit LE-packed wires. The actual byte +/// length is specified by `spec.domain_param_len`, and the wires must have sufficient capacity +/// (i.e., `domain_param.len() * 8 >= spec.domain_param_len`) +/// * `message` - Message to verify (32 bytes as 4x64-bit LE wires) +/// * `signature` - The XMSS signature containing all witness data +/// * `root_hash` - Expected Merkle tree root hash (32 bytes as 4x64-bit LE wires) +/// +/// # Returns +/// +/// An `XmssHashers` struct containing all hashers that need witness population +pub fn circuit_xmss( + builder: &CircuitBuilder, + spec: &WinternitzSpec, + domain_param: &[Wire], + message: &[Wire], + signature: &XmssSignature, + root_hash: &[Wire; 4], +) -> XmssHashers { + // Step 1: Verify the Winternitz OTS signature + let winternitz_ots = circuit_winternitz_ots( + builder, + domain_param, + message, + &signature.nonce, + &signature.signature_hashes, + &signature.public_key_hashes, + spec, + ); + + // Step 2: Compute the public key hash from the Winternitz public key + let pk_hash_output: [Wire; 4] = std::array::from_fn(|_| builder.add_witness()); + let public_key_hasher = circuit_public_key_hash( + builder, + domain_param.to_vec(), + spec.domain_param_len, + &signature.public_key_hashes, + pk_hash_output, + ); + + // Step 3: Verify the Merkle tree path + let merkle_path_hashers = circuit_merkle_path( + builder, + domain_param, + spec.domain_param_len, + &pk_hash_output, + signature.epoch, + &signature.auth_path, + root_hash, + ); + + XmssHashers { + winternitz_ots, + public_key_hasher, + merkle_path_hashers, + } +} + +#[cfg(test)] +mod tests { + use binius_core::Word; + use rand::{RngCore, SeedableRng, rngs::StdRng}; + use rstest::rstest; + + use super::*; + use crate::{ + circuits::hash_based_sig::{ + hashing::{hash_chain_keccak, hash_public_key_keccak}, + winternitz_ots::grind_nonce, + witness_utils::{ + XmssHasherData, build_merkle_tree, extract_auth_path, populate_xmss_hashers, + }, + }, + constraint_verifier::verify_constraints, + util::pack_bytes_into_wires_le, + }; + + /// Helper struct containing all test data for XMSS verification + struct XmssTestData { + param_bytes: Vec, + message_bytes: [u8; 32], + nonce_bytes: Vec, + epoch: u64, + coords: Vec, + sig_hashes: Vec<[u8; 32]>, + pk_hashes: Vec<[u8; 32]>, + auth_path: Vec<[u8; 32]>, + root_hash: [u8; 32], + tree_depth: usize, + } + + impl XmssTestData { + /// Generate test data for XMSS verification + fn generate( + spec: &WinternitzSpec, + tree_size: usize, + signing_epoch: u64, + rng: &mut StdRng, + ) -> Self { + // Generate random parameters based on spec + let mut param_bytes = vec![0u8; spec.domain_param_len]; + rng.fill_bytes(&mut param_bytes); + + let mut message_bytes = [0u8; 32]; + rng.fill_bytes(&mut message_bytes); + + // Find valid nonce + let grind_result = grind_nonce(spec, rng, ¶m_bytes, &message_bytes) + .expect("Failed to find valid nonce"); + + // Generate Winternitz signature and public key + let mut sig_hashes = Vec::new(); + let mut pk_hashes = Vec::new(); + + for (chain_idx, &coord) in grind_result.coords.iter().enumerate() { + let mut sig_hash = [0u8; 32]; + rng.fill_bytes(&mut sig_hash); + sig_hashes.push(sig_hash); + + let pk_hash = hash_chain_keccak( + ¶m_bytes, + chain_idx, + &sig_hash, + coord as usize, + spec.chain_len() - 1 - coord as usize, + ); + pk_hashes.push(pk_hash); + } + + // Build Merkle tree + let mut leaves = Vec::new(); + for i in 0..tree_size { + if i as u64 == signing_epoch { + leaves.push(hash_public_key_keccak(¶m_bytes, &pk_hashes)); + } else { + // Fill other epochs with random values - these represent other public keys + // in the tree that we're not using for this signature verification + let mut leaf = [0u8; 32]; + rng.fill_bytes(&mut leaf); + leaves.push(leaf); + } + } + + let (tree_levels, root_hash) = build_merkle_tree(¶m_bytes, &leaves); + let auth_path = extract_auth_path(&tree_levels, signing_epoch as usize); + + XmssTestData { + param_bytes, + message_bytes, + nonce_bytes: grind_result.nonce, + epoch: signing_epoch, + coords: grind_result.coords, + sig_hashes, + pk_hashes, + auth_path, + root_hash, + tree_depth: tree_levels.len() - 1, + } + } + + /// Run verification test with this test data + fn run(&self, spec: &WinternitzSpec) -> Result<(), String> { + let builder = CircuitBuilder::new(); + + // Create input wires based on spec + let param_wire_count = spec.domain_param_len.div_ceil(8); + let param: Vec = (0..param_wire_count).map(|_| builder.add_inout()).collect(); + let message: Vec = (0..4).map(|_| builder.add_inout()).collect(); + let nonce: Vec = (0..3).map(|_| builder.add_inout()).collect(); + let epoch = builder.add_inout(); + let root_hash: [Wire; 4] = std::array::from_fn(|_| builder.add_inout()); + + let signature_hashes: Vec<[Wire; 4]> = (0..spec.dimension()) + .map(|_| std::array::from_fn(|_| builder.add_inout())) + .collect(); + + let public_key_hashes: Vec<[Wire; 4]> = (0..spec.dimension()) + .map(|_| std::array::from_fn(|_| builder.add_inout())) + .collect(); + + let auth_path: Vec<[Wire; 4]> = (0..self.tree_depth) + .map(|_| std::array::from_fn(|_| builder.add_inout())) + .collect(); + + // Create the verification circuit + let signature = XmssSignature { + nonce: nonce.clone(), + epoch, + signature_hashes: signature_hashes.clone(), + public_key_hashes: public_key_hashes.clone(), + auth_path: auth_path.clone(), + }; + + let hashers = circuit_xmss(&builder, spec, ¶m, &message, &signature, &root_hash); + + let circuit = builder.build(); + let mut w = circuit.new_witness_filler(); + + // Pack inputs into wires (pad param_bytes to match wire count) + let mut padded_param = vec![0u8; param.len() * 8]; + padded_param[..self.param_bytes.len()].copy_from_slice(&self.param_bytes); + pack_bytes_into_wires_le(&mut w, ¶m, &padded_param); + pack_bytes_into_wires_le(&mut w, &message, &self.message_bytes); + + let mut nonce_padded = vec![0u8; 24]; + nonce_padded[..self.nonce_bytes.len()].copy_from_slice(&self.nonce_bytes); + pack_bytes_into_wires_le(&mut w, &nonce, &nonce_padded); + + w[epoch] = Word::from_u64(self.epoch); + pack_bytes_into_wires_le(&mut w, &root_hash, &self.root_hash); + + for (i, sig_hash) in self.sig_hashes.iter().enumerate() { + pack_bytes_into_wires_le(&mut w, &signature_hashes[i], sig_hash); + } + + for (i, pk_hash) in self.pk_hashes.iter().enumerate() { + pack_bytes_into_wires_le(&mut w, &public_key_hashes[i], pk_hash); + } + + for (i, auth_node) in self.auth_path.iter().enumerate() { + pack_bytes_into_wires_le(&mut w, &auth_path[i], auth_node); + } + + let hasher_data = XmssHasherData { + param_bytes: self.param_bytes.clone(), + message_bytes: self.message_bytes, + nonce_bytes: self.nonce_bytes.clone(), + epoch: self.epoch, + coords: self.coords.clone(), + sig_hashes: self.sig_hashes.clone(), + pk_hashes: self.pk_hashes.clone(), + auth_path: self.auth_path.clone(), + }; + populate_xmss_hashers(&mut w, &hashers, spec, &hasher_data); + + circuit + .populate_wire_witness(&mut w) + .map_err(|e| format!("Wire population failed: {:?}", e))?; + + let cs = circuit.constraint_system(); + verify_constraints(cs, &w.into_value_vec()) + .map_err(|e| format!("Constraint verification failed: {:?}", e))?; + + Ok(()) + } + } + + /// Test case configuration for parameterized testing + enum TestCase { + Valid { + tree_size: usize, + signing_epoch: u64, + }, + Invalid { + tree_size: usize, + signing_epoch: u64, + corrupt_fn: fn(&mut XmssTestData), + }, + } + + impl TestCase { + fn run(&self, spec: WinternitzSpec) { + let mut rng = StdRng::seed_from_u64(42); + + match self { + TestCase::Valid { + tree_size, + signing_epoch, + } => { + // Generate test data + let test_data = + XmssTestData::generate(&spec, *tree_size, *signing_epoch, &mut rng); + + let result = test_data.run(&spec); + result.unwrap_or_else(|e| { + panic!("Test expected to pass but failed: {}", e); + }); + } + TestCase::Invalid { + tree_size, + signing_epoch, + corrupt_fn, + } => { + // Generate test data + let mut test_data = + XmssTestData::generate(&spec, *tree_size, *signing_epoch, &mut rng); + + // Apply corruption + corrupt_fn(&mut test_data); + + let result = test_data.run(&spec); + assert!(result.is_err(), "Test expected to fail but passed"); + } + } + } + } + + fn corrupt_signature(test_data: &mut XmssTestData) { + // Corrupt the first signature hash + if !test_data.sig_hashes.is_empty() { + test_data.sig_hashes[0][0] ^= 0xFF; + } + } + + fn corrupt_public_key(test_data: &mut XmssTestData) { + // Corrupt the first public key hash + if !test_data.pk_hashes.is_empty() { + test_data.pk_hashes[0][0] ^= 0xFF; + } + } + + fn corrupt_auth_path(test_data: &mut XmssTestData) { + // Corrupt a node in the authentication path + if !test_data.auth_path.is_empty() { + test_data.auth_path[0][0] ^= 0xFF; + } + } + + fn corrupt_root_hash(test_data: &mut XmssTestData) { + // Corrupt the root hash + test_data.root_hash[0] ^= 0xFF; + } + + fn corrupt_message(test_data: &mut XmssTestData) { + // Change the message after signing + test_data.message_bytes[0] ^= 0xFF; + } + + fn corrupt_epoch(test_data: &mut XmssTestData) { + // Use wrong epoch + test_data.epoch = (test_data.epoch + 1) % 4; + } + + // ==================== Test Specs ==================== + + fn test_spec_small() -> WinternitzSpec { + WinternitzSpec { + message_hash_len: 4, + coordinate_resolution_bits: 2, + target_sum: 24, + domain_param_len: 32, + } + } + + /// Valid test cases with different configurations + #[rstest] + #[case::small_tree_4(test_spec_small(), 4, 1)] + #[case::small_tree_8(test_spec_small(), 8, 3)] + #[case::medium_tree_16(test_spec_small(), 16, 7)] + #[case::spec1(WinternitzSpec::spec_1(), 4, 0)] + #[case::spec2(WinternitzSpec::spec_2(), 4, 2)] + fn test_xmss_valid( + #[case] spec: WinternitzSpec, + #[case] tree_size: usize, + #[case] signing_epoch: u64, + ) { + TestCase::Valid { + tree_size, + signing_epoch, + } + .run(spec); + } + + /// Invalid test cases with various corruption scenarios + #[rstest] + #[case::corrupt_signature(corrupt_signature)] + #[case::corrupt_public_key(corrupt_public_key)] + #[case::corrupt_auth_path(corrupt_auth_path)] + #[case::corrupt_root(corrupt_root_hash)] + #[case::corrupt_message(corrupt_message)] + #[case::corrupt_epoch(corrupt_epoch)] + fn test_xmss_invalid(#[case] corrupt_fn: fn(&mut XmssTestData)) { + TestCase::Invalid { + tree_size: 4, + signing_epoch: 1, + corrupt_fn, + } + .run(test_spec_small()); + } +} diff --git a/crates/frontend/src/circuits/hash_based_sig/xmss_aggregate.rs b/crates/frontend/src/circuits/hash_based_sig/xmss_aggregate.rs new file mode 100644 index 000000000..f8f02e512 --- /dev/null +++ b/crates/frontend/src/circuits/hash_based_sig/xmss_aggregate.rs @@ -0,0 +1,471 @@ +//! XMSS multi-signature aggregation for multiple validators. +//! +//! This module implements aggregation of XMSS signatures where each validator +//! has their own independent XMSS tree and signs at their designated epoch. +//! The aggregation creates a single proof that all signatures are valid. + +use super::{ + winternitz_ots::WinternitzSpec, + xmss::{XmssHashers, XmssSignature, circuit_xmss}, +}; +use crate::compiler::{CircuitBuilder, Wire}; + +/// The collection of XMSS hashers for multi-signature verification. +/// Contains one `XmssHashers` struct per validator. +pub struct XmssMultisigHashers { + /// Vector of XmssHashers, one for each validator + pub validator_hashers: Vec, +} + +/// Verifies multiple XMSS signatures on the same message from different validators at a common +/// epoch. +/// +/// This function implements multi-signature aggregation where: +/// - Each validator has their own independent XMSS tree (different roots) +/// - All validators sign the same message +/// - All validators sign at the same epoch (leaf index) +/// - The proof aggregates all individual signature verifications +/// +/// # Public Inputs (inout wires) +/// - `domain_param`: Cryptographic domain parameter shared by all validators as 64-bit LE-packed +/// wires. The actual byte length is specified by `spec.domain_param_len` +/// - `message`: The common message being signed by all validators +/// - `epoch`: The common epoch (leaf index) at which all validators sign +/// - `validator_roots`: Each validator's XMSS tree root +/// +/// # Private Inputs (witness wires) +/// - `validator_signatures`: Each validator's signature data (witness) +/// +/// # Returns +/// +/// An `XmssMultisigHashers` struct containing all hashers that need witness population +pub fn circuit_xmss_multisig( + builder: &CircuitBuilder, + spec: &WinternitzSpec, + domain_param: &[Wire], + message: &[Wire], + epoch: Wire, + validator_roots: &[[Wire; 4]], + validator_signatures: &[XmssSignature], +) -> XmssMultisigHashers { + assert_eq!( + validator_roots.len(), + validator_signatures.len(), + "Number of validator roots must match number of signatures" + ); + + let mut validator_hashers = Vec::new(); + for (root, sig) in validator_roots.iter().zip(validator_signatures.iter()) { + builder.assert_eq("epoch_equality", sig.epoch, epoch); + let hashers = circuit_xmss(builder, spec, domain_param, message, sig, root); + validator_hashers.push(hashers); + } + + XmssMultisigHashers { validator_hashers } +} + +/// Convenience structure for building multi-signature circuits. +/// +/// This helps organize the wire allocation for multiple validators. +pub struct MultiSigBuilder<'a> { + builder: &'a CircuitBuilder, + spec: &'a WinternitzSpec, +} + +impl<'a> MultiSigBuilder<'a> { + pub fn new(builder: &'a CircuitBuilder, spec: &'a WinternitzSpec) -> Self { + Self { builder, spec } + } + + /// Creates public input wires for parameters, message, and epoch. + pub fn create_public_inputs(&self) -> (Vec, Vec, Wire) { + let param_wire_count = self.spec.domain_param_len.div_ceil(8); + let param: Vec = (0..param_wire_count) + .map(|_| self.builder.add_inout()) + .collect(); + let message: Vec = (0..4).map(|_| self.builder.add_inout()).collect(); + let epoch = self.builder.add_inout(); + (param, message, epoch) + } + + /// Creates public input wires for validator roots. + pub fn create_validator_roots(&self, num_validators: usize) -> Vec<[Wire; 4]> { + (0..num_validators) + .map(|_| std::array::from_fn(|_| self.builder.add_inout())) + .collect() + } + + /// Creates private witness wires for a single validator's signature using the shared epoch. + pub fn create_validator_signature(&self, tree_height: usize, epoch: Wire) -> XmssSignature { + XmssSignature { + nonce: (0..3).map(|_| self.builder.add_witness()).collect(), + epoch, // Use the shared epoch wire + signature_hashes: (0..self.spec.dimension()) + .map(|_| std::array::from_fn(|_| self.builder.add_witness())) + .collect(), + public_key_hashes: (0..self.spec.dimension()) + .map(|_| std::array::from_fn(|_| self.builder.add_witness())) + .collect(), + auth_path: (0..tree_height) + .map(|_| std::array::from_fn(|_| self.builder.add_witness())) + .collect(), + } + } +} + +#[cfg(test)] +mod tests { + use std::error::Error; + + use binius_core::Word; + use rand::{RngCore, SeedableRng, rngs::StdRng}; + use rstest::rstest; + + use super::*; + use crate::{ + circuits::hash_based_sig::witness_utils::{ + ValidatorSignatureData, XmssHasherData, populate_xmss_hashers, + }, + constraint_verifier::verify_constraints, + util::pack_bytes_into_wires_le, + }; + + fn test_spec_small() -> WinternitzSpec { + WinternitzSpec { + message_hash_len: 4, + coordinate_resolution_bits: 2, + target_sum: 24, + domain_param_len: 32, + } + } + + enum MultisigTestCase { + Valid { + num_validators: usize, + tree_height: usize, + epoch: u32, + }, + Invalid { + num_validators: usize, + tree_height: usize, + epoch: u32, + corrupt_fn: fn(&mut MultisigTestData), + }, + } + + impl MultisigTestCase { + fn run(&self, spec: WinternitzSpec) { + let mut rng = StdRng::seed_from_u64(42); + + match self { + MultisigTestCase::Valid { + num_validators, + tree_height, + epoch, + } => { + let test_data = MultisigTestData::generate( + *num_validators, + *tree_height, + *epoch, + &spec, + &mut rng, + ); + test_data.run(&spec, *tree_height).unwrap(); + } + MultisigTestCase::Invalid { + num_validators, + tree_height, + epoch, + corrupt_fn, + } => { + let mut test_data = MultisigTestData::generate( + *num_validators, + *tree_height, + *epoch, + &spec, + &mut rng, + ); + corrupt_fn(&mut test_data); + let result = test_data.run(&spec, *tree_height); + assert!(result.is_err(), "Test expected to fail but passed"); + } + } + } + } + + // These functions corrupt specific aspects of multisig test data + struct MultisigTestData { + param_bytes: Vec, + message_bytes: [u8; 32], + epoch: u32, // Single shared epoch for all validators + validators: Vec, + } + + impl MultisigTestData { + /// Generate test data for multi-signature verification + fn generate( + num_validators: usize, + tree_height: usize, + epoch: u32, + spec: &WinternitzSpec, + rng: &mut StdRng, + ) -> Self { + let mut param_bytes = vec![0u8; spec.domain_param_len]; + rng.fill_bytes(&mut param_bytes); + + let mut message_bytes = [0u8; 32]; + rng.fill_bytes(&mut message_bytes); + + let mut validators = Vec::new(); + for _ in 0..num_validators { + validators.push(ValidatorSignatureData::generate( + rng, + ¶m_bytes, + &message_bytes, + epoch, // All validators sign at the same epoch + spec, + tree_height, + )); + } + + MultisigTestData { + param_bytes, + message_bytes, + epoch, + validators, + } + } + + /// Run the multi-signature verification test + fn run(&self, spec: &WinternitzSpec, tree_height: usize) -> Result<(), Box> { + let builder = CircuitBuilder::new(); + let multisig_builder = MultiSigBuilder::new(&builder, spec); + + let (param, message, epoch_wire) = multisig_builder.create_public_inputs(); + let num_validators = self.validators.len(); + let validator_roots = multisig_builder.create_validator_roots(num_validators); + + let mut validator_signatures = Vec::new(); + for _ in 0..num_validators { + validator_signatures + .push(multisig_builder.create_validator_signature(tree_height, epoch_wire)); + } + + let hashers = circuit_xmss_multisig( + &builder, + spec, + ¶m, + &message, + epoch_wire, + &validator_roots, + &validator_signatures, + ); + + let circuit = builder.build(); + let mut w = circuit.new_witness_filler(); + + // Pack param_bytes (pad to match wire count) + let mut padded_param = vec![0u8; param.len() * 8]; + padded_param[..self.param_bytes.len()].copy_from_slice(&self.param_bytes); + pack_bytes_into_wires_le(&mut w, ¶m, &padded_param); + pack_bytes_into_wires_le(&mut w, &message, &self.message_bytes); + w[epoch_wire] = Word::from_u64(self.epoch as u64); + + for (i, validator) in self.validators.iter().enumerate() { + pack_bytes_into_wires_le(&mut w, &validator_roots[i], &validator.root); + + let mut nonce_padded = [0u8; 24]; + nonce_padded[..23].copy_from_slice(&validator.nonce); + pack_bytes_into_wires_le(&mut w, &validator_signatures[i].nonce, &nonce_padded); + + for (j, sig_hash) in validator.signature_hashes.iter().enumerate() { + pack_bytes_into_wires_le( + &mut w, + &validator_signatures[i].signature_hashes[j], + sig_hash, + ); + } + + for (j, pk_hash) in validator.public_key_hashes.iter().enumerate() { + pack_bytes_into_wires_le( + &mut w, + &validator_signatures[i].public_key_hashes[j], + pk_hash, + ); + } + + for (j, auth_node) in validator.auth_path.iter().enumerate() { + pack_bytes_into_wires_le( + &mut w, + &validator_signatures[i].auth_path[j], + auth_node, + ); + } + } + + for (val_idx, validator) in self.validators.iter().enumerate() { + let validator_hasher = &hashers.validator_hashers[val_idx]; + + let hasher_data = XmssHasherData { + param_bytes: self.param_bytes.to_vec(), + message_bytes: self.message_bytes, + nonce_bytes: validator.nonce.to_vec(), + epoch: self.epoch as u64, // Use shared epoch + coords: validator.coords.clone(), + sig_hashes: validator.signature_hashes.clone(), + pk_hashes: validator.public_key_hashes.clone(), + auth_path: validator.auth_path.clone(), + }; + + populate_xmss_hashers(&mut w, validator_hasher, spec, &hasher_data); + } + + circuit.populate_wire_witness(&mut w)?; + + let cs = circuit.constraint_system(); + verify_constraints(cs, &w.into_value_vec())?; + + Ok(()) + } + } + + // ==================== Parameterized Tests ==================== + + /// Valid test cases with different configurations + #[rstest] + #[case::three_validators_epoch_1(3, 3, 1, test_spec_small())] + #[case::single_validator_epoch_2(1, 3, 2, test_spec_small())] + #[case::five_validators_epoch_0(5, 3, 0, test_spec_small())] + #[case::two_validators_spec1_epoch_0(2, 2, 0, WinternitzSpec::spec_1())] + #[case::four_validators_spec2_epoch_1(4, 3, 1, WinternitzSpec::spec_2())] + #[case::two_validators_small_tree_epoch_1(2, 2, 1, test_spec_small())] + #[case::three_validators_large_tree_epoch_2(3, 4, 2, test_spec_small())] + #[case::many_validators_same_epoch(6, 3, 2, test_spec_small())] + fn test_xmss_multisig_valid( + #[case] num_validators: usize, + #[case] tree_height: usize, + #[case] epoch: u32, + #[case] spec: WinternitzSpec, + ) { + MultisigTestCase::Valid { + num_validators, + tree_height, + epoch, + } + .run(spec); + } + + fn corrupt_one_validator_signature(test_data: &mut MultisigTestData) { + // Corrupt the second validator's first signature hash + if test_data.validators.len() > 1 { + test_data.validators[1].signature_hashes[0][0] ^= 0xFF; + } + } + + fn corrupt_shared_epoch(test_data: &mut MultisigTestData) { + // Change the shared epoch to an incorrect value + test_data.epoch = (test_data.epoch + 1) % 8; + } + + fn corrupt_one_validator_message(test_data: &mut MultisigTestData) { + // Make second validator sign a different message + if test_data.validators.len() > 1 { + let mut rng = StdRng::seed_from_u64(99999); + let mut wrong_message = [0u8; 32]; + rng.fill_bytes(&mut wrong_message); + + // Regenerate second validator's signature with wrong message + let spec = test_spec_small(); + test_data.validators[1] = ValidatorSignatureData::generate( + &mut rng, + &test_data.param_bytes, + &wrong_message, + test_data.epoch, + &spec, + 3, + ); + } + } + + fn corrupt_one_validator_root(test_data: &mut MultisigTestData) { + // Corrupt the first validator's root + if !test_data.validators.is_empty() { + test_data.validators[0].root[0] ^= 0xFF; + } + } + + fn corrupt_one_validator_auth_path(test_data: &mut MultisigTestData) { + // Corrupt the last validator's first auth path node + if let Some(validator) = test_data.validators.last_mut() + && !validator.auth_path.is_empty() + { + validator.auth_path[0][0] ^= 0xFF; + } + } + + fn corrupt_validator_epochs(test_data: &mut MultisigTestData) { + // Make validators sign at different epochs + if test_data.validators.len() > 1 { + let mut rng = StdRng::seed_from_u64(88888); + let spec = test_spec_small(); + + // Regenerate second validator with a different epoch + let different_epoch = (test_data.epoch + 1) % 8; + test_data.validators[1] = ValidatorSignatureData::generate( + &mut rng, + &test_data.param_bytes, + &test_data.message_bytes, + different_epoch, + &spec, + 3, + ); + } + } + + /// Test that mismatched number of roots and signatures causes panic + #[test] + #[should_panic(expected = "Number of validator roots must match number of signatures")] + fn test_multisig_mismatched_validators() { + let builder = CircuitBuilder::new(); + let spec = test_spec_small(); + let multisig_builder = MultiSigBuilder::new(&builder, &spec); + + let (param, message, epoch) = multisig_builder.create_public_inputs(); + + // Create 3 roots but only 2 signatures + let validator_roots = multisig_builder.create_validator_roots(3); + let validator_signatures = vec![ + multisig_builder.create_validator_signature(3, epoch), + multisig_builder.create_validator_signature(3, epoch), + ]; + + // This should panic + circuit_xmss_multisig( + &builder, + &spec, + ¶m, + &message, + epoch, + &validator_roots, + &validator_signatures, + ); + } + + /// Invalid test cases for multisig with various corruption scenarios + #[rstest] + #[case::corrupt_one_signature(corrupt_one_validator_signature)] + #[case::corrupt_epoch(corrupt_shared_epoch)] + #[case::corrupt_different_message(corrupt_one_validator_message)] + #[case::corrupt_root(corrupt_one_validator_root)] + #[case::corrupt_auth_path(corrupt_one_validator_auth_path)] + #[case::corrupt_validator_epochs(corrupt_validator_epochs)] + fn test_xmss_multisig_invalid(#[case] corrupt_fn: fn(&mut MultisigTestData)) { + MultisigTestCase::Invalid { + num_validators: 3, + tree_height: 3, + epoch: 2, // All validators sign at epoch 2 + corrupt_fn, + } + .run(test_spec_small()); + } +}