Skip to content

Commit 8efc681

Browse files
committed
refactor custom predicates
1 parent 784f964 commit 8efc681

File tree

9 files changed

+342
-259
lines changed

9 files changed

+342
-259
lines changed

src/backends/plonky2/mock/mainpod/statement.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use anyhow::{anyhow, Result};
44

55
// use serde::{Deserialize, Serialize};
66
use crate::middleware::{
7-
self, AnchoredKey, NativePredicate, Params, Predicate, StatementArg, ToFields,
7+
self, AnchoredKey, NativePredicate, Params, Predicate, StatementArg, ToFields, WildcardValue,
88
};
99

1010
#[derive(Clone, Debug, PartialEq)]
@@ -81,15 +81,15 @@ impl TryFrom<Statement> for middleware::Statement {
8181
_ => Err(anyhow!("Ill-formed statement expression {:?}", s))?,
8282
},
8383
Predicate::Custom(cpr) => {
84-
let aks: Vec<AnchoredKey> = proper_args
84+
let vs: Vec<WildcardValue> = proper_args
8585
.into_iter()
8686
.filter_map(|arg| match arg {
8787
SA::None => None,
88-
SA::Key(ak) => Some(ak),
89-
SA::Literal(_) => unreachable!(),
88+
SA::WildcardLiteral(v) => Some(v),
89+
_ => unreachable!(),
9090
})
9191
.collect();
92-
S::Custom(cpr, aks)
92+
S::Custom(cpr, vs)
9393
}
9494
Predicate::BatchSelf(_) => {
9595
unreachable!()

src/examples/custom.rs

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use anyhow::Result;
44
use StatementTmplBuilder as STB;
55

66
use crate::{
7-
frontend::{literal, CustomPredicateBatchBuilder, StatementTmplBuilder},
7+
frontend::{key, literal, CustomPredicateBatchBuilder, StatementTmplBuilder},
88
middleware::{
99
self, CustomPredicateBatch, CustomPredicateRef, NativePredicate as NP, Params, PodType,
1010
Predicate, KEY_SIGNER, KEY_TYPE,
@@ -25,15 +25,15 @@ pub fn eth_friend_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
2525
&[
2626
// there is an attestation pod that's a SignedPod
2727
STB::new(NP::ValueOf)
28-
.arg(("attestation_pod", literal(KEY_TYPE)))
29-
.arg(PodType::MockSigned), // TODO
28+
.arg(("attestation_pod", key(KEY_TYPE)))
29+
.arg(literal(PodType::MockSigned)), // TODO
3030
// the attestation pod is signed by (src_or, src_key)
3131
STB::new(NP::Equal)
32-
.arg(("attestation_pod", literal(KEY_SIGNER)))
32+
.arg(("attestation_pod", key(KEY_SIGNER)))
3333
.arg(("src_ori", "src_key")),
3434
// that same attestation pod has an "attestation"
3535
STB::new(NP::Equal)
36-
.arg(("attestation_pod", literal("attestation")))
36+
.arg(("attestation_pod", key("attestation")))
3737
.arg(("dst_ori", "dst_key")),
3838
],
3939
)?;
@@ -72,7 +72,7 @@ pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
7272
.arg(("dst_ori", "dst_key")),
7373
STB::new(NP::ValueOf)
7474
.arg(("distance_ori", "distance_key"))
75-
.arg(0),
75+
.arg(literal(0)),
7676
],
7777
)?;
7878
println!(
@@ -106,19 +106,26 @@ pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
106106
&[
107107
// statement templates:
108108
STB::new(eth_dos_distance)
109-
.arg(("src_ori", "src_key"))
110-
.arg(("intermed_ori", "intermed_key"))
111-
.arg(("shorter_distance_ori", "shorter_distance_key")),
109+
.arg("src_ori")
110+
.arg("src_key")
111+
.arg("intermed_ori")
112+
.arg("intermed_key")
113+
.arg("shorter_distance_ori")
114+
.arg("shorter_distance_key"),
112115
// distance == shorter_distance + 1
113-
STB::new(NP::ValueOf).arg(("one_ori", "one_key")).arg(1),
116+
STB::new(NP::ValueOf)
117+
.arg(("one_ori", "one_key"))
118+
.arg(literal(1)),
114119
STB::new(NP::SumOf)
115120
.arg(("distance_ori", "distance_key"))
116121
.arg(("shorter_distance_ori", "shorter_distance_key"))
117122
.arg(("one_ori", "one_key")),
118123
// intermed is a friend of dst
119124
STB::new(eth_friend)
120-
.arg(("intermed_ori", "intermed_key"))
121-
.arg(("dst_ori", "dst_key")),
125+
.arg("intermed_ori")
126+
.arg("intermed_key")
127+
.arg("dst_ori")
128+
.arg("dst_key"),
122129
],
123130
)?;
124131

@@ -141,13 +148,19 @@ pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
141148
&[],
142149
&[
143150
STB::new(eth_dos_distance_base)
144-
.arg(("src_ori", "src_key"))
145-
.arg(("dst_ori", "dst_key"))
146-
.arg(("distance_ori", "distance_key")),
151+
.arg("src_ori")
152+
.arg("src_key")
153+
.arg("dst_ori")
154+
.arg("dst_key")
155+
.arg("distance_ori")
156+
.arg("distance_key"),
147157
STB::new(eth_dos_distance_ind)
148-
.arg(("src_ori", "src_key"))
149-
.arg(("dst_ori", "dst_key"))
150-
.arg(("distance_ori", "distance_key")),
158+
.arg("src_ori")
159+
.arg("src_key")
160+
.arg("dst_ori")
161+
.arg("dst_key")
162+
.arg("distance_ori")
163+
.arg("distance_key"),
151164
],
152165
)?;
153166

src/examples/mod.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ pub fn eth_dos_pod_builder(
139139
let ethdos_alice_alice_is_zero = alice_bob_ethdos.priv_op(op!(
140140
custom,
141141
eth_dos.clone(),
142-
ethdos_alice_alice_is_zero_base
142+
ethdos_alice_alice_is_zero_base,
143+
Statement::None
143144
))?;
144145

145146
// Alice and Charlie are ETH friends.
@@ -192,6 +193,7 @@ pub fn eth_dos_pod_builder(
192193
let ethdos_alice_charlie_is_one = alice_bob_ethdos.priv_op(op!(
193194
custom,
194195
eth_dos.clone(),
196+
Statement::None,
195197
ethdos_alice_charlie_is_one_ind
196198
))?;
197199

@@ -210,8 +212,12 @@ pub fn eth_dos_pod_builder(
210212
ethdos_sum,
211213
ethfriends_charlie_bob
212214
))?;
213-
let _ethdos_alice_bob_is_two =
214-
alice_bob_ethdos.pub_op(op!(custom, eth_dos.clone(), ethdos_alice_bob_is_two_ind))?;
215+
let _ethdos_alice_bob_is_two = alice_bob_ethdos.pub_op(op!(
216+
custom,
217+
eth_dos.clone(),
218+
Statement::None,
219+
ethdos_alice_bob_is_two_ind
220+
))?;
215221

216222
Ok(alice_bob_ethdos)
217223
}

src/frontend/custom.rs

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,28 @@ pub enum KeyOrWildcardStr {
2222
}
2323

2424
/// helper to build a literal KeyOrWildcardStr::Key from the given str
25-
pub fn literal(s: &str) -> KeyOrWildcardStr {
25+
pub fn key(s: &str) -> KeyOrWildcardStr {
2626
KeyOrWildcardStr::Key(s.to_string())
2727
}
2828

2929
/// helper to build a KeyOrWildcardStr::Wildcard from the given str. For the
3030
/// moment this method does not need to be public.
31-
fn wildcard(s: &str) -> KeyOrWildcardStr {
32-
KeyOrWildcardStr::Wildcard(s.to_string())
33-
}
31+
// fn wildcard(s: &str) -> KeyOrWildcardStr {
32+
// KeyOrWildcardStr::Wildcard(s.to_string())
33+
// }
3434

3535
/// Builder Argument for the StatementTmplBuilder
3636
pub enum BuilderArg {
3737
Literal(Value),
3838
/// Key: (origin, key), where origin is a Wildcard and key can be both Key or Wildcard
3939
Key(String, KeyOrWildcardStr),
40+
WildcardLiteral(String),
4041
}
4142

4243
/// When defining a `BuilderArg`, it can be done from 3 different inputs:
4344
/// i. (&str, literal): this is to set a POD and a field, ie. (POD, literal("field"))
4445
/// ii. (&str, &str): this is to define a origin-key wildcard pair, ie. (src_origin, src_dest)
45-
/// iii. Value: this is to define a literal value, ie. 0
46+
/// iii. &str: this is to define a WildcardValue wildcard, ie. "src_or"
4647
///
4748
/// case i.
4849
impl From<(&str, KeyOrWildcardStr)> for BuilderArg {
@@ -58,19 +59,20 @@ impl From<(&str, KeyOrWildcardStr)> for BuilderArg {
5859
/// case ii.
5960
impl From<(&str, &str)> for BuilderArg {
6061
fn from((origin, field): (&str, &str)) -> Self {
61-
Self::Key(origin.into(), wildcard(field))
62+
Self::Key(origin.into(), KeyOrWildcardStr::Wildcard(field.to_string()))
6263
}
6364
}
6465
/// case iii.
65-
impl<V> From<V> for BuilderArg
66-
where
67-
V: Into<Value>,
68-
{
69-
fn from(v: V) -> Self {
70-
Self::Literal(v.into())
66+
impl From<&str> for BuilderArg {
67+
fn from(wc: &str) -> Self {
68+
Self::WildcardLiteral(wc.to_string())
7169
}
7270
}
7371

72+
pub fn literal(v: impl Into<Value>) -> BuilderArg {
73+
BuilderArg::Literal(v.into())
74+
}
75+
7476
pub struct StatementTmplBuilder {
7577
predicate: Predicate,
7678
args: Vec<BuilderArg>,
@@ -136,6 +138,21 @@ impl CustomPredicateBatchBuilder {
136138
priv_args: &[&str],
137139
sts: &[StatementTmplBuilder],
138140
) -> Result<Predicate> {
141+
if args.len() > params.max_statement_args {
142+
return Err(anyhow!(
143+
"args.len {} is over the limit {}",
144+
args.len(),
145+
params.max_statement_args
146+
));
147+
}
148+
if (args.len() + priv_args.len()) > params.max_custom_predicate_wildcards {
149+
return Err(anyhow!(
150+
"wildcards.len {} is over the limit {}",
151+
args.len() + priv_args.len(),
152+
params.max_custom_predicate_wildcards
153+
));
154+
}
155+
139156
let statements = sts
140157
.iter()
141158
.map(|sb| {
@@ -148,6 +165,9 @@ impl CustomPredicateBatchBuilder {
148165
resolve_wildcard(args, priv_args, &pod_id),
149166
resolve_key_or_wildcard(args, priv_args, &key),
150167
),
168+
BuilderArg::WildcardLiteral(v) => {
169+
StatementTmplArg::WildcardLiteral(resolve_wildcard(args, priv_args, &v))
170+
}
151171
})
152172
.collect();
153173
StatementTmpl {

0 commit comments

Comments
 (0)