@@ -3,9 +3,16 @@ use std::{array, iter, sync::Arc};
33use itertools:: { zip_eq, Itertools } ;
44use plonky2:: {
55 field:: types:: Field ,
6- hash:: { hash_types:: HashOutTarget , poseidon:: PoseidonHash } ,
7- iop:: { target:: BoolTarget , witness:: PartialWitness } ,
8- plonk:: circuit_builder:: CircuitBuilder ,
6+ hash:: {
7+ hash_types:: { HashOutTarget , RichField , NUM_HASH_OUT_ELTS } ,
8+ hashing:: PlonkyPermutation ,
9+ poseidon:: { PoseidonHash , PoseidonPermutation } ,
10+ } ,
11+ iop:: {
12+ target:: { BoolTarget , Target } ,
13+ witness:: PartialWitness ,
14+ } ,
15+ plonk:: { circuit_builder:: CircuitBuilder , config:: AlgebraicHasher } ,
916} ;
1017
1118use crate :: {
@@ -22,7 +29,7 @@ use crate::{
2229 signedpod:: { SignedPodVerifyGadget , SignedPodVerifyTarget } ,
2330 } ,
2431 error:: Result ,
25- mainpod,
32+ mainpod:: { self , pad_statement } ,
2633 primitives:: merkletree:: {
2734 MerkleClaimAndProof , MerkleClaimAndProofTarget , MerkleProofGadget ,
2835 } ,
@@ -894,6 +901,88 @@ impl CustomOperationVerifyGadget {
894901 }
895902}
896903
904+ struct CalculateIdGadget {
905+ params : Params ,
906+ }
907+
908+ impl CalculateIdGadget {
909+ /// Precompute the hash state by absorbing all full chunks from `inputs` and return the reminder
910+ /// elements that didn't fit into a chunk.
911+ fn precompute_hash_state < F : RichField , P : PlonkyPermutation < F > > ( inputs : & [ F ] ) -> ( P , & [ F ] ) {
912+ let ( inputs, inputs_rem) = inputs. split_at ( ( inputs. len ( ) / P :: RATE ) * P :: RATE ) ;
913+ let mut perm = P :: new ( core:: iter:: repeat ( F :: ZERO ) ) ;
914+
915+ // Absorb all inputs up to the biggest multiple of RATE.
916+ for input_chunk in inputs. chunks ( P :: RATE ) {
917+ perm. set_from_slice ( input_chunk, 0 ) ;
918+ perm. permute ( ) ;
919+ }
920+
921+ ( perm, inputs_rem)
922+ }
923+
924+ /// Hash `inputs` starting from a circuit-constant `perm` state.
925+ fn hash_from_state < H : AlgebraicHasher < F > , P : PlonkyPermutation < F > > (
926+ builder : & mut CircuitBuilder < F , D > ,
927+ perm : P ,
928+ inputs : & [ Target ] ,
929+ ) -> HashOutTarget {
930+ let mut state =
931+ H :: AlgebraicPermutation :: new ( perm. as_ref ( ) . iter ( ) . map ( |v| builder. constant ( * v) ) ) ;
932+
933+ // Absorb all input chunks.
934+ for input_chunk in inputs. chunks ( H :: AlgebraicPermutation :: RATE ) {
935+ // Overwrite the first r elements with the inputs. This differs from a standard sponge,
936+ // where we would xor or add in the inputs. This is a well-known variant, though,
937+ // sometimes called "overwrite mode".
938+ state. set_from_slice ( input_chunk, 0 ) ;
939+ state = builder. permute :: < H > ( state) ;
940+ }
941+
942+ let num_outputs = NUM_HASH_OUT_ELTS ;
943+ // Squeeze until we have the desired number of outputs.
944+ let mut outputs = Vec :: with_capacity ( num_outputs) ;
945+ loop {
946+ for & s in state. squeeze ( ) {
947+ outputs. push ( s) ;
948+ if outputs. len ( ) == num_outputs {
949+ return HashOutTarget :: from_vec ( outputs) ;
950+ }
951+ }
952+ state = builder. permute :: < H > ( state) ;
953+ }
954+ }
955+
956+ fn eval (
957+ & self ,
958+ builder : & mut CircuitBuilder < F , D > ,
959+ statements : & [ StatementTarget ] ,
960+ ) -> HashOutTarget {
961+ let measure = measure_gates_begin ! ( builder, "CalculateId" ) ;
962+ let statements_rev_flattened = statements. iter ( ) . rev ( ) . flat_map ( |s| s. flatten ( ) ) ;
963+ let mut none_st = mainpod:: Statement :: from ( Statement :: None ) ;
964+ pad_statement ( & self . params , & mut none_st) ;
965+ let front_pad_elts = iter:: repeat ( & none_st)
966+ . take ( self . params . num_public_statements_id - self . params . max_public_statements )
967+ . flat_map ( |s| s. to_fields ( & self . params ) )
968+ . collect_vec ( ) ;
969+ let ( perm, front_pad_elts_rem) =
970+ Self :: precompute_hash_state :: < F , PoseidonPermutation < F > > ( & front_pad_elts) ;
971+
972+ // Precompute the Poseidon state for the initial padding chunks
973+ let inputs = front_pad_elts_rem
974+ . iter ( )
975+ . map ( |v| builder. constant ( * v) )
976+ . chain ( statements_rev_flattened)
977+ . collect_vec ( ) ;
978+ let id =
979+ Self :: hash_from_state :: < PoseidonHash , PoseidonPermutation < F > > ( builder, perm, & inputs) ;
980+
981+ measure_gates_end ! ( builder, measure) ;
982+ id
983+ }
984+ }
985+
897986struct MainPodVerifyGadget {
898987 params : Params ,
899988}
@@ -1089,10 +1178,10 @@ impl MainPodVerifyGadget {
10891178 self . build_custom_predicate_verification_table ( builder, & custom_predicate_table) ?;
10901179
10911180 // 2. Calculate the Pod Id from the public statements
1092- let measure_calc_id = measure_gates_begin ! ( builder , "MainPodId" ) ;
1093- let pub_statements_flattened = pub_statements . iter ( ) . flat_map ( |s| s . flatten ( ) ) . collect ( ) ;
1094- let id = builder . hash_n_to_hash_no_pad :: < PoseidonHash > ( pub_statements_flattened ) ;
1095- measure_gates_end ! ( builder, measure_calc_id ) ;
1181+ let id = CalculateIdGadget {
1182+ params : self . params . clone ( ) ,
1183+ }
1184+ . eval ( builder, pub_statements ) ;
10961185
10971186 // 4. Verify type
10981187 let type_statement = & pub_statements[ 0 ] ;
@@ -1266,10 +1355,12 @@ impl MainPodVerifyCircuit {
12661355
12671356#[ cfg( test) ]
12681357mod tests {
1269- use std:: ops:: Not ;
1358+ use std:: { iter , ops:: Not } ;
12701359
12711360 use plonky2:: {
12721361 field:: { goldilocks_field:: GoldilocksField , types:: Field } ,
1362+ hash:: hash_types:: HashOut ,
1363+ iop:: witness:: WitnessWrite ,
12731364 plonk:: { circuit_builder:: CircuitBuilder , circuit_data:: CircuitConfig } ,
12741365 } ;
12751366
@@ -1278,7 +1369,7 @@ mod tests {
12781369 backends:: plonky2:: {
12791370 basetypes:: C ,
12801371 circuits:: common:: tests:: I64_TEST_PAIRS ,
1281- mainpod:: { OperationArg , OperationAux } ,
1372+ mainpod:: { calculate_id , OperationArg , OperationAux } ,
12821373 primitives:: merkletree:: { MerkleClaimAndProof , MerkleTree } ,
12831374 } ,
12841375 frontend:: { self , key, literal, CustomPredicateBatchBuilder , StatementTmplBuilder } ,
@@ -2681,4 +2772,106 @@ mod tests {
26812772
26822773 Ok ( ( ) )
26832774 }
2775+
2776+ fn helper_calculate_id ( params : & Params , statements : & [ Statement ] ) -> Result < ( ) > {
2777+ let config = CircuitConfig :: standard_recursion_config ( ) ;
2778+ let mut builder = CircuitBuilder :: < F , D > :: new ( config) ;
2779+ let gadget = CalculateIdGadget {
2780+ params : params. clone ( ) ,
2781+ } ;
2782+
2783+ let statements_target = ( 0 ..params. max_public_statements )
2784+ . map ( |_| builder. add_virtual_statement ( params) )
2785+ . collect_vec ( ) ;
2786+ let id_target = gadget. eval ( & mut builder, & statements_target) ;
2787+
2788+ let mut pw = PartialWitness :: < F > :: new ( ) ;
2789+
2790+ // Input
2791+ let statements = statements
2792+ . into_iter ( )
2793+ . map ( |st| {
2794+ let mut st = mainpod:: Statement :: from ( st. clone ( ) ) ;
2795+ pad_statement ( params, & mut st) ;
2796+ st
2797+ } )
2798+ . collect_vec ( ) ;
2799+ for ( st_target, st) in statements_target. iter ( ) . zip ( statements. iter ( ) ) {
2800+ st_target. set_targets ( & mut pw, params, st) ?;
2801+ }
2802+ // Expected Output
2803+ let expected_id = calculate_id ( & statements, params) ;
2804+ pw. set_hash_target (
2805+ id_target,
2806+ HashOut {
2807+ elements : expected_id. 0 ,
2808+ } ,
2809+ ) ?;
2810+
2811+ // generate & verify proof
2812+ let data = builder. build :: < C > ( ) ;
2813+ let proof = data. prove ( pw) ?;
2814+ Ok ( data. verify ( proof. clone ( ) ) ?)
2815+ }
2816+
2817+ #[ test]
2818+ fn test_calculate_id ( ) -> frontend:: Result < ( ) > {
2819+ // Case with no public public statements
2820+ let params = Params {
2821+ max_public_statements : 0 ,
2822+ num_public_statements_id : 8 ,
2823+ ..Default :: default ( )
2824+ } ;
2825+
2826+ helper_calculate_id ( & params, & [ ] ) . unwrap ( ) ;
2827+
2828+ // Case with number of statements for the id equal to number of public statements
2829+ let params = Params {
2830+ max_public_statements : 2 ,
2831+ num_public_statements_id : 2 ,
2832+ ..Default :: default ( )
2833+ } ;
2834+
2835+ let statements = [
2836+ Statement :: ValueOf ( AnchoredKey :: from ( ( SELF , "foo" ) ) , Value :: from ( 42 ) ) ,
2837+ Statement :: Equal (
2838+ AnchoredKey :: from ( ( SELF , "bar" ) ) ,
2839+ AnchoredKey :: from ( ( SELF , "baz" ) ) ,
2840+ ) ,
2841+ ]
2842+ . into_iter ( )
2843+ . chain ( iter:: repeat ( Statement :: None ) )
2844+ . take ( params. max_public_statements )
2845+ . collect_vec ( ) ;
2846+
2847+ helper_calculate_id ( & params, & statements) . unwrap ( ) ;
2848+
2849+ // Case with more statements for the id than the number of public statements
2850+ let params = Params {
2851+ max_public_statements : 4 ,
2852+ num_public_statements_id : 6 ,
2853+ ..Default :: default ( )
2854+ } ;
2855+
2856+ let pod_id = PodId ( hash_str ( "pod_id" ) ) ;
2857+ let statements = [
2858+ Statement :: ValueOf ( AnchoredKey :: from ( ( SELF , "foo" ) ) , Value :: from ( 42 ) ) ,
2859+ Statement :: Equal (
2860+ AnchoredKey :: from ( ( SELF , "bar" ) ) ,
2861+ AnchoredKey :: from ( ( SELF , "baz" ) ) ,
2862+ ) ,
2863+ Statement :: Lt (
2864+ AnchoredKey :: from ( ( pod_id, "one" ) ) ,
2865+ AnchoredKey :: from ( ( pod_id, "two" ) ) ,
2866+ ) ,
2867+ ]
2868+ . into_iter ( )
2869+ . chain ( iter:: repeat ( Statement :: None ) )
2870+ . take ( params. max_public_statements )
2871+ . collect_vec ( ) ;
2872+
2873+ helper_calculate_id ( & params, & statements) . unwrap ( ) ;
2874+
2875+ Ok ( ( ) )
2876+ }
26842877}
0 commit comments