Skip to content

Commit 7373b95

Browse files
authored
feat: custom predicates in frontend statement and operation types (#97)
* Modify frontend statement type * Modify frontend operation type * Add exception to typos.toml
1 parent bcfad30 commit 7373b95

File tree

8 files changed

+168
-125
lines changed

8 files changed

+168
-125
lines changed

.github/workflows/typos.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
groth = "groth" # to avoid it dectecting it as 'growth'
33
BA = "BA"
44
Ded = "Ded" # "ANDed", it thought "Ded" should be "Dead"
5+
OT = "OT"

src/backends/plonky2/mock_main/mod.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::fmt;
77

88
use crate::middleware::{
99
self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativePredicate, NonePod,
10-
Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF,
10+
OperationType, Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF,
1111
};
1212

1313
mod operation;
@@ -261,7 +261,11 @@ impl MockMainPod {
261261
.map(|mid_arg| Self::find_op_arg(statements, mid_arg))
262262
.collect::<Result<Vec<_>>>()?;
263263
Self::pad_operation_args(params, &mut args);
264-
operations.push(Operation(op.code(), args));
264+
let op_code = match op.code() {
265+
OperationType::Native(code) => code,
266+
_ => unimplemented!(),
267+
};
268+
operations.push(Operation(op_code, args));
265269
}
266270
Ok(operations)
267271
}

src/backends/plonky2/mock_main/operation.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use anyhow::Result;
22
use std::fmt;
33

44
use super::Statement;
5-
use crate::middleware::{self, NativeOperation};
5+
use crate::middleware::{self, NativeOperation, OperationType};
66

77
#[derive(Clone, Debug, PartialEq, Eq)]
88
pub enum OperationArg {
@@ -29,7 +29,7 @@ impl Operation {
2929
OperationArg::Index(i) => Some(statements[*i].clone().try_into()),
3030
})
3131
.collect::<Result<Vec<crate::middleware::Statement>>>()?;
32-
middleware::Operation::op(self.0, &deref_args)
32+
middleware::Operation::op(OperationType::Native(self.0), &deref_args)
3333
}
3434
}
3535

src/frontend/custom.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::middleware::{
66
Predicate, StatementTmpl, StatementTmplArg, ToFields, Value, F,
77
};
88

