Skip to content

Commit 2864ef2

Browse files
tideofwordsax0
andauthored
Implement more frontend ops (#111)
* middleware operation output statement? * small refactor to op() on frontend * Implement op() * cargo fmt * Clippy * Code review --------- Co-authored-by: Ahmad <root@ahmadafuni.com>
1 parent 6627b46 commit 2864ef2

File tree

5 files changed

+561
-61
lines changed

5 files changed

+561
-61
lines changed

src/backends/plonky2/mock_main/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ impl Pod for MockMainPod {
455455
StatementArg::Key(AnchoredKey(pod_id, h)) if *pod_id == SELF => {
456456
StatementArg::Key(AnchoredKey(self.id(), *h))
457457
}
458-
_ => sa.clone(),
458+
_ => *sa,
459459
})
460460
.collect(),
461461
)

src/frontend/mod.rs

Lines changed: 286 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! The frontend includes the user-level abstractions and user-friendly types to define and work
22
//! with Pods.
33
4-
use anyhow::{anyhow, Result};
4+
use anyhow::{anyhow, Error, Result};
55
use itertools::Itertools;
66
use std::collections::HashMap;
77
use std::convert::From;
@@ -83,6 +83,17 @@ impl From<middleware::Value> for Value {
8383
}
8484
}
8585

