Skip to content

Commit c302dad

Browse files
committed
limit the number of StatementTmpl in CustomPredicate:
- add constructor method for CustomPredicate - make size checks at the CustomPredicate creation, so that once instantiated we can assume that contains valid data This resolves #79
1 parent 7373b95 commit c302dad

File tree

6 files changed

+94
-54
lines changed

6 files changed

+94
-54
lines changed

src/backends/plonky2/mock_main/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ impl Pod for MockMainPod {
435435
self.operations[i]
436436
.deref(&self.statements[..input_statement_offset + i])
437437
.unwrap()
438-
.check(&s.clone().try_into().unwrap())
438+
.check(&self.params, &s.clone().try_into().unwrap())
439439
})
440440
.collect::<Result<Vec<_>>>()
441441
.unwrap();

src/frontend/custom.rs

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#![allow(unused)]
2+
use anyhow::Result;
23
use std::sync::Arc;
34

45
use crate::middleware::{
5-
hash_str, CustomPredicate, CustomPredicateBatch, Hash, HashOrWildcard, NativePredicate,
6+
hash_str, CustomPredicate, CustomPredicateBatch, Hash, HashOrWildcard, NativePredicate, Params,
67
Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, F,
78
};
89

@@ -96,31 +97,34 @@ impl CustomPredicateBatchBuilder {
9697

9798
fn predicate_and(
9899
&mut self,
100+
params: &Params,
99101
args: &[&str],
100102
priv_args: &[&str],
101103
sts: &[StatementTmplBuilder],
102-
) -> Predicate {
103-
self.predicate(true, args, priv_args, sts)
104+
) -> Result<Predicate> {
105+
self.predicate(params, true, args, priv_args, sts)
104106
}
105107

106108
fn predicate_or(
107109
&mut self,
110+
params: &Params,
108111
args: &[&str],
109112
priv_args: &[&str],
110113
sts: &[StatementTmplBuilder],
111-
) -> Predicate {
112-
self.predicate(false, args, priv_args, sts)
114+
) -> Result<Predicate> {
115+
self.predicate(params, false, args, priv_args, sts)
113116
}
114117

115118
/// creates the custom predicate from the given input, adds it to the
116119
/// self.predicates, and returns the index of the created predicate
117120
fn predicate(
118121
&mut self,
122+
params: &Params,
119123
conjunction: bool,
120124
args: &[&str],
121125
priv_args: &[&str],
122126
sts: &[StatementTmplBuilder],
123-
) -> Predicate {
127+
) -> Result<Predicate> {
124128
let statements = sts
125129
.iter()
126130
.map(|sb| {
@@ -138,13 +142,9 @@ impl CustomPredicateBatchBuilder {
138142
StatementTmpl(sb.predicate.clone(), args)
139143
})
140144
.collect();
141-
let custom_predicate = CustomPredicate {
142-
conjunction,
143-
statements,
144-
args_len: args.len(),
145-
};
145+
let custom_predicate = CustomPredicate::new(params, conjunction, statements, args.len())?;
146146
self.predicates.push(custom_predicate);
147-
Predicate::BatchSelf(self.predicates.len() - 1)
147+
Ok(Predicate::BatchSelf(self.predicates.len() - 1))
148148
}
149149

150150
fn finish(self) -> Arc<CustomPredicateBatch> {
@@ -174,7 +174,7 @@ mod tests {
174174
use crate::middleware::{CustomPredicateRef, Params, PodType};
175175

176176
#[test]
177-
fn test_custom_pred() {
177+
fn test_custom_pred() -> Result<()> {
178178
use NativePredicate as NP;
179179
use StatementTmplBuilder as STB;
180180

@@ -183,6 +183,7 @@ mod tests {
183183

184184
let mut builder = CustomPredicateBatchBuilder::new("eth_friend".into());
185185
let _eth_friend = builder.predicate_and(
186+
&params,
186187
// arguments:
187188
&["src_ori", "src_key", "dst_ori", "dst_key"],
188189
// private arguments:
@@ -202,7 +203,7 @@ mod tests {
202203
.arg(("attestation_pod", literal("attestation")))
203204
.arg(("dst_ori", "dst_key")),
204205
],
205-
);
206+
)?;
206207

207208
println!("a.0. eth_friend = {}", builder.predicates.last().unwrap());
208209
let eth_friend = builder.finish();
@@ -216,6 +217,7 @@ mod tests {
216217
// >
217218
let mut builder = CustomPredicateBatchBuilder::new("eth_dos_distance_base".into());
218219
let eth_dos_distance_base = builder.predicate_and(
220+
&params,
219221
&[
220222
// arguments:
221223
"src_ori",
@@ -236,7 +238,7 @@ mod tests {
236238
.arg(("distance_ori", "distance_key"))
237239
.arg(0),
238240
],
239-
);
241+
)?;
240242
println!(
241243
"b.0. eth_dos_distance_base = {}",
242244
builder.predicates.last().unwrap()
@@ -246,6 +248,7 @@ mod tests {
246248

247249
// next chunk builds:
248250
let eth_dos_distance_ind = builder.predicate_and(
251+
&params,
249252
&[
250253
// arguments:
251254
"src_ori",
@@ -281,14 +284,15 @@ mod tests {
281284
.arg(("intermed_ori", "intermed_key"))
282285
.arg(("dst_ori", "dst_key")),
283286
],
284-
);
287+
)?;
285288

286289
println!(
287290
"b.1. eth_dos_distance_ind = {}",
288291
builder.predicates.last().unwrap()
289292
);
290293

291294
let _eth_dos_distance = builder.predicate_or(
295+
&params,
292296
&[
293297
"src_ori",
294298
"src_key",
@@ -308,7 +312,7 @@ mod tests {
308312
.arg(("dst_ori", "dst_key"))
309313
.arg(("distance_ori", "distance_key")),
310314
],
311-
);
315+
)?;
312316

313317
println!(
314318
"b.2. eth_dos_distance = {}",
@@ -318,5 +322,7 @@ mod tests {
318322
let eth_dos_batch_b = builder.finish();
319323
let fields = eth_dos_batch_b.to_fields(&params);
320324
println!("Batch b, serialized: {:?}", fields);
325+
326+
Ok(())
321327
}
322328
}

src/frontend/operation.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::fmt;
22

33
use super::{AnchoredKey, SignedPod, Statement, StatementArg, Value};
4-
use crate::middleware::{hash_str, NativeOperation, NativePredicate, OperationType, Predicate};
4+
use crate::middleware::{hash_str, NativePredicate, OperationType, Predicate};
55

66
#[derive(Clone, Debug, PartialEq, Eq)]
77
pub enum OperationArg {

src/middleware/custom.rs

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,42 @@ impl ToFields for StatementTmpl {
195195

196196
#[derive(Clone, Debug, PartialEq, Eq)]
197197
pub struct CustomPredicate {
198+
/// NOTE: fields are not public (outside of crate) to enforce the struct instantiation through
199+
/// the `::and/or` methods, which performs checks on the values.
200+
198201
/// true for "and", false for "or"
199-
pub conjunction: bool,
200-
pub statements: Vec<StatementTmpl>,
201-
pub args_len: usize,
202+
pub(crate) conjunction: bool,
203+
pub(crate) statements: Vec<StatementTmpl>,
204+
pub(crate) args_len: usize,
202205
// TODO: Add private args length?
203206
// TODO: Add args type information?
204207
}
205208

209+
impl CustomPredicate {
210+
pub fn and(params: &Params, statements: Vec<StatementTmpl>, args_len: usize) -> Result<Self> {
211+
Self::new(params, true, statements, args_len)
212+
}
213+
pub fn or(params: &Params, statements: Vec<StatementTmpl>, args_len: usize) -> Result<Self> {
214+
Self::new(params, false, statements, args_len)
215+
}
216+
pub fn new(
217+
params: &Params,
218+
conjunction: bool,
219+
statements: Vec<StatementTmpl>,
220+
args_len: usize,
221+
) -> Result<Self> {
222+
if statements.len() > params.max_custom_predicate_arity {
223+
return Err(anyhow!("Custom predicate depends on too many statements"));
224+
}
225+
226+
Ok(Self {
227+
conjunction,
228+
statements,
229+
args_len,
230+
})
231+
}
232+
}
233+
206234
impl ToFields for CustomPredicate {
207235
fn to_fields(&self, params: &Params) -> (Vec<F>, usize) {
208236
// serialize as:
@@ -212,9 +240,9 @@ impl ToFields for CustomPredicate {
212240
// (params.max_custom_predicate_arity * params.statement_tmpl_size())
213241
// field elements
214242

215-
// TODO think if this check should go into the StatementTmpl creation,
216-
// instead of at the `to_fields` method, where we should assume that the
217-
// values are already valid
243+
// NOTE: this method assumes that the self.params.len() is inside the
244+
// expected bound, as Self should be instantiated with the constructor
245+
// method `new` which performs the check.
218246
if self.statements.len() > params.max_custom_predicate_arity {
219247
panic!("Custom predicate depends on too many statements");
220248
}
@@ -353,7 +381,7 @@ mod tests {
353381

354382
use crate::middleware::{
355383
AnchoredKey, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Hash,
356-
HashOrWildcard, NativePredicate, Operation, PodId, PodType, Predicate, Statement,
384+
HashOrWildcard, NativePredicate, Operation, Params, PodId, PodType, Predicate, Statement,
357385
StatementTmpl, StatementTmplArg, SELF,
358386
};
359387

@@ -418,13 +446,16 @@ mod tests {
418446
],
419447
);
420448

421-
assert!(custom_deduction.check(&custom_statement)?);
449+
let params = Params::default();
450+
assert!(custom_deduction.check(&params, &custom_statement)?);
422451

423452
Ok(())
424453
}
425454

426455
#[test]
427456
fn ethdos_test() -> Result<()> {
457+
let params = Params::default();
458+
428459
let eth_friend_cp = CustomPredicate {
429460
conjunction: true,
430461
statements: vec![
@@ -561,7 +592,7 @@ mod tests {
561592
);
562593

563594
// Copies should work.
564-
assert!(Operation::CopyStatement(ethdos_example.clone()).check(&ethdos_example)?);
595+
assert!(Operation::CopyStatement(ethdos_example.clone()).check(&params, &ethdos_example)?);
565596

566597
// This could arise as the inductive step.
567598
let ethdos_ind_example = Statement::Custom(
@@ -577,7 +608,7 @@ mod tests {
577608
CustomPredicateRef(eth_dos_distance_batch.clone(), 2),
578609
vec![ethdos_ind_example.clone()]
579610
)
580-
.check(&ethdos_example)?);
611+
.check(&params, &ethdos_example)?);
581612

582613
// And the inductive step would arise as follows: Say the
583614
// ETHDoS distance from Alice to Charlie is 6, which is one
@@ -610,7 +641,7 @@ mod tests {
610641
CustomPredicateRef(eth_dos_distance_batch.clone(), 1),
611642
ethdos_facts
612643
)
613-
.check(&ethdos_ind_example)?);
644+
.check(&params, &ethdos_ind_example)?);
614645

615646
Ok(())
616647
}

src/middleware/mod.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,22 @@ pub struct Params {
9292
pub max_custom_batch_size: usize,
9393
}
9494

95+
impl Default for Params {
96+
fn default() -> Self {
97+
Self {
98+
max_input_signed_pods: 3,
99+
max_input_main_pods: 3,
100+
max_statements: 20,
101+
max_signed_pod_values: 8,
102+
max_public_statements: 10,
103+
max_statement_args: 5,
104+
max_operation_args: 5,
105+
max_custom_predicate_arity: 5,
106+
max_custom_batch_size: 5,
107+
}
108+
}
109+
}
110+
95111
impl Params {
96112
pub fn max_priv_statements(&self) -> usize {
97113
self.max_statements - self.max_public_statements
@@ -134,22 +150,6 @@ impl Params {
134150
}
135151
}
136152

137-
impl Default for Params {
138-
fn default() -> Self {
139-
Self {
140-
max_input_signed_pods: 3,
141-
max_input_main_pods: 3,
142-
max_statements: 20,
143-
max_signed_pod_values: 8,
144-
max_public_statements: 10,
145-
max_statement_args: 5,
146-
max_operation_args: 5,
147-
max_custom_predicate_arity: 5,
148-
max_custom_batch_size: 5,
149-
}
150-
}
151-
}
152-
153153
pub trait Pod: fmt::Debug + DynClone {
154154
fn verify(&self) -> bool;
155155
fn id(&self) -> PodId;

src/middleware/operation.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ use anyhow::{anyhow, Result};
44

55
use super::{CustomPredicateRef, Statement};
66
use crate::{
7-
middleware::{AnchoredKey, CustomPredicate, PodId, Predicate, StatementTmpl, Value, SELF},
7+
middleware::{
8+
AnchoredKey, CustomPredicate, Params, PodId, Predicate, StatementTmpl, Value, SELF,
9+
},
810
util::hashmap_insert_no_dupe,
911
};
1012

@@ -145,7 +147,7 @@ impl Operation {
145147
})
146148
}
147149
/// Checks the given operation against a statement.
148-
pub fn check(&self, output_statement: &Statement) -> Result<bool> {
150+
pub fn check(&self, params: &Params, output_statement: &Statement) -> Result<bool> {
149151
use Statement::*;
150152
match (self, output_statement) {
151153
(Self::None, None) => Ok(true),
@@ -211,10 +213,11 @@ impl Operation {
211213
// references with custom predicate references.
212214
let custom_predicate = {
213215
let cp = (**cpb).predicates[*i].clone();
214-
CustomPredicate {
215-
conjunction: cp.conjunction,
216-
statements: cp
217-
.statements
216+
CustomPredicate::new(
217+
params,
218+
cp.conjunction,
219+
// statments:
220+
cp.statements
218221
.into_iter()
219222
.map(|StatementTmpl(p, args)| {
220223
StatementTmpl(
@@ -228,8 +231,8 @@ impl Operation {
228231
)
229232
})
230233
.collect(),
231-
args_len: cp.args_len,
232-
}
234+
cp.args_len,
235+
)?
233236
};
234237
match custom_predicate.conjunction {
235238
true if custom_predicate.statements.len() == args.len() => {

0 commit comments

Comments
 (0)