Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 102 additions & 138 deletions src/backends/mock_main.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
mod operation;
mod statement;

use crate::middleware::{
self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativeStatement, NonePod,
Params, Pod, PodId, PodProver, Statement, StatementArg, ToFields, KEY_TYPE, SELF,
Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF,
};
use anyhow::Result;
use itertools::Itertools;
pub use operation::*;
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::plonk::config::Hasher;
pub use statement::*;
use std::any::Any;
use std::error::Error;
use std::fmt;

pub const VALUE_TYPE: &str = "MockMainPOD";

pub struct MockProver {}

impl PodProver for MockProver {
Expand All @@ -18,72 +24,6 @@ impl PodProver for MockProver {
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
enum OperationArg {
None,
Index(usize),
}

impl OperationArg {
fn is_none(&self) -> bool {
matches!(self, OperationArg::None)
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
enum OperationArgError {
KeyNotFound,
StatementNotFound,
}

impl std::fmt::Display for OperationArgError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
OperationArgError::KeyNotFound => write!(f, "Key not found"),
OperationArgError::StatementNotFound => write!(f, "Statement not found"),
}
}
}

impl std::error::Error for OperationArgError {}

#[derive(Clone, Debug, PartialEq, Eq)]
struct Operation(pub NativeOperation, pub Vec<OperationArg>);

impl Operation {
pub fn deref(&self, statements: &[Statement]) -> crate::middleware::Operation {
let deref_args = self
.1
.iter()
.map(|arg| match arg {
OperationArg::None => middleware::OperationArg::None,
OperationArg::Index(i) => {
middleware::OperationArg::Statement(statements[*i].clone())
}
})
.collect();
middleware::Operation(self.0, deref_args)
}
}

impl fmt::Display for Operation {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?} ", self.0)?;
for (i, arg) in self.1.iter().enumerate() {
if !(!f.alternate() && arg.is_none()) {
if i != 0 {
write!(f, " ")?;
}
match arg {
OperationArg::None => write!(f, "none")?,
OperationArg::Index(i) => write!(f, "{:02}", i)?,
}
}
}
Ok(())
}
}

