Skip to content

Commit 1e322a3

Browse files
committed
add initial counter setup
We can extend it to also count the POD operations or other kind of logic that we might want to count.
1 parent a77b522 commit 1e322a3

File tree

4 files changed

+99
-17
lines changed

4 files changed

+99
-17
lines changed

src/backends/counter.rs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
//! This module allows to count operations involved in tests, isolating by test.
2+
//!
3+
//! Example of usage:
4+
//! ```rust
5+
//! #[test]
6+
//! fn test_example() {
7+
//! // [...]
8+
//! println!("{}", counter::counter_get());
9+
//! }
10+
//! ```
11+
//!
12+
use std::cell::RefCell;
13+
use std::fmt;
14+
use std::thread_local;
15+
16+
thread_local! {
17+
static COUNTER: RefCell<Counter> = RefCell::new(Counter::new());
18+
}
19+
20+
#[derive(Clone, Debug)]
21+
pub(crate) struct Counter {
22+
hash: usize,
23+
tree_insert: usize,
24+
tree_proof_gen: usize,
25+
}
26+
27+
impl Counter {
28+
const fn new() -> Self {
29+
Counter {
30+
hash: 0,
31+
tree_insert: 0,
32+
tree_proof_gen: 0,
33+
}
34+
}
35+
}
36+
37+
impl fmt::Display for Counter {
38+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39+
let counter = counter_get();
40+
write!(f, "Counter:\n")?;
41+
write!(f, " hashes: {},\n", counter.hash)?;
42+
write!(f, " tree_inserts: {},\n", counter.tree_insert)?;
43+
write!(f, " tree_proof_gens: {}\n", counter.tree_proof_gen)?;
44+
Ok(())
45+
}
46+
}
47+
48+
pub(crate) fn count_hash() {
49+
#[cfg(test)]
50+
COUNTER.with(|c| c.borrow_mut().hash += 1);
51+
}
52+
53+
pub(crate) fn count_tree_insert() {
54+
#[cfg(test)]
55+
COUNTER.with(|c| c.borrow_mut().tree_insert += 1);
56+
}
57+
58+
pub(crate) fn count_tree_proof_gen() {
59+
#[cfg(test)]
60+
COUNTER.with(|c| c.borrow_mut().tree_proof_gen += 1);
61+
}
62+
63+
pub(crate) fn counter_get() -> Counter {
64+
COUNTER.with(|c| c.borrow().clone())
65+
}
66+
67+
pub(crate) fn counter_reset() {
68+
COUNTER.with(|c| {
69+
c.borrow_mut().hash = 0;
70+
c.borrow_mut().tree_insert = 0;
71+
c.borrow_mut().tree_proof_gen = 0;
72+
});
73+
}

src/backends/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
1+
pub(crate) mod counter;
2+
13
#[cfg(feature = "backend_plonky2")]
24
pub mod plonky2;

src/backends/plonky2/basetypes.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ use std::fmt;
1515

1616
use crate::middleware::{Params, ToFields};
1717

18+
use crate::backends::counter;
19+
1820
/// F is the native field we use everywhere. Currently it's Goldilocks from plonky2
1921
pub type F = GoldilocksField;
2022
/// C is the Plonky2 config used in POD2 to work with Plonky2 recursion.
@@ -119,10 +121,14 @@ impl fmt::Display for Value {
119121
pub struct Hash(pub [F; HASH_SIZE]);
120122

121123
pub fn hash_value(input: &Value) -> Hash {
122-
Hash(PoseidonHash::hash_no_pad(&input.0).elements)
124+
hash_fields(&input.0)
123125
}
126+
124127
pub fn hash_fields(input: &[F]) -> Hash {
125-
Hash(PoseidonHash::hash_no_pad(input).elements)
128+
// Note: the counter counts when this method is called, but different input
129+
// sizes will have different costs in-circuit.
130+
counter::count_hash();
131+
Hash(PoseidonHash::hash_no_pad(&input).elements)
126132
}
127133

128134
impl From<Value> for Hash {
@@ -203,7 +209,7 @@ pub fn hash_str(s: &str) -> Hash {
203209
F::from_canonical_u64(v)
204210
})
205211
.collect();
206-
Hash(PoseidonHash::hash_no_pad(&input).elements)
212+
hash_fields(&input)
207213
}
208214

209215
#[cfg(test)]

src/backends/plonky2/primitives/merkletree.rs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22
//! https://0xparc.github.io/pod2/merkletree.html .
33
use anyhow::{anyhow, Result};
44
use plonky2::field::goldilocks_field::GoldilocksField;
5-
use plonky2::hash::poseidon::PoseidonHash;
6-
use plonky2::plonk::config::Hasher;
75
use std::collections::HashMap;
86
use std::fmt;
97
use std::iter::IntoIterator;
108

11-
use crate::backends::plonky2::basetypes::{Hash, Value, F, NULL};
9+
use crate::backends::counter;
10+
use crate::backends::plonky2::basetypes::{hash_fields, Hash, Value, F, NULL};
1211

1312
/// Implements the MerkleTree specified at
1413
/// https://0xparc.github.io/pod2/merkletree.html
@@ -71,6 +70,8 @@ impl MerkleTree {
7170
/// the tree. It returns the `value` of the leaf at the given `key`, and the
7271
/// `MerkleProof`.
7372
pub fn prove(&self, key: &Value) -> Result<(Value, MerkleProof)> {
73+
counter::count_tree_proof_gen();
74+
7475
let path = keypath(self.max_depth, *key)?;
7576

7677
let mut siblings: Vec<Hash> = Vec::new();
@@ -96,6 +97,8 @@ impl MerkleTree {
9697
/// the key-value pair in the leaf reached as a result of
9798
/// resolving `key` as well as a `MerkleProof`.
9899
pub fn prove_nonexistence(&self, key: &Value) -> Result<MerkleProof> {
100+
counter::count_tree_proof_gen();
101+
99102
let path = keypath(self.max_depth, *key)?;
100103

101104
let mut siblings: Vec<Hash> = Vec::new();
@@ -175,14 +178,7 @@ impl MerkleTree {
175178
/// mitigate fake proofs.
176179
pub fn kv_hash(key: &Value, value: Option<Value>) -> Hash {
177180
value
178-
.map(|v| {
179-
Hash(
180-
PoseidonHash::hash_no_pad(
181-
&[key.0.to_vec(), v.0.to_vec(), vec![GoldilocksField(1)]].concat(),
182-
)
183-
.elements,
184-
)
185-
})
181+
.map(|v| hash_fields(&[key.0.to_vec(), v.0.to_vec(), vec![GoldilocksField(1)]].concat()))
186182
.unwrap_or(Hash([GoldilocksField(0); 4]))
187183
}
188184

@@ -253,7 +249,7 @@ impl MerkleProof {
253249
} else {
254250
[h.0, sibling.0].concat()
255251
};
256-
h = Hash(PoseidonHash::hash_no_pad(&input).elements);
252+
h = hash_fields(&input);
257253
}
258254
Ok(h)
259255
}
@@ -319,7 +315,7 @@ impl Node {
319315
}
320316
}
321317

322-
/// Goes down from the current node until it encounters a terminal node,
318+
/// Goes down from the current node until it encounter a terminal node,
323319
/// viz. a leaf or empty node, or until it reaches the maximum depth. The
324320
/// `siblings` parameter is used to store the siblings while going down to
325321
/// the leaf, if the given parameter is set to `None`, then no siblings are
@@ -365,6 +361,8 @@ impl Node {
365361

366362
// adds the leaf at the tree from the current node (self), without computing any hash
367363
fn add_leaf(&mut self, lvl: usize, max_depth: usize, leaf: Leaf) -> Result<()> {
364+
counter::count_tree_insert();
365+
368366
if lvl >= max_depth {
369367
return Err(anyhow!("max depth reached"));
370368
}
@@ -480,7 +478,7 @@ impl Intermediate {
480478
let l_hash = self.left.compute_hash();
481479
let r_hash = self.right.compute_hash();
482480
let input: Vec<F> = [l_hash.0, r_hash.0].concat();
483-
let h = Hash(PoseidonHash::hash_no_pad(&input).elements);
481+
let h = hash_fields(&input);
484482
self.hash = Some(h);
485483
h
486484
}
@@ -599,8 +597,11 @@ pub mod tests {
599597
let (v, proof) = tree.prove(&Value::from(13))?;
600598
assert_eq!(v, Value::from(1013));
601599
println!("{}", proof);
600+
println!("after proof generation, {}", counter::counter_get());
602601

602+
counter::counter_reset();
603603
MerkleTree::verify(32, tree.root(), &proof, &key, &value)?;
604+
println!("after verify, {}", counter::counter_get());
604605

605606
// Exclusion checks
606607
let key = Value::from(12);

0 commit comments

Comments
 (0)