Skip to content

Commit 0186913

Browse files
committed
Implement custom op check
1 parent 05c21eb commit 0186913

File tree

5 files changed

+337
-4
lines changed

5 files changed

+337
-4
lines changed

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ pub mod constants;
33
pub mod frontend;
44
pub mod middleware;
55
pub mod primitives;
6+
mod util;
67

78
#[cfg(test)]
89
pub mod examples;

src/middleware/custom.rs

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ use std::sync::Arc;
22
use std::{fmt, hash as h, iter::zip};
33

44
use anyhow::{anyhow, Result};
5+
use plonky2::field::goldilocks_field::GoldilocksField;
6+
7+
use crate::middleware::{Operation, SELF};
58

69
use super::{
710
hash_str, AnchoredKey, Hash, NativePredicate, PodId, Statement, StatementArg, ToFields, Value,
@@ -208,3 +211,239 @@ impl fmt::Display for Predicate {
208211
}
209212
}
210213
}
214+
215+
mod tests {
216+
use std::{array, sync::Arc};
217+
218+
use anyhow::Result;
219+
use plonky2::field::goldilocks_field::GoldilocksField;
220+
221+
use crate::middleware::{
222+
AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Hash,
223+
HashOrWildcard, NativePredicate, Operation, PodId, PodType, Predicate, Statement,
224+
StatementTmpl, StatementTmplArg, SELF,
225+
};
226+
227+
fn st(p: Predicate, args: Vec<StatementTmplArg>) -> StatementTmpl {
228+
StatementTmpl(p, args)
229+
}
230+
231+
type STA = StatementTmplArg;
232+
type HOW = HashOrWildcard;
233+
type P = Predicate;
234+
type NP = NativePredicate;
235+
236+
#[test]
237+
fn is_double_test() -> Result<()> {
238+
/*
239+
is_double(S1, S2) :-
240+
p:value_of(Constant, 2),
241+
p:product_of(S1, Constant, S2)
242+
*/
243+
let cust_pred_batch = Arc::new(CustomPredicateBatch {
244+
name: "is_double".to_string(),
245+
predicates: vec![CustomPredicate {
246+
conjunction: true,
247+
statements: vec![
248+
st(
249+
P::Native(NP::ValueOf),
250+
vec![
251+
STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)),
252+
STA::Literal(2.into()),
253+
],
254+
),
255+
st(
256+
P::Native(NP::ProductOf),
257+
vec![
258+
STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)),
259+
STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)),
260+
STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)),
261+
],
262+
),
263+
],
264+
args_len: 4,
265+
}],
266+
});
267+
268+
let custom_statement = Statement::Custom(
269+
CustomPredicateRef(cust_pred_batch.clone(), 0),
270+
vec![
271+
AnchoredKey(SELF, "Some value".into()),
272+
AnchoredKey(SELF, "Some other value".into()),
273+
],
274+
);
275+
276+
let custom_deduction = Operation::Custom(
277+
CustomPredicateRef(cust_pred_batch, 0),
278+
vec![
279+
Statement::ValueOf(AnchoredKey(SELF, "Some constant".into()), 2.into()),
280+
Statement::ProductOf(
281+
AnchoredKey(SELF, "Some value".into()),
282+
AnchoredKey(SELF, "Some constant".into()),
283+
AnchoredKey(SELF, "Some other value".into()),
284+
),
285+
],
286+
);
287+
288+
assert!(custom_deduction.check(&custom_statement)?);
289+
290+
Ok(())
291+
}
292+
293+
#[test]
294+
fn ethdos_test() -> Result<()> {
295+
let eth_friend_cp = CustomPredicate {
296+
conjunction: true,
297+
statements: vec![
298+
st(
299+
P::Native(NP::ValueOf),
300+
vec![
301+
STA::Key(HOW::Wildcard(4), HashOrWildcard::Hash("type".into())),
302+
STA::Literal(PodType::Signed.into()),
303+
],
304+
),
305+
st(
306+
P::Native(NP::Equal),
307+
vec![
308+
STA::Key(HOW::Wildcard(4), HashOrWildcard::Hash("signer".into())),
309+
STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)),
310+
],
311+
),
312+
st(
313+
P::Native(NP::Equal),
314+
vec![
315+
STA::Key(HOW::Wildcard(4), HashOrWildcard::Hash("attestation".into())),
316+
STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)),
317+
],
318+
),
319+
],
320+
args_len: 4,
321+
};
322+
323+
let eth_friend_batch = Arc::new(CustomPredicateBatch {
324+
name: "eth_friend".to_string(),
325+
predicates: vec![eth_friend_cp],
326+
});
327+
328+
let eth_dos_base = CustomPredicate {
329+
conjunction: true,
330+
statements: vec![
331+
st(
332+
P::Native(NP::Equal),
333+
vec![
334+
STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)),
335+
STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)),
336+
],
337+
),
338+
st(
339+
P::Native(NP::ValueOf),
340+
vec![
341+
STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)),
342+
STA::Literal(0.into()),
343+
],
344+
),
345+
],
346+
args_len: 6,
347+
};
348+
349+
let eth_dos_ind = CustomPredicate {
350+
conjunction: true,
351+
statements: vec![
352+
st(
353+
P::BatchSelf(2),
354+
vec![
355+
STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)),
356+
STA::Key(HOW::Wildcard(10), HOW::Wildcard(11)),
357+
STA::Key(HOW::Wildcard(8), HOW::Wildcard(9)),
358+
],
359+
),
360+
st(
361+
P::Native(NP::ValueOf),
362+
vec![
363+
STA::Key(HOW::Wildcard(6), HOW::Wildcard(7)),
364+
STA::Literal(1.into()),
365+
],
366+
),
367+
st(
368+
P::Native(NP::SumOf),
369+
vec![
370+
STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)),
371+
STA::Key(HOW::Wildcard(8), HOW::Wildcard(9)),
372+
STA::Key(HOW::Wildcard(6), HOW::Wildcard(7)),
373+
],
374+
),
375+
st(
376+
P::Custom(CustomPredicateRef(eth_friend_batch.clone(), 0)),
377+
vec![
378+
STA::Key(HOW::Wildcard(10), HOW::Wildcard(11)),
379+
STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)),
380+
],
381+
),
382+
],
383+
args_len: 6,
384+
};
385+
386+
let eth_dos_distance_either = CustomPredicate {
387+
conjunction: false,
388+
statements: vec![
389+
st(
390+
P::BatchSelf(0),
391+
vec![
392+
STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)),
393+
STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)),
394+
STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)),
395+
],
396+
),
397+
st(
398+
P::BatchSelf(1),
399+
vec![
400+
STA::Key(HOW::Wildcard(0), HOW::Wildcard(1)),
401+
STA::Key(HOW::Wildcard(2), HOW::Wildcard(3)),
402+
STA::Key(HOW::Wildcard(4), HOW::Wildcard(5)),
403+
],
404+
),
405+
],
406+
args_len: 6,
407+
};
408+
409+
let eth_dos_distance_batch = Arc::new(CustomPredicateBatch {
410+
name: "ETHDoS_distance".to_string(),
411+
predicates: vec![eth_dos_base, eth_dos_ind, eth_dos_distance_either],
412+
});
413+
414+
// Some POD IDs
415+
let attestation_pod_id = PodId(Hash(array::from_fn(|i| GoldilocksField(i as u64))));
416+
let other_pod_id = PodId(Hash(array::from_fn(|i| GoldilocksField((i * i) as u64))));
417+
418+
// Example statement
419+
let ethdos_example = Statement::Custom(
420+
CustomPredicateRef(eth_dos_distance_batch.clone(), 2),
421+
vec![
422+
AnchoredKey(SELF, "Alice".into()),
423+
AnchoredKey(SELF, "Bob".into()),
424+
AnchoredKey(SELF, "Seven".into()),
425+
],
426+
);
427+
428+
// Copies should work.
429+
assert!(Operation::CopyStatement(ethdos_example.clone()).check(&ethdos_example)?);
430+
431+
// This could arise as the inductive step.
432+
let ethdos_ind_example = Statement::Custom(
433+
CustomPredicateRef(eth_dos_distance_batch.clone(), 1),
434+
vec![
435+
AnchoredKey(SELF, "Alice".into()),
436+
AnchoredKey(SELF, "Bob".into()),
437+
AnchoredKey(SELF, "Seven".into()),
438+
],
439+
);
440+
441+
assert!(Operation::Custom(
442+
CustomPredicateRef(eth_dos_distance_batch.clone(), 2),
443+
vec![ethdos_ind_example.clone()]
444+
)
445+
.check(&ethdos_example)?);
446+
447+
Ok(())
448+
}
449+
}

