Skip to content

Commit 0405a25

Browse files
committed
feat: implement custom pred recursion
1 parent 05d2162 commit 0405a25

File tree

1 file changed

+108
-58
lines changed

1 file changed

+108
-58
lines changed

src/middleware/mod.rs

Lines changed: 108 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -318,15 +318,7 @@ impl fmt::Display for CustomPredicate {
318318
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
319319
writeln!(f, "{}<", if self.conjunction { "and" } else { "or" })?;
320320
for st in &self.statements {
321-
// NOTE: With recursive custom predicates we can't just display the predicate again
322-
// because then this call will run into an infinite loop. Instead we should find a way
323-
// to name custom predicates and use the names here. For this we will probably need an
324-
// auxiliary data structure to hold the names, which IMO would be too complex to live
325-
// in the middleware. For the middleware we may just print the custom predicate hash.
326-
match &st.0 {
327-
Predicate::Native(p) => write!(f, " {:?}(", p)?,
328-
Predicate::Custom(_p) => write!(f, " TODO(")?,
329-
}
321+
write!(f, " {}", st.0)?;
330322
for (i, arg) in st.1.iter().enumerate() {
331323
if i != 0 {
332324
write!(f, ", ")?;
@@ -347,10 +339,23 @@ impl fmt::Display for CustomPredicate {
347339
}
348340
}
349341

342+
#[derive(Debug)]
343+
pub struct CustomPredicateBatch {
344+
predicates: Vec<CustomPredicate>,
345+
}
346+
347+
impl CustomPredicateBatch {
348+
pub fn hash(&self) -> Hash {
349+
// TODO
350+
hash_str(&format!("{:?}", self))
351+
}
352+
}
353+
350354
#[derive(Clone, Debug)]
351355
pub enum Predicate {
352356
Native(NativePredicate),
353-
Custom(Arc<CustomPredicate>),
357+
BatchSelf(usize),
358+
Custom(Arc<CustomPredicateBatch>, usize),
354359
}
355360

356361
impl From<NativePredicate> for Predicate {
@@ -369,7 +374,8 @@ impl fmt::Display for Predicate {
369374
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
370375
match self {
371376
Self::Native(p) => write!(f, "{:?}", p),
372-
Self::Custom(p) => write!(f, "{}", p),
377+
Self::BatchSelf(i) => write!(f, "self.{}", i),
378+
Self::Custom(pb, i) => write!(f, "{}.{}", pb.hash(), i),
373379
}
374380
}
375381
}
@@ -778,12 +784,74 @@ mod tests {
778784
}
779785
}
780786

781-
fn predicate_and(args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder]) -> Predicate {
782-
predicate(true, args, priv_args, sts)
787+
struct CustomPredicateBatchBuilder {
788+
predicates: Vec<CustomPredicate>,
783789
}
784790

785-
fn predicate_or(args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder]) -> Predicate {
786-
predicate(false, args, priv_args, sts)
791+
impl CustomPredicateBatchBuilder {
792+
fn new() -> Self {
793+
Self {
794+
predicates: Vec::new(),
795+
}
796+
}
797+
798+
fn predicate_and(
799+
&mut self,
800+
args: &[&str],
801+
priv_args: &[&str],
802+
sts: &[StatementTmplBuilder],
803+
) -> Predicate {
804+
self.predicate(true, args, priv_args, sts)
805+
}
806+
807+
fn predicate_or(
808+
&mut self,
809+
args: &[&str],
810+
priv_args: &[&str],
811+
sts: &[StatementTmplBuilder],
812+
) -> Predicate {
813+
self.predicate(false, args, priv_args, sts)
814+
}
815+
816+
fn predicate(
817+
&mut self,
818+
conjunction: bool,
819+
args: &[&str],
820+
priv_args: &[&str],
821+
sts: &[StatementTmplBuilder],
822+
) -> Predicate {
823+
use BuilderArg as BA;
824+
let statements = sts
825+
.iter()
826+
.map(|sb| {
827+
let args = sb
828+
.args
829+
.iter()
830+
.map(|a| match a {
831+
BA::Literal(v) => StatementTmplArg::Literal(*v),
832+
BA::Key(pod_id, key) => StatementTmplArg::Key(
833+
resolve_wildcard(args, priv_args, pod_id),
834+
resolve_wildcard(args, priv_args, key),
835+
),
836+
})
837+
.collect();
838+
StatementTmpl(sb.predicate.clone(), args)
839+
})
840+
.collect();
841+
let custom_predicate = CustomPredicate {
842+
conjunction,
843+
statements,
844+
args_len: args.len(),
845+
};
846+
self.predicates.push(custom_predicate);
847+
Predicate::BatchSelf(self.predicates.len() - 1)
848+
}
849+
850+
fn finish(self) -> Arc<CustomPredicateBatch> {
851+
Arc::new(CustomPredicateBatch {
852+
predicates: self.predicates,
853+
})
854+
}
787855
}
788856

789857
fn resolve_wildcard(
@@ -803,42 +871,12 @@ mod tests {
803871
}
804872
}
805873

806-
fn predicate(
807-
conjunction: bool,
808-
args: &[&str],
809-
priv_args: &[&str],
810-
sts: &[StatementTmplBuilder],
811-
) -> Predicate {
812-
use BuilderArg as BA;
813-
let statements = sts
814-
.iter()
815-
.map(|sb| {
816-
let args = sb
817-
.args
818-
.iter()
819-
.map(|a| match a {
820-
BA::Literal(v) => StatementTmplArg::Literal(*v),
821-
BA::Key(pod_id, key) => StatementTmplArg::Key(
822-
resolve_wildcard(args, priv_args, pod_id),
823-
resolve_wildcard(args, priv_args, key),
824-
),
825-
})
826-
.collect();
827-
StatementTmpl(sb.predicate.clone(), args)
828-
})
829-
.collect();
830-
let custom_predicate = CustomPredicate {
831-
conjunction,
832-
statements,
833-
args_len: args.len(),
834-
};
835-
Predicate::Custom(Arc::new(custom_predicate))
836-
}
837-
838874
#[test]
839875
fn test_custom_pred() {
840876
use NativePredicate as NP;
841-
let eth_friend = predicate_and(
877+
878+
let mut builder = CustomPredicateBatchBuilder::new();
879+
let eth_friend = builder.predicate_and(
842880
&["src_or", "src_key", "dst_or", "dst_key"],
843881
&["attestation_pod"],
844882
&[
@@ -854,9 +892,13 @@ mod tests {
854892
],
855893
);
856894

857-
println!("eth_friend = {}", eth_friend);
895+
println!("a.0. eth_friend = {}", builder.predicates.last().unwrap());
896+
let eth_friend = builder.finish();
897+
// This batch only has 1 predicate, so we pick it already for convenience
898+
let eth_friend = Predicate::Custom(eth_friend, 0);
858899

859-
let eth_dos_distance_base = predicate_and(
900+
let mut builder = CustomPredicateBatchBuilder::new();
901+
let eth_dos_distance_base = builder.predicate_and(
860902
&[
861903
"src_or",
862904
"src_key",
@@ -876,12 +918,14 @@ mod tests {
876918
],
877919
);
878920

879-
println!("eth_dos_distance_base = {}", eth_dos_distance_base);
921+
println!(
922+
"b.0. eth_dos_distance_base = {}",
923+
builder.predicates.last().unwrap()
924+
);
880925

881-
// TODO: replace this with a symbolic predicate index for recursion
882-
let eth_dos_distance = NativePredicate::None;
926+
let eth_dos_distance = Predicate::BatchSelf(3);
883927

884-
let eth_dos_distance_ind = predicate_and(
928+
let eth_dos_distance_ind = builder.predicate_and(
885929
&[
886930
"src_or",
887931
"src_key",
@@ -899,7 +943,7 @@ mod tests {
899943
"intermed_key",
900944
],
901945
&[
902-
st_tmpl(eth_dos_distance) // TODO: Handle recursion
946+
st_tmpl(eth_dos_distance)
903947
.arg((w("src_or"), w("src_key")))
904948
.arg((w("intermed_or"), w("intermed_key")))
905949
.arg((w("shorter_distance_or"), w("shorter_distance_key"))),
@@ -916,9 +960,12 @@ mod tests {
916960
],
917961
);
918962

919-
println!("eth_dos_distance_ind = {}", eth_dos_distance_ind);
963+
println!(
964+
"b.1. eth_dos_distance_ind = {}",
965+
builder.predicates.last().unwrap()
966+
);
920967

921-
let eth_dos_distance = predicate_or(
968+
let eth_dos_distance = builder.predicate_or(
922969
&[
923970
"src_or",
924971
"src_key",
@@ -940,6 +987,9 @@ mod tests {
940987
],
941988
);
942989

943-
println!("eth_dos_distance = {}", eth_dos_distance);
990+
println!(
991+
"b.2. eth_dos_distance = {}",
992+
builder.predicates.last().unwrap()
993+
);
944994
}
945995
}

0 commit comments

Comments
 (0)