Skip to content

Commit 220a387

Browse files
committed
Experiment with statement & op enums
1 parent 90e9782 commit 220a387

File tree

3 files changed

+728
-377
lines changed

3 files changed

+728
-377
lines changed

src/backends/mock_main.rs

Lines changed: 170 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
use crate::middleware::{
22
self, hash_str, AnchoredKey, Hash, MainPod, MainPodInputs, NativeOperation, NativeStatement,
3-
NoneMainPod, NoneSignedPod, Params, PodId, PodProver, SignedPod, Statement, StatementArg,
4-
ToFields, KEY_TYPE, SELF,
3+
NoneMainPod, NoneSignedPod, Params, PodId, PodProver, SignedPod, StatementArg, ToFields, Value,
4+
KEY_TYPE, SELF,
55
};
6-
use anyhow::Result;
6+
use anyhow::{anyhow, Result};
77
use itertools::Itertools;
88
use plonky2::hash::poseidon::PoseidonHash;
99
use plonky2::plonk::config::Hasher;
1010
use std::any::Any;
1111
use std::fmt;
1212
use std::io::{self, Write};
1313

14+
pub const VALUE_TYPE: &str = "MockMainPOD";
15+
1416
pub struct MockProver {}
1517

1618
impl PodProver for MockProver {
@@ -35,18 +37,103 @@ impl OperationArg {
3537
struct Operation(pub NativeOperation, pub Vec<OperationArg>);
3638

3739
impl Operation {
38-
pub fn deref(&self, statements: &[Statement]) -> crate::middleware::Operation {
40+
pub fn deref(&self, statements: &[Statement]) -> Result<crate::middleware::Operation> {
3941
let deref_args = self
4042
.1
4143
.iter()
42-
.map(|arg| match arg {
43-
OperationArg::None => middleware::OperationArg::None,
44+
.flat_map(|arg| match arg {
45+
OperationArg::None => vec![],
4446
OperationArg::Index(i) => {
45-
middleware::OperationArg::Statement(statements[*i].clone())
47+
vec![statements[*i].clone().try_into()]
4648
}
4749
})
48-
.collect();
49-
middleware::Operation(self.0, deref_args)
50+
.collect::<Result<Vec<crate::middleware::Statement>>>()?;
51+
middleware::Operation::op(self.0, &deref_args)
52+
}
53+
}
54+
55+
#[derive(Clone, Debug, PartialEq, Eq)]
56+
pub struct Statement(pub NativeStatement, pub Vec<StatementArg>);
57+
58+
impl Statement {
59+
pub fn is_none(&self) -> bool {
60+
self.0 == NativeStatement::None
61+
}
62+
}
63+
64+
impl ToFields for Statement {
65+
fn to_fields(self) -> (Vec<middleware::F>, usize) {
66+
let (native_statement_f, native_statement_f_len) = self.0.to_fields();
67+
let (vec_statementarg_f, vec_statementarg_f_len) = self
68+
.1
69+
.into_iter()
70+
.map(|statement_arg| statement_arg.to_fields())
71+
.fold((Vec::new(), 0), |mut acc, (f, l)| {
72+
acc.0.extend(f);
73+
acc.1 += l;
74+
acc
75+
});
76+
(
77+
[native_statement_f, vec_statementarg_f].concat(),
78+
native_statement_f_len + vec_statementarg_f_len,
79+
)
80+
}
81+
}
82+
83+
impl TryInto<middleware::Statement> for Statement {
84+
type Error = anyhow::Error;
85+
fn try_into(self) -> Result<middleware::Statement> {
86+
type S = middleware::Statement;
87+
type NS = NativeStatement;
88+
type SA = StatementArg;
89+
let args = (
90+
self.1.get(0).cloned(),
91+
self.1.get(1).cloned(),
92+
self.1.get(2).cloned(),
93+
);
94+
Ok(match (self.0, args) {
95+
(NS::None, _) => S::None,
96+
(NS::ValueOf, (Some(SA::Key(ak)), Some(SA::Literal(v)), None)) => S::ValueOf(ak, v),
97+
(NS::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => S::Equal(ak1, ak2),
98+
(NS::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => S::NotEqual(ak1, ak2),
99+
(NS::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => S::Gt(ak1, ak2),
100+
(NS::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => S::Lt(ak1, ak2),
101+
(NS::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => S::Contains(ak1, ak2),
102+
(NS::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
103+
S::NotContains(ak1, ak2)
104+
}
105+
(NS::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
106+
S::SumOf(ak1, ak2, ak3)
107+
}
108+
(NS::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
109+
S::ProductOf(ak1, ak2, ak3)
110+
}
111+
(NS::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
112+
S::MaxOf(ak1, ak2, ak3)
113+
}
114+
_ => Err(anyhow!("Malformed statement expression {}", self))?,
115+
})
116+
}
117+
}
118+
119+
impl From<middleware::Statement> for Statement {
120+
fn from(s: middleware::Statement) -> Self {
121+
Statement(s.code(), s.args().into_iter().map(|arg| arg).collect())
122+
}
123+
}
124+
125+
impl fmt::Display for Statement {
126+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
127+
write!(f, "{:?} ", self.0)?;
128+
for (i, arg) in self.1.iter().enumerate() {
129+
if !(!f.alternate() && arg.is_none()) {
130+
if i != 0 {
131+
write!(f, " ")?;
132+
}
133+
write!(f, "{}", arg)?;
134+
}
135+
}
136+
Ok(())
50137
}
51138
}
52139

@@ -158,6 +245,11 @@ fn fill_pad<T: Clone>(v: &mut Vec<T>, pad_value: T, len: usize) {
158245
}
159246
}
160247

248+
fn pad<T: Clone>(v: Vec<T>, pad_value: T, len: usize) -> Vec<T> {
249+
let v_len = v.len();
250+
[v, (v_len..len).map(|_| pad_value.clone()).collect()].concat()
251+
}
252+
161253
impl MockMainPod {
162254
fn offset_input_signed_pods(&self) -> usize {
163255
0
@@ -172,12 +264,19 @@ impl MockMainPod {
172264
fn offset_public_statements(&self) -> usize {
173265
self.offset_input_statements() + self.params.max_priv_statements()
174266
}
267+
fn pad_statement(params: &Params, s: Statement) -> Statement {
268+
Statement(s.0, pad(s.1, StatementArg::None, params.max_statement_args))
269+
}
270+
fn pad_operation(params: &Params, op: Operation) -> Operation {
271+
Operation(
272+
op.0,
273+
pad(op.1, OperationArg::None, params.max_operation_args),
274+
)
275+
}
175276

176277
fn layout_statements(params: &Params, inputs: &MainPodInputs) -> Vec<Statement> {
177278
let mut statements = Vec::new();
178279

179-
let st_none = Self::statement_none(params);
180-
181280
// Input signed pods region
182281
let none_sig_pod: Box<dyn SignedPod> = Box::new(NoneSignedPod {});
183282
assert!(inputs.signed_pods.len() <= params.max_input_signed_pods);
@@ -187,12 +286,12 @@ impl MockMainPod {
187286
.get(i)
188287
.map(|p| *p)
189288
.unwrap_or(&none_sig_pod);
289+
// TODO
190290
let sts = pod.pub_statements();
191291
assert!(sts.len() <= params.max_signed_pod_values);
192292
for j in 0..params.max_signed_pod_values {
193-
let mut st = sts.get(j).unwrap_or(&st_none).clone();
194-
Self::pad_statement_args(params, &mut st.1);
195-
statements.push(st);
293+
let mut st = sts.get(j).unwrap_or(&middleware::Statement::None).clone();
294+
statements.push(Self::pad_statement(params, st.into()));
196295
}
197296
}
198297

@@ -208,80 +307,74 @@ impl MockMainPod {
208307
let sts = pod.pub_statements();
209308
assert!(sts.len() <= params.max_public_statements);
210309
for j in 0..params.max_public_statements {
211-
let mut st = sts.get(j).unwrap_or(&st_none).clone();
212-
Self::pad_statement_args(params, &mut st.1);
213-
statements.push(st);
310+
let mut st = sts.get(j).unwrap_or(&middleware::Statement::None).clone();
311+
statements.push(Self::pad_statement(params, st.into()));
214312
}
215313
}
216314

217315
// Input statements
218316
assert!(inputs.statements.len() <= params.max_priv_statements());
219317
for i in 0..params.max_priv_statements() {
220-
let mut st = inputs.statements.get(i).unwrap_or(&st_none).clone();
221-
Self::pad_statement_args(params, &mut st.1);
222-
statements.push(st);
318+
let mut st = inputs
319+
.statements
320+
.get(i)
321+
.unwrap_or(&middleware::Statement::None)
322+
.clone();
323+
statements.push(Self::pad_statement(params, st.into()));
223324
}
224325

225326
// Public statements
226327
assert!(inputs.public_statements.len() < params.max_public_statements);
227-
statements.push(Statement(
228-
NativeStatement::ValueOf,
229-
vec![StatementArg::Key(AnchoredKey(SELF, hash_str(KEY_TYPE)))],
328+
statements.push(Self::pad_statement(
329+
params,
330+
middleware::Statement::ValueOf(
331+
AnchoredKey(SELF, hash_str(KEY_TYPE)),
332+
middleware::Value(hash_str(VALUE_TYPE).0),
333+
)
334+
.into(),
230335
));
231336
for i in 0..(params.max_public_statements - 1) {
232-
let mut st = inputs.public_statements.get(i).unwrap_or(&st_none).clone();
233-
Self::pad_statement_args(params, &mut st.1);
234-
statements.push(st);
337+
let st = inputs
338+
.public_statements
339+
.get(i)
340+
.unwrap_or(&middleware::Statement::None)
341+
.clone();
342+
statements.push(Self::pad_statement(params, st.into()));
235343
}
236344

237345
statements
238346
}
239347

240-
fn find_op_arg(statements: &[Statement], op_arg: &middleware::OperationArg) -> OperationArg {
241-
match op_arg {
242-
middleware::OperationArg::None => OperationArg::None,
243-
middleware::OperationArg::Key(k) => OperationArg::Index(
244-
// TODO: Error handling when the key is not found in any ValueOf statement
245-
statements
246-
.iter()
247-
.enumerate()
248-
.find_map(|(i, s)| match s.0 {
249-
NativeStatement::ValueOf => match &s.1[0] {
250-
StatementArg::Key(sk) => (sk == k).then_some(i),
251-
_ => None,
252-
},
253-
_ => None,
254-
})
255-
.unwrap(),
256-
),
257-
middleware::OperationArg::Statement(st) => OperationArg::Index(
258-
// TODO: Error handling when the statement is not found
259-
statements
260-
.iter()
261-
.enumerate()
262-
.find_map(|(i, s)| (s == st).then_some(i))
263-
.unwrap(),
264-
),
265-
}
348+
fn find_op_arg(statements: &[Statement], op_arg: &middleware::Statement) -> OperationArg {
349+
OperationArg::Index(
350+
// TODO: Error handling when the statement is not found
351+
statements
352+
.iter()
353+
.enumerate()
354+
.find_map(|(i, s)| (s == &op_arg.clone().into()).then_some(i))
355+
.unwrap(),
356+
)
266357
}
267358

268359
fn process_private_statements_operations(
269360
params: &Params,
270361
statements: &[Statement],
271362
input_operations: &[middleware::Operation],
272363
) -> Vec<Operation> {
273-
let op_none = Self::operation_none(params);
274-
275364
let mut operations = Vec::new();
276365
for i in 0..params.max_priv_statements() {
277-
let op = input_operations.get(i).unwrap_or(&op_none).clone();
278-
let mut mid_args = op.1;
366+
let op = input_operations
367+
.get(i)
368+
.unwrap_or(&middleware::Operation::None)
369+
.clone();
370+
let mut mid_args = op.args();
279371
Self::pad_operation_args(params, &mut mid_args);
280372
let mut args = Vec::with_capacity(mid_args.len());
281373
for mid_arg in &mid_args {
282-
args.push(Self::find_op_arg(statements, mid_arg));
374+
let op_arg = Self::find_op_arg(statements, mid_arg);
375+
args.push(op_arg)
283376
}
284-
operations.push(Operation(op.0, args));
377+
operations.push(Operation(op.code(), args));
285378
}
286379
operations
287380
}
@@ -293,19 +386,18 @@ impl MockMainPod {
293386
statements: &[Statement],
294387
mut operations: Vec<Operation>,
295388
) -> Vec<Operation> {
296-
let op_none = Self::operation_none(params);
297-
298389
let offset_public_statements = statements.len() - params.max_public_statements;
299390
operations.push(Operation(NativeOperation::NewEntry, vec![]));
300391
for i in 0..(params.max_public_statements - 1) {
301392
let st = &statements[offset_public_statements + i + 1];
302393
let mut op = if st.is_none() {
303394
Operation(NativeOperation::None, vec![])
304395
} else {
305-
let mid_arg = middleware::OperationArg::Statement(st.clone());
396+
let mid_arg = st.clone();
306397
Operation(
307398
NativeOperation::CopyStatement,
308-
vec![Self::find_op_arg(statements, &mid_arg)],
399+
// TODO
400+
vec![Self::find_op_arg(statements, &mid_arg.try_into().unwrap())],
309401
)
310402
};
311403
fill_pad(&mut op.1, OperationArg::None, params.max_operation_args);
@@ -332,7 +424,12 @@ impl MockMainPod {
332424
.map(|p| (*p).clone())
333425
.collect_vec();
334426
let input_main_pods = inputs.main_pods.iter().map(|p| (*p).clone()).collect_vec();
335-
let input_statements = inputs.statements.iter().cloned().collect_vec();
427+
let input_statements = inputs
428+
.statements
429+
.iter()
430+
.cloned()
431+
.map(|s| Self::pad_statement(params, s.into()))
432+
.collect_vec();
336433
let public_statements =
337434
statements[statements.len() - params.max_public_statements..].to_vec();
338435

@@ -357,26 +454,20 @@ impl MockMainPod {
357454
Statement(NativeStatement::None, args)
358455
}
359456

360-
fn operation_none(params: &Params) -> middleware::Operation {
361-
let mut args = Vec::with_capacity(params.max_operation_args);
362-
Self::pad_operation_args(&params, &mut args);
363-
middleware::Operation(NativeOperation::None, args)
457+
fn operation_none(params: &Params) -> Operation {
458+
Self::pad_operation(params, Operation(NativeOperation::None, vec![]))
364459
}
365460

366461
fn pad_statement_args(params: &Params, args: &mut Vec<StatementArg>) {
367462
fill_pad(args, StatementArg::None, params.max_statement_args)
368463
}
369464

370-
fn pad_operation_args(params: &Params, args: &mut Vec<middleware::OperationArg>) {
371-
fill_pad(
372-
args,
373-
middleware::OperationArg::None,
374-
params.max_operation_args,
375-
)
465+
fn pad_operation_args(params: &Params, args: &mut Vec<middleware::Statement>) {
466+
fill_pad(args, middleware::Statement::None, params.max_operation_args)
376467
}
377468
}
378469

379-
pub fn hash_statements(statements: &[middleware::Statement]) -> Result<middleware::Hash> {
470+
pub fn hash_statements(statements: &[Statement]) -> Result<middleware::Hash> {
380471
let field_elems = statements
381472
.into_iter()
382473
.flat_map(|statement| statement.clone().to_fields().0)
@@ -444,7 +535,8 @@ impl MainPod for MockMainPod {
444535
.map(|(i, s)| {
445536
self.operations[i]
446537
.deref(&self.statements[..input_statement_offset + i])
447-
.check(s.clone())
538+
.unwrap()
539+
.check(&s.clone().try_into().unwrap())
448540
})
449541
.collect::<Result<Vec<_>>>()
450542
.unwrap();
@@ -453,7 +545,7 @@ impl MainPod for MockMainPod {
453545
fn id(&self) -> PodId {
454546
self.id
455547
}
456-
fn pub_statements(&self) -> Vec<Statement> {
548+
fn pub_statements(&self) -> Vec<middleware::Statement> {
457549
// return the public statements, where when origin=SELF is replaced by origin=self.id()
458550
self.statements
459551
.iter()
@@ -473,6 +565,8 @@ impl MainPod for MockMainPod {
473565
})
474566
.collect(),
475567
)
568+
.try_into()
569+
.unwrap()
476570
})
477571
.collect()
478572
}

0 commit comments

Comments
 (0)