86+
impl TryInto<i64> for Value {
87+
type Error = Error;
88+
fn try_into(self) -> std::result::Result<i64, Self::Error> {
89+
if let Value::Int(n) = self {
90+
Ok(n)
91+
} else {
92+
Err(anyhow!("Value not an int"))
93+
}
94+
}
95+
}
96+
8697
impl fmt::Display for Value {
8798
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
8899
match self {
@@ -317,6 +328,7 @@ impl MainPodBuilder {
317328
panic!("Invalid statement argument.");
318329
}
319330
}
331+
// todo: better error handling
320332
OperationArg::Literal(v) => {
321333
let k = format!("c{}", self.const_cnt);
322334
self.const_cnt += 1;
@@ -354,45 +366,226 @@ impl MainPodBuilder {
354366
use NativeOperation::*;
355367
let Operation(op_type, ref mut args) = &mut op;
356368
// TODO: argument type checking
357-
let st = match op_type {
369+
let pred = op_type
370+
.output_predicate()
371+
.map(|p| Ok(p))
372+
.unwrap_or_else(|| {
373+
// We are dealing with a copy here.
374+
match (&args).get(0) {
375+
Some(OperationArg::Statement(s)) if args.len() == 1 => Ok(s.0.clone()),
376+
_ => Err(anyhow!("Invalid arguments to copy operation: {:?}", args)),
377+
}
378+
})?;
379+
380+
let st_args: Vec<StatementArg> = match op_type {
358381
OperationType::Native(o) => match o {
359-
None => Statement(Predicate::Native(NativePredicate::None), vec![]),
360-
NewEntry => Statement(
361-
Predicate::Native(NativePredicate::ValueOf),
362-
self.op_args_entries(public, args)?,
363-
),
364-
CopyStatement => todo!(),
365-
EqualFromEntries => Statement(
366-
Predicate::Native(NativePredicate::Equal),
367-
self.op_args_entries(public, args)?,
368-
),
369-
NotEqualFromEntries => Statement(
370-
Predicate::Native(NativePredicate::NotEqual),
371-
self.op_args_entries(public, args)?,
372-
),
373-
GtFromEntries => Statement(
374-
Predicate::Native(NativePredicate::Gt),
375-
self.op_args_entries(public, args)?,
376-
),
377-
LtFromEntries => Statement(
378-
Predicate::Native(NativePredicate::Lt),
379-
self.op_args_entries(public, args)?,
380-
),
381-
TransitiveEqualFromStatements => todo!(),
382-
GtToNotEqual => todo!(),
383-
LtToNotEqual => todo!(),
384-
ContainsFromEntries => Statement(
385-
Predicate::Native(NativePredicate::Contains),
386-
self.op_args_entries(public, args)?,
387-
),
388-
NotContainsFromEntries => Statement(
389-
Predicate::Native(NativePredicate::NotContains),
390-
self.op_args_entries(public, args)?,
391-
),
392-
RenameContainedBy => todo!(),
393-
SumOf => todo!(),
394-
ProductOf => todo!(),
395-
MaxOf => todo!(),
382+
None => vec![],
383+
NewEntry => self.op_args_entries(public, args)?,
384+
CopyStatement => match &args[0] {
385+
OperationArg::Statement(s) => s.1.clone(),
386+
_ => {
387+
return Err(anyhow!("Invalid arguments to operation: {}", op));
388+
}
389+
},
390+
EqualFromEntries => self.op_args_entries(public, args)?,
391+
NotEqualFromEntries => self.op_args_entries(public, args)?,
392+
GtFromEntries => self.op_args_entries(public, args)?,
393+
LtFromEntries => self.op_args_entries(public, args)?,
394+
TransitiveEqualFromStatements => {
395+
match (args[0].clone(), args[1].clone()) {
396+
(
397+
OperationArg::Statement(Statement(
398+
Predicate::Native(NativePredicate::Equal),
399+
st0_args,
400+
)),
401+
OperationArg::Statement(Statement(
402+
Predicate::Native(NativePredicate::Equal),
403+
st1_args,
404+
)),
405+
) => {
406+
// st_args0 == vec![ak0, ak1]
407+
// st_args1 == vec![ak1, ak2]
408+
// output statement Equals(ak0, ak2)
409+
if st0_args[1] == st1_args[0] {
410+
vec![st0_args[0].clone(), st1_args[1].clone()]
411+
} else {
412+
return Err(anyhow!("Invalid arguments to operation"));
413+
}
414+
}
415+
_ => {
416+
return Err(anyhow!("Invalid arguments to operation"));
417+
}
418+
}
419+
}
420+
GtToNotEqual => match args[0].clone() {
421+
OperationArg::Statement(Statement(
422+
Predicate::Native(NativePredicate::Gt),
423+
st_args,
424+
)) => {
425+
vec![st_args[0].clone()]
426+
}
427+
_ => {
428+
return Err(anyhow!("Invalid arguments to operation"));
429+
}
430+
},
431+
LtToNotEqual => match args[0].clone() {
432+
OperationArg::Statement(Statement(
433+
Predicate::Native(NativePredicate::Lt),
434+
st_args,
435+
)) => {
436+
vec![st_args[0].clone()]
437+
}
438+
_ => {
439+
return Err(anyhow!("Invalid arguments to operation"));
440+
}
441+
},
442+
ContainsFromEntries => self.op_args_entries(public, args)?,
443+
NotContainsFromEntries => self.op_args_entries(public, args)?,
444+
SumOf => match (args[0].clone(), args[1].clone(), args[2].clone()) {
445+
(
446+
OperationArg::Statement(Statement(
447+
Predicate::Native(NativePredicate::ValueOf),
448+
st0_args,
449+
)),
450+
OperationArg::Statement(Statement(
451+
Predicate::Native(NativePredicate::ValueOf),
452+
st1_args,
453+
)),
454+
OperationArg::Statement(Statement(
455+
Predicate::Native(NativePredicate::ValueOf),
456+
st2_args,
457+
)),
458+
) => {
459+
let st_args: Vec<StatementArg> = match (
460+
st0_args[1].clone(),
461+
st1_args[1].clone(),
462+
st2_args[1].clone(),
463+
) {
464+
(
465+
StatementArg::Literal(v0),
466+
StatementArg::Literal(v1),
467+
StatementArg::Literal(v2),
468+
) => {
469+
let v0: i64 = v0.clone().try_into()?;
470+
let v1: i64 = v1.clone().try_into()?;
471+
let v2: i64 = v2.clone().try_into()?;
472+
if v0 == v1 + v2 {
473+
vec![
474+
st0_args[0].clone(),
475+
st1_args[0].clone(),
476+
st2_args[0].clone(),
477+
]
478+
} else {
479+
return Err(anyhow!("Invalid arguments to operation"));
480+
}
481+
}
482+
_ => {
483+
return Err(anyhow!("Invalid arguments to operation"));
484+
}
485+
};
486+
st_args
487+
}
488+
_ => {
489+
return Err(anyhow!("Invalid arguments to operation"));
490+
}
491+
},
492+
ProductOf => match (args[0].clone(), args[1].clone(), args[2].clone()) {
493+
(
494+
OperationArg::Statement(Statement(
495+
Predicate::Native(NativePredicate::ValueOf),
496+
st0_args,
497+
)),
498+
OperationArg::Statement(Statement(
499+
Predicate::Native(NativePredicate::ValueOf),
500+
st1_args,
501+
)),
502+
OperationArg::Statement(Statement(
503+
Predicate::Native(NativePredicate::ValueOf),
504+
st2_args,
505+
)),
506+
) => {
507+
let st_args: Vec<StatementArg> = match (
508+
st0_args[1].clone(),
509+
st1_args[1].clone(),
510+
st2_args[1].clone(),
511+
) {
512+
(
513+
StatementArg::Literal(v0),
514+
StatementArg::Literal(v1),
515+
StatementArg::Literal(v2),
516+
) => {
517+
let v0: i64 = v0.clone().try_into()?;
518+
let v1: i64 = v1.clone().try_into()?;
519+
let v2: i64 = v2.clone().try_into()?;
520+
if v0 == v1 * v2 {
521+
vec![
522+
st0_args[0].clone(),
523+
st1_args[0].clone(),
524+
st2_args[0].clone(),
525+
]
526+
} else {
527+
return Err(anyhow!("Invalid arguments to operation"));
528+
}
529+
}
530+
_ => {
531+
return Err(anyhow!("Invalid arguments to operation"));
532+
}
533+
};
534+
st_args
535+
}
536+
_ => {
537+
return Err(anyhow!("Invalid arguments to operation"));
538+
}
539+
},
540+
MaxOf => match (args[0].clone(), args[1].clone(), args[2].clone()) {
541+
(
542+
OperationArg::Statement(Statement(
543+
Predicate::Native(NativePredicate::ValueOf),
544+
st0_args,
545+
)),
546+
OperationArg::Statement(Statement(
547+
Predicate::Native(NativePredicate::ValueOf),
548+
st1_args,
549+
)),
550+
OperationArg::Statement(Statement(
551+
Predicate::Native(NativePredicate::ValueOf),
552+
st2_args,
553+
)),
554+
) => {
555+
let st_args: Vec<StatementArg> = match (
556+
st0_args[1].clone(),
557+
st1_args[1].clone(),
558+
st2_args[1].clone(),
559+
) {
560+
(
561+
StatementArg::Literal(v0),
562+
StatementArg::Literal(v1),
563+
StatementArg::Literal(v2),
564+
) => {
565+
let v0: i64 = v0.clone().try_into()?;
566+
let v1: i64 = v1.clone().try_into()?;
567+
let v2: i64 = v2.clone().try_into()?;
568+
if v0 == std::cmp::max(v1, v2) {
569+
vec![
570+
st0_args[0].clone(),
571+
st1_args[0].clone(),
572+
st2_args[0].clone(),
573+
]
574+
} else {
575+
return Err(anyhow!("Invalid arguments to operation"));
576+
}
577+
}
578+
_ => {
579+
return Err(anyhow!("Invalid arguments to operation"));
580+
}
581+
};
582+
st_args
583+
}
584+
RenameContainedBy => todo!(),
585+
_ => {
586+
return Err(anyhow!("Invalid arguments to operation"));
587+
}
588+
},
396589
},
397590
OperationType::Custom(cpr) => {
398591
// All args should be statements to be pattern matched against statement templates.
@@ -413,7 +606,8 @@ impl MainPodBuilder {
413606
))
414607
})
415608
.collect::<Result<Vec<_>>>()?;
416-
let output_args = output_arg_values
609+
610+
output_arg_values
417611
.chunks(2)
418612
.map(|chunk| {
419613
Ok(StatementArg::Key(AnchoredKey(
@@ -430,10 +624,10 @@ impl MainPodBuilder {
430624
.ok_or(anyhow!("Missing key corresponding to hash."))?,
431625
)))
432626
})
433-
.collect::<Result<Vec<_>>>()?;
434-
Statement(Predicate::Custom(cpr.clone()), output_args)
627+
.collect::<Result<Vec<_>>>()?
435628
}
436629
};
630+
let st = Statement(pred, st_args);
437631
self.operations.push(op);
438632
if public {
439633
self.public_statements.push(st.clone());
@@ -679,8 +873,8 @@ pub mod build_utils {
679873
$crate::middleware::OperationType::Native($crate::middleware::NativeOperation::EqualFromEntries),
680874
$crate::op_args!($($arg),*)) };
681875
(ne, $($arg:expr),+) => { $crate::frontend::Operation(
682-
$crate::middleware::OperationType::Native(crate::middleware::NativeOperation::NotEqualFromEntries),
683-
crate::op_args!($($arg),*)) };
876+
$crate::middleware::OperationType::Native($crate::middleware::NativeOperation::NotEqualFromEntries),
877+
$crate::op_args!($($arg),*)) };
684878
(gt, $($arg:expr),+) => { crate::frontend::Operation(
685879
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::GtFromEntries),
686880
crate::op_args!($($arg),*)) };
@@ -830,6 +1024,54 @@ pub mod tests {
8301024
Ok(())
8311025
}
8321026

