@@ -318,15 +318,7 @@ impl fmt::Display for CustomPredicate {
318318 fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
319319 writeln ! ( f, "{}<" , if self . conjunction { "and" } else { "or" } ) ?;
320320 for st in & self . statements {
321- // NOTE: With recursive custom predicates we can't just display the predicate again
322- // because then this call will run into an infinite loop. Instead we should find a way
323- // to name custom predicates and use the names here. For this we will probably need an
324- // auxiliary data structure to hold the names, which IMO would be too complex to live
325- // in the middleware. For the middleware we may just print the custom predicate hash.
326- match & st. 0 {
327- Predicate :: Native ( p) => write ! ( f, " {:?}(" , p) ?,
328- Predicate :: Custom ( _p) => write ! ( f, " TODO(" ) ?,
329- }
321+ write ! ( f, " {}" , st. 0 ) ?;
330322 for ( i, arg) in st. 1 . iter ( ) . enumerate ( ) {
331323 if i != 0 {
332324 write ! ( f, ", " ) ?;
@@ -347,10 +339,23 @@ impl fmt::Display for CustomPredicate {
347339 }
348340}
349341
342+ #[ derive( Debug ) ]
343+ pub struct CustomPredicateBatch {
344+ predicates : Vec < CustomPredicate > ,
345+ }
346+
347+ impl CustomPredicateBatch {
348+ pub fn hash ( & self ) -> Hash {
349+ // TODO
350+ hash_str ( & format ! ( "{:?}" , self ) )
351+ }
352+ }
353+
350354#[ derive( Clone , Debug ) ]
351355pub enum Predicate {
352356 Native ( NativePredicate ) ,
353- Custom ( Arc < CustomPredicate > ) ,
357+ BatchSelf ( usize ) ,
358+ Custom ( Arc < CustomPredicateBatch > , usize ) ,
354359}
355360
356361impl From < NativePredicate > for Predicate {
@@ -369,7 +374,8 @@ impl fmt::Display for Predicate {
369374 fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
370375 match self {
371376 Self :: Native ( p) => write ! ( f, "{:?}" , p) ,
372- Self :: Custom ( p) => write ! ( f, "{}" , p) ,
377+ Self :: BatchSelf ( i) => write ! ( f, "self.{}" , i) ,
378+ Self :: Custom ( pb, i) => write ! ( f, "{}.{}" , pb. hash( ) , i) ,
373379 }
374380 }
375381}
@@ -778,12 +784,74 @@ mod tests {
778784 }
779785 }
780786
781- fn predicate_and ( args : & [ & str ] , priv_args : & [ & str ] , sts : & [ StatementTmplBuilder ] ) -> Predicate {
782- predicate ( true , args , priv_args , sts )
787+ struct CustomPredicateBatchBuilder {
788+ predicates : Vec < CustomPredicate > ,
783789 }
784790
785- fn predicate_or ( args : & [ & str ] , priv_args : & [ & str ] , sts : & [ StatementTmplBuilder ] ) -> Predicate {
786- predicate ( false , args, priv_args, sts)
791+ impl CustomPredicateBatchBuilder {
792+ fn new ( ) -> Self {
793+ Self {
794+ predicates : Vec :: new ( ) ,
795+ }
796+ }
797+
798+ fn predicate_and (
799+ & mut self ,
800+ args : & [ & str ] ,
801+ priv_args : & [ & str ] ,
802+ sts : & [ StatementTmplBuilder ] ,
803+ ) -> Predicate {
804+ self . predicate ( true , args, priv_args, sts)
805+ }
806+
807+ fn predicate_or (
808+ & mut self ,
809+ args : & [ & str ] ,
810+ priv_args : & [ & str ] ,
811+ sts : & [ StatementTmplBuilder ] ,
812+ ) -> Predicate {
813+ self . predicate ( false , args, priv_args, sts)
814+ }
815+
816+ fn predicate (
817+ & mut self ,
818+ conjunction : bool ,
819+ args : & [ & str ] ,
820+ priv_args : & [ & str ] ,
821+ sts : & [ StatementTmplBuilder ] ,
822+ ) -> Predicate {
823+ use BuilderArg as BA ;
824+ let statements = sts
825+ . iter ( )
826+ . map ( |sb| {
827+ let args = sb
828+ . args
829+ . iter ( )
830+ . map ( |a| match a {
831+ BA :: Literal ( v) => StatementTmplArg :: Literal ( * v) ,
832+ BA :: Key ( pod_id, key) => StatementTmplArg :: Key (
833+ resolve_wildcard ( args, priv_args, pod_id) ,
834+ resolve_wildcard ( args, priv_args, key) ,
835+ ) ,
836+ } )
837+ . collect ( ) ;
838+ StatementTmpl ( sb. predicate . clone ( ) , args)
839+ } )
840+ . collect ( ) ;
841+ let custom_predicate = CustomPredicate {
842+ conjunction,
843+ statements,
844+ args_len : args. len ( ) ,
845+ } ;
846+ self . predicates . push ( custom_predicate) ;
847+ Predicate :: BatchSelf ( self . predicates . len ( ) - 1 )
848+ }
849+
850+ fn finish ( self ) -> Arc < CustomPredicateBatch > {
851+ Arc :: new ( CustomPredicateBatch {
852+ predicates : self . predicates ,
853+ } )
854+ }
787855 }
788856
789857 fn resolve_wildcard (
@@ -803,42 +871,12 @@ mod tests {
803871 }
804872 }
805873
806- fn predicate (
807- conjunction : bool ,
808- args : & [ & str ] ,
809- priv_args : & [ & str ] ,
810- sts : & [ StatementTmplBuilder ] ,
811- ) -> Predicate {
812- use BuilderArg as BA ;
813- let statements = sts
814- . iter ( )
815- . map ( |sb| {
816- let args = sb
817- . args
818- . iter ( )
819- . map ( |a| match a {
820- BA :: Literal ( v) => StatementTmplArg :: Literal ( * v) ,
821- BA :: Key ( pod_id, key) => StatementTmplArg :: Key (
822- resolve_wildcard ( args, priv_args, pod_id) ,
823- resolve_wildcard ( args, priv_args, key) ,
824- ) ,
825- } )
826- . collect ( ) ;
827- StatementTmpl ( sb. predicate . clone ( ) , args)
828- } )
829- . collect ( ) ;
830- let custom_predicate = CustomPredicate {
831- conjunction,
832- statements,
833- args_len : args. len ( ) ,
834- } ;
835- Predicate :: Custom ( Arc :: new ( custom_predicate) )
836- }
837-
838874 #[ test]
839875 fn test_custom_pred ( ) {
840876 use NativePredicate as NP ;
841- let eth_friend = predicate_and (
877+
878+ let mut builder = CustomPredicateBatchBuilder :: new ( ) ;
879+ let eth_friend = builder. predicate_and (
842880 & [ "src_or" , "src_key" , "dst_or" , "dst_key" ] ,
843881 & [ "attestation_pod" ] ,
844882 & [
@@ -854,9 +892,13 @@ mod tests {
854892 ] ,
855893 ) ;
856894
857- println ! ( "eth_friend = {}" , eth_friend) ;
895+ println ! ( "a.0. eth_friend = {}" , builder. predicates. last( ) . unwrap( ) ) ;
896+ let eth_friend = builder. finish ( ) ;
897+ // This batch only has 1 predicate, so we pick it already for convenience
898+ let eth_friend = Predicate :: Custom ( eth_friend, 0 ) ;
858899
859- let eth_dos_distance_base = predicate_and (
900+ let mut builder = CustomPredicateBatchBuilder :: new ( ) ;
901+ let eth_dos_distance_base = builder. predicate_and (
860902 & [
861903 "src_or" ,
862904 "src_key" ,
@@ -876,12 +918,14 @@ mod tests {
876918 ] ,
877919 ) ;
878920
879- println ! ( "eth_dos_distance_base = {}" , eth_dos_distance_base) ;
921+ println ! (
922+ "b.0. eth_dos_distance_base = {}" ,
923+ builder. predicates. last( ) . unwrap( )
924+ ) ;
880925
881- // TODO: replace this with a symbolic predicate index for recursion
882- let eth_dos_distance = NativePredicate :: None ;
926+ let eth_dos_distance = Predicate :: BatchSelf ( 3 ) ;
883927
884- let eth_dos_distance_ind = predicate_and (
928+ let eth_dos_distance_ind = builder . predicate_and (
885929 & [
886930 "src_or" ,
887931 "src_key" ,
@@ -899,7 +943,7 @@ mod tests {
899943 "intermed_key" ,
900944 ] ,
901945 & [
902- st_tmpl ( eth_dos_distance) // TODO: Handle recursion
946+ st_tmpl ( eth_dos_distance)
903947 . arg ( ( w ( "src_or" ) , w ( "src_key" ) ) )
904948 . arg ( ( w ( "intermed_or" ) , w ( "intermed_key" ) ) )
905949 . arg ( ( w ( "shorter_distance_or" ) , w ( "shorter_distance_key" ) ) ) ,
@@ -916,9 +960,12 @@ mod tests {
916960 ] ,
917961 ) ;
918962
919- println ! ( "eth_dos_distance_ind = {}" , eth_dos_distance_ind) ;
963+ println ! (
964+ "b.1. eth_dos_distance_ind = {}" ,
965+ builder. predicates. last( ) . unwrap( )
966+ ) ;
920967
921- let eth_dos_distance = predicate_or (
968+ let eth_dos_distance = builder . predicate_or (
922969 & [
923970 "src_or" ,
924971 "src_key" ,
@@ -940,6 +987,9 @@ mod tests {
940987 ] ,
941988 ) ;
942989
943- println ! ( "eth_dos_distance = {}" , eth_dos_distance) ;
990+ println ! (
991+ "b.2. eth_dos_distance = {}" ,
992+ builder. predicates. last( ) . unwrap( )
993+ ) ;
944994 }
945995}
0 commit comments