Skip to content

Commit a37b96a

Browse files
authored
Serialize and hash custom predicates (#90)
* Print pods from SignedPodBuilder * Add additional print to test printing SignedPodBuilder * Mock-prove and print MainPod * Implement ToFields for custom predicates and dependencies * Test: print serialization of a recursive batch * Rearrange serialization of CustomPredicate so args_len is always in the same position * Serialize predicates with first entry nonzero to avoid collision with padding * Off by one error in ethdos test BatchSelf(2) * cargo fmt * not a typo * Typos, trying again
1 parent 05c21eb commit a37b96a

File tree

8 files changed

+262
-39
lines changed

8 files changed

+262
-39
lines changed

.github/workflows/typos.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[default.extend-words]
22
groth = "groth" # to avoid it dectecting it as 'growth'
33
BA = "BA"
4+
Ded = "Ded" # "ANDed", it thought "Ded" should be "Dead"

src/backends/mock_main/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ impl MockMainPod {
327327
statements[statements.len() - params.max_public_statements..].to_vec();
328328

329329
// get the id out of the public statements
330-
let id: PodId = PodId(hash_statements(&public_statements));
330+
let id: PodId = PodId(hash_statements(&public_statements, *params));
331331

332332
Ok(Self {
333333
params: params.clone(),
@@ -362,10 +362,10 @@ impl MockMainPod {
362362
}
363363
}
364364

365-
pub fn hash_statements(statements: &[Statement]) -> middleware::Hash {
365+
pub fn hash_statements(statements: &[Statement], params: Params) -> middleware::Hash {
366366
let field_elems = statements
367367
.into_iter()
368-
.flat_map(|statement| statement.clone().to_fields().0)
368+
.flat_map(|statement| statement.clone().to_fields(params).0)
369369
.collect::<Vec<_>>();
370370
Hash(PoseidonHash::hash_no_pad(&field_elems).elements)
371371
}
@@ -376,7 +376,7 @@ impl Pod for MockMainPod {
376376
// get the input_statements from the self.statements
377377
let input_statements = &self.statements[input_statement_offset..];
378378
// get the id out of the public statements, and ensure it is equal to self.id
379-
let ids_match = self.id == PodId(hash_statements(&self.public_statements));
379+
let ids_match = self.id == PodId(hash_statements(&self.public_statements, self.params));
380380
// find a ValueOf statement from the public statements with key=KEY_TYPE and check that the
381381
// value is PodType::MockMainPod
382382
let has_type_statement = self

src/backends/mock_main/statement.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use anyhow::{anyhow, Result};
22
use std::fmt;
33

4-
use crate::middleware::{self, NativePredicate, StatementArg, ToFields};
4+
use crate::middleware::{self, NativePredicate, Params, StatementArg, ToFields};
55

66
#[derive(Clone, Debug, PartialEq, Eq)]
77
pub struct Statement(pub NativePredicate, pub Vec<StatementArg>);
@@ -21,12 +21,13 @@ impl Statement {
2121
}
2222

2323
impl ToFields for Statement {
24-
fn to_fields(self) -> (Vec<middleware::F>, usize) {
25-
let (native_statement_f, native_statement_f_len) = self.0.to_fields();
24+
fn to_fields(&self, params: Params) -> (Vec<middleware::F>, usize) {
25+
let (native_statement_f, native_statement_f_len) = self.0.to_fields(params);
2626
let (vec_statementarg_f, vec_statementarg_f_len) = self
2727
.1
28+
.clone()
2829
.into_iter()
29-
.map(|statement_arg| statement_arg.to_fields())
30+
.map(|statement_arg| statement_arg.to_fields(params))
3031
.fold((Vec::new(), 0), |mut acc, (f, l)| {
3132
acc.0.extend(f);
3233
acc.1 += l;

src/frontend/custom.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,16 @@ fn resolve_wildcard(args: &[&str], priv_args: &[&str], v: &HashOrWildcardStr) ->
171171
#[cfg(test)]
172172
mod tests {
173173
use super::*;
174-
use crate::middleware::{CustomPredicateRef, PodType};
174+
use crate::middleware::{CustomPredicateRef, Params, PodType};
175175

176176
#[test]
177177
fn test_custom_pred() {
178178
use NativePredicate as NP;
179179
use StatementTmplBuilder as STB;
180180

181+
let params = Params::default();
182+
params.print_serialized_sizes();
183+
181184
let mut builder = CustomPredicateBatchBuilder::new("eth_friend".into());
182185
let _eth_friend = builder.predicate_and(
183186
// arguments:
@@ -239,7 +242,7 @@ mod tests {
239242
builder.predicates.last().unwrap()
240243
);
241244

242-
let eth_dos_distance = Predicate::BatchSelf(3);
245+
let eth_dos_distance = Predicate::BatchSelf(2);
243246

244247
// next chunk builds:
245248
let eth_dos_distance_ind = builder.predicate_and(
@@ -311,5 +314,9 @@ mod tests {
311314
"b.2. eth_dos_distance = {}",
312315
builder.predicates.last().unwrap()
313316
);
317+
318+
let eth_dos_batch_b = builder.finish();
319+
let fields = eth_dos_batch_b.to_fields(params);
320+
println!("Batch b, serialized: {:?}", fields);
314321
}
315322
}

src/frontend/mod.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,16 @@ pub struct SignedPodBuilder {
9696
pub kvs: HashMap<String, Value>,
9797
}
9898

99+
impl fmt::Display for SignedPodBuilder {
100+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101+
writeln!(f, "SignedPodBuilder:")?;
102+
for (k, v) in self.kvs.iter().sorted_by_key(|kv| kv.0) {
103+
writeln!(f, " - {}: {}", k, v)?;
104+
}
105+
Ok(())
106+
}
107+
}
108+
99109
impl SignedPodBuilder {
100110
pub fn new(params: &Params) -> Self {
101111
Self {
@@ -347,6 +357,22 @@ pub struct MainPod {
347357
// TODO: metadata
348358
}
349359

360+
impl fmt::Display for MainPod {
361+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362+
writeln!(f, "MainPod: {}", self.pod.id())?;
363+
writeln!(f, " valid? {}", self.pod.verify())?;
364+
writeln!(f, " statements:")?;
365+
for st in &self.pod.pub_statements() {
366+
writeln!(f, " - {}", st)?;
367+
}
368+
writeln!(f, " kvs:")?;
369+
for (k, v) in &self.pod.kvs() {
370+
writeln!(f, " - {}: {}", k, v)?;
371+
}
372+
Ok(())
373+
}
374+
}
375+
350376
impl MainPod {
351377
pub fn id(&self) -> PodId {
352378
self.pod.id()
@@ -487,6 +513,7 @@ pub mod build_utils {
487513
#[cfg(test)]
488514
pub mod tests {
489515
use super::*;
516+
use crate::backends::mock_main::MockProver;
490517
use crate::backends::mock_signed::MockSigner;
491518
use crate::examples::{
492519
great_boy_pod_full_flow, tickets_pod_full_flow, zu_kyc_pod_builder,
@@ -498,7 +525,8 @@ pub mod tests {
498525
let params = Params::default();
499526
let (gov_id, pay_stub) = zu_kyc_sign_pod_builders(&params);
500527

501-
// TODO: print pods from the builder
528+
println!("{}", gov_id);
529+
println!("{}", pay_stub);
502530

503531
let mut signer = MockSigner {
504532
pk: "ZooGov".into(),
@@ -515,7 +543,11 @@ pub mod tests {
515543
let kyc = zu_kyc_pod_builder(&params, &gov_id, &pay_stub)?;
516544
println!("{}", kyc);
517545

546+
let mut prover = MockProver {};
547+
let kyc = kyc.prove(&mut prover)?;
548+
518549
// TODO: prove kyc with MockProver and print it
550+
println!("{}", kyc);
519551

520552
Ok(())
521553
}

src/middleware/custom.rs

Lines changed: 143 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@ use std::sync::Arc;
22
use std::{fmt, hash as h, iter::zip};
33

44
use anyhow::{anyhow, Result};
5+
use plonky2::field::goldilocks_field::GoldilocksField;
6+
use plonky2::field::types::Field;
7+
use plonky2::hash::poseidon::PoseidonHash;
8+
use plonky2::plonk::config::Hasher;
59

610
use super::{
7-
hash_str, AnchoredKey, Hash, NativePredicate, PodId, Statement, StatementArg, ToFields, Value,
8-
F,
11+
hash_str, AnchoredKey, Hash, NativePredicate, Params, PodId, Statement, StatementArg, ToFields,
12+
Value, F,
913
};
1014

1115
// BEGIN Custom 1b
@@ -38,6 +42,22 @@ impl fmt::Display for HashOrWildcard {
3842
}
3943
}
4044

45+
impl ToFields for HashOrWildcard {
46+
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
47+
match self {
48+
HashOrWildcard::Hash(h) => h.to_fields(params),
49+
HashOrWildcard::Wildcard(w) => {
50+
let usizes: Vec<usize> = vec![0, 0, 0, *w];
51+
let fields: Vec<F> = usizes
52+
.iter()
53+
.map(|x| F::from_canonical_u64(*x as u64))
54+
.collect();
55+
(fields, 4)
56+
}
57+
}
58+
}
59+
}
60+
4161
#[derive(Clone, Debug, PartialEq, Eq, h::Hash)]
4262
pub enum StatementTmplArg {
4363
None,
@@ -64,6 +84,40 @@ impl StatementTmplArg {
6484
}
6585
}
6686

87+
impl ToFields for StatementTmplArg {
88+
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
89+
// None => (0, ...)
90+
// Literal(value) => (1, [value], 0, 0, 0, 0)
91+
// Key(hash_or_wildcard1, hash_or_wildcard2)
92+
// => (2, [hash_or_wildcard1], [hash_or_wildcard2])
93+
// In all three cases, we pad to 2 * hash_size + 1 = 9 field elements
94+
let hash_size = 4;
95+
let statement_tmpl_arg_size = 2 * hash_size + 1;
96+
match self {
97+
StatementTmplArg::None => {
98+
let fields: Vec<F> = std::iter::repeat_with(|| F::from_canonical_u64(0))
99+
.take(statement_tmpl_arg_size)
100+
.collect();
101+
(fields, statement_tmpl_arg_size)
102+
}
103+
StatementTmplArg::Literal(v) => {
104+
let fields: Vec<F> = std::iter::once(F::from_canonical_u64(1))
105+
.chain(v.to_fields(params).0.into_iter())
106+
.chain(std::iter::repeat_with(|| F::from_canonical_u64(0)).take(hash_size))
107+
.collect();
108+
(fields, statement_tmpl_arg_size)
109+
}
110+
StatementTmplArg::Key(hw1, hw2) => {
111+
let fields: Vec<F> = std::iter::once(F::from_canonical_u64(2))
112+
.chain(hw1.to_fields(params).0.into_iter())
113+
.chain(hw2.to_fields(params).0.into_iter())
114+
.collect();
115+
(fields, statement_tmpl_arg_size)
116+
}
117+
}
118+
}
119+
}
120+
67121
impl fmt::Display for StatementTmplArg {
68122
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
69123
match self {
@@ -117,6 +171,26 @@ impl StatementTmpl {
117171
}
118172
}
119173

174+
impl ToFields for StatementTmpl {
175+
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
176+
// serialize as:
177+
// predicate (6 field elements)
178+
// then the StatementTmplArgs
179+
if self.1.len() > params.max_statement_args {
180+
panic!("Statement template has too many arguments");
181+
}
182+
let mut fields: Vec<F> = self
183+
.0
184+
.to_fields(params)
185+
.0
186+
.into_iter()
187+
.chain(self.1.iter().flat_map(|sta| sta.to_fields(params).0))
188+
.collect();
189+
fields.resize_with(params.statement_tmpl_size(), || F::from_canonical_u64(0));
190+
(fields, params.statement_tmpl_size())
191+
}
192+
}
193+
120194
#[derive(Clone, Debug, PartialEq, Eq)]
121195
pub struct CustomPredicate {
122196
/// true for "and", false for "or"
@@ -128,10 +202,22 @@ pub struct CustomPredicate {
128202
}
129203

130204
impl ToFields for CustomPredicate {
131-
fn to_fields(self) -> (Vec<F>, usize) {
132-
todo!()
133-
// let f: Vec<F> = Vec::new();
134-
// (self.conjunction.to_f(), 1)
205+
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
206+
// serialize as:
207+
// conjunction (one field element)
208+
// args_len (one field element)
209+
// statements
210+
// (params.max_custom_predicate_arity * params.statement_tmpl_size())
211+
// field elements
212+
if self.statements.len() > params.max_custom_predicate_arity {
213+
panic!("Custom predicate depends on too many statements");
214+
}
215+
let mut fields: Vec<F> = std::iter::once(F::from_bool(self.conjunction))
216+
.chain(std::iter::once(F::from_canonical_usize(self.args_len)))
217+
.chain(self.statements.iter().flat_map(|st| st.to_fields(params).0))
218+
.collect();
219+
fields.resize_with(params.custom_predicate_size(), || F::from_canonical_u64(0));
220+
(fields, params.custom_predicate_size())
135221
}
136222
}
137223

@@ -166,10 +252,30 @@ pub struct CustomPredicateBatch {
166252
pub predicates: Vec<CustomPredicate>,
167253
}
168254

255+
impl ToFields for CustomPredicateBatch {
256+
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
257+
// all the custom predicates in order
258+
if self.predicates.len() > params.max_custom_batch_size {
259+
panic!("Predicate batch exceeds maximum size");
260+
}
261+
let mut fields: Vec<F> = self
262+
.predicates
263+
.iter()
264+
.flat_map(|p| p.to_fields(params).0)
265+
.collect();
266+
fields.resize_with(params.custom_predicate_batch_size_field_elts(), || {
267+
F::from_canonical_u64(0)
268+
});
269+
270+
(fields, params.custom_predicate_batch_size_field_elts())
271+
}
272+
}
273+
169274
impl CustomPredicateBatch {
170-
pub fn hash(&self) -> Hash {
171-
// TODO
172-
hash_str(&format!("{:?}", self))
275+
pub fn hash(&self, params: Params) -> Hash {
276+
let input = self.to_fields(params).0;
277+
let h = Hash(PoseidonHash::hash_no_pad(&input).elements);
278+
h
173279
}
174280
}
175281

@@ -190,12 +296,36 @@ impl From<NativePredicate> for Predicate {
190296
}
191297

192298
impl ToFields for Predicate {
193-
fn to_fields(self) -> (Vec<F>, usize) {
299+
fn to_fields(&self, params: Params) -> (Vec<F>, usize) {
300+
// serialize:
301+
// NativePredicate(id) as (0, id, 0, 0, 0, 0) -- id: usize
302+
// BatchSelf(i) as (1, i, 0, 0, 0, 0) -- i: usize
303+
// CustomPredicateRef(pb, i) as
304+
// (2, [hash of pb], i) -- pb hashes to 4 field elements
305+
// -- i: usize
306+
307+
// in every case: pad to (hash_size + 2) field elements
308+
let mut fields: Vec<F> = Vec::new();
194309
match self {
195-
Self::Native(p) => p.to_fields(),
196-
Self::BatchSelf(i) => Value::from(i as i64).to_fields(),
197-
Self::Custom(_) => todo!(), // TODO
310+
Self::Native(p) => {
311+
fields = std::iter::once(F::from_canonical_u64(1))
312+
.chain(p.to_fields(params).0.into_iter())
313+
.collect();
314+
}
315+
Self::BatchSelf(i) => {
316+
fields = std::iter::once(F::from_canonical_u64(2))
317+
.chain(std::iter::once(F::from_canonical_usize(*i)))
318+
.collect();
319+
}
320+
Self::Custom(CustomPredicateRef(pb, i)) => {
321+
fields = std::iter::once(F::from_canonical_u64(3))
322+
.chain(pb.hash(params).0)
323+
.chain(std::iter::once(F::from_canonical_usize(*i)))
324+
.collect();
325+
}
198326
}
327+
fields.resize_with(params.predicate_size(), || F::from_canonical_u64(0));
328+
(fields, params.predicate_size())
199329
}
200330
}
201331

0 commit comments

Comments
 (0)