1027+
#[test]
1028+
// Transitive equality not implemented yet
1029+
#[should_panic]
1030+
fn test_equal() {
1031+
let params = Params::default();
1032+
let mut signed_builder = SignedPodBuilder::new(&params);
1033+
signed_builder.insert("a", 1);
1034+
signed_builder.insert("b", 1);
1035+
let mut signer = MockSigner { pk: "key".into() };
1036+
let signed_pod = signed_builder.sign(&mut signer).unwrap();
1037+
1038+
let mut builder = MainPodBuilder::new(&params);
1039+
builder.add_signed_pod(&signed_pod);
1040+
1041+
//let op_val1 = Operation{
1042+
// OperationType::Native(NativeOperation::CopyStatement),
1043+
// signed_pod.
1044+
//}
1045+
1046+
let op_eq1 = Operation(
1047+
OperationType::Native(NativeOperation::EqualFromEntries),
1048+
vec![
1049+
OperationArg::from((&signed_pod, "a")),
1050+
OperationArg::from((&signed_pod, "b")),
1051+
],
1052+
);
1053+
let st1 = builder.op(true, op_eq1).unwrap();
1054+
let op_eq2 = Operation(
1055+
OperationType::Native(NativeOperation::EqualFromEntries),
1056+
vec![
1057+
OperationArg::from((&signed_pod, "b")),
1058+
OperationArg::from((&signed_pod, "a")),
1059+
],
1060+
);
1061+
let st2 = builder.op(true, op_eq2).unwrap();
1062+
1063+
let op_eq3 = Operation(
1064+
OperationType::Native(NativeOperation::TransitiveEqualFromStatements),
1065+
vec![OperationArg::Statement(st1), OperationArg::Statement(st2)],
1066+
);
1067+
let st3 = builder.op(true, op_eq3);
1068+
1069+
let mut prover = MockProver {};
1070+
let pod = builder.prove(&mut prover, &params).unwrap();
1071+
1072+
println!("{}", pod);
1073+
}
1074+
8331075
#[test]
8341076
#[should_panic]
8351077
fn test_false_st() {

0 commit comments

Comments
 (0)