9-
/// Argument to an statement template
9+
/// Argument to a statement template
1010
pub enum HashOrWildcardStr {
1111
Hash(Hash), // represents a literal key
1212
Wildcard(String),

src/frontend/mod.rs

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use crate::middleware::{
1313
hash_str, Hash, MainPodInputs, NativeOperation, NativePredicate, Params, PodId, PodProver,
1414
PodSigner, SELF,
1515
};
16+
use crate::middleware::{OperationType, Predicate};
1617

1718
mod custom;
1819
mod operation;
@@ -254,7 +255,7 @@ impl MainPodBuilder {
254255
for arg in args.iter_mut() {
255256
match arg {
256257
OperationArg::Statement(s) => {
257-
if s.0 == NativePredicate::ValueOf {
258+
if s.0 == Predicate::Native(NativePredicate::ValueOf) {
258259
st_args.push(s.1[0].clone())
259260
} else {
260261
panic!("Invalid statement argument.");
@@ -266,7 +267,7 @@ impl MainPodBuilder {
266267
let value_of_st = self.op(
267268
public,
268269
Operation(
269-
NativeOperation::NewEntry,
270+
OperationType::Native(NativeOperation::NewEntry),
270271
vec![OperationArg::Entry(k.clone(), v.clone())],
271272
),
272273
);
@@ -291,36 +292,49 @@ impl MainPodBuilder {
291292

292293
pub fn op(&mut self, public: bool, mut op: Operation) -> Statement {
293294
use NativeOperation::*;
294-
let Operation(op_type, ref mut args) = op;
295+
let Operation(op_type, ref mut args) = &mut op;
295296
// TODO: argument type checking
296297
let st = match op_type {
297-
None => Statement(NativePredicate::None, vec![]),
298-
NewEntry => Statement(NativePredicate::ValueOf, self.op_args_entries(public, args)),
299-
CopyStatement => todo!(),
300-
EqualFromEntries => {
301-
Statement(NativePredicate::Equal, self.op_args_entries(public, args))
302-
}
303-
NotEqualFromEntries => Statement(
304-
NativePredicate::NotEqual,
305-
self.op_args_entries(public, args),
306-
),
307-
GtFromEntries => Statement(NativePredicate::Gt, self.op_args_entries(public, args)),
308-
LtFromEntries => Statement(NativePredicate::Lt, self.op_args_entries(public, args)),
309-
TransitiveEqualFromStatements => todo!(),
310-
GtToNotEqual => todo!(),
311-
LtToNotEqual => todo!(),
312-
ContainsFromEntries => Statement(
313-
NativePredicate::Contains,
314-
self.op_args_entries(public, args),
315-
),
316-
NotContainsFromEntries => Statement(
317-
NativePredicate::NotContains,
318-
self.op_args_entries(public, args),
319-
),
320-
RenameContainedBy => todo!(),
321-
SumOf => todo!(),
322-
ProductOf => todo!(),
323-
MaxOf => todo!(),
298+
OperationType::Native(o) => match o {
299+
None => Statement(Predicate::Native(NativePredicate::None), vec![]),
300+
NewEntry => Statement(
301+
Predicate::Native(NativePredicate::ValueOf),
302+
self.op_args_entries(public, args),
303+
),
304+
CopyStatement => todo!(),
305+
EqualFromEntries => Statement(
306+
Predicate::Native(NativePredicate::Equal),
307+
self.op_args_entries(public, args),
308+
),
309+
NotEqualFromEntries => Statement(
310+
Predicate::Native(NativePredicate::NotEqual),
311+
self.op_args_entries(public, args),
312+
),
313+
GtFromEntries => Statement(
314+
Predicate::Native(NativePredicate::Gt),
315+
self.op_args_entries(public, args),
316+
),
317+
LtFromEntries => Statement(
318+
Predicate::Native(NativePredicate::Lt),
319+
self.op_args_entries(public, args),
320+
),
321+
TransitiveEqualFromStatements => todo!(),
322+
GtToNotEqual => todo!(),
323+
LtToNotEqual => todo!(),
324+
ContainsFromEntries => Statement(
325+
Predicate::Native(NativePredicate::Contains),
326+
self.op_args_entries(public, args),
327+
),
328+
NotContainsFromEntries => Statement(
329+
Predicate::Native(NativePredicate::NotContains),
330+
self.op_args_entries(public, args),
331+
),
332+
RenameContainedBy => todo!(),
333+
SumOf => todo!(),
334+
ProductOf => todo!(),
335+
MaxOf => todo!(),
336+
},
337+
_ => todo!(),
324338
};
325339
self.operations.push(op);
326340
if public {
@@ -440,7 +454,7 @@ impl MainPodCompiler {
440454

441455
fn compile_op(&self, op: &Operation) -> middleware::Operation {
442456
// TODO
443-
let mop_code: middleware::NativeOperation = op.0.into();
457+
let mop_code: OperationType = op.0.clone();
444458
let mop_args =
445459
op.1.iter()
446460
.flat_map(|arg| self.compile_op_arg(arg).map(|s| s.try_into().unwrap()))
@@ -496,22 +510,22 @@ pub mod build_utils {
496510
#[macro_export]
497511
macro_rules! op {
498512
(eq, $($arg:expr),+) => { crate::frontend::Operation(
499-
crate::middleware::NativeOperation::EqualFromEntries,
513+
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::EqualFromEntries),
500514
crate::op_args!($($arg),*)) };
501515
(ne, $($arg:expr),+) => { crate::frontend::Operation(
502-
crate::middleware::NativeOperation::NotEqualFromEntries,
516+
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::NotEqualFromEntries),
503517
crate::op_args!($($arg),*)) };
504518
(gt, $($arg:expr),+) => { crate::frontend::Operation(
505-
crate::middleware::NativeOperation::GtFromEntries,
519+
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::GtFromEntries),
506520
crate::op_args!($($arg),*)) };
507521
(lt, $($arg:expr),+) => { crate::frontend::Operation(
508-
crate::middleware::NativeOperation::LtFromEntries,
522+
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::LtFromEntries),
509523
crate::op_args!($($arg),*)) };
510524
(contains, $($arg:expr),+) => { crate::frontend::Operation(
511-
crate::middleware::NativeOperation::ContainsFromEntries,
525+
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::ContainsFromEntries),
512526
crate::op_args!($($arg),*)) };
513527
(not_contains, $($arg:expr),+) => { crate::frontend::Operation(
514-
crate::middleware::NativeOperation::NotContainsFromEntries,
528+
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::NotContainsFromEntries),
515529
crate::op_args!($($arg),*)) };
516530
}
517531
}

src/frontend/operation.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::fmt;
22

33
use super::{AnchoredKey, SignedPod, Statement, StatementArg, Value};
4-
use crate::middleware::{hash_str, NativeOperation, NativePredicate};
4+
use crate::middleware::{hash_str, NativeOperation, NativePredicate, OperationType, Predicate};
55

66
#[derive(Clone, Debug, PartialEq, Eq)]
77
pub enum OperationArg {
@@ -55,7 +55,7 @@ impl From<(&SignedPod, &str)> for OperationArg {
5555
// TODO: Actual value, TryFrom.
5656
let value = pod.kvs().get(&hash_str(key)).unwrap().clone();
5757
Self::Statement(Statement(
58-
NativePredicate::ValueOf,
58+
Predicate::Native(NativePredicate::ValueOf),
5959
vec![
6060
StatementArg::Key(AnchoredKey(pod.origin(), key.to_string())),
6161
StatementArg::Literal(Value::Raw(value)),
@@ -65,7 +65,7 @@ impl From<(&SignedPod, &str)> for OperationArg {
6565
}
6666

6767
#[derive(Clone, Debug, PartialEq, Eq)]
68-
pub struct Operation(pub NativeOperation, pub Vec<OperationArg>);
68+
pub struct Operation(pub OperationType, pub Vec<OperationArg>);
6969

7070
impl fmt::Display for Operation {
7171
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {

src/frontend/statement.rs

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
22
use std::fmt;
33

44
use super::{AnchoredKey, Value};
5-
use crate::middleware::{self, NativePredicate};
5+
use crate::middleware::{self, NativePredicate, Predicate};
66

77
#[derive(Clone, Debug, PartialEq, Eq)]
88
pub enum StatementArg {
@@ -20,7 +20,7 @@ impl fmt::Display for StatementArg {
2020
}
2121

2222
#[derive(Clone, Debug, PartialEq, Eq)]
23-
pub struct Statement(pub NativePredicate, pub Vec<StatementArg>);
23+
pub struct Statement(pub Predicate, pub Vec<StatementArg>);
2424

2525
impl TryFrom<Statement> for middleware::Statement {
2626
type Error = anyhow::Error;
@@ -33,38 +33,50 @@ impl TryFrom<Statement> for middleware::Statement {
3333
s.1.get(1).cloned(),
3434
s.1.get(2).cloned(),
3535
);
36-
Ok(match (s.0, args) {
37-
(NP::None, (None, None, None)) => MS::None,
38-
(NP::ValueOf, (Some(SA::Key(ak)), Some(StatementArg::Literal(v)), None)) => {
39-
MS::ValueOf(ak.into(), (&v).into())
40-
}
41-
(NP::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
42-
MS::Equal(ak1.into(), ak2.into())
43-
}
44-
(NP::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
45-
MS::NotEqual(ak1.into(), ak2.into())
46-
}
47-
(NP::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
48-
MS::Gt(ak1.into(), ak2.into())
49-
}
50-
(NP::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
51-
MS::Lt(ak1.into(), ak2.into())
52-
}
53-
(NP::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
54-
MS::Contains(ak1.into(), ak2.into())
55-
}
56-
(NP::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
57-
MS::NotContains(ak1.into(), ak2.into())
58-
}
59-
(NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
60-
MS::SumOf(ak1.into(), ak2.into(), ak3.into())
61-
}
62-
(NP::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
63-
MS::ProductOf(ak1.into(), ak2.into(), ak3.into())
64-
}
65-
(NP::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
66-
MS::MaxOf(ak1.into(), ak2.into(), ak3.into())
67-
}
36+
Ok(match &s.0 {
37+
Predicate::Native(np) => match (np, args) {
38+
(NP::None, (None, None, None)) => MS::None,
39+
(NP::ValueOf, (Some(SA::Key(ak)), Some(StatementArg::Literal(v)), None)) => {
40+
MS::ValueOf(ak.into(), (&v).into())
41+
}
42+
(NP::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
43+
MS::Equal(ak1.into(), ak2.into())
44+
}
45+
(NP::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
46+
MS::NotEqual(ak1.into(), ak2.into())
47+
}
48+
(NP::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
49+
MS::Gt(ak1.into(), ak2.into())
50+
}
51+
(NP::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
52+
MS::Lt(ak1.into(), ak2.into())
53+
}
54+
(NP::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
55+
MS::Contains(ak1.into(), ak2.into())
56+
}
57+
(NP::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
58+
MS::NotContains(ak1.into(), ak2.into())
59+
}
60+
(NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
61+
MS::SumOf(ak1.into(), ak2.into(), ak3.into())
62+
}
63+
(NP::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
64+
MS::ProductOf(ak1.into(), ak2.into(), ak3.into())
65+
}
66+
(NP::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
67+
MS::MaxOf(ak1.into(), ak2.into(), ak3.into())
68+
}
69+
_ => Err(anyhow!("Ill-formed statement: {}", s))?,
70+
},
71+
Predicate::Custom(cpr) => MS::Custom(
72+
cpr.clone(),
73+
s.1.iter()
74+
.map(|arg| match arg {
75+
StatementArg::Key(ak) => Ok(ak.clone().into()),
76+
_ => Err(anyhow!("Invalid statement arg: {}", arg)),
77+
})
78+
.collect::<Result<Vec<_>>>()?,
79+
),
6880
_ => Err(anyhow!("Ill-formed statement: {}", s))?,
6981
})
7082
}

0 commit comments

Comments
 (0)