Skip to content

Commit 2b7c5ac

Browse files
committed
Experiment with statement & op enums
1 parent 83a4f89 commit 2b7c5ac

File tree

3 files changed

+711
-348
lines changed

3 files changed

+711
-348
lines changed

src/backends/mock_main.rs

Lines changed: 163 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@ use crate::middleware::{
22
self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativeStatement, NonePod,
33
Params, Pod, PodId, PodProver, Statement, StatementArg, ToFields, KEY_TYPE, SELF,
44
};
5-
use anyhow::Result;
5+
use anyhow::{anyhow, Result};
66
use itertools::Itertools;
77
use plonky2::hash::poseidon::PoseidonHash;
88
use plonky2::plonk::config::Hasher;
99
use std::any::Any;
1010
use std::error::Error;
1111
use std::fmt;
1212

13+
pub const VALUE_TYPE: &str = "MockMainPOD";
14+
1315
pub struct MockProver {}
1416

1517
impl PodProver for MockProver {
@@ -51,18 +53,103 @@ impl std::error::Error for OperationArgError {}
5153
struct Operation(pub NativeOperation, pub Vec<OperationArg>);
5254

5355
impl Operation {
54-
pub fn deref(&self, statements: &[Statement]) -> crate::middleware::Operation {
56+
pub fn deref(&self, statements: &[Statement]) -> Result<crate::middleware::Operation> {
5557
let deref_args = self
5658
.1
5759
.iter()
58-
.map(|arg| match arg {
59-
OperationArg::None => middleware::OperationArg::None,
60+
.flat_map(|arg| match arg {
61+
OperationArg::None => vec![],
6062
OperationArg::Index(i) => {
61-
middleware::OperationArg::Statement(statements[*i].clone())
63+
vec![statements[*i].clone().try_into()]
6264
}
6365
})
64-
.collect();
65-
middleware::Operation(self.0, deref_args)
66+
.collect::<Result<Vec<crate::middleware::Statement>>>()?;
67+
middleware::Operation::op(self.0, &deref_args)
68+
}
69+
}
70+
71+
#[derive(Clone, Debug, PartialEq, Eq)]
72+
pub struct Statement(pub NativeStatement, pub Vec<StatementArg>);
73+
74+
impl Statement {
75+
pub fn is_none(&self) -> bool {
76+
self.0 == NativeStatement::None
77+
}
78+
}
79+
80+
impl ToFields for Statement {
81+
fn to_fields(self) -> (Vec<middleware::F>, usize) {
82+
let (native_statement_f, native_statement_f_len) = self.0.to_fields();
83+
let (vec_statementarg_f, vec_statementarg_f_len) = self
84+
.1
85+
.into_iter()
86+
.map(|statement_arg| statement_arg.to_fields())
87+
.fold((Vec::new(), 0), |mut acc, (f, l)| {
88+
acc.0.extend(f);
89+
acc.1 += l;
90+
acc
91+
});
92+
(
93+
[native_statement_f, vec_statementarg_f].concat(),
94+
native_statement_f_len + vec_statementarg_f_len,
95+
)
96+
}
97+
}
98+
99+
impl TryInto<middleware::Statement> for Statement {
100+
type Error = anyhow::Error;
101+
fn try_into(self) -> Result<middleware::Statement> {
102+
type S = middleware::Statement;
103+
type NS = NativeStatement;
104+
type SA = StatementArg;
105+
let args = (
106+
self.1.get(0).cloned(),
107+
self.1.get(1).cloned(),
108+
self.1.get(2).cloned(),
109+
);
110+
Ok(match (self.0, args) {
111+
(NS::None, _) => S::None,
112+
(NS::ValueOf, (Some(SA::Key(ak)), Some(SA::Literal(v)), None)) => S::ValueOf(ak, v),
113+
(NS::Equal, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => S::Equal(ak1, ak2),
114+
(NS::NotEqual, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => S::NotEqual(ak1, ak2),
115+
(NS::Gt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => S::Gt(ak1, ak2),
116+
(NS::Lt, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => S::Lt(ak1, ak2),
117+
(NS::Contains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => S::Contains(ak1, ak2),
118+
(NS::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None)) => {
119+
S::NotContains(ak1, ak2)
120+
}
121+
(NS::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
122+
S::SumOf(ak1, ak2, ak3)
123+
}
124+
(NS::ProductOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
125+
S::ProductOf(ak1, ak2, ak3)
126+
}
127+
(NS::MaxOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3)))) => {
128+
S::MaxOf(ak1, ak2, ak3)
129+
}
130+
_ => Err(anyhow!("Malformed statement expression {}", self))?,
131+
})
132+
}
133+
}
134+
135+
impl From<middleware::Statement> for Statement {
136+
fn from(s: middleware::Statement) -> Self {
137+
Statement(s.code(), s.args().into_iter().map(|arg| arg).collect())
138+
}
139+
}
140+
141+
impl fmt::Display for Statement {
142+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
143+
write!(f, "{:?} ", self.0)?;
144+
for (i, arg) in self.1.iter().enumerate() {
145+
if !(!f.alternate() && arg.is_none()) {
146+
if i != 0 {
147+
write!(f, " ")?;
148+
}
149+
write!(f, "{}", arg)?;
150+
}
151+
}
152+
Ok(())
66153
}
67154
}
68155

@@ -174,6 +261,11 @@ fn fill_pad<T: Clone>(v: &mut Vec<T>, pad_value: T, len: usize) {
174261
}
175262
}
176263

264+
fn pad<T: Clone>(v: Vec<T>, pad_value: T, len: usize) -> Vec<T> {
265+
let v_len = v.len();
266+
[v, (v_len..len).map(|_| pad_value.clone()).collect()].concat()
267+
}
268+
177269
impl MockMainPod {
178270
fn offset_input_signed_pods(&self) -> usize {
179271
0
@@ -188,12 +280,19 @@ impl MockMainPod {
188280
fn offset_public_statements(&self) -> usize {
189281
self.offset_input_statements() + self.params.max_priv_statements()
190282
}
283+
fn pad_statement(params: &Params, s: Statement) -> Statement {
284+
Statement(s.0, pad(s.1, StatementArg::None, params.max_statement_args))
285+
}
286+
fn pad_operation(params: &Params, op: Operation) -> Operation {
287+
Operation(
288+
op.0,
289+
pad(op.1, OperationArg::None, params.max_operation_args),
290+
)
291+
}
191292

192293
fn layout_statements(params: &Params, inputs: &MainPodInputs) -> Vec<Statement> {
193294
let mut statements = Vec::new();
194295

195-
let st_none = Self::statement_none(params);
196-
197296
// Input signed pods region
198297
let none_sig_pod: Box<dyn Pod> = Box::new(NonePod {});
199298
assert!(inputs.signed_pods.len() <= params.max_input_signed_pods);
@@ -203,12 +302,12 @@ impl MockMainPod {
203302
.get(i)
204303
.map(|p| *p)
205304
.unwrap_or(&none_sig_pod);
305+
// TODO
206306
let sts = pod.pub_statements();
207307
assert!(sts.len() <= params.max_signed_pod_values);
208308
for j in 0..params.max_signed_pod_values {
209-
let mut st = sts.get(j).unwrap_or(&st_none).clone();
210-
Self::pad_statement_args(params, &mut st.1);
211-
statements.push(st);
309+
let mut st = sts.get(j).unwrap_or(&middleware::Statement::None).clone();
310+
statements.push(Self::pad_statement(params, st.into()));
212311
}
213312
}
214313

@@ -224,64 +323,51 @@ impl MockMainPod {
224323
let sts = pod.pub_statements();
225324
assert!(sts.len() <= params.max_public_statements);
226325
for j in 0..params.max_public_statements {
227-
let mut st = sts.get(j).unwrap_or(&st_none).clone();
228-
Self::pad_statement_args(params, &mut st.1);
229-
statements.push(st);
326+
let mut st = sts.get(j).unwrap_or(&middleware::Statement::None).clone();
327+
statements.push(Self::pad_statement(params, st.into()));
230328
}
231329
}
232330

233331
// Input statements
234332
assert!(inputs.statements.len() <= params.max_priv_statements());
235333
for i in 0..params.max_priv_statements() {
236-
let mut st = inputs.statements.get(i).unwrap_or(&st_none).clone();
237-
Self::pad_statement_args(params, &mut st.1);
238-
statements.push(st);
334+
let mut st = inputs
335+
.statements
336+
.get(i)
337+
.unwrap_or(&middleware::Statement::None)
338+
.clone();
339+
statements.push(Self::pad_statement(params, st.into()));
239340
}
240341

241342
// Public statements
242343
assert!(inputs.public_statements.len() < params.max_public_statements);
243-
statements.push(Statement(
244-
NativeStatement::ValueOf,
245-
vec![StatementArg::Key(AnchoredKey(SELF, hash_str(KEY_TYPE)))],
344+
statements.push(Self::pad_statement(
345+
params,
346+
middleware::Statement::ValueOf(
347+
AnchoredKey(SELF, hash_str(KEY_TYPE)),
348+
middleware::Value(hash_str(VALUE_TYPE).0),
349+
)
350+
.into(),
246351
));
247352
for i in 0..(params.max_public_statements - 1) {
248-
let mut st = inputs.public_statements.get(i).unwrap_or(&st_none).clone();
249-
Self::pad_statement_args(params, &mut st.1);
250-
statements.push(st);
353+
let st = inputs
354+
.public_statements
355+
.get(i)
356+
.unwrap_or(&middleware::Statement::None)
357+
.clone();
358+
statements.push(Self::pad_statement(params, st.into()));
251359
}
252360

253361
statements
254362
}
255363

256-
pub fn find_op_arg(
257-
statements: &[Statement],
258-
op_arg: &middleware::OperationArg,
259-
) -> Result<OperationArg, OperationArgError> {
260-
match op_arg {
261-
middleware::OperationArg::None => Ok(OperationArg::None),
262-
middleware::OperationArg::Key(k) => {
263-
statements
264-
.iter()
265-
.enumerate()
266-
.find_map(|(i, s)| match s.0 {
267-
NativeStatement::ValueOf => match &s.1[0] {
268-
StatementArg::Key(sk) => (sk == k).then_some(i),
269-
_ => None,
270-
},
271-
_ => None,
272-
})
273-
.map(OperationArg::Index)
274-
.ok_or(OperationArgError::KeyNotFound)
275-
}
276-
middleware::OperationArg::Statement(st) => {
277-
statements
278-
.iter()
279-
.enumerate()
280-
.find_map(|(i, s)| (s == st).then_some(i))
281-
.map(OperationArg::Index)
364+
fn find_op_arg(statements: &[Statement], op_arg: &middleware::Statement) -> Result<OperationArg, OperationArgError> {
365+
statements
366+
.iter()
367+
.enumerate()
368+
.find_map(|(i, s)| (s == &op_arg.clone().into()))
369+
.map(OperationArg::Index)
282370
.ok_or(OperationArgError::StatementNotFound)
283-
}
284-
}
285371
}
286372

287373
fn process_private_statements_operations(
@@ -293,15 +379,18 @@ impl MockMainPod {
293379

294380
let mut operations = Vec::new();
295381
for i in 0..params.max_priv_statements() {
296-
let op = input_operations.get(i).unwrap_or(&op_none).clone();
297-
let mut mid_args = op.1;
382+
let op = input_operations
383+
.get(i)
384+
.unwrap_or(&middleware::Operation::None)
385+
.clone();
386+
let mut mid_args = op.args();
298387
Self::pad_operation_args(params, &mut mid_args);
299388
let mut args = Vec::with_capacity(mid_args.len());
300389
for mid_arg in &mid_args {
301390
let op_arg = Self::find_op_arg(statements, mid_arg)?;
302391
args.push(op_arg)
303392
}
304-
operations.push(Operation(op.0, args));
393+
operations.push(Operation(op.code(), args));
305394
}
306395
Ok(operations)
307396
}
@@ -320,11 +409,11 @@ impl MockMainPod {
320409
let mut op = if st.is_none() {
321410
Operation(NativeOperation::None, vec![])
322411
} else {
323-
let mid_arg = middleware::OperationArg::Statement(st.clone());
324-
let op_arg = Self::find_op_arg(statements, &mid_arg)?;
412+
let mid_arg = st.clone();
325413
Operation(
326414
NativeOperation::CopyStatement,
327-
vec![op_arg],
415+
// TODO
416+
vec![Self::find_op_arg(statements, &mid_arg.try_into().unwrap())],
328417
)
329418
};
330419
fill_pad(&mut op.1, OperationArg::None, params.max_operation_args);
@@ -351,7 +440,12 @@ impl MockMainPod {
351440
.map(|p| (*p).clone())
352441
.collect_vec();
353442
let input_main_pods = inputs.main_pods.iter().map(|p| (*p).clone()).collect_vec();
354-
let input_statements = inputs.statements.iter().cloned().collect_vec();
443+
let input_statements = inputs
444+
.statements
445+
.iter()
446+
.cloned()
447+
.map(|s| Self::pad_statement(params, s.into()))
448+
.collect_vec();
355449
let public_statements =
356450
statements[statements.len() - params.max_public_statements..].to_vec();
357451

@@ -376,26 +470,20 @@ impl MockMainPod {
376470
Statement(NativeStatement::None, args)
377471
}
378472

379-
fn operation_none(params: &Params) -> middleware::Operation {
380-
let mut args = Vec::with_capacity(params.max_operation_args);
381-
Self::pad_operation_args(&params, &mut args);
382-
middleware::Operation(NativeOperation::None, args)
473+
fn operation_none(params: &Params) -> Operation {
474+
Self::pad_operation(params, Operation(NativeOperation::None, vec![]))
383475
}
384476

385477
fn pad_statement_args(params: &Params, args: &mut Vec<StatementArg>) {
386478
fill_pad(args, StatementArg::None, params.max_statement_args)
387479
}
388480

389-
fn pad_operation_args(params: &Params, args: &mut Vec<middleware::OperationArg>) {
390-
fill_pad(
391-
args,
392-
middleware::OperationArg::None,
393-
params.max_operation_args,
394-
)
481+
fn pad_operation_args(params: &Params, args: &mut Vec<middleware::Statement>) {
482+
fill_pad(args, middleware::Statement::None, params.max_operation_args)
395483
}
396484
}
397485

398-
pub fn hash_statements(statements: &[middleware::Statement]) -> Result<middleware::Hash> {
486+
pub fn hash_statements(statements: &[Statement]) -> Result<middleware::Hash> {
399487
let field_elems = statements
400488
.into_iter()
401489
.flat_map(|statement| statement.clone().to_fields().0)
@@ -463,7 +551,8 @@ impl Pod for MockMainPod {
463551
.map(|(i, s)| {
464552
self.operations[i]
465553
.deref(&self.statements[..input_statement_offset + i])
466-
.check(s.clone())
554+
.unwrap()
555+
.check(&s.clone().try_into().unwrap())
467556
})
468557
.collect::<Result<Vec<_>>>()
469558
.unwrap();
@@ -472,7 +561,7 @@ impl Pod for MockMainPod {
472561
fn id(&self) -> PodId {
473562
self.id
474563
}
475-
fn pub_statements(&self) -> Vec<Statement> {
564+
fn pub_statements(&self) -> Vec<middleware::Statement> {
476565
// return the public statements, where when origin=SELF is replaced by origin=self.id()
477566
self.statements
478567
.iter()
@@ -492,6 +581,8 @@ impl Pod for MockMainPod {
492581
})
493582
.collect(),
494583
)
584+
.try_into()
585+
.unwrap()
495586
})
496587
.collect()
497588
}

0 commit comments

Comments
 (0)