diff --git a/src/backends/counter.rs b/src/backends/counter.rs new file mode 100644 index 00000000..939e8af2 --- /dev/null +++ b/src/backends/counter.rs @@ -0,0 +1,73 @@ +//! This module allows to count operations involved in tests, isolating by test. +//! +//! Example of usage: +//! ```rust +//! #[test] +//! fn test_example() { +//! // [...] +//! println!("{}", counter::counter_get()); +//! } +//! ``` +//! +use std::cell::RefCell; +use std::fmt; +use std::thread_local; + +thread_local! { + static COUNTER: RefCell = RefCell::new(Counter::new()); +} + +#[derive(Clone, Debug)] +pub(crate) struct Counter { + hash: usize, + tree_insert: usize, + tree_proof_gen: usize, +} + +impl Counter { + const fn new() -> Self { + Counter { + hash: 0, + tree_insert: 0, + tree_proof_gen: 0, + } + } +} + +impl fmt::Display for Counter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let counter = counter_get(); + write!(f, "Counter:\n")?; + write!(f, " hashes: {},\n", counter.hash)?; + write!(f, " tree_inserts: {},\n", counter.tree_insert)?; + write!(f, " tree_proof_gens: {}\n", counter.tree_proof_gen)?; + Ok(()) + } +} + +pub(crate) fn count_hash() { + #[cfg(test)] + COUNTER.with(|c| c.borrow_mut().hash += 1); +} + +pub(crate) fn count_tree_insert() { + #[cfg(test)] + COUNTER.with(|c| c.borrow_mut().tree_insert += 1); +} + +pub(crate) fn count_tree_proof_gen() { + #[cfg(test)] + COUNTER.with(|c| c.borrow_mut().tree_proof_gen += 1); +} + +pub(crate) fn counter_get() -> Counter { + COUNTER.with(|c| c.borrow().clone()) +} + +pub(crate) fn counter_reset() { + COUNTER.with(|c| { + c.borrow_mut().hash = 0; + c.borrow_mut().tree_insert = 0; + c.borrow_mut().tree_proof_gen = 0; + }); +} diff --git a/src/backends/mod.rs b/src/backends/mod.rs index 0d137268..da174ee2 100644 --- a/src/backends/mod.rs +++ b/src/backends/mod.rs @@ -1,2 +1,4 @@ +pub(crate) mod counter; + #[cfg(feature = "backend_plonky2")] pub mod plonky2; diff --git a/src/backends/plonky2/basetypes.rs b/src/backends/plonky2/basetypes.rs index 1ca92213..12568756 100644 --- a/src/backends/plonky2/basetypes.rs +++ b/src/backends/plonky2/basetypes.rs @@ -15,6 +15,8 @@ use std::fmt; use crate::middleware::{Params, ToFields}; +use crate::backends::counter; + /// F is the native field we use everywhere. Currently it's Goldilocks from plonky2 pub type F = GoldilocksField; /// C is the Plonky2 config used in POD2 to work with Plonky2 recursion. @@ -119,10 +121,14 @@ impl fmt::Display for Value { pub struct Hash(pub [F; HASH_SIZE]); pub fn hash_value(input: &Value) -> Hash { - Hash(PoseidonHash::hash_no_pad(&input.0).elements) + hash_fields(&input.0) } + pub fn hash_fields(input: &[F]) -> Hash { - Hash(PoseidonHash::hash_no_pad(input).elements) + // Note: the counter counts when this method is called, but different input + // sizes will have different costs in-circuit. + counter::count_hash(); + Hash(PoseidonHash::hash_no_pad(&input).elements) } impl From for Hash { @@ -203,7 +209,7 @@ pub fn hash_str(s: &str) -> Hash { F::from_canonical_u64(v) }) .collect(); - Hash(PoseidonHash::hash_no_pad(&input).elements) + hash_fields(&input) } #[cfg(test)] diff --git a/src/backends/plonky2/primitives/merkletree.rs b/src/backends/plonky2/primitives/merkletree.rs index bcae7ceb..7ec4636c 100644 --- a/src/backends/plonky2/primitives/merkletree.rs +++ b/src/backends/plonky2/primitives/merkletree.rs @@ -2,13 +2,12 @@ //! https://0xparc.github.io/pod2/merkletree.html . use anyhow::{anyhow, Result}; use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::hash::poseidon::PoseidonHash; -use plonky2::plonk::config::Hasher; use std::collections::HashMap; use std::fmt; use std::iter::IntoIterator; -use crate::backends::plonky2::basetypes::{Hash, Value, F, NULL}; +use crate::backends::counter; +use crate::backends::plonky2::basetypes::{hash_fields, Hash, Value, F, NULL}; /// Implements the MerkleTree specified at /// https://0xparc.github.io/pod2/merkletree.html @@ -71,6 +70,8 @@ impl MerkleTree { /// the tree. It returns the `value` of the leaf at the given `key`, and the /// `MerkleProof`. pub fn prove(&self, key: &Value) -> Result<(Value, MerkleProof)> { + counter::count_tree_proof_gen(); + let path = keypath(self.max_depth, *key)?; let mut siblings: Vec = Vec::new(); @@ -96,6 +97,8 @@ impl MerkleTree { /// the key-value pair in the leaf reached as a result of /// resolving `key` as well as a `MerkleProof`. pub fn prove_nonexistence(&self, key: &Value) -> Result { + counter::count_tree_proof_gen(); + let path = keypath(self.max_depth, *key)?; let mut siblings: Vec = Vec::new(); @@ -175,14 +178,7 @@ impl MerkleTree { /// mitigate fake proofs. pub fn kv_hash(key: &Value, value: Option) -> Hash { value - .map(|v| { - Hash( - PoseidonHash::hash_no_pad( - &[key.0.to_vec(), v.0.to_vec(), vec![GoldilocksField(1)]].concat(), - ) - .elements, - ) - }) + .map(|v| hash_fields(&[key.0.to_vec(), v.0.to_vec(), vec![GoldilocksField(1)]].concat())) .unwrap_or(Hash([GoldilocksField(0); 4])) } @@ -253,7 +249,7 @@ impl MerkleProof { } else { [h.0, sibling.0].concat() }; - h = Hash(PoseidonHash::hash_no_pad(&input).elements); + h = hash_fields(&input); } Ok(h) } @@ -365,6 +361,8 @@ impl Node { // adds the leaf at the tree from the current node (self), without computing any hash fn add_leaf(&mut self, lvl: usize, max_depth: usize, leaf: Leaf) -> Result<()> { + counter::count_tree_insert(); + if lvl >= max_depth { return Err(anyhow!("max depth reached")); } @@ -480,7 +478,7 @@ impl Intermediate { let l_hash = self.left.compute_hash(); let r_hash = self.right.compute_hash(); let input: Vec = [l_hash.0, r_hash.0].concat(); - let h = Hash(PoseidonHash::hash_no_pad(&input).elements); + let h = hash_fields(&input); self.hash = Some(h); h } @@ -599,8 +597,11 @@ pub mod tests { let (v, proof) = tree.prove(&Value::from(13))?; assert_eq!(v, Value::from(1013)); println!("{}", proof); + println!("after proof generation, {}", counter::counter_get()); + counter::counter_reset(); MerkleTree::verify(32, tree.root(), &proof, &key, &value)?; + println!("after verify, {}", counter::counter_get()); // Exclusion checks let key = Value::from(12);