#[derive(Clone, Debug)]
pub struct MockMainPod {
params: Params,
Expand Down Expand Up @@ -188,12 +128,16 @@ impl MockMainPod {
fn offset_public_statements(&self) -> usize {
self.offset_input_statements() + self.params.max_priv_statements()
}
fn pad_statement(params: &Params, s: &mut Statement) {
fill_pad(&mut s.1, StatementArg::None, params.max_statement_args)
}
fn pad_operation(params: &Params, op: &mut Operation) {
fill_pad(&mut op.1, OperationArg::None, params.max_operation_args)
}

fn layout_statements(params: &Params, inputs: &MainPodInputs) -> Vec<Statement> {
let mut statements = Vec::new();

let st_none = Self::statement_none(params);

// Input signed pods region
let none_sig_pod: Box<dyn Pod> = Box::new(NonePod {});
assert!(inputs.signed_pods.len() <= params.max_input_signed_pods);
Expand All @@ -206,8 +150,12 @@ impl MockMainPod {
let sts = pod.pub_statements();
assert!(sts.len() <= params.max_signed_pod_values);
for j in 0..params.max_signed_pod_values {
let mut st = sts.get(j).unwrap_or(&st_none).clone();
Self::pad_statement_args(params, &mut st.1);
let mut st = sts
.get(j)
.unwrap_or(&middleware::Statement::None)
.clone()
.into();
Self::pad_statement(params, &mut st);
statements.push(st);
}
}
Expand All @@ -224,63 +172,68 @@ impl MockMainPod {
let sts = pod.pub_statements();
assert!(sts.len() <= params.max_public_statements);
for j in 0..params.max_public_statements {
let mut st = sts.get(j).unwrap_or(&st_none).clone();
Self::pad_statement_args(params, &mut st.1);
let mut st = sts
.get(j)
.unwrap_or(&middleware::Statement::None)
.clone()
.into();
Self::pad_statement(params, &mut st);
statements.push(st);
}
}

// Input statements
assert!(inputs.statements.len() <= params.max_priv_statements());
for i in 0..params.max_priv_statements() {
let mut st = inputs.statements.get(i).unwrap_or(&st_none).clone();
Self::pad_statement_args(params, &mut st.1);
let mut st = inputs
.statements
.get(i)
.unwrap_or(&middleware::Statement::None)
.clone()
.into();
Self::pad_statement(params, &mut st);
statements.push(st);
}

// Public statements
assert!(inputs.public_statements.len() < params.max_public_statements);
statements.push(Statement(
NativeStatement::ValueOf,
vec![StatementArg::Key(AnchoredKey(SELF, hash_str(KEY_TYPE)))],
));
let mut type_st = middleware::Statement::ValueOf(
AnchoredKey(SELF, hash_str(KEY_TYPE)),
middleware::Value(hash_str(VALUE_TYPE).0),
)
.into();
Self::pad_statement(params, &mut type_st);
statements.push(type_st);

for i in 0..(params.max_public_statements - 1) {
let mut st = inputs.public_statements.get(i).unwrap_or(&st_none).clone();
Self::pad_statement_args(params, &mut st.1);
let mut st = inputs
.public_statements
.get(i)
.unwrap_or(&middleware::Statement::None)
.clone()
.into();
Self::pad_statement(params, &mut st);
statements.push(st);
}

statements
}

pub fn find_op_arg(
fn find_op_arg(
statements: &[Statement],
op_arg: &middleware::OperationArg,
op_arg: &middleware::Statement,
) -> Result<OperationArg, OperationArgError> {
match op_arg {
middleware::OperationArg::None => Ok(OperationArg::None),
middleware::OperationArg::Key(k) => {
statements
.iter()
.enumerate()
.find_map(|(i, s)| match s.0 {
NativeStatement::ValueOf => match &s.1[0] {
StatementArg::Key(sk) => (sk == k).then_some(i),
_ => None,
},
_ => None,
})
.map(OperationArg::Index)
.ok_or(OperationArgError::KeyNotFound)
}
middleware::OperationArg::Statement(st) => {
statements
.iter()
.enumerate()
.find_map(|(i, s)| (s == st).then_some(i))
.map(OperationArg::Index)
.ok_or(OperationArgError::StatementNotFound)
}
middleware::Statement::None => Ok(OperationArg::None),
_ => statements
.iter()
.enumerate()
.find_map(|(i, s)| {
// TODO: Error handling
(&middleware::Statement::try_from(s.clone()).unwrap() == op_arg).then_some(i)
})
.map(OperationArg::Index)
.ok_or(OperationArgError::StatementNotFound),
}
}

Expand All @@ -289,19 +242,19 @@ impl MockMainPod {
statements: &[Statement],
input_operations: &[middleware::Operation],
) -> Result<Vec<Operation>, OperationArgError> {
let op_none = Self::operation_none(params);

let mut operations = Vec::new();
for i in 0..params.max_priv_statements() {
let op = input_operations.get(i).unwrap_or(&op_none).clone();
let mut mid_args = op.1;
Self::pad_operation_args(params, &mut mid_args);
let mut args = Vec::with_capacity(mid_args.len());
for mid_arg in &mid_args {
let op_arg = Self::find_op_arg(statements, mid_arg)?;
args.push(op_arg)
}
operations.push(Operation(op.0, args));
let op = input_operations
.get(i)
.unwrap_or(&middleware::Operation::None)
.clone();
let mid_args = op.args();
let mut args = mid_args
.iter()
.map(|mid_arg| Self::find_op_arg(statements, mid_arg))
.collect::<Result<Vec<_>, OperationArgError>>()?;
Self::pad_operation_args(params, &mut args);
operations.push(Operation(op.code(), args));
}
Ok(operations)
}
Expand All @@ -320,11 +273,11 @@ impl MockMainPod {
let mut op = if st.is_none() {
Operation(NativeOperation::None, vec![])
} else {
let mid_arg = middleware::OperationArg::Statement(st.clone());
let op_arg = Self::find_op_arg(statements, &mid_arg)?;
let mid_arg = st.clone();
Operation(
NativeOperation::CopyStatement,
vec![op_arg],
// TODO
vec![Self::find_op_arg(statements, &mid_arg.try_into().unwrap())?],
)
};
fill_pad(&mut op.1, OperationArg::None, params.max_operation_args);
Expand All @@ -351,7 +304,16 @@ impl MockMainPod {
.map(|p| (*p).clone())
.collect_vec();
let input_main_pods = inputs.main_pods.iter().map(|p| (*p).clone()).collect_vec();
let input_statements = inputs.statements.iter().cloned().collect_vec();
let input_statements = inputs
.statements
.iter()
.cloned()
.map(|s| {
let mut s = s.into();
Self::pad_statement(params, &mut s);
s
})
.collect_vec();
let public_statements =
statements[statements.len() - params.max_public_statements..].to_vec();

Expand All @@ -376,26 +338,22 @@ impl MockMainPod {
Statement(NativeStatement::None, args)
}

fn operation_none(params: &Params) -> middleware::Operation {
let mut args = Vec::with_capacity(params.max_operation_args);
Self::pad_operation_args(&params, &mut args);
middleware::Operation(NativeOperation::None, args)
fn operation_none(params: &Params) -> Operation {
let mut op = Operation(NativeOperation::None, vec![]);
fill_pad(&mut op.1, OperationArg::None, params.max_operation_args);
op
}

fn pad_statement_args(params: &Params, args: &mut Vec<StatementArg>) {
fill_pad(args, StatementArg::None, params.max_statement_args)
}

fn pad_operation_args(params: &Params, args: &mut Vec<middleware::OperationArg>) {
fill_pad(
args,
middleware::OperationArg::None,
params.max_operation_args,
)
fn pad_operation_args(params: &Params, args: &mut Vec<OperationArg>) {
fill_pad(args, OperationArg::None, params.max_operation_args)
}
}

pub fn hash_statements(statements: &[middleware::Statement]) -> Result<middleware::Hash> {
pub fn hash_statements(statements: &[Statement]) -> Result<middleware::Hash> {
let field_elems = statements
.into_iter()
.flat_map(|statement| statement.clone().to_fields().0)
Expand Down Expand Up @@ -444,7 +402,7 @@ impl Pod for MockMainPod {
s,
)
})
.filter(|(i, s)| s.0 == NativeStatement::ValueOf)
.filter(|(_, s)| s.0 == NativeStatement::ValueOf)
.flat_map(|(i, s)| {
if let StatementArg::Key(ak) = &s.1[0] {
vec![(i, ak.1, ak.0)]
Expand All @@ -463,7 +421,8 @@ impl Pod for MockMainPod {
.map(|(i, s)| {
self.operations[i]
.deref(&self.statements[..input_statement_offset + i])
.check(s.clone())
.unwrap()
.check(&s.clone().try_into().unwrap())
})
.collect::<Result<Vec<_>>>()
.unwrap();
Expand All @@ -472,7 +431,7 @@ impl Pod for MockMainPod {
fn id(&self) -> PodId {
self.id
}
fn pub_statements(&self) -> Vec<Statement> {
fn pub_statements(&self) -> Vec<middleware::Statement> {
// return the public statements, where when origin=SELF is replaced by origin=self.id()
self.statements
.iter()
Expand All @@ -492,6 +451,8 @@ impl Pod for MockMainPod {
})
.collect(),
)
.try_into()
.unwrap()
})
.collect()
}
Expand All @@ -505,7 +466,10 @@ impl Pod for MockMainPod {
pub mod tests {
use super::*;
use crate::backends::mock_signed::MockSigner;
use crate::examples::{great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_sign_pod_builders};
use crate::examples::{
great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder,
zu_kyc_sign_pod_builders,
};
use crate::middleware;

#[test]
Expand Down Expand Up @@ -559,6 +523,6 @@ pub mod tests {
let pod = proof_pod.pod.into_any().downcast::<MockMainPod>().unwrap();

println!("{}", pod);
assert_eq!(pod.verify(), true);
assert_eq!(pod.verify(), true);
}
}
Loading
Loading