@@ -32,9 +32,10 @@ use crate::{
3232 } ,
3333 middleware:: {
3434 CustomPredicate , CustomPredicateBatch , CustomPredicateRef , NativeOperation ,
35- NativePredicate , OperationType , Params , Predicate , PredicatePrefix , RawValue , StatementArg ,
36- StatementTmpl , StatementTmplArg , StatementTmplArgPrefix , ToFields , Value , EMPTY_VALUE , F ,
37- HASH_SIZE , STATEMENT_ARG_F_LEN , VALUE_SIZE ,
35+ NativePredicate , OperationType , Params , Predicate , PredicateOrWildcard ,
36+ PredicateOrWildcardPrefix , PredicatePrefix , RawValue , StatementArg , StatementTmpl ,
37+ StatementTmplArg , StatementTmplArgPrefix , ToFields , Value , EMPTY_VALUE , F , HASH_SIZE ,
38+ STATEMENT_ARG_F_LEN , VALUE_SIZE ,
3839 } ,
3940} ;
4041
@@ -46,6 +47,22 @@ pub struct ValueTarget {
4647 pub elements : [ Target ; VALUE_SIZE ] ,
4748}
4849
50+ impl From < ValueTarget > for HashOutTarget {
51+ fn from ( v : ValueTarget ) -> HashOutTarget {
52+ HashOutTarget {
53+ elements : v. elements ,
54+ }
55+ }
56+ }
57+
58+ impl From < HashOutTarget > for ValueTarget {
59+ fn from ( h : HashOutTarget ) -> ValueTarget {
60+ ValueTarget {
61+ elements : h. elements ,
62+ }
63+ }
64+ }
65+
4966impl ValueTarget {
5067 pub fn zero ( builder : & mut CircuitBuilder ) -> Self {
5168 Self {
@@ -524,18 +541,112 @@ impl StatementTmplArgTarget {
524541 }
525542}
526543
544+ #[ derive( Clone , Serialize , Deserialize ) ]
545+ pub struct PredicateHashOrWildcardTarget {
546+ /// layout: `prefix | [data]`, where data is predicate_hash or wildcard_index
547+ pub elements : [ Target ; Params :: pred_hash_or_wc_size ( ) ] ,
548+ }
549+
550+ impl PredicateHashOrWildcardTarget {
551+ pub fn new ( prefix : Target , data : ValueTarget ) -> Self {
552+ let v = data. elements ;
553+ Self {
554+ elements : [ prefix, v[ 0 ] , v[ 1 ] , v[ 2 ] , v[ 3 ] ] ,
555+ }
556+ }
557+ pub fn new_pred_hash ( builder : & mut CircuitBuilder , pred_hash : HashOutTarget ) -> Self {
558+ Self :: new (
559+ builder. constant ( F :: from ( PredicateOrWildcardPrefix :: Predicate ) ) ,
560+ ValueTarget :: from ( pred_hash) ,
561+ )
562+ }
563+ pub fn is_pred ( & self , builder : & mut CircuitBuilder ) -> BoolTarget {
564+ let prefix_pred = builder. constant ( F :: from ( PredicateOrWildcardPrefix :: Predicate ) ) ;
565+ builder. is_equal ( self . elements [ 0 ] , prefix_pred)
566+ }
567+ pub fn data ( & self ) -> ValueTarget {
568+ ValueTarget {
569+ elements : self . elements [ 1 ..] . try_into ( ) . expect ( "4 elements" ) ,
570+ }
571+ }
572+ pub fn pred_hash ( & self ) -> HashOutTarget {
573+ HashOutTarget :: from ( self . data ( ) )
574+ }
575+ pub fn wc_index ( & self ) -> Target {
576+ self . elements [ 1 ]
577+ }
578+ pub fn set_targets_raw (
579+ & self ,
580+ pw : & mut PartialWitness < F > ,
581+ prefix : PredicateOrWildcardPrefix ,
582+ data : RawValue ,
583+ ) -> Result < ( ) > {
584+ pw. set_target ( self . elements [ 0 ] , F :: from ( prefix) ) ?;
585+ pw. set_target_arr ( & self . elements [ 1 ..] , & data. 0 ) ?;
586+ Ok ( ( ) )
587+ }
588+ pub fn set_targets (
589+ & self ,
590+ pw : & mut PartialWitness < F > ,
591+ params : & Params ,
592+ pred : & PredicateOrWildcard ,
593+ ) -> Result < ( ) > {
594+ match pred {
595+ PredicateOrWildcard :: Predicate ( pred) => {
596+ self . set_targets_raw (
597+ pw,
598+ PredicateOrWildcardPrefix :: Predicate ,
599+ RawValue :: from ( pred. hash ( params) ) ,
600+ ) ?;
601+ }
602+ PredicateOrWildcard :: Wildcard ( wc) => {
603+ self . set_targets_raw (
604+ pw,
605+ PredicateOrWildcardPrefix :: Wildcard ,
606+ RawValue ( [ F :: from_canonical_usize ( wc. index ) , F :: ZERO , F :: ZERO , F :: ZERO ] ) ,
607+ ) ?;
608+ }
609+ }
610+ Ok ( ( ) )
611+ }
612+ }
613+
614+ impl Flattenable for PredicateHashOrWildcardTarget {
615+ fn flatten ( & self ) -> Vec < Target > {
616+ self . elements . to_vec ( )
617+ }
618+ fn from_flattened ( _params : & Params , vs : & [ Target ] ) -> Self {
619+ Self {
620+ elements : vs. try_into ( ) . expect ( "5 elements" ) ,
621+ }
622+ }
623+ fn size ( _params : & Params ) -> usize {
624+ Params :: pred_hash_or_wc_size ( )
625+ }
626+ }
627+
527628#[ derive( Clone , Serialize , Deserialize ) ]
528629pub struct StatementTmplTarget {
630+ /// The preimage of the predicate_hash. This predicate is needed only to build the custom
631+ /// predicate table because it needs to normalize statement templates with predicates that
632+ /// refer to self into content-addressed predicates (using the batch id and index). The
633+ /// predicate type is inspected to do this normalization. After the table is built we only use
634+ /// the predicate hash for equality checks.
529635 pred : Option < PredicateTarget > ,
530- pred_hash : HashOutTarget ,
636+ /// This is constrained to be `hash(pred)` through the type constructor when we have `pred`
637+ /// and the template uses a predicate and not a wildcard.
638+ pred_hash_or_wc : PredicateHashOrWildcardTarget ,
531639 pub args : Vec < StatementTmplArgTarget > ,
532640}
533641
534642impl StatementTmplTarget {
535- pub fn new ( pred_hash : HashOutTarget , args : Vec < StatementTmplArgTarget > ) -> Self {
643+ pub fn new (
644+ pred_hash_or_wc : PredicateHashOrWildcardTarget ,
645+ args : Vec < StatementTmplArgTarget > ,
646+ ) -> Self {
536647 Self {
537648 pred : None ,
538- pred_hash ,
649+ pred_hash_or_wc ,
539650 args,
540651 }
541652 }
@@ -546,9 +657,22 @@ impl StatementTmplTarget {
546657 st_tmpl : & StatementTmpl ,
547658 ) -> Result < ( ) > {
548659 if let Some ( pred) = & self . pred {
549- pred. set_targets ( pw, params, & st_tmpl. pred ) ?;
660+ match & st_tmpl. pred_or_wc {
661+ PredicateOrWildcard :: Predicate ( p) => {
662+ // We store a predicate (not a wildcard) and we have it available. In this
663+ // case the hash will be calculated by constraints later on and we should not
664+ // rely on the original data.
665+ pred. set_targets ( pw, params, p) ?
666+ }
667+ PredicateOrWildcard :: Wildcard ( _wc) => {
668+ // Fill in with a recognizable constant for better debugging; this value is
669+ // not supposed to be used.
670+ pw. set_target_arr ( & pred. elements , & [ F ( 0xdead ) ; Params :: predicate_size ( ) ] ) ?
671+ }
672+ }
550673 }
551- pw. set_hash_target ( self . pred_hash , HashOut :: from ( st_tmpl. pred . hash ( params) ) ) ?;
674+ self . pred_hash_or_wc
675+ . set_targets ( pw, params, & st_tmpl. pred_or_wc ) ?;
552676 let arg_pad = StatementTmplArg :: None ;
553677 for ( i, arg) in st_tmpl
554678 . args
@@ -564,8 +688,8 @@ impl StatementTmplTarget {
564688 pub fn pred ( & self ) -> Option < & PredicateTarget > {
565689 self . pred . as_ref ( )
566690 }
567- pub fn pred_hash ( & self ) -> & HashOutTarget {
568- & self . pred_hash
691+ pub fn pred_hash_or_wc ( & self ) -> & PredicateHashOrWildcardTarget {
692+ & self . pred_hash_or_wc
569693 }
570694}
571695
@@ -603,6 +727,8 @@ impl CustomPredicateTarget {
603727 }
604728}
605729
730+ /// This type is used to build the custom predicate table, which exposes the custom predicates with
731+ /// normalized statement templates indexed by batch_id and custom_predicate_index.
606732#[ derive( Clone , Serialize , Deserialize ) ]
607733pub struct CustomPredicateBatchTarget {
608734 pub predicates : Vec < CustomPredicateTarget > ,
@@ -660,15 +786,17 @@ impl CustomPredicateEntryTarget {
660786 . clone ( )
661787 . into_iter ( )
662788 . map ( |st_tmpl| {
663- let pred = match st_tmpl. pred {
664- Predicate :: BatchSelf ( i) => Predicate :: Custom ( CustomPredicateRef {
665- batch : batch. clone ( ) ,
666- index : i,
667- } ) ,
668- p => p,
789+ let pred_or_wc = match st_tmpl. pred_or_wc {
790+ PredicateOrWildcard :: Predicate ( Predicate :: BatchSelf ( i) ) => {
791+ PredicateOrWildcard :: Predicate ( Predicate :: Custom ( CustomPredicateRef {
792+ batch : batch. clone ( ) ,
793+ index : i,
794+ } ) )
795+ }
796+ x => x. clone ( ) ,
669797 } ;
670798 StatementTmpl {
671- pred ,
799+ pred_or_wc ,
672800 args : st_tmpl. args ,
673801 }
674802 } )
@@ -724,15 +852,15 @@ pub struct CustomPredicateVerifyEntryTarget {
724852}
725853
726854impl CustomPredicateVerifyEntryTarget {
727- pub fn new_virtual ( params : & Params , builder : & mut CircuitBuilder , with_pred : bool ) -> Self {
855+ pub fn new_virtual ( params : & Params , builder : & mut CircuitBuilder ) -> Self {
728856 let custom_predicate_table_len =
729857 params. max_custom_predicate_batches * params. max_custom_batch_size ;
730858 CustomPredicateVerifyEntryTarget {
731859 custom_predicate_table_index : IndexTarget :: new_virtual (
732860 custom_predicate_table_len,
733861 builder,
734862 ) ,
735- custom_predicate : builder. add_virtual_custom_predicate_entry ( params, with_pred ) ,
863+ custom_predicate : builder. add_virtual_custom_predicate_entry ( params) ,
736864 args : ( 0 ..params. max_custom_predicate_wildcards )
737865 . map ( |_| builder. add_virtual_value ( ) )
738866 . collect ( ) ,
@@ -1062,7 +1190,7 @@ impl Flattenable for CustomPredicateTarget {
10621190
10631191impl Flattenable for StatementTmplTarget {
10641192 fn flatten ( & self ) -> Vec < Target > {
1065- self . pred_hash
1193+ self . pred_hash_or_wc
10661194 . flatten ( )
10671195 . into_iter ( )
10681196 . chain ( self . args . iter ( ) . flat_map ( |sta| sta. flatten ( ) ) )
@@ -1071,24 +1199,27 @@ impl Flattenable for StatementTmplTarget {
10711199
10721200 fn from_flattened ( params : & Params , v : & [ Target ] ) -> Self {
10731201 assert_eq ! ( v. len( ) , Self :: size( params) ) ;
1074- let pred_hash_end = HASH_SIZE ;
1075- let pred_hash = HashOutTarget :: from_flattened ( params, & v[ ..pred_hash_end] ) ;
1202+ let pred_hash_or_wc_end = Params :: pred_hash_or_wc_size ( ) ;
1203+ let pred_hash_or_wc =
1204+ PredicateHashOrWildcardTarget :: from_flattened ( params, & v[ ..pred_hash_or_wc_end] ) ;
10761205 let sta_size = Params :: statement_tmpl_arg_size ( ) ;
10771206 let args = ( 0 ..params. max_statement_args )
10781207 . map ( |i| {
1079- let sta_v = & v[ pred_hash_end + sta_size * i..pred_hash_end + sta_size * ( i + 1 ) ] ;
1208+ let sta_v = & v
1209+ [ pred_hash_or_wc_end + sta_size * i..pred_hash_or_wc_end + sta_size * ( i + 1 ) ] ;
10801210 StatementTmplArgTarget :: from_flattened ( params, sta_v)
10811211 } )
10821212 . collect ( ) ;
10831213 Self {
10841214 pred : None ,
1085- pred_hash ,
1215+ pred_hash_or_wc ,
10861216 args,
10871217 }
10881218 }
10891219
10901220 fn size ( params : & Params ) -> usize {
1091- HASH_SIZE + params. max_statement_args * StatementTmplArgTarget :: size ( params)
1221+ Params :: pred_hash_or_wc_size ( )
1222+ + params. max_statement_args * StatementTmplArgTarget :: size ( params)
10921223 }
10931224}
10941225
@@ -1168,11 +1299,8 @@ pub trait CircuitBuilderPod<F: RichField + Extendable<D>, const D: usize> {
11681299 params : & Params ,
11691300 with_pred : bool ,
11701301 ) -> CustomPredicateBatchTarget ;
1171- fn add_virtual_custom_predicate_entry (
1172- & mut self ,
1173- params : & Params ,
1174- with_pred : bool ,
1175- ) -> CustomPredicateEntryTarget ;
1302+ fn add_virtual_custom_predicate_entry ( & mut self , params : & Params )
1303+ -> CustomPredicateEntryTarget ;
11761304 fn select_value ( & mut self , b : BoolTarget , x : ValueTarget , y : ValueTarget ) -> ValueTarget ;
11771305 fn select_statement_arg (
11781306 & mut self ,
@@ -1320,24 +1448,32 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
13201448 }
13211449 }
13221450
1323- /// If `with_pred = true` a predicate is included and its hash constrained .
1451+ /// If `with_pred = true` a predicate is included.
13241452 /// If `with_pred = false` only the predicate hash is included.
1453+ /// The pred_hash is constrained to be hash(pred) conditionally on the template using a
1454+ /// predicate and not a wildcard.
13251455 fn add_virtual_statement_tmpl (
13261456 & mut self ,
13271457 params : & Params ,
13281458 with_pred : bool ,
13291459 ) -> StatementTmplTarget {
1330- let ( pred, pred_hash) = if with_pred {
1460+ let pred_hash_or_wc =
1461+ PredicateHashOrWildcardTarget :: new ( self . add_virtual_target ( ) , self . add_virtual_value ( ) ) ;
1462+ let pred = if with_pred {
13311463 let pred = self . add_virtual_predicate ( ) ;
13321464 let pred_hash = pred. hash ( self ) ;
1333- ( Some ( pred) , pred_hash)
1465+ let is_pred = pred_hash_or_wc. is_pred ( self ) ;
1466+ let data = pred_hash_or_wc. data ( ) ;
1467+ for i in 0 ..VALUE_SIZE {
1468+ self . conditional_assert_eq ( is_pred. target , data. elements [ i] , pred_hash. elements [ i] ) ;
1469+ }
1470+ Some ( pred)
13341471 } else {
1335- let pred_hash = self . add_virtual_hash ( ) ;
1336- ( None , pred_hash)
1472+ None
13371473 } ;
13381474 StatementTmplTarget {
13391475 pred,
1340- pred_hash ,
1476+ pred_hash_or_wc ,
13411477 args : ( 0 ..params. max_statement_args )
13421478 . map ( |_| self . add_virtual_statement_tmpl_arg ( ) )
13431479 . collect ( ) ,
@@ -1377,12 +1513,11 @@ impl CircuitBuilderPod<F, D> for CircuitBuilder {
13771513 fn add_virtual_custom_predicate_entry (
13781514 & mut self ,
13791515 params : & Params ,
1380- with_pred : bool ,
13811516 ) -> CustomPredicateEntryTarget {
13821517 CustomPredicateEntryTarget {
13831518 id : self . add_virtual_hash ( ) ,
13841519 index : self . add_virtual_target ( ) ,
1385- predicate : self . add_virtual_custom_predicate ( params, with_pred ) ,
1520+ predicate : self . add_virtual_custom_predicate ( params, false ) ,
13861521 }
13871522 }
13881523
0 commit comments