Skip to content

Commit d3fef83

Browse files
authored
calculate MainPod id in a dynamic-friendly way (#241)
* calculate MainPod id in a dynamic-friendly way The MainPod id is now calculated with front padding and a fixed size independent of max_public_statements so that introduction gadgets can be verified by a MainPod while paying only for the number of statements they use. This is because with front padding of none-statements we can precompute the poseidon state corresponding to absorbing all the padding statements and only pay constraints for the non-padding statements. The id is calculated as follows: `id = hash(serialize(reverse(statements || none-statements)))` * fix test
1 parent 82481e8 commit d3fef83

File tree

6 files changed

+245
-26
lines changed

6 files changed

+245
-26
lines changed

src/backends/plonky2/circuits/mainpod.rs

Lines changed: 203 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,16 @@ use std::{array, iter, sync::Arc};
33
use itertools::{zip_eq, Itertools};
44
use plonky2::{
55
field::types::Field,
6-
hash::{hash_types::HashOutTarget, poseidon::PoseidonHash},
7-
iop::{target::BoolTarget, witness::PartialWitness},
8-
plonk::circuit_builder::CircuitBuilder,
6+
hash::{
7+
hash_types::{HashOutTarget, RichField, NUM_HASH_OUT_ELTS},
8+
hashing::PlonkyPermutation,
9+
poseidon::{PoseidonHash, PoseidonPermutation},
10+
},
11+
iop::{
12+
target::{BoolTarget, Target},
13+
witness::PartialWitness,
14+
},
15+
plonk::{circuit_builder::CircuitBuilder, config::AlgebraicHasher},
916
};
1017

1118
use crate::{
@@ -22,7 +29,7 @@ use crate::{
2229
signedpod::{SignedPodVerifyGadget, SignedPodVerifyTarget},
2330
},
2431
error::Result,
25-
mainpod,
32+
mainpod::{self, pad_statement},
2633
primitives::merkletree::{
2734
MerkleClaimAndProof, MerkleClaimAndProofTarget, MerkleProofGadget,
2835
},
@@ -894,6 +901,88 @@ impl CustomOperationVerifyGadget {
894901
}
895902
}
896903

904+
struct CalculateIdGadget {
905+
params: Params,
906+
}
907+
908+
impl CalculateIdGadget {
909+
/// Precompute the hash state by absorbing all full chunks from `inputs` and return the reminder
910+
/// elements that didn't fit into a chunk.
911+
fn precompute_hash_state<F: RichField, P: PlonkyPermutation<F>>(inputs: &[F]) -> (P, &[F]) {
912+
let (inputs, inputs_rem) = inputs.split_at((inputs.len() / P::RATE) * P::RATE);
913+
let mut perm = P::new(core::iter::repeat(F::ZERO));
914+
915+
// Absorb all inputs up to the biggest multiple of RATE.
916+
for input_chunk in inputs.chunks(P::RATE) {
917+
perm.set_from_slice(input_chunk, 0);
918+
perm.permute();
919+
}
920+
921+
(perm, inputs_rem)
922+
}
923+
924+
/// Hash `inputs` starting from a circuit-constant `perm` state.
925+
fn hash_from_state<H: AlgebraicHasher<F>, P: PlonkyPermutation<F>>(
926+
builder: &mut CircuitBuilder<F, D>,
927+
perm: P,
928+
inputs: &[Target],
929+
) -> HashOutTarget {
930+
let mut state =
931+
H::AlgebraicPermutation::new(perm.as_ref().iter().map(|v| builder.constant(*v)));
932+
933+
// Absorb all input chunks.
934+
for input_chunk in inputs.chunks(H::AlgebraicPermutation::RATE) {
935+
// Overwrite the first r elements with the inputs. This differs from a standard sponge,
936+
// where we would xor or add in the inputs. This is a well-known variant, though,
937+
// sometimes called "overwrite mode".
938+
state.set_from_slice(input_chunk, 0);
939+
state = builder.permute::<H>(state);
940+
}
941+
942+
let num_outputs = NUM_HASH_OUT_ELTS;
943+
// Squeeze until we have the desired number of outputs.
944+
let mut outputs = Vec::with_capacity(num_outputs);
945+
loop {
946+
for &s in state.squeeze() {
947+
outputs.push(s);
948+
if outputs.len() == num_outputs {
949+
return HashOutTarget::from_vec(outputs);
950+
}
951+
}
952+
state = builder.permute::<H>(state);
953+
}
954+
}
955+
956+
fn eval(
957+
&self,
958+
builder: &mut CircuitBuilder<F, D>,
959+
statements: &[StatementTarget],
960+
) -> HashOutTarget {
961+
let measure = measure_gates_begin!(builder, "CalculateId");
962+
let statements_rev_flattened = statements.iter().rev().flat_map(|s| s.flatten());
963+
let mut none_st = mainpod::Statement::from(Statement::None);
964+
pad_statement(&self.params, &mut none_st);
965+
let front_pad_elts = iter::repeat(&none_st)
966+
.take(self.params.num_public_statements_id - self.params.max_public_statements)
967+
.flat_map(|s| s.to_fields(&self.params))
968+
.collect_vec();
969+
let (perm, front_pad_elts_rem) =
970+
Self::precompute_hash_state::<F, PoseidonPermutation<F>>(&front_pad_elts);
971+
972+
// Precompute the Poseidon state for the initial padding chunks
973+
let inputs = front_pad_elts_rem
974+
.iter()
975+
.map(|v| builder.constant(*v))
976+
.chain(statements_rev_flattened)
977+
.collect_vec();
978+
let id =
979+
Self::hash_from_state::<PoseidonHash, PoseidonPermutation<F>>(builder, perm, &inputs);
980+
981+
measure_gates_end!(builder, measure);
982+
id
983+
}
984+
}
985+
897986
struct MainPodVerifyGadget {
898987
params: Params,
899988
}
@@ -1089,10 +1178,10 @@ impl MainPodVerifyGadget {
10891178
self.build_custom_predicate_verification_table(builder, &custom_predicate_table)?;
10901179

10911180
// 2. Calculate the Pod Id from the public statements
1092-
let measure_calc_id = measure_gates_begin!(builder, "MainPodId");
1093-
let pub_statements_flattened = pub_statements.iter().flat_map(|s| s.flatten()).collect();
1094-
let id = builder.hash_n_to_hash_no_pad::<PoseidonHash>(pub_statements_flattened);
1095-
measure_gates_end!(builder, measure_calc_id);
1181+
let id = CalculateIdGadget {
1182+
params: self.params.clone(),
1183+
}
1184+
.eval(builder, pub_statements);
10961185

10971186
// 4. Verify type
10981187
let type_statement = &pub_statements[0];
@@ -1266,10 +1355,12 @@ impl MainPodVerifyCircuit {
12661355

12671356
#[cfg(test)]
12681357
mod tests {
1269-
use std::ops::Not;
1358+
use std::{iter, ops::Not};
12701359

12711360
use plonky2::{
12721361
field::{goldilocks_field::GoldilocksField, types::Field},
1362+
hash::hash_types::HashOut,
1363+
iop::witness::WitnessWrite,
12731364
plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig},
12741365
};
12751366

@@ -1278,7 +1369,7 @@ mod tests {
12781369
backends::plonky2::{
12791370
basetypes::C,
12801371
circuits::common::tests::I64_TEST_PAIRS,
1281-
mainpod::{OperationArg, OperationAux},
1372+
mainpod::{calculate_id, OperationArg, OperationAux},
12821373
primitives::merkletree::{MerkleClaimAndProof, MerkleTree},
12831374
},
12841375
frontend::{self, key, literal, CustomPredicateBatchBuilder, StatementTmplBuilder},
@@ -2681,4 +2772,106 @@ mod tests {
26812772

26822773
Ok(())
26832774
}
2775+
2776+
fn helper_calculate_id(params: &Params, statements: &[Statement]) -> Result<()> {
2777+
let config = CircuitConfig::standard_recursion_config();
2778+
let mut builder = CircuitBuilder::<F, D>::new(config);
2779+
let gadget = CalculateIdGadget {
2780+
params: params.clone(),
2781+
};
2782+
2783+
let statements_target = (0..params.max_public_statements)
2784+
.map(|_| builder.add_virtual_statement(params))
2785+
.collect_vec();
2786+
let id_target = gadget.eval(&mut builder, &statements_target);
2787+
2788+
let mut pw = PartialWitness::<F>::new();
2789+
2790+
// Input
2791+
let statements = statements
2792+
.into_iter()
2793+
.map(|st| {
2794+
let mut st = mainpod::Statement::from(st.clone());
2795+
pad_statement(params, &mut st);
2796+
st
2797+
})
2798+
.collect_vec();
2799+
for (st_target, st) in statements_target.iter().zip(statements.iter()) {
2800+
st_target.set_targets(&mut pw, params, st)?;
2801+
}
2802+
// Expected Output
2803+
let expected_id = calculate_id(&statements, params);
2804+
pw.set_hash_target(
2805+
id_target,
2806+
HashOut {
2807+
elements: expected_id.0,
2808+
},
2809+
)?;
2810+
2811+
// generate & verify proof
2812+
let data = builder.build::<C>();
2813+
let proof = data.prove(pw)?;
2814+
Ok(data.verify(proof.clone())?)
2815+
}
2816+
2817+
#[test]
2818+
fn test_calculate_id() -> frontend::Result<()> {
2819+
// Case with no public public statements
2820+
let params = Params {
2821+
max_public_statements: 0,
2822+
num_public_statements_id: 8,
2823+
..Default::default()
2824+
};
2825+
2826+
helper_calculate_id(&params, &[]).unwrap();
2827+
2828+
// Case with number of statements for the id equal to number of public statements
2829+
let params = Params {
2830+
max_public_statements: 2,
2831+
num_public_statements_id: 2,
2832+
..Default::default()
2833+
};
2834+
2835+
let statements = [
2836+
Statement::ValueOf(AnchoredKey::from((SELF, "foo")), Value::from(42)),
2837+
Statement::Equal(
2838+
AnchoredKey::from((SELF, "bar")),
2839+
AnchoredKey::from((SELF, "baz")),
2840+
),
2841+
]
2842+
.into_iter()
2843+
.chain(iter::repeat(Statement::None))
2844+
.take(params.max_public_statements)
2845+
.collect_vec();
2846+
2847+
helper_calculate_id(&params, &statements).unwrap();
2848+
2849+
// Case with more statements for the id than the number of public statements
2850+
let params = Params {
2851+
max_public_statements: 4,
2852+
num_public_statements_id: 6,
2853+
..Default::default()
2854+
};
2855+
2856+
let pod_id = PodId(hash_str("pod_id"));
2857+
let statements = [
2858+
Statement::ValueOf(AnchoredKey::from((SELF, "foo")), Value::from(42)),
2859+
Statement::Equal(
2860+
AnchoredKey::from((SELF, "bar")),
2861+
AnchoredKey::from((SELF, "baz")),
2862+
),
2863+
Statement::Lt(
2864+
AnchoredKey::from((pod_id, "one")),
2865+
AnchoredKey::from((pod_id, "two")),
2866+
),
2867+
]
2868+
.into_iter()
2869+
.chain(iter::repeat(Statement::None))
2870+
.take(params.max_public_statements)
2871+
.collect_vec();
2872+
2873+
helper_calculate_id(&params, &statements).unwrap();
2874+
2875+
Ok(())
2876+
}
26842877
}

src/backends/plonky2/mainpod/mod.rs

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
pub mod operation;
22
pub mod statement;
3-
use std::{any::Any, sync::Arc};
3+
use std::{any::Any, iter, sync::Arc};
44

55
use base64::{prelude::BASE64_STANDARD, Engine};
66
use itertools::Itertools;
@@ -35,11 +35,30 @@ use crate::{
3535
},
3636
};
3737

38-
/// Hash a list of public statements to derive the PodId
39-
pub(crate) fn hash_statements(statements: &[Statement], _params: &Params) -> middleware::Hash {
40-
let field_elems = statements
38+
/// Hash a list of public statements to derive the PodId. To make circuits with different number
39+
/// of `max_public_statements compatible we pad the statements up to `num_public_statements_id`.
40+
/// As an optimization we front pad with none-statements so that circuits with a small
41+
/// `max_public_statements` only pay for `max_public_statements` by starting the poseidon state
42+
/// with a precomputed constant corresponding to the front-padding part:
43+
/// `id = hash(serialize(reverse(statements || none-statements)))`
44+
pub(crate) fn calculate_id(statements: &[Statement], params: &Params) -> middleware::Hash {
45+
assert_eq!(params.max_public_statements, statements.len());
46+
assert!(params.max_public_statements <= params.num_public_statements_id);
47+
statements
48+
.iter()
49+
.for_each(|st| assert_eq!(params.max_statement_args, st.1.len()));
50+
51+
let mut none_st: Statement = middleware::Statement::None.into();
52+
pad_statement(params, &mut none_st);
53+
let statements_back_padded = statements
54+
.iter()
55+
.chain(iter::repeat(&none_st))
56+
.take(params.num_public_statements_id)
57+
.collect_vec();
58+
let field_elems = statements_back_padded
4159
.iter()
42-
.flat_map(|statement| statement.clone().to_fields(_params))
60+
.rev()
61+
.flat_map(|statement| statement.to_fields(params))
4362
.collect::<Vec<_>>();
4463
Hash(PoseidonHash::hash_no_pad(&field_elems).elements)
4564
}
@@ -421,7 +440,7 @@ impl Prover {
421440
let public_statements =
422441
statements[statements.len() - params.max_public_statements..].to_vec();
423442
// get the id out of the public statements
424-
let id: PodId = PodId(hash_statements(&public_statements, params));
443+
let id: PodId = PodId(calculate_id(&public_statements, params));
425444

426445
let input = MainPodVerifyInput {
427446
signed_pods: signed_pods_input,
@@ -505,7 +524,7 @@ fn get_common_data(params: &Params) -> Result<CommonCircuitData<F, D>, Error> {
505524
impl MainPod {
506525
fn _verify(&self) -> Result<()> {
507526
// 2. get the id out of the public statements
508-
let id: PodId = PodId(hash_statements(&self.public_statements, &self.params));
527+
let id: PodId = PodId(calculate_id(&self.public_statements, &self.params));
509528
if id != self.id {
510529
return Err(Error::id_not_equal(self.id, id));
511530
}
@@ -700,6 +719,7 @@ pub mod tests {
700719
max_statements: 5,
701720
max_signed_pod_values: 2,
702721
max_public_statements: 2,
722+
num_public_statements_id: 4,
703723
max_statement_args: 2,
704724
max_operation_args: 3,
705725
max_custom_predicate_batches: 2,

src/backends/plonky2/mock/mainpod.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::{
1111
backends::plonky2::{
1212
error::{Error, Result},
1313
mainpod::{
14-
extract_merkle_proofs, hash_statements, layout_statements, normalize_statement,
14+
calculate_id, extract_merkle_proofs, layout_statements, normalize_statement,
1515
process_private_statements_operations, process_public_statements_operations, Operation,
1616
Statement,
1717
},
@@ -163,7 +163,7 @@ impl MockMainPod {
163163
statements[statements.len() - params.max_public_statements..].to_vec();
164164

165165
// get the id out of the public statements
166-
let id: PodId = PodId(hash_statements(&public_statements, params));
166+
let id: PodId = PodId(calculate_id(&public_statements, params));
167167

168168
Ok(Self {
169169
params: params.clone(),
@@ -197,7 +197,7 @@ impl MockMainPod {
197197
// get the input_statements from the self.statements
198198
let input_statements = &self.statements[input_statement_offset..];
199199
// 2. get the id out of the public statements, and ensure it is equal to self.id
200-
let ids_match = self.id == PodId(hash_statements(&self.public_statements, &self.params));
200+
let ids_match = self.id == PodId(calculate_id(&self.public_statements, &self.params));
201201
// find a ValueOf statement from the public statements with key=KEY_TYPE and check that the
202202
// value is PodType::MockMainPod
203203
let has_type_statement = self.public_statements.iter().any(|s| {
@@ -351,8 +351,7 @@ pub mod tests {
351351

352352
#[test]
353353
fn test_mock_main_great_boy() -> frontend::Result<()> {
354-
let params = middleware::Params::default();
355-
let great_boy_builder = great_boy_pod_full_flow()?;
354+
let (params, great_boy_builder) = great_boy_pod_full_flow()?;
356355

357356
let mut prover = MockProver {};
358357
let great_boy_pod = great_boy_builder.prove(&mut prover, &params)?;

0 commit comments

Comments
 (0)