src/middleware/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,12 @@ impl FromHex for Hash {
178178
}
179179
}
180180

181+
impl From<&str> for Hash {
182+
fn from(s: &str) -> Self {
183+
hash_str(s)
184+
}
185+
}
186+
181187
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
182188
pub struct PodId(pub Hash);
183189

src/middleware/operation.rs

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
use std::collections::HashMap;
2+
13
use anyhow::{anyhow, Result};
24

35
use super::{CustomPredicateRef, Statement};
4-
use crate::middleware::{AnchoredKey, SELF};
6+
use crate::{
7+
middleware::{AnchoredKey, CustomPredicate, PodId, Predicate, StatementTmpl, Value, SELF},
8+
util::hashmap_insert_no_dupe,
9+
};
510

611
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
712
pub enum NativeOperation {
@@ -175,9 +180,69 @@ impl Operation {
175180
Ok((v1 == v2 + v3) && ak4 == ak1 && ak5 == ak2 && ak6 == ak3)
176181
}
177182
(
178-
Self::Custom(CustomPredicateRef(cpb, i), _args),
179-
Custom(CustomPredicateRef(s_cpb, s_i), _s_args),
180-
) if cpb == s_cpb && i == s_i => todo!(),
183+
Self::Custom(CustomPredicateRef(cpb, i), args),
184+
Custom(CustomPredicateRef(s_cpb, s_i), s_args),
185+
) if cpb == s_cpb && i == s_i => {
186+
// Bind statement arguments
187+
let mut bindings = s_args
188+
.into_iter()
189+
.enumerate()
190+
.flat_map(|(i, AnchoredKey(PodId(o), k))| {
191+
vec![
192+
(2 * i, Value::from(o.clone())),
193+
(2 * i + 1, Value::from(k.clone())),
194+
]
195+
})
196+
.collect::<HashMap<_, _>>();
197+
198+
// Single out custom predicate, replacing batch-self
199+
// references with custom predicate references.
200+
let custom_predicate = {
201+
let cp = (**cpb).predicates[*i].clone();
202+
CustomPredicate {
203+
conjunction: cp.conjunction,
204+
statements: cp
205+
.statements
206+
.into_iter()
207+
.map(|StatementTmpl(p, args)| {
208+
StatementTmpl(
209+
match p {
210+
Predicate::BatchSelf(i) => {
211+
Predicate::Custom(CustomPredicateRef(cpb.clone(), i))
212+
}
213+
_ => p,
214+
},
215+
args,
216+
)
217+
})
218+
.collect(),
219+
args_len: cp.args_len,
220+
}
221+
};
222+
match custom_predicate.conjunction {
223+
true if custom_predicate.statements.len() == args.len() => {
224+
// Match op args against statement templates
225+
let match_bindings = std::iter::zip(custom_predicate.statements, args).map(
226+
|(s_tmpl, s)| s_tmpl.match_against(s)
227+
).collect::<Result<Vec<_>>>()
228+
.map(|v| v.concat())?;
229+
// Add bindings to binding table, throwing if there is an inconsistency.
230+
match_bindings.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?;
231+
Ok(true)
232+
},
233+
false if args.len() == 1 => {
234+
// Match op arg against each statement template
235+
custom_predicate.statements.into_iter().map(
236+
|s_tmpl| {
237+
let mut bindings = bindings.clone();
238+
s_tmpl.match_against(&args[0])?.into_iter().try_for_each(|kv| hashmap_insert_no_dupe(&mut bindings, kv))?;
239+
Ok::<_, anyhow::Error>(true)
240+
}
241+
).find(|m| m.is_ok()).unwrap_or(Ok(false))
242+
},
243+
_ => Err(anyhow!("Custom predicate statement template list {:?} does not match op argument list {:?}.", custom_predicate.statements, args))
244+
}
245+
}
181246
_ => Err(anyhow!(
182247
"Invalid deduction: {:?} ⇏ {:#}",
183248
self,

src/util.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
use std::collections::HashMap;
2+
use std::fmt::Debug;
3+
use std::hash::Hash;
4+
5+
use anyhow::{anyhow, Result};
6+
7+
pub(crate) fn hashmap_insert_no_dupe<S: Clone + Debug + Eq + Hash, T: Clone + Debug + Eq>(
8+
hm: &mut HashMap<S, T>,
9+
kv: (S, T),
10+
) -> Result<()> {
11+
let (k, v) = kv.clone();
12+
let res = hm.insert(kv.0, kv.1);
13+
match res {
14+
Some(w) if w != v => Err(anyhow!(
15+
"Key {:?} exists in table with value {:?} != {:?}.",
16+
k,
17+
w,
18+
v
19+
)),
20+
_ => Ok(()),
21+
}
22+
}

0 commit comments

Comments
 (0)