Skip to content

Commit 01f93be

Browse files
committed
Add custom op enum variant and wildcard matching procedures
1 parent f98f297 commit 01f93be

File tree

3 files changed

+77
-7
lines changed

3 files changed

+77
-7
lines changed

src/middleware/custom.rs

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
use std::sync::Arc;
2-
use std::{fmt, hash as h};
2+
use std::{fmt, hash as h, iter::zip};
33

4-
use super::{hash_str, Hash, NativePredicate, ToFields, Value, F};
4+
use anyhow::{anyhow, Result};
5+
6+
use super::{
7+
hash_str, AnchoredKey, Hash, NativePredicate, PodId, Statement, StatementArg, ToFields, Value,
8+
F,
9+
};
510

611
// BEGIN Custom 1b
712

@@ -11,6 +16,19 @@ pub enum HashOrWildcard {
1116
Wildcard(usize),
1217
}
1318

19+
impl HashOrWildcard {
20+
/// Matches a hash or wildcard against a value, returning a pair
21+
/// representing a wildcard binding (if any) or an error if no
22+
/// match is possible.
23+
pub fn match_against(&self, v: &Value) -> Result<Option<(usize, Value)>> {
24+
match self {
25+
HashOrWildcard::Hash(h) if &Value::from(h.clone()) == v => Ok(None),
26+
HashOrWildcard::Wildcard(i) => Ok(Some((*i, v.clone()))),
27+
_ => Err(anyhow!("Failed to match {} against {}.", self, v)),
28+
}
29+
}
30+
}
31+
1432
impl fmt::Display for HashOrWildcard {
1533
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1634
match self {
@@ -27,6 +45,25 @@ pub enum StatementTmplArg {
2745
Key(HashOrWildcard, HashOrWildcard),
2846
}
2947

48+
impl StatementTmplArg {
49+
/// Matches a statement template argument against a statement
50+
/// argument, returning a wildcard correspondence in the case of
51+
/// one or more wildcard matches, nothing in the case of a
52+
/// literal/hash match, and an error otherwise.
53+
pub fn match_against(&self, s_arg: &StatementArg) -> Result<Vec<(usize, Value)>> {
54+
match (self, s_arg) {
55+
(Self::None, StatementArg::None) => Ok(vec![]),
56+
(Self::Literal(v), StatementArg::Literal(w)) if v == w => Ok(vec![]),
57+
(Self::Key(tmpl_o, tmpl_k), StatementArg::Key(AnchoredKey(PodId(o), k))) => {
58+
let o_corr = tmpl_o.match_against(&o.clone().into())?;
59+
let k_corr = tmpl_k.match_against(&k.clone().into())?;
60+
Ok([o_corr, k_corr].into_iter().flat_map(|x| x).collect())
61+
}
62+
_ => Err(anyhow!("Failed to match {} against {}.", self, s_arg)),
63+
}
64+
}
65+
}
66+
3067
impl fmt::Display for StatementTmplArg {
3168
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
3269
match self {
@@ -53,6 +90,33 @@ impl fmt::Display for StatementTmplArg {
5390
#[derive(Clone, Debug, PartialEq, Eq)]
5491
pub struct StatementTmpl(pub Predicate, pub Vec<StatementTmplArg>);
5592

93+
impl StatementTmpl {
94+
pub fn pred(&self) -> &Predicate {
95+
&self.0
96+
}
97+
pub fn args(&self) -> &[StatementTmplArg] {
98+
&self.1
99+
}
100+
/// Matches a statement template against a statement, returning
101+
/// the variable bindings as an association list. Returns an error
102+
/// if there is type or argument mismatch.
103+
pub fn match_against(&self, s: &Statement) -> Result<Vec<(usize, Value)>> {
104+
type P = Predicate;
105+
if matches!(self, Self(P::BatchSelf(_), _)) {
106+
Err(anyhow!(
107+
"Cannot check self-referencing statement templates."
108+
))
109+
} else if self.pred() != &s.code() {
110+
Err(anyhow!("Type mismatch between {:?} and {}.", self, s))
111+
} else {
112+
zip(self.args(), s.args())
113+
.map(|(t_arg, s_arg)| t_arg.match_against(&s_arg))
114+
.collect::<Result<Vec<_>>>()
115+
.map(|v| v.concat())
116+
}
117+
}
118+
}
119+
56120
#[derive(Clone, Debug, PartialEq, Eq)]
57121
pub struct CustomPredicate {
58122
/// true for "and", false for "or"

src/middleware/operation.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use anyhow::{anyhow, Result};
22

3-
use super::Statement;
3+
use super::{CustomPredicateRef, Statement};
44
use crate::middleware::{AnchoredKey, SELF};
55

66
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -42,6 +42,7 @@ pub enum Operation {
4242
SumOf(Statement, Statement, Statement),
4343
ProductOf(Statement, Statement, Statement),
4444
MaxOf(Statement, Statement, Statement),
45+
Custom(CustomPredicateRef, Vec<Statement>),
4546
}
4647

4748
impl Operation {
@@ -64,6 +65,7 @@ impl Operation {
6465
Self::SumOf(_, _, _) => SumOf,
6566
Self::ProductOf(_, _, _) => ProductOf,
6667
Self::MaxOf(_, _, _) => MaxOf,
68+
Self::Custom(_, _) => todo!(),
6769
}
6870
}
6971

@@ -85,6 +87,7 @@ impl Operation {
8587
Self::SumOf(s1, s2, s3) => vec![s1, s2, s3],
8688
Self::ProductOf(s1, s2, s3) => vec![s1, s2, s3],
8789
Self::MaxOf(s1, s2, s3) => vec![s1, s2, s3],
90+
Self::Custom(_, args) => args,
8891
}
8992
}
9093
/// Forms operation from op-code and arguments.
@@ -171,6 +174,10 @@ impl Operation {
171174
let v3: i64 = v3.clone().try_into()?;
172175
Ok((v1 == v2 + v3) && ak4 == ak1 && ak5 == ak2 && ak6 == ak3)
173176
}
177+
(
178+
Self::Custom(CustomPredicateRef(cpb, i), _args),
179+
Custom(CustomPredicateRef(s_cpb, s_i), _s_args),
180+
) if cpb == s_cpb && i == s_i => todo!(),
174181
_ => Err(anyhow!(
175182
"Invalid deduction: {:?} ⇏ {:#}",
176183
self,

src/middleware/statement.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use anyhow::{anyhow, Result};
22
use plonky2::field::types::Field;
3-
use std::fmt;
3+
use std::{collections::HashMap, fmt};
44
use strum_macros::FromRepr;
55

66
use super::{AnchoredKey, CustomPredicateRef, Hash, Predicate, ToFields, Value, F};
@@ -30,7 +30,6 @@ impl ToFields for NativePredicate {
3030
}
3131
}
3232

33-
// TODO: Incorporate custom statements into this enum.
3433
/// Type encapsulating statements with their associated arguments.
3534
#[derive(Clone, Debug, PartialEq, Eq)]
3635
pub enum Statement {
@@ -45,7 +44,7 @@ pub enum Statement {
4544
SumOf(AnchoredKey, AnchoredKey, AnchoredKey),
4645
ProductOf(AnchoredKey, AnchoredKey, AnchoredKey),
4746
MaxOf(AnchoredKey, AnchoredKey, AnchoredKey),
48-
Custom(CustomPredicateRef, Vec<Hash>),
47+
Custom(CustomPredicateRef, Vec<AnchoredKey>),
4948
}
5049

5150
impl Statement {
@@ -83,7 +82,7 @@ impl Statement {
8382
Self::SumOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
8483
Self::ProductOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
8584
Self::MaxOf(ak1, ak2, ak3) => vec![Key(ak1), Key(ak2), Key(ak3)],
86-
Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(|h| Literal(h.into()))),
85+
Self::Custom(_, args) => Vec::from_iter(args.into_iter().map(|h| Key(h))),
8786
}
8887
}
8988
}

0 commit comments

Comments
 (0)