|
| 1 | +use binius_core::Word; |
| 2 | + |
| 3 | +use super::tweak::ChainTweak; |
| 4 | +use crate::{ |
| 5 | + circuits::multiplexer::multi_wire_multiplex, |
| 6 | + compiler::{CircuitBuilder, Wire}, |
| 7 | +}; |
| 8 | + |
| 9 | +/// Verifies a hash chain for hash-based signature schemes using Keccak-256. |
| 10 | +/// |
| 11 | +/// This function iteratively hashes a signature chunk a specified number of times |
| 12 | +/// and verifies that the final result matches an expected end hash. |
| 13 | +/// |
| 14 | +/// # Hash Chain Structure |
| 15 | +/// |
| 16 | +/// A hash chain is a sequence of values where each value is computed by hashing the previous one: |
| 17 | +/// ```text |
| 18 | +/// start → H(start) → H(H(start)) → ... → end |
| 19 | +/// ``` |
| 20 | +/// |
| 21 | +/// # Circuit Operation |
| 22 | +/// |
| 23 | +/// The circuit performs `coordinate` iterations of hashing, where each iteration: |
| 24 | +/// 1. Takes the current hash value |
| 25 | +/// 2. Applies Keccak-256 with appropriate tweaking parameters |
| 26 | +/// 3. Uses the result as input for the next iteration |
| 27 | +/// |
| 28 | +/// After all iterations, it verifies the final hash equals `end_hash`. |
| 29 | +/// |
| 30 | +/// # Arguments |
| 31 | +/// |
| 32 | +/// * `builder` - Circuit builder for constructing constraints |
| 33 | +/// * `param` - Cryptographic parameter as 64-bit packed wires (LE format) |
| 34 | +/// * `param_len` - Actual byte length of the parameter (must be less than or equal to param.len() * |
| 35 | +/// 8) |
| 36 | +/// * `chain_index` - Index of this chain in the signature structure |
| 37 | +/// * `signature_chunk` - Starting hash value (32 bytes as 4x64-bit LE wires) |
| 38 | +/// * `coordinate` - Number of hash iterations to perform (from codeword) |
| 39 | +/// * `max_chain_len` - Maximum chain length |
| 40 | +/// * `end_hash` - Expected final hash value (32 bytes as 4x64-bit LE wires) |
| 41 | +/// |
| 42 | +/// # Returns |
| 43 | +/// |
| 44 | +/// A vector of `ChainTweak` hashers that need to be populated with witness values. |
| 45 | +/// The number of hashers equals the maximum chain length supported. |
| 46 | +#[allow(clippy::too_many_arguments)] |
| 47 | +pub fn verify_chain( |
| 48 | + builder: &CircuitBuilder, |
| 49 | + param: &[Wire], |
| 50 | + param_len: usize, |
| 51 | + chain_index: Wire, |
| 52 | + signature_chunk: [Wire; 4], |
| 53 | + coordinate: Wire, |
| 54 | + max_chain_len: u64, |
| 55 | + end_hash: [Wire; 4], |
| 56 | +) -> Vec<ChainTweak> { |
| 57 | + assert!( |
| 58 | + param_len <= param.len() * 8, |
| 59 | + "param_len {} exceeds maximum capacity {} of param wires", |
| 60 | + param_len, |
| 61 | + param.len() * 8 |
| 62 | + ); |
| 63 | + let mut hashers = Vec::with_capacity(max_chain_len as usize); |
| 64 | + let mut current_hash = signature_chunk; |
| 65 | + |
| 66 | + let one = builder.add_constant(Word::ONE); |
| 67 | + let zero = builder.add_constant(Word::ZERO); |
| 68 | + let max_chain_len_wire = builder.add_constant_64(max_chain_len); |
| 69 | + |
| 70 | + // Build the hash chain |
| 71 | + for step in 0..max_chain_len { |
| 72 | + let step_wire = builder.add_constant_64(step); |
| 73 | + let (position, _) = builder.iadd_cin_cout(step_wire, coordinate, zero); |
| 74 | + let (position_plus_one, _) = builder.iadd_cin_cout(position, one, zero); |
| 75 | + |
| 76 | + let next_hash = std::array::from_fn(|_| builder.add_witness()); |
| 77 | + let hasher = ChainTweak::new( |
| 78 | + builder, |
| 79 | + param.to_vec(), |
| 80 | + param_len, |
| 81 | + current_hash, |
| 82 | + chain_index, |
| 83 | + position_plus_one, |
| 84 | + next_hash, |
| 85 | + ); |
| 86 | + |
| 87 | + hashers.push(hasher); |
| 88 | + |
| 89 | + // Conditionally select the hash based on whether position + 1 < max_chain_len |
| 90 | + // If position + 1 < max_chain_len, use next_hash, otherwise keep current_hash |
| 91 | + let position_lt_max_chain_len = builder.icmp_ult(position_plus_one, max_chain_len_wire); |
| 92 | + current_hash = |
| 93 | + multi_wire_multiplex(builder, &[¤t_hash, &next_hash], position_lt_max_chain_len) |
| 94 | + .try_into() |
| 95 | + .expect("multi_wire_multiplex should return 4 wires"); |
| 96 | + } |
| 97 | + |
| 98 | + // Assert that the final hash matches the expected end_hash |
| 99 | + builder.assert_eq_v("hash_chain_end_check", current_hash, end_hash); |
| 100 | + hashers |
| 101 | +} |
| 102 | + |
| 103 | +#[cfg(test)] |
| 104 | +mod tests { |
| 105 | + use binius_core::Word; |
| 106 | + use proptest::{prelude::*, strategy::Just}; |
| 107 | + use sha3::{Digest, Keccak256}; |
| 108 | + |
| 109 | + use super::*; |
| 110 | + use crate::{constraint_verifier::verify_constraints, util::pack_bytes_into_wires_le}; |
| 111 | + |
| 112 | + proptest! { |
| 113 | + #[test] |
| 114 | + fn test_verify_chain( |
| 115 | + (coordinate_val, max_chain_len) in (0u64..10).prop_flat_map(|coord| { |
| 116 | + // max_chain_len must be > coordinate_val for any hashing to occur |
| 117 | + // Generate max_chain_len in range [coord + 1, coord + 8] |
| 118 | + (Just(coord), (coord + 1)..=(coord + 8)) |
| 119 | + }), |
| 120 | + chain_index_val in 0u64..100, |
| 121 | + param_bytes in prop::collection::vec(any::<u8>(), 1..120), // Variable length param (1-119 bytes) |
| 122 | + signature_chunk_bytes in prop::array::uniform32(any::<u8>()), |
| 123 | + ) { |
| 124 | + let builder = CircuitBuilder::new(); |
| 125 | + |
| 126 | + let param_wire_count = param_bytes.len().div_ceil(8); |
| 127 | + let param: Vec<Wire> = (0..param_wire_count).map(|_| builder.add_inout()).collect(); |
| 128 | + let chain_index = builder.add_inout(); |
| 129 | + let signature_chunk: [Wire; 4] = std::array::from_fn(|_| builder.add_inout()); |
| 130 | + let coordinate = builder.add_inout(); |
| 131 | + let end_hash: [Wire; 4] = std::array::from_fn(|_| builder.add_inout()); |
| 132 | + |
| 133 | + let hashers = verify_chain( |
| 134 | + &builder, |
| 135 | + ¶m, |
| 136 | + param_bytes.len(), |
| 137 | + chain_index, |
| 138 | + signature_chunk, |
| 139 | + coordinate, |
| 140 | + max_chain_len, |
| 141 | + end_hash, |
| 142 | + ); |
| 143 | + |
| 144 | + let circuit = builder.build(); |
| 145 | + let mut w = circuit.new_witness_filler(); |
| 146 | + |
| 147 | + w[chain_index] = Word::from_u64(chain_index_val); |
| 148 | + w[coordinate] = Word::from_u64(coordinate_val); |
| 149 | + |
| 150 | + let mut current_hash: [u8; 32] = signature_chunk_bytes; |
| 151 | + for (step, hasher) in hashers.iter().enumerate() { |
| 152 | + let hash_position = step as u64 + coordinate_val + 1; |
| 153 | + |
| 154 | + hasher.populate_param(&mut w, ¶m_bytes); |
| 155 | + hasher.populate_hash(&mut w, ¤t_hash); |
| 156 | + hasher.populate_chain_index(&mut w, chain_index_val); |
| 157 | + hasher.populate_position(&mut w, hash_position); |
| 158 | + |
| 159 | + let message = ChainTweak::build_message( |
| 160 | + ¶m_bytes, |
| 161 | + ¤t_hash, |
| 162 | + chain_index_val, |
| 163 | + hash_position, |
| 164 | + ); |
| 165 | + hasher.populate_message(&mut w, &message); |
| 166 | + |
| 167 | + // The circuit always computes the hash, even if it won't be used in the final result |
| 168 | + // This is because the constraint system verifies all hash computations |
| 169 | + let digest: [u8; 32] = Keccak256::digest(&message).into(); |
| 170 | + hasher.populate_digest(&mut w, digest); |
| 171 | + |
| 172 | + // Only update current_hash if this hash is actually selected by the multiplexer |
| 173 | + // (when hash_position < max_chain_len) |
| 174 | + if hash_position < max_chain_len { |
| 175 | + current_hash = digest; |
| 176 | + } |
| 177 | + } |
| 178 | + |
| 179 | + pack_bytes_into_wires_le(&mut w, &end_hash, ¤t_hash); |
| 180 | + pack_bytes_into_wires_le(&mut w, &signature_chunk, &signature_chunk_bytes); |
| 181 | + pack_bytes_into_wires_le(&mut w, ¶m, ¶m_bytes); |
| 182 | + circuit.populate_wire_witness(&mut w).unwrap(); |
| 183 | + |
| 184 | + let cs = circuit.constraint_system(); |
| 185 | + verify_constraints(cs, &w.into_value_vec()).unwrap(); |
| 186 | + } |
| 187 | + } |
| 188 | +} |
0 commit comments