From 2576535675ccafe61444d58ab305bbfd2652e942 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Tue, 18 Feb 2025 11:24:56 +0100 Subject: [PATCH 01/12] wip --- src/backends/mock_main.rs | 12 +-- src/backends/mock_signed.rs | 4 +- src/frontend.rs | 20 ++--- src/middleware/mod.rs | 149 +++++++++++++++++++++++++++++------- 4 files changed, 141 insertions(+), 44 deletions(-) diff --git a/src/backends/mock_main.rs b/src/backends/mock_main.rs index 25a8ebf4..ff60fc23 100644 --- a/src/backends/mock_main.rs +++ b/src/backends/mock_main.rs @@ -1,5 +1,5 @@ use crate::middleware::{ - self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativeStatement, NonePod, + self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativePredicate, NonePod, Params, Pod, PodId, PodProver, Statement, StatementArg, ToFields, KEY_TYPE, SELF, }; use anyhow::Result; @@ -241,7 +241,7 @@ impl MockMainPod { // Public statements assert!(inputs.public_statements.len() < params.max_public_statements); statements.push(Statement( - NativeStatement::ValueOf, + NativePredicate::ValueOf, vec![StatementArg::Key(AnchoredKey(SELF, hash_str(KEY_TYPE)))], )); for i in 0..(params.max_public_statements - 1) { @@ -264,7 +264,7 @@ impl MockMainPod { .iter() .enumerate() .find_map(|(i, s)| match s.0 { - NativeStatement::ValueOf => match &s.1[0] { + NativePredicate::ValueOf => match &s.1[0] { StatementArg::Key(sk) => (sk == k).then_some(i), _ => None, }, @@ -373,7 +373,7 @@ impl MockMainPod { fn statement_none(params: &Params) -> Statement { let mut args = Vec::with_capacity(params.max_statement_args); Self::pad_statement_args(¶ms, &mut args); - Statement(NativeStatement::None, args) + Statement(NativePredicate::None, args) } fn operation_none(params: &Params) -> middleware::Operation { @@ -416,7 +416,7 @@ impl Pod for MockMainPod { .public_statements .iter() .find(|s| { - s.0 == NativeStatement::ValueOf + s.0 == NativePredicate::ValueOf && s.1.len() > 0 && if let StatementArg::Key(AnchoredKey(pod_id, key_hash)) = s.1[0] { pod_id == SELF && key_hash == hash_str(KEY_TYPE) @@ -444,7 +444,7 @@ impl Pod for MockMainPod { s, ) }) - .filter(|(i, s)| s.0 == NativeStatement::ValueOf) + .filter(|(i, s)| s.0 == NativePredicate::ValueOf) .flat_map(|(i, s)| { if let StatementArg::Key(ak) = &s.1[0] { vec![(i, ak.1, ak.0)] diff --git a/src/backends/mock_signed.rs b/src/backends/mock_signed.rs index 62a8ccc6..3fd9a908 100644 --- a/src/backends/mock_signed.rs +++ b/src/backends/mock_signed.rs @@ -1,5 +1,5 @@ use crate::middleware::{ - containers::Dictionary, hash_str, AnchoredKey, Hash, NativeStatement, Params, Pod, PodId, + containers::Dictionary, hash_str, AnchoredKey, Hash, NativePredicate, Params, Pod, PodId, PodSigner, PodType, Statement, StatementArg, Value, KEY_SIGNER, KEY_TYPE, }; use crate::primitives::merkletree::MerkleTree; @@ -83,7 +83,7 @@ impl Pod for MockSignedPod { .iter() .map(|(k, v)| { Statement( - NativeStatement::ValueOf, + NativePredicate::ValueOf, vec![ StatementArg::Key(AnchoredKey(id, Hash(k.0))), StatementArg::Literal(*v), diff --git a/src/frontend.rs b/src/frontend.rs index 4b861297..2c18d346 100644 --- a/src/frontend.rs +++ b/src/frontend.rs @@ -10,7 +10,7 @@ use std::fmt; use crate::middleware::{ self, containers::{Array, Dictionary, Set}, - hash_str, Hash, MainPodInputs, NativeOperation, NativeStatement, Params, PodId, PodProver, + hash_str, Hash, MainPodInputs, NativeOperation, NativePredicate, Params, PodId, PodProver, PodSigner, SELF, }; @@ -166,7 +166,7 @@ impl fmt::Display for StatementArg { } #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Statement(pub NativeStatement, pub Vec); +pub struct Statement(pub NativePredicate, pub Vec); impl fmt::Display for Statement { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -350,27 +350,27 @@ impl MainPodBuilder { let Operation(op_type, ref mut args) = op; // TODO: argument type checking let st = match op_type { - None => Statement(NativeStatement::None, vec![]), - NewEntry => Statement(NativeStatement::ValueOf, self.op_args_entries(public, args)), + None => Statement(NativePredicate::None, vec![]), + NewEntry => Statement(NativePredicate::ValueOf, self.op_args_entries(public, args)), CopyStatement => todo!(), EqualFromEntries => { - Statement(NativeStatement::Equal, self.op_args_entries(public, args)) + Statement(NativePredicate::Equal, self.op_args_entries(public, args)) } NotEqualFromEntries => Statement( - NativeStatement::NotEqual, + NativePredicate::NotEqual, self.op_args_entries(public, args), ), - GtFromEntries => Statement(NativeStatement::Gt, self.op_args_entries(public, args)), - LtFromEntries => Statement(NativeStatement::Lt, self.op_args_entries(public, args)), + GtFromEntries => Statement(NativePredicate::Gt, self.op_args_entries(public, args)), + LtFromEntries => Statement(NativePredicate::Lt, self.op_args_entries(public, args)), TransitiveEqualFromStatements => todo!(), GtToNotEqual => todo!(), LtToNotEqual => todo!(), ContainsFromEntries => Statement( - NativeStatement::Contains, + NativePredicate::Contains, self.op_args_entries(public, args), ), NotContainsFromEntries => Statement( - NativeStatement::NotContains, + NativePredicate::NotContains, self.op_args_entries(public, args), ), RenameContainedBy => todo!(), diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index e5bddd90..78e4def0 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -232,7 +232,7 @@ impl Default for Params { } #[derive(Clone, Copy, Debug, FromRepr, PartialEq, Eq)] -pub enum NativeStatement { +pub enum NativePredicate { None = 0, ValueOf = 1, Equal = 2, @@ -246,12 +246,75 @@ pub enum NativeStatement { MaxOf = 10, } -impl ToFields for NativeStatement { +impl ToFields for NativePredicate { fn to_fields(self) -> (Vec, usize) { (vec![F::from_canonical_u64(self as u64)], 1) } } +use std::sync::Arc; + +// BEGIN Custom 1b + +/* +pub enum PodIdOrWildcard { + PodId(PodId), + Wildcard(usize), +} + +pub enum HashOrWildcard { + Hash(Hash), + Wildcard(usize), +} + +pub enum StatementTmplArg { + None, + Literal(Value), + Key(PodIdOrWildcard, HashOrWildcard), +} +*/ + +// END + +// BEGIN Custom 2 + +pub enum StatementTmplArg { + None, + Literal(Value), + Wildcard(usize), +} + +// END + +/// Statement Template for a Custom Predicate +pub struct StatementTmpl(Predicate, Vec); + +pub struct CustomPredicate { + /// true for "and", false for "or" + pub conjunction: bool, + pub statements: Vec, + pub args_len: usize, + // TODO: Add private args length? + // TODO: Add args type information? +} + +pub enum Predicate { + Native(NativePredicate), + Custom(Arc), +} + +impl From for Predicate { + fn from(v: NativePredicate) -> Self { + Self::Native(v) + } +} + +impl ToFields for Predicate { + fn to_fields(self) -> (Vec, usize) { + todo!() + } +} + #[derive(Clone, Debug, PartialEq, Eq, Hash)] /// AnchoredKey is a tuple containing (OriginId: PodId, key: Hash) pub struct AnchoredKey(pub PodId, pub Hash); @@ -332,7 +395,7 @@ impl ToFields for StatementArg { // TODO: Replace this with a more stringly typed enum as in the Devcon implementation. #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Statement(pub NativeStatement, pub Vec); +pub struct Statement(pub NativePredicate, pub Vec); impl fmt::Display for Statement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -350,14 +413,14 @@ impl fmt::Display for Statement { } impl Statement { - pub fn code(&self) -> NativeStatement { + pub fn code(&self) -> NativePredicate { self.0 } pub fn args(&self) -> &[StatementArg] { &self.1 } pub fn is_none(&self) -> bool { - matches!(self.0, NativeStatement::None) + matches!(self.0, NativePredicate::None) } } @@ -443,11 +506,11 @@ impl Operation { use NativeOperation::*; match self.0 { // Nothing to check. - None => Ok(output_statement.code() == NativeStatement::None), + None => Ok(output_statement.code() == NativePredicate::None), // Check that the resulting statement is of type `ValueOf` // and its origin is `SELF`. NewEntry => - Ok(output_statement.code() == NativeStatement::ValueOf && output_statement.args()[0].key()?.origin() == SELF) + Ok(output_statement.code() == NativePredicate::ValueOf && output_statement.args()[0].key()?.origin() == SELF) , // Check that the operation acts on a statement *and* the // output is equal to this statement. @@ -458,30 +521,30 @@ impl Operation { let (s1_key, s1_value) = (s1.args()[0].key()?, s1.args()[1].literal()?); let s2 = self.args()[1].statement()?; let (s2_key, s2_value) = (s2.args()[0].key()?, s2.args()[1].literal()?); - let statements_equal = s1.code() == NativeStatement::ValueOf && s2.code() == NativeStatement::ValueOf && s1_value == s2_value; - Ok(statements_equal && output_statement.code() == NativeStatement::Equal && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key)} + let statements_equal = s1.code() == NativePredicate::ValueOf && s2.code() == NativePredicate::ValueOf && s1_value == s2_value; + Ok(statements_equal && output_statement.code() == NativePredicate::Equal && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key)} , NotEqualFromEntries => { let s1 = self.args()[0].statement()?; let (s1_key, s1_value) = (s1.args()[0].key()?, s1.args()[1].literal()?); let s2 = self.args()[1].statement()?; let (s2_key, s2_value) = (s2.args()[0].key()?, s2.args()[1].literal()?); - let statements_not_equal = s1.code() == NativeStatement::ValueOf && s2.code() == NativeStatement::ValueOf && s1_value != s2_value; - Ok(statements_not_equal && output_statement.code() == NativeStatement::NotEqual && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key)} , + let statements_not_equal = s1.code() == NativePredicate::ValueOf && s2.code() == NativePredicate::ValueOf && s1_value != s2_value; + Ok(statements_not_equal && output_statement.code() == NativePredicate::NotEqual && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key)} , GtFromEntries => { let s1 = self.args()[0].statement()?; let (s1_key, s1_value) = (s1.args()[0].key()?, s1.args()[1].literal()?); let s2 = self.args()[1].statement()?; let (s2_key, s2_value) = (s2.args()[0].key()?, s2.args()[1].literal()?); - let statements_not_equal = s1.code() == NativeStatement::ValueOf && s2.code() == NativeStatement::ValueOf && s1_value > s2_value; - Ok(statements_not_equal && output_statement.code() == NativeStatement::Gt && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key)}, + let statements_not_equal = s1.code() == NativePredicate::ValueOf && s2.code() == NativePredicate::ValueOf && s1_value > s2_value; + Ok(statements_not_equal && output_statement.code() == NativePredicate::Gt && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key)}, LtFromEntries => { let s1 = self.args()[0].statement()?; let (s1_key, s1_value) = (s1.args()[0].key()?, s1.args()[1].literal()?); let s2 = self.args()[1].statement()?; let (s2_key, s2_value) = (s2.args()[0].key()?, s2.args()[1].literal()?); - let statements_not_equal = s1.code() == NativeStatement::ValueOf && s2.code() == NativeStatement::ValueOf && s1_value < s2_value; - Ok(statements_not_equal && output_statement.code() == NativeStatement::Lt && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key)}, + let statements_not_equal = s1.code() == NativePredicate::ValueOf && s2.code() == NativePredicate::ValueOf && s1_value < s2_value; + Ok(statements_not_equal && output_statement.code() == NativePredicate::Lt && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key)}, TransitiveEqualFromStatements => { let s1 = self.args()[0].statement()?; let s2 = self.args()[1].statement()?; @@ -489,18 +552,18 @@ impl Operation { let key2 = s1.args()[1].key()?; let key3 = s2.args()[0].key()?; let key4 = s2.args()[1].key()?; - let statements_satisfy_transitivity = s1.code() == NativeStatement::Equal && s2.code() == NativeStatement::Equal && key2 == key3; - Ok(statements_satisfy_transitivity && output_statement.code() == NativeStatement::Equal && output_statement.args()[0].key()? == key1 && output_statement.args()[1].key()? == key4) + let statements_satisfy_transitivity = s1.code() == NativePredicate::Equal && s2.code() == NativePredicate::Equal && key2 == key3; + Ok(statements_satisfy_transitivity && output_statement.code() == NativePredicate::Equal && output_statement.args()[0].key()? == key1 && output_statement.args()[1].key()? == key4) }, GtToNotEqual => { let s = self.args()[0].statement()?; - let arg_is_gt = s.code() == NativeStatement::Gt; - Ok(arg_is_gt && output_statement.code() == NativeStatement::NotEqual && output_statement.args() == s.args()) + let arg_is_gt = s.code() == NativePredicate::Gt; + Ok(arg_is_gt && output_statement.code() == NativePredicate::NotEqual && output_statement.args() == s.args()) }, LtToNotEqual => { let s = self.args()[0].statement()?; - let arg_is_lt = s.code() == NativeStatement::Lt; - Ok(arg_is_lt && output_statement.code() == NativeStatement::NotEqual && output_statement.args() == s.args()) + let arg_is_lt = s.code() == NativePredicate::Lt; + Ok(arg_is_lt && output_statement.code() == NativePredicate::NotEqual && output_statement.args() == s.args()) }, RenameContainedBy => { let s1 = self.args()[0].statement()?; @@ -509,8 +572,8 @@ impl Operation { let key2 = s1.args()[1].key()?; let key3 = s2.args()[0].key()?; let key4 = s2.args()[1].key()?; - let args_satisfy_rename = s1.code() == NativeStatement::Contains && s2.code() == NativeStatement::Equal && key1 == key3; - Ok(args_satisfy_rename && output_statement.code() == NativeStatement::Contains && output_statement.args()[0].key()? == key4 && output_statement.args()[1].key()? == key2) + let args_satisfy_rename = s1.code() == NativePredicate::Contains && s2.code() == NativePredicate::Equal && key1 == key3; + Ok(args_satisfy_rename && output_statement.code() == NativePredicate::Contains && output_statement.args()[0].key()? == key4 && output_statement.args()[1].key()? == key2) }, SumOf => { let s1 = self.args()[0].statement()?; @@ -522,8 +585,8 @@ impl Operation { let s3 = self.args()[2].statement()?; let s3_key = s3.args()[0].key()?; let s3_value: i64 = s3.args()[1].literal()?.try_into()?; - let sum_holds = s1.code() == NativeStatement::ValueOf && s2.code() == NativeStatement::ValueOf && s3.code() == NativeStatement::ValueOf && s1_value == s2_value + s3_value; - Ok(sum_holds && output_statement.code() == NativeStatement::SumOf && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key && output_statement.args()[2].key()? == s3_key) + let sum_holds = s1.code() == NativePredicate::ValueOf && s2.code() == NativePredicate::ValueOf && s3.code() == NativePredicate::ValueOf && s1_value == s2_value + s3_value; + Ok(sum_holds && output_statement.code() == NativePredicate::SumOf && output_statement.args()[0].key()? == s1_key && output_statement.args()[1].key()? == s2_key && output_statement.args()[2].key()? == s3_key) }, // TODO: Remaining ops. _ => Ok(true) @@ -540,7 +603,7 @@ pub trait Pod: fmt::Debug + DynClone { self.pub_statements() .into_iter() .filter_map(|st| match st.0 { - NativeStatement::ValueOf => Some(( + NativePredicate::ValueOf => Some(( st.1[0].key().expect("key"), st.1[1].literal().expect("literal"), )), @@ -599,3 +662,37 @@ pub trait ToFields { /// does the vector contain fn to_fields(self) -> (Vec, usize); } + +#[cfg(test)] +mod tests { + use super::*; + + struct StatementTmplBuilder { + predicate: Predicate, + } + + fn st_tmpl(p: impl Into) -> StatementTmplBuilder { + StatementTmplBuilder { + predicate: p.into(), + } + } + + fn predicate_and(args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder]) -> Predicate { + let custom_predicate = CustomPredicate { + conjunction: true, + statements: vec![], // TODO + args_len: args.len(), + }; + Predicate::Custom(Arc::new(custom_predicate)) + } + + #[test] + fn test_custom_pred() { + use NativePredicate::*; + let eth_friend = predicate_and( + &["src_or", "src_key", "dst_or", "dst_key"], + &["attestation_pod"], + &[st_tmpl(Equal)], + ); + } +} From 05d2162dde76d6492de9cb8c49383e58ab1539a8 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Tue, 18 Feb 2025 16:57:04 +0100 Subject: [PATCH 02/12] prototype custom predicates 1b --- src/middleware/mod.rs | 281 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 264 insertions(+), 17 deletions(-) diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 78e4def0..d20352b6 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -256,39 +256,55 @@ use std::sync::Arc; // BEGIN Custom 1b -/* -pub enum PodIdOrWildcard { - PodId(PodId), - Wildcard(usize), -} - +#[derive(Debug)] pub enum HashOrWildcard { Hash(Hash), Wildcard(usize), } +impl fmt::Display for HashOrWildcard { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Hash(h) => write!(f, "{}", h), + Self::Wildcard(n) => write!(f, "*{}", n), + } + } +} + +#[derive(Debug)] pub enum StatementTmplArg { None, Literal(Value), - Key(PodIdOrWildcard, HashOrWildcard), + Key(HashOrWildcard, HashOrWildcard), +} + +impl fmt::Display for StatementTmplArg { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::None => write!(f, "none"), + Self::Literal(v) => write!(f, "{}", v), + Self::Key(pod_id, key) => write!(f, "({}, {})", pod_id, key), + } + } } -*/ // END // BEGIN Custom 2 -pub enum StatementTmplArg { - None, - Literal(Value), - Wildcard(usize), -} +// pub enum StatementTmplArg { +// None, +// Literal(Value), +// Wildcard(usize), +// } // END /// Statement Template for a Custom Predicate +#[derive(Debug)] pub struct StatementTmpl(Predicate, Vec); +#[derive(Debug)] pub struct CustomPredicate { /// true for "and", false for "or" pub conjunction: bool, @@ -298,6 +314,40 @@ pub struct CustomPredicate { // TODO: Add args type information? } +impl fmt::Display for CustomPredicate { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "{}<", if self.conjunction { "and" } else { "or" })?; + for st in &self.statements { + // NOTE: With recursive custom predicates we can't just display the predicate again + // because then this call will run into an infinite loop. Instead we should find a way + // to name custom predicates and use the names here. For this we will probably need an + // auxiliary data structure to hold the names, which IMO would be too complex to live + // in the middleware. For the middleware we may just print the custom predicate hash. + match &st.0 { + Predicate::Native(p) => write!(f, " {:?}(", p)?, + Predicate::Custom(_p) => write!(f, " TODO(")?, + } + for (i, arg) in st.1.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", arg)?; + } + writeln!(f, "),")?; + } + write!(f, ">(")?; + for i in 0..self.args_len { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "*{}", i)?; + } + writeln!(f, ")")?; + Ok(()) + } +} + +#[derive(Clone, Debug)] pub enum Predicate { Native(NativePredicate), Custom(Arc), @@ -315,6 +365,15 @@ impl ToFields for Predicate { } } +impl fmt::Display for Predicate { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Native(p) => write!(f, "{:?}", p), + Self::Custom(p) => write!(f, "{}", p), + } + } +} + #[derive(Clone, Debug, PartialEq, Eq, Hash)] /// AnchoredKey is a tuple containing (OriginId: PodId, key: Hash) pub struct AnchoredKey(pub PodId, pub Hash); @@ -667,20 +726,110 @@ pub trait ToFields { mod tests { use super::*; + enum HashOrWildcardStr { + Hash(Hash), + Wildcard(String), + } + + fn l(s: &str) -> HashOrWildcardStr { + HashOrWildcardStr::Hash(hash_str(s)) + } + + fn w(s: &str) -> HashOrWildcardStr { + HashOrWildcardStr::Wildcard(s.to_string()) + } + + enum BuilderArg { + Literal(Value), + Key(HashOrWildcardStr, HashOrWildcardStr), + } + + impl From<(HashOrWildcardStr, HashOrWildcardStr)> for BuilderArg { + fn from((pod_id, key): (HashOrWildcardStr, HashOrWildcardStr)) -> Self { + Self::Key(pod_id, key) + } + } + + impl From for BuilderArg + where + V: Into, + { + fn from(v: V) -> Self { + Self::Literal(v.into()) + } + } + struct StatementTmplBuilder { predicate: Predicate, + args: Vec, } fn st_tmpl(p: impl Into) -> StatementTmplBuilder { StatementTmplBuilder { predicate: p.into(), + args: Vec::new(), + } + } + + impl StatementTmplBuilder { + fn arg(mut self, a: impl Into) -> Self { + self.args.push(a.into()); + self } } fn predicate_and(args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder]) -> Predicate { + predicate(true, args, priv_args, sts) + } + + fn predicate_or(args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder]) -> Predicate { + predicate(false, args, priv_args, sts) + } + + fn resolve_wildcard( + args: &[&str], + priv_args: &[&str], + v: &HashOrWildcardStr, + ) -> HashOrWildcard { + match v { + HashOrWildcardStr::Hash(h) => HashOrWildcard::Hash(*h), + HashOrWildcardStr::Wildcard(s) => HashOrWildcard::Wildcard( + args.iter() + .chain(priv_args.iter()) + .enumerate() + .find_map(|(i, name)| (&s == name).then_some(i)) + .unwrap(), + ), + } + } + + fn predicate( + conjunction: bool, + args: &[&str], + priv_args: &[&str], + sts: &[StatementTmplBuilder], + ) -> Predicate { + use BuilderArg as BA; + let statements = sts + .iter() + .map(|sb| { + let args = sb + .args + .iter() + .map(|a| match a { + BA::Literal(v) => StatementTmplArg::Literal(*v), + BA::Key(pod_id, key) => StatementTmplArg::Key( + resolve_wildcard(args, priv_args, pod_id), + resolve_wildcard(args, priv_args, key), + ), + }) + .collect(); + StatementTmpl(sb.predicate.clone(), args) + }) + .collect(); let custom_predicate = CustomPredicate { - conjunction: true, - statements: vec![], // TODO + conjunction, + statements, args_len: args.len(), }; Predicate::Custom(Arc::new(custom_predicate)) @@ -688,11 +837,109 @@ mod tests { #[test] fn test_custom_pred() { - use NativePredicate::*; + use NativePredicate as NP; let eth_friend = predicate_and( &["src_or", "src_key", "dst_or", "dst_key"], &["attestation_pod"], - &[st_tmpl(Equal)], + &[ + st_tmpl(NP::ValueOf) + .arg((w("attestation_pod"), l("type"))) + .arg(PodType::Signed), + st_tmpl(NP::Equal) + .arg((w("attestation_pod"), l("signer"))) + .arg((w("src_or"), w("src_key"))), + st_tmpl(NP::Equal) + .arg((w("attestation_pod"), l("attestation"))) + .arg((w("dst_or"), w("dst_key"))), + ], + ); + + println!("eth_friend = {}", eth_friend); + + let eth_dos_distance_base = predicate_and( + &[ + "src_or", + "src_key", + "dst_or", + "dst_key", + "distance_or", + "distance_key", + ], + &[], + &[ + st_tmpl(NP::Equal) + .arg((w("src_or"), l("src_key"))) + .arg((w("dst_or"), w("dst_key"))), + st_tmpl(NP::ValueOf) + .arg((w("distance_or"), w("distance_key"))) + .arg(0), + ], + ); + + println!("eth_dos_distance_base = {}", eth_dos_distance_base); + + // TODO: replace this with a symbolic predicate index for recursion + let eth_dos_distance = NativePredicate::None; + + let eth_dos_distance_ind = predicate_and( + &[ + "src_or", + "src_key", + "dst_or", + "dst_key", + "distance_or", + "distance_key", + ], + &[ + "one_or", + "one_key", + "shorter_distance_or", + "shorter_distance_key", + "intermed_or", + "intermed_key", + ], + &[ + st_tmpl(eth_dos_distance) // TODO: Handle recursion + .arg((w("src_or"), w("src_key"))) + .arg((w("intermed_or"), w("intermed_key"))) + .arg((w("shorter_distance_or"), w("shorter_distance_key"))), + // distance == shorter_distance + 1 + st_tmpl(NP::ValueOf).arg((w("one_or"), w("one_key"))).arg(1), + st_tmpl(NP::SumOf) + .arg((w("distance_or"), w("distance_key"))) + .arg((w("shorter_distance_or"), w("shorter_distance_key"))) + .arg((w("one_or"), w("one_key"))), + // intermed is a friend of dst + st_tmpl(eth_friend) + .arg((w("intermed_or"), w("intermed_key"))) + .arg((w("dst_or"), w("dst_key"))), + ], + ); + + println!("eth_dos_distance_ind = {}", eth_dos_distance_ind); + + let eth_dos_distance = predicate_or( + &[ + "src_or", + "src_key", + "dst_or", + "dst_key", + "distance_or", + "distance_key", + ], + &[], + &[ + st_tmpl(eth_dos_distance_base) + .arg((w("src_or"), w("src_key"))) + .arg((w("dst_or"), w("dst_key"))) + .arg((w("distance_or"), w("distance_key"))), + st_tmpl(eth_dos_distance_ind) + .arg((w("src_or"), w("src_key"))) + .arg((w("dst_or"), w("dst_key"))) + .arg((w("distance_or"), w("distance_key"))), + ], ); + + println!("eth_dos_distance = {}", eth_dos_distance); } } From 0405a255bcac615905745e36b07e5af08491bb73 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Tue, 18 Feb 2025 18:18:48 +0100 Subject: [PATCH 03/12] feat: implement custom pred recursion --- src/middleware/mod.rs | 166 +++++++++++++++++++++++++++--------------- 1 file changed, 108 insertions(+), 58 deletions(-) diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index d20352b6..1693e8e0 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -318,15 +318,7 @@ impl fmt::Display for CustomPredicate { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { writeln!(f, "{}<", if self.conjunction { "and" } else { "or" })?; for st in &self.statements { - // NOTE: With recursive custom predicates we can't just display the predicate again - // because then this call will run into an infinite loop. Instead we should find a way - // to name custom predicates and use the names here. For this we will probably need an - // auxiliary data structure to hold the names, which IMO would be too complex to live - // in the middleware. For the middleware we may just print the custom predicate hash. - match &st.0 { - Predicate::Native(p) => write!(f, " {:?}(", p)?, - Predicate::Custom(_p) => write!(f, " TODO(")?, - } + write!(f, " {}", st.0)?; for (i, arg) in st.1.iter().enumerate() { if i != 0 { write!(f, ", ")?; @@ -347,10 +339,23 @@ impl fmt::Display for CustomPredicate { } } +#[derive(Debug)] +pub struct CustomPredicateBatch { + predicates: Vec, +} + +impl CustomPredicateBatch { + pub fn hash(&self) -> Hash { + // TODO + hash_str(&format!("{:?}", self)) + } +} + #[derive(Clone, Debug)] pub enum Predicate { Native(NativePredicate), - Custom(Arc), + BatchSelf(usize), + Custom(Arc, usize), } impl From for Predicate { @@ -369,7 +374,8 @@ impl fmt::Display for Predicate { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::Native(p) => write!(f, "{:?}", p), - Self::Custom(p) => write!(f, "{}", p), + Self::BatchSelf(i) => write!(f, "self.{}", i), + Self::Custom(pb, i) => write!(f, "{}.{}", pb.hash(), i), } } } @@ -778,12 +784,74 @@ mod tests { } } - fn predicate_and(args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder]) -> Predicate { - predicate(true, args, priv_args, sts) + struct CustomPredicateBatchBuilder { + predicates: Vec, } - fn predicate_or(args: &[&str], priv_args: &[&str], sts: &[StatementTmplBuilder]) -> Predicate { - predicate(false, args, priv_args, sts) + impl CustomPredicateBatchBuilder { + fn new() -> Self { + Self { + predicates: Vec::new(), + } + } + + fn predicate_and( + &mut self, + args: &[&str], + priv_args: &[&str], + sts: &[StatementTmplBuilder], + ) -> Predicate { + self.predicate(true, args, priv_args, sts) + } + + fn predicate_or( + &mut self, + args: &[&str], + priv_args: &[&str], + sts: &[StatementTmplBuilder], + ) -> Predicate { + self.predicate(false, args, priv_args, sts) + } + + fn predicate( + &mut self, + conjunction: bool, + args: &[&str], + priv_args: &[&str], + sts: &[StatementTmplBuilder], + ) -> Predicate { + use BuilderArg as BA; + let statements = sts + .iter() + .map(|sb| { + let args = sb + .args + .iter() + .map(|a| match a { + BA::Literal(v) => StatementTmplArg::Literal(*v), + BA::Key(pod_id, key) => StatementTmplArg::Key( + resolve_wildcard(args, priv_args, pod_id), + resolve_wildcard(args, priv_args, key), + ), + }) + .collect(); + StatementTmpl(sb.predicate.clone(), args) + }) + .collect(); + let custom_predicate = CustomPredicate { + conjunction, + statements, + args_len: args.len(), + }; + self.predicates.push(custom_predicate); + Predicate::BatchSelf(self.predicates.len() - 1) + } + + fn finish(self) -> Arc { + Arc::new(CustomPredicateBatch { + predicates: self.predicates, + }) + } } fn resolve_wildcard( @@ -803,42 +871,12 @@ mod tests { } } - fn predicate( - conjunction: bool, - args: &[&str], - priv_args: &[&str], - sts: &[StatementTmplBuilder], - ) -> Predicate { - use BuilderArg as BA; - let statements = sts - .iter() - .map(|sb| { - let args = sb - .args - .iter() - .map(|a| match a { - BA::Literal(v) => StatementTmplArg::Literal(*v), - BA::Key(pod_id, key) => StatementTmplArg::Key( - resolve_wildcard(args, priv_args, pod_id), - resolve_wildcard(args, priv_args, key), - ), - }) - .collect(); - StatementTmpl(sb.predicate.clone(), args) - }) - .collect(); - let custom_predicate = CustomPredicate { - conjunction, - statements, - args_len: args.len(), - }; - Predicate::Custom(Arc::new(custom_predicate)) - } - #[test] fn test_custom_pred() { use NativePredicate as NP; - let eth_friend = predicate_and( + + let mut builder = CustomPredicateBatchBuilder::new(); + let eth_friend = builder.predicate_and( &["src_or", "src_key", "dst_or", "dst_key"], &["attestation_pod"], &[ @@ -854,9 +892,13 @@ mod tests { ], ); - println!("eth_friend = {}", eth_friend); + println!("a.0. eth_friend = {}", builder.predicates.last().unwrap()); + let eth_friend = builder.finish(); + // This batch only has 1 predicate, so we pick it already for convenience + let eth_friend = Predicate::Custom(eth_friend, 0); - let eth_dos_distance_base = predicate_and( + let mut builder = CustomPredicateBatchBuilder::new(); + let eth_dos_distance_base = builder.predicate_and( &[ "src_or", "src_key", @@ -876,12 +918,14 @@ mod tests { ], ); - println!("eth_dos_distance_base = {}", eth_dos_distance_base); + println!( + "b.0. eth_dos_distance_base = {}", + builder.predicates.last().unwrap() + ); - // TODO: replace this with a symbolic predicate index for recursion - let eth_dos_distance = NativePredicate::None; + let eth_dos_distance = Predicate::BatchSelf(3); - let eth_dos_distance_ind = predicate_and( + let eth_dos_distance_ind = builder.predicate_and( &[ "src_or", "src_key", @@ -899,7 +943,7 @@ mod tests { "intermed_key", ], &[ - st_tmpl(eth_dos_distance) // TODO: Handle recursion + st_tmpl(eth_dos_distance) .arg((w("src_or"), w("src_key"))) .arg((w("intermed_or"), w("intermed_key"))) .arg((w("shorter_distance_or"), w("shorter_distance_key"))), @@ -916,9 +960,12 @@ mod tests { ], ); - println!("eth_dos_distance_ind = {}", eth_dos_distance_ind); + println!( + "b.1. eth_dos_distance_ind = {}", + builder.predicates.last().unwrap() + ); - let eth_dos_distance = predicate_or( + let eth_dos_distance = builder.predicate_or( &[ "src_or", "src_key", @@ -940,6 +987,9 @@ mod tests { ], ); - println!("eth_dos_distance = {}", eth_dos_distance); + println!( + "b.2. eth_dos_distance = {}", + builder.predicates.last().unwrap() + ); } } From 9e6f2d1f4b2a72eb7b298c0f3f1c16d899a147d5 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Thu, 20 Feb 2025 15:53:37 +0100 Subject: [PATCH 04/12] files reorg, add github CI for rustfmt checks --- .github/workflows/rustfmt.yml | 21 + .github/workflows/typos.toml | 1 + .../{mock_main.rs => mock_main/mod.rs} | 59 +-- src/backends/mock_main/operation.rs | 23 +- src/backends/mock_main/statement.rs | 3 +- src/backends/mock_signed.rs | 7 +- src/{frontend.rs => frontend/mod.rs} | 6 +- src/frontend/operation.rs | 3 +- src/frontend/statement.rs | 6 +- src/middleware/custom.rs | 397 ++++++++++++++++++ src/middleware/mod.rs | 6 +- src/middleware/operation.rs | 2 +- src/middleware/statement.rs | 397 +----------------- 13 files changed, 470 insertions(+), 461 deletions(-) create mode 100644 .github/workflows/rustfmt.yml rename src/backends/{mock_main.rs => mock_main/mod.rs} (94%) rename src/{frontend.rs => frontend/mod.rs} (100%) create mode 100644 src/middleware/custom.rs diff --git a/.github/workflows/rustfmt.yml b/.github/workflows/rustfmt.yml new file mode 100644 index 00000000..af1ff025 --- /dev/null +++ b/.github/workflows/rustfmt.yml @@ -0,0 +1,21 @@ +name: Rustfmt Check + +on: + pull_request: + branches: [ main ] + types: [ready_for_review, opened, synchronize, reopened] + push: + branches: [ main ] + +jobs: + rustfmt: + if: github.event.pull_request.draft == false + name: Rust formatting + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + components: rustfmt + - name: Check formatting + uses: actions-rust-lang/rustfmt@v1 diff --git a/.github/workflows/typos.toml b/.github/workflows/typos.toml index 39107461..48147009 100644 --- a/.github/workflows/typos.toml +++ b/.github/workflows/typos.toml @@ -1,2 +1,3 @@ [default.extend-words] groth = "groth" # to avoid it dectecting it as 'growth' +BA = "BA" diff --git a/src/backends/mock_main.rs b/src/backends/mock_main/mod.rs similarity index 94% rename from src/backends/mock_main.rs rename to src/backends/mock_main/mod.rs index 09a9ef14..cbb0c7dc 100644 --- a/src/backends/mock_main.rs +++ b/src/backends/mock_main/mod.rs @@ -1,18 +1,19 @@ -mod operation; -mod statement; +use anyhow::{anyhow, Result}; +use itertools::Itertools; +use plonky2::hash::poseidon::PoseidonHash; +use plonky2::plonk::config::Hasher; +use std::any::Any; +use std::fmt; use crate::middleware::{ self, hash_str, AnchoredKey, Hash, MainPodInputs, NativeOperation, NativePredicate, NonePod, Params, Pod, PodId, PodProver, StatementArg, ToFields, KEY_TYPE, SELF, }; -use anyhow::Result; -use itertools::Itertools; + +mod operation; +mod statement; pub use operation::*; -use plonky2::hash::poseidon::PoseidonHash; -use plonky2::plonk::config::Hasher; pub use statement::*; -use std::any::Any; -use std::fmt; pub const VALUE_TYPE: &str = "MockMainPOD"; @@ -222,18 +223,17 @@ impl MockMainPod { fn find_op_arg( statements: &[Statement], op_arg: &middleware::Statement, - ) -> Result { + ) -> Result { match op_arg { middleware::Statement::None => Ok(OperationArg::None), _ => statements .iter() .enumerate() .find_map(|(i, s)| { - // TODO: Error handling - (&middleware::Statement::try_from(s.clone()).unwrap() == op_arg).then_some(i) + (&middleware::Statement::try_from(s.clone()).ok()? == op_arg).then_some(i) }) .map(OperationArg::Index) - .ok_or(OperationArgError::StatementNotFound), + .ok_or(anyhow!("statement not found")), } } @@ -241,7 +241,7 @@ impl MockMainPod { params: &Params, statements: &[Statement], input_operations: &[middleware::Operation], - ) -> Result, OperationArgError> { + ) -> Result> { let mut operations = Vec::new(); for i in 0..params.max_priv_statements() { let op = input_operations @@ -252,7 +252,7 @@ impl MockMainPod { let mut args = mid_args .iter() .map(|mid_arg| Self::find_op_arg(statements, mid_arg)) - .collect::, OperationArgError>>()?; + .collect::>>()?; Self::pad_operation_args(params, &mut args); operations.push(Operation(op.code(), args)); } @@ -265,7 +265,7 @@ impl MockMainPod { params: &Params, statements: &[Statement], mut operations: Vec, - ) -> Result, OperationArgError> { + ) -> Result> { let offset_public_statements = statements.len() - params.max_public_statements; operations.push(Operation(NativeOperation::NewEntry, vec![])); for i in 0..(params.max_public_statements - 1) { @@ -318,7 +318,7 @@ impl MockMainPod { statements[statements.len() - params.max_public_statements..].to_vec(); // get the id out of the public statements - let id: PodId = PodId(hash_statements(&public_statements)?); + let id: PodId = PodId(hash_statements(&public_statements)); Ok(Self { params: params.clone(), @@ -353,12 +353,12 @@ impl MockMainPod { } } -pub fn hash_statements(statements: &[Statement]) -> Result { +pub fn hash_statements(statements: &[Statement]) -> middleware::Hash { let field_elems = statements .into_iter() .flat_map(|statement| statement.clone().to_fields().0) .collect::>(); - Ok(Hash(PoseidonHash::hash_no_pad(&field_elems).elements)) + Hash(PoseidonHash::hash_no_pad(&field_elems).elements) } impl Pod for MockMainPod { @@ -367,7 +367,7 @@ impl Pod for MockMainPod { // get the input_statements from the self.statements let input_statements = &self.statements[input_statement_offset..]; // get the id out of the public statements, and ensure it is equal to self.id - let ids_match = self.id == PodId(hash_statements(&self.public_statements).unwrap()); + let ids_match = self.id == PodId(hash_statements(&self.public_statements)); // find a ValueOf statement from the public statements with key=KEY_TYPE and check that the // value is PodType::MockMainPod let has_type_statement = self @@ -473,22 +473,22 @@ pub mod tests { use crate::middleware; #[test] - fn test_mock_main_zu_kyc() { + fn test_mock_main_zu_kyc() -> Result<()> { let params = middleware::Params::default(); let (gov_id_builder, pay_stub_builder) = zu_kyc_sign_pod_builders(¶ms); let mut signer = MockSigner { pk: "ZooGov".into(), }; - let gov_id_pod = gov_id_builder.sign(&mut signer).unwrap(); + let gov_id_pod = gov_id_builder.sign(&mut signer)?; let mut signer = MockSigner { pk: "ZooDeel".into(), }; - let pay_stub_pod = pay_stub_builder.sign(&mut signer).unwrap(); + let pay_stub_pod = pay_stub_builder.sign(&mut signer)?; let kyc_builder = zu_kyc_pod_builder(¶ms, &gov_id_pod, &pay_stub_pod); let mut prover = MockProver {}; - let kyc_pod = kyc_builder.prove(&mut prover).unwrap(); + let kyc_pod = kyc_builder.prove(&mut prover)?; let pod = kyc_pod.pod.into_any().downcast::().unwrap(); println!("{:#}", pod); @@ -496,14 +496,15 @@ pub mod tests { assert_eq!(pod.verify(), true); // TODO // println!("id: {}", pod.id()); // println!("pub_statements: {:?}", pod.pub_statements()); + Ok(()) } #[test] - fn test_mock_main_great_boy() { + fn test_mock_main_great_boy() -> Result<()> { let great_boy_builder = great_boy_pod_full_flow(); let mut prover = MockProver {}; - let great_boy_pod = great_boy_builder.prove(&mut prover).unwrap(); + let great_boy_pod = great_boy_builder.prove(&mut prover)?; let pod = great_boy_pod .pod .into_any() @@ -513,16 +514,20 @@ pub mod tests { println!("{}", pod); assert_eq!(pod.verify(), true); + + Ok(()) } #[test] - fn test_mock_main_tickets() { + fn test_mock_main_tickets() -> Result<()> { let tickets_builder = tickets_pod_full_flow(); let mut prover = MockProver {}; - let proof_pod = tickets_builder.prove(&mut prover).unwrap(); + let proof_pod = tickets_builder.prove(&mut prover)?; let pod = proof_pod.pod.into_any().downcast::().unwrap(); println!("{}", pod); assert_eq!(pod.verify(), true); + + Ok(()) } } diff --git a/src/backends/mock_main/operation.rs b/src/backends/mock_main/operation.rs index 12cb9333..cb5ff3a1 100644 --- a/src/backends/mock_main/operation.rs +++ b/src/backends/mock_main/operation.rs @@ -1,10 +1,8 @@ -use std::fmt; - use anyhow::Result; - -use crate::middleware::{self, NativeOperation}; +use std::fmt; use super::Statement; +use crate::middleware::{self, NativeOperation}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum OperationArg { @@ -18,23 +16,6 @@ impl OperationArg { } } -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum OperationArgError { - KeyNotFound, - StatementNotFound, -} - -impl std::fmt::Display for OperationArgError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - OperationArgError::KeyNotFound => write!(f, "Key not found"), - OperationArgError::StatementNotFound => write!(f, "Statement not found"), - } - } -} - -impl std::error::Error for OperationArgError {} - #[derive(Clone, Debug, PartialEq, Eq)] pub struct Operation(pub NativeOperation, pub Vec); diff --git a/src/backends/mock_main/statement.rs b/src/backends/mock_main/statement.rs index 71b38b35..290bd61f 100644 --- a/src/backends/mock_main/statement.rs +++ b/src/backends/mock_main/statement.rs @@ -1,6 +1,5 @@ -use std::fmt; - use anyhow::{anyhow, Result}; +use std::fmt; use crate::middleware::{self, NativePredicate, StatementArg, ToFields}; diff --git a/src/backends/mock_signed.rs b/src/backends/mock_signed.rs index 3da701f6..afbaf760 100644 --- a/src/backends/mock_signed.rs +++ b/src/backends/mock_signed.rs @@ -1,11 +1,12 @@ +use anyhow::Result; +use std::any::Any; +use std::collections::HashMap; + use crate::middleware::{ containers::Dictionary, hash_str, AnchoredKey, Hash, Params, Pod, PodId, PodSigner, PodType, Statement, Value, KEY_SIGNER, KEY_TYPE, }; use crate::primitives::merkletree::MerkleTree; -use anyhow::Result; -use std::any::Any; -use std::collections::HashMap; pub struct MockSigner { pub pk: String, diff --git a/src/frontend.rs b/src/frontend/mod.rs similarity index 100% rename from src/frontend.rs rename to src/frontend/mod.rs index 2e5c8956..085b093c 100644 --- a/src/frontend.rs +++ b/src/frontend/mod.rs @@ -1,9 +1,6 @@ //! The frontend includes the user-level abstractions and user-friendly types to define and work //! with Pods. -mod operation; -mod statement; - use anyhow::Result; use itertools::Itertools; use std::collections::HashMap; @@ -16,6 +13,9 @@ use crate::middleware::{ hash_str, Hash, MainPodInputs, NativeOperation, NativePredicate, Params, PodId, PodProver, PodSigner, SELF, }; + +mod operation; +mod statement; pub use operation::*; pub use statement::*; diff --git a/src/frontend/operation.rs b/src/frontend/operation.rs index aeac5089..57d6f4f1 100644 --- a/src/frontend/operation.rs +++ b/src/frontend/operation.rs @@ -1,8 +1,7 @@ use std::fmt; -use crate::middleware::{hash_str, NativeOperation, NativePredicate}; - use super::{AnchoredKey, SignedPod, Statement, StatementArg, Value}; +use crate::middleware::{hash_str, NativeOperation, NativePredicate}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum OperationArg { diff --git a/src/frontend/statement.rs b/src/frontend/statement.rs index 085de3bf..59a75e22 100644 --- a/src/frontend/statement.rs +++ b/src/frontend/statement.rs @@ -1,10 +1,8 @@ -use std::fmt; - use anyhow::{anyhow, Result}; - -use crate::middleware::{self, NativePredicate}; +use std::fmt; use super::{AnchoredKey, Value}; +use crate::middleware::{self, NativePredicate}; #[derive(Clone, Debug, PartialEq, Eq)] pub enum StatementArg { diff --git a/src/middleware/custom.rs b/src/middleware/custom.rs new file mode 100644 index 00000000..4bef99da --- /dev/null +++ b/src/middleware/custom.rs @@ -0,0 +1,397 @@ +use std::fmt; +use std::sync::Arc; + +use super::{hash_str, Hash, NativePredicate, ToFields, Value, F}; + +// BEGIN Custom 1b + +#[derive(Debug)] +pub enum HashOrWildcard { + Hash(Hash), + Wildcard(usize), +} + +impl fmt::Display for HashOrWildcard { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Hash(h) => write!(f, "{}", h), + Self::Wildcard(n) => write!(f, "*{}", n), + } + } +} + +#[derive(Debug)] +pub enum StatementTmplArg { + None, + Literal(Value), + Key(HashOrWildcard, HashOrWildcard), +} + +impl fmt::Display for StatementTmplArg { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::None => write!(f, "none"), + Self::Literal(v) => write!(f, "{}", v), + Self::Key(pod_id, key) => write!(f, "({}, {})", pod_id, key), + } + } +} + +// END + +// BEGIN Custom 2 + +// pub enum StatementTmplArg { +// None, +// Literal(Value), +// Wildcard(usize), +// } + +// END + +/// Statement Template for a Custom Predicate +#[derive(Debug)] +pub struct StatementTmpl(Predicate, Vec); + +#[derive(Debug)] +pub struct CustomPredicate { + /// true for "and", false for "or" + pub conjunction: bool, + pub statements: Vec, + pub args_len: usize, + // TODO: Add private args length? + // TODO: Add args type information? +} + +impl fmt::Display for CustomPredicate { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "{}<", if self.conjunction { "and" } else { "or" })?; + for st in &self.statements { + write!(f, " {}", st.0)?; + for (i, arg) in st.1.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}", arg)?; + } + writeln!(f, "),")?; + } + write!(f, ">(")?; + for i in 0..self.args_len { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "*{}", i)?; + } + writeln!(f, ")")?; + Ok(()) + } +} + +#[derive(Debug)] +pub struct CustomPredicateBatch { + predicates: Vec, +} + +impl CustomPredicateBatch { + pub fn hash(&self) -> Hash { + // TODO + hash_str(&format!("{:?}", self)) + } +} + +#[derive(Clone, Debug)] +pub enum Predicate { + Native(NativePredicate), + BatchSelf(usize), + Custom(Arc, usize), +} + +impl From for Predicate { + fn from(v: NativePredicate) -> Self { + Self::Native(v) + } +} + +impl ToFields for Predicate { + fn to_fields(self) -> (Vec, usize) { + todo!() + } +} + +impl fmt::Display for Predicate { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Native(p) => write!(f, "{:?}", p), + Self::BatchSelf(i) => write!(f, "self.{}", i), + Self::Custom(pb, i) => write!(f, "{}.{}", pb.hash(), i), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::middleware::PodType; + + enum HashOrWildcardStr { + Hash(Hash), + Wildcard(String), + } + + fn l(s: &str) -> HashOrWildcardStr { + HashOrWildcardStr::Hash(hash_str(s)) + } + + fn w(s: &str) -> HashOrWildcardStr { + HashOrWildcardStr::Wildcard(s.to_string()) + } + + enum BuilderArg { + Literal(Value), + Key(HashOrWildcardStr, HashOrWildcardStr), + } + + impl From<(HashOrWildcardStr, HashOrWildcardStr)> for BuilderArg { + fn from((pod_id, key): (HashOrWildcardStr, HashOrWildcardStr)) -> Self { + Self::Key(pod_id, key) + } + } + + impl From for BuilderArg + where + V: Into, + { + fn from(v: V) -> Self { + Self::Literal(v.into()) + } + } + + struct StatementTmplBuilder { + predicate: Predicate, + args: Vec, + } + + fn st_tmpl(p: impl Into) -> StatementTmplBuilder { + StatementTmplBuilder { + predicate: p.into(), + args: Vec::new(), + } + } + + impl StatementTmplBuilder { + fn arg(mut self, a: impl Into) -> Self { + self.args.push(a.into()); + self + } + } + + struct CustomPredicateBatchBuilder { + predicates: Vec, + } + + impl CustomPredicateBatchBuilder { + fn new() -> Self { + Self { + predicates: Vec::new(), + } + } + + fn predicate_and( + &mut self, + args: &[&str], + priv_args: &[&str], + sts: &[StatementTmplBuilder], + ) -> Predicate { + self.predicate(true, args, priv_args, sts) + } + + fn predicate_or( + &mut self, + args: &[&str], + priv_args: &[&str], + sts: &[StatementTmplBuilder], + ) -> Predicate { + self.predicate(false, args, priv_args, sts) + } + + fn predicate( + &mut self, + conjunction: bool, + args: &[&str], + priv_args: &[&str], + sts: &[StatementTmplBuilder], + ) -> Predicate { + use BuilderArg as BA; + let statements = sts + .iter() + .map(|sb| { + let args = sb + .args + .iter() + .map(|a| match a { + BA::Literal(v) => StatementTmplArg::Literal(*v), + BA::Key(pod_id, key) => StatementTmplArg::Key( + resolve_wildcard(args, priv_args, pod_id), + resolve_wildcard(args, priv_args, key), + ), + }) + .collect(); + StatementTmpl(sb.predicate.clone(), args) + }) + .collect(); + let custom_predicate = CustomPredicate { + conjunction, + statements, + args_len: args.len(), + }; + self.predicates.push(custom_predicate); + Predicate::BatchSelf(self.predicates.len() - 1) + } + + fn finish(self) -> Arc { + Arc::new(CustomPredicateBatch { + predicates: self.predicates, + }) + } + } + + fn resolve_wildcard( + args: &[&str], + priv_args: &[&str], + v: &HashOrWildcardStr, + ) -> HashOrWildcard { + match v { + HashOrWildcardStr::Hash(h) => HashOrWildcard::Hash(*h), + HashOrWildcardStr::Wildcard(s) => HashOrWildcard::Wildcard( + args.iter() + .chain(priv_args.iter()) + .enumerate() + .find_map(|(i, name)| (&s == name).then_some(i)) + .unwrap(), + ), + } + } + + #[test] + fn test_custom_pred() { + use NativePredicate as NP; + + let mut builder = CustomPredicateBatchBuilder::new(); + let _eth_friend = builder.predicate_and( + &["src_or", "src_key", "dst_or", "dst_key"], + &["attestation_pod"], + &[ + st_tmpl(NP::ValueOf) + .arg((w("attestation_pod"), l("type"))) + .arg(PodType::Signed), + st_tmpl(NP::Equal) + .arg((w("attestation_pod"), l("signer"))) + .arg((w("src_or"), w("src_key"))), + st_tmpl(NP::Equal) + .arg((w("attestation_pod"), l("attestation"))) + .arg((w("dst_or"), w("dst_key"))), + ], + ); + + println!("a.0. eth_friend = {}", builder.predicates.last().unwrap()); + let eth_friend = builder.finish(); + // This batch only has 1 predicate, so we pick it already for convenience + let eth_friend = Predicate::Custom(eth_friend, 0); + + let mut builder = CustomPredicateBatchBuilder::new(); + let eth_dos_distance_base = builder.predicate_and( + &[ + "src_or", + "src_key", + "dst_or", + "dst_key", + "distance_or", + "distance_key", + ], + &[], + &[ + st_tmpl(NP::Equal) + .arg((w("src_or"), l("src_key"))) + .arg((w("dst_or"), w("dst_key"))), + st_tmpl(NP::ValueOf) + .arg((w("distance_or"), w("distance_key"))) + .arg(0), + ], + ); + + println!( + "b.0. eth_dos_distance_base = {}", + builder.predicates.last().unwrap() + ); + + let eth_dos_distance = Predicate::BatchSelf(3); + + let eth_dos_distance_ind = builder.predicate_and( + &[ + "src_or", + "src_key", + "dst_or", + "dst_key", + "distance_or", + "distance_key", + ], + &[ + "one_or", + "one_key", + "shorter_distance_or", + "shorter_distance_key", + "intermed_or", + "intermed_key", + ], + &[ + st_tmpl(eth_dos_distance) + .arg((w("src_or"), w("src_key"))) + .arg((w("intermed_or"), w("intermed_key"))) + .arg((w("shorter_distance_or"), w("shorter_distance_key"))), + // distance == shorter_distance + 1 + st_tmpl(NP::ValueOf).arg((w("one_or"), w("one_key"))).arg(1), + st_tmpl(NP::SumOf) + .arg((w("distance_or"), w("distance_key"))) + .arg((w("shorter_distance_or"), w("shorter_distance_key"))) + .arg((w("one_or"), w("one_key"))), + // intermed is a friend of dst + st_tmpl(eth_friend) + .arg((w("intermed_or"), w("intermed_key"))) + .arg((w("dst_or"), w("dst_key"))), + ], + ); + + println!( + "b.1. eth_dos_distance_ind = {}", + builder.predicates.last().unwrap() + ); + + let _eth_dos_distance = builder.predicate_or( + &[ + "src_or", + "src_key", + "dst_or", + "dst_key", + "distance_or", + "distance_key", + ], + &[], + &[ + st_tmpl(eth_dos_distance_base) + .arg((w("src_or"), w("src_key"))) + .arg((w("dst_or"), w("dst_key"))) + .arg((w("distance_or"), w("distance_key"))), + st_tmpl(eth_dos_distance_ind) + .arg((w("src_or"), w("src_key"))) + .arg((w("dst_or"), w("dst_key"))) + .arg((w("distance_or"), w("distance_key"))), + ], + ); + + println!( + "b.2. eth_dos_distance = {}", + builder.predicates.last().unwrap() + ); + } +} diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 546b83de..14cd9f29 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -1,18 +1,20 @@ //! The middleware includes the type definitions and the traits used to connect the frontend and //! the backend. +mod custom; mod operation; mod statement; +pub use custom::*; +pub use operation::*; +pub use statement::*; use anyhow::{anyhow, Error, Result}; use dyn_clone::DynClone; use hex::{FromHex, FromHexError}; -pub use operation::*; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::types::{Field, PrimeField64}; use plonky2::hash::poseidon::PoseidonHash; use plonky2::plonk::config::{Hasher, PoseidonGoldilocksConfig}; -pub use statement::*; use std::any::Any; use std::cmp::{Ord, Ordering}; use std::collections::HashMap; diff --git a/src/middleware/operation.rs b/src/middleware/operation.rs index 5f3a5c36..f8934deb 100644 --- a/src/middleware/operation.rs +++ b/src/middleware/operation.rs @@ -1,7 +1,7 @@ -use crate::middleware::{AnchoredKey, SELF}; use anyhow::{anyhow, Result}; use super::Statement; +use crate::middleware::{AnchoredKey, SELF}; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum NativeOperation { diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index eb437d3f..0d35805b 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -3,7 +3,7 @@ use plonky2::field::types::Field; use std::fmt; use strum_macros::FromRepr; -use super::{hash_str, AnchoredKey, Hash, ToFields, Value, F}; +use super::{AnchoredKey, ToFields, Value, F}; pub const KEY_SIGNER: &str = "_signer"; pub const KEY_TYPE: &str = "_type"; @@ -181,398 +181,3 @@ impl ToFields for StatementArg { (f, STATEMENT_ARG_F_LEN) } } - -use std::sync::Arc; - -// BEGIN Custom 1b - -#[derive(Debug)] -pub enum HashOrWildcard { - Hash(Hash), - Wildcard(usize), -} - -impl fmt::Display for HashOrWildcard { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Hash(h) => write!(f, "{}", h), - Self::Wildcard(n) => write!(f, "*{}", n), - } - } -} - -#[derive(Debug)] -pub enum StatementTmplArg { - None, - Literal(Value), - Key(HashOrWildcard, HashOrWildcard), -} - -impl fmt::Display for StatementTmplArg { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::None => write!(f, "none"), - Self::Literal(v) => write!(f, "{}", v), - Self::Key(pod_id, key) => write!(f, "({}, {})", pod_id, key), - } - } -} - -// END - -// BEGIN Custom 2 - -// pub enum StatementTmplArg { -// None, -// Literal(Value), -// Wildcard(usize), -// } - -// END - -/// Statement Template for a Custom Predicate -#[derive(Debug)] -pub struct StatementTmpl(Predicate, Vec); - -#[derive(Debug)] -pub struct CustomPredicate { - /// true for "and", false for "or" - pub conjunction: bool, - pub statements: Vec, - pub args_len: usize, - // TODO: Add private args length? - // TODO: Add args type information? -} - -impl fmt::Display for CustomPredicate { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - writeln!(f, "{}<", if self.conjunction { "and" } else { "or" })?; - for st in &self.statements { - write!(f, " {}", st.0)?; - for (i, arg) in st.1.iter().enumerate() { - if i != 0 { - write!(f, ", ")?; - } - write!(f, "{}", arg)?; - } - writeln!(f, "),")?; - } - write!(f, ">(")?; - for i in 0..self.args_len { - if i != 0 { - write!(f, ", ")?; - } - write!(f, "*{}", i)?; - } - writeln!(f, ")")?; - Ok(()) - } -} - -#[derive(Debug)] -pub struct CustomPredicateBatch { - predicates: Vec, -} - -impl CustomPredicateBatch { - pub fn hash(&self) -> Hash { - // TODO - hash_str(&format!("{:?}", self)) - } -} - -#[derive(Clone, Debug)] -pub enum Predicate { - Native(NativePredicate), - BatchSelf(usize), - Custom(Arc, usize), -} - -impl From for Predicate { - fn from(v: NativePredicate) -> Self { - Self::Native(v) - } -} - -impl ToFields for Predicate { - fn to_fields(self) -> (Vec, usize) { - todo!() - } -} - -impl fmt::Display for Predicate { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Native(p) => write!(f, "{:?}", p), - Self::BatchSelf(i) => write!(f, "self.{}", i), - Self::Custom(pb, i) => write!(f, "{}.{}", pb.hash(), i), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::middleware::PodType; - - enum HashOrWildcardStr { - Hash(Hash), - Wildcard(String), - } - - fn l(s: &str) -> HashOrWildcardStr { - HashOrWildcardStr::Hash(hash_str(s)) - } - - fn w(s: &str) -> HashOrWildcardStr { - HashOrWildcardStr::Wildcard(s.to_string()) - } - - enum BuilderArg { - Literal(Value), - Key(HashOrWildcardStr, HashOrWildcardStr), - } - - impl From<(HashOrWildcardStr, HashOrWildcardStr)> for BuilderArg { - fn from((pod_id, key): (HashOrWildcardStr, HashOrWildcardStr)) -> Self { - Self::Key(pod_id, key) - } - } - - impl From for BuilderArg - where - V: Into, - { - fn from(v: V) -> Self { - Self::Literal(v.into()) - } - } - - struct StatementTmplBuilder { - predicate: Predicate, - args: Vec, - } - - fn st_tmpl(p: impl Into) -> StatementTmplBuilder { - StatementTmplBuilder { - predicate: p.into(), - args: Vec::new(), - } - } - - impl StatementTmplBuilder { - fn arg(mut self, a: impl Into) -> Self { - self.args.push(a.into()); - self - } - } - - struct CustomPredicateBatchBuilder { - predicates: Vec, - } - - impl CustomPredicateBatchBuilder { - fn new() -> Self { - Self { - predicates: Vec::new(), - } - } - - fn predicate_and( - &mut self, - args: &[&str], - priv_args: &[&str], - sts: &[StatementTmplBuilder], - ) -> Predicate { - self.predicate(true, args, priv_args, sts) - } - - fn predicate_or( - &mut self, - args: &[&str], - priv_args: &[&str], - sts: &[StatementTmplBuilder], - ) -> Predicate { - self.predicate(false, args, priv_args, sts) - } - - fn predicate( - &mut self, - conjunction: bool, - args: &[&str], - priv_args: &[&str], - sts: &[StatementTmplBuilder], - ) -> Predicate { - use BuilderArg as BA; - let statements = sts - .iter() - .map(|sb| { - let args = sb - .args - .iter() - .map(|a| match a { - BA::Literal(v) => StatementTmplArg::Literal(*v), - BA::Key(pod_id, key) => StatementTmplArg::Key( - resolve_wildcard(args, priv_args, pod_id), - resolve_wildcard(args, priv_args, key), - ), - }) - .collect(); - StatementTmpl(sb.predicate.clone(), args) - }) - .collect(); - let custom_predicate = CustomPredicate { - conjunction, - statements, - args_len: args.len(), - }; - self.predicates.push(custom_predicate); - Predicate::BatchSelf(self.predicates.len() - 1) - } - - fn finish(self) -> Arc { - Arc::new(CustomPredicateBatch { - predicates: self.predicates, - }) - } - } - - fn resolve_wildcard( - args: &[&str], - priv_args: &[&str], - v: &HashOrWildcardStr, - ) -> HashOrWildcard { - match v { - HashOrWildcardStr::Hash(h) => HashOrWildcard::Hash(*h), - HashOrWildcardStr::Wildcard(s) => HashOrWildcard::Wildcard( - args.iter() - .chain(priv_args.iter()) - .enumerate() - .find_map(|(i, name)| (&s == name).then_some(i)) - .unwrap(), - ), - } - } - - #[test] - fn test_custom_pred() { - use NativePredicate as NP; - - let mut builder = CustomPredicateBatchBuilder::new(); - let eth_friend = builder.predicate_and( - &["src_or", "src_key", "dst_or", "dst_key"], - &["attestation_pod"], - &[ - st_tmpl(NP::ValueOf) - .arg((w("attestation_pod"), l("type"))) - .arg(PodType::Signed), - st_tmpl(NP::Equal) - .arg((w("attestation_pod"), l("signer"))) - .arg((w("src_or"), w("src_key"))), - st_tmpl(NP::Equal) - .arg((w("attestation_pod"), l("attestation"))) - .arg((w("dst_or"), w("dst_key"))), - ], - ); - - println!("a.0. eth_friend = {}", builder.predicates.last().unwrap()); - let eth_friend = builder.finish(); - // This batch only has 1 predicate, so we pick it already for convenience - let eth_friend = Predicate::Custom(eth_friend, 0); - - let mut builder = CustomPredicateBatchBuilder::new(); - let eth_dos_distance_base = builder.predicate_and( - &[ - "src_or", - "src_key", - "dst_or", - "dst_key", - "distance_or", - "distance_key", - ], - &[], - &[ - st_tmpl(NP::Equal) - .arg((w("src_or"), l("src_key"))) - .arg((w("dst_or"), w("dst_key"))), - st_tmpl(NP::ValueOf) - .arg((w("distance_or"), w("distance_key"))) - .arg(0), - ], - ); - - println!( - "b.0. eth_dos_distance_base = {}", - builder.predicates.last().unwrap() - ); - - let eth_dos_distance = Predicate::BatchSelf(3); - - let eth_dos_distance_ind = builder.predicate_and( - &[ - "src_or", - "src_key", - "dst_or", - "dst_key", - "distance_or", - "distance_key", - ], - &[ - "one_or", - "one_key", - "shorter_distance_or", - "shorter_distance_key", - "intermed_or", - "intermed_key", - ], - &[ - st_tmpl(eth_dos_distance) - .arg((w("src_or"), w("src_key"))) - .arg((w("intermed_or"), w("intermed_key"))) - .arg((w("shorter_distance_or"), w("shorter_distance_key"))), - // distance == shorter_distance + 1 - st_tmpl(NP::ValueOf).arg((w("one_or"), w("one_key"))).arg(1), - st_tmpl(NP::SumOf) - .arg((w("distance_or"), w("distance_key"))) - .arg((w("shorter_distance_or"), w("shorter_distance_key"))) - .arg((w("one_or"), w("one_key"))), - // intermed is a friend of dst - st_tmpl(eth_friend) - .arg((w("intermed_or"), w("intermed_key"))) - .arg((w("dst_or"), w("dst_key"))), - ], - ); - - println!( - "b.1. eth_dos_distance_ind = {}", - builder.predicates.last().unwrap() - ); - - let eth_dos_distance = builder.predicate_or( - &[ - "src_or", - "src_key", - "dst_or", - "dst_key", - "distance_or", - "distance_key", - ], - &[], - &[ - st_tmpl(eth_dos_distance_base) - .arg((w("src_or"), w("src_key"))) - .arg((w("dst_or"), w("dst_key"))) - .arg((w("distance_or"), w("distance_key"))), - st_tmpl(eth_dos_distance_ind) - .arg((w("src_or"), w("src_key"))) - .arg((w("dst_or"), w("dst_key"))) - .arg((w("distance_or"), w("distance_key"))), - ], - ); - - println!( - "b.2. eth_dos_distance = {}", - builder.predicates.last().unwrap() - ); - } -} From cad724f486aa1edc45a710158cca7e5a773f514d Mon Sep 17 00:00:00 2001 From: arnaucube Date: Thu, 13 Feb 2025 19:12:40 +0100 Subject: [PATCH 05/12] start sparsemerkletree. impl add_leaf method, initial Leaf & Intermediate types with methods --- src/middleware/mod.rs | 9 ++ src/primitives/merkletree_new.rs | 236 +++++++++++++++++++++++++++++++ src/primitives/mod.rs | 1 + 3 files changed, 246 insertions(+) create mode 100644 src/primitives/merkletree_new.rs diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 14cd9f29..34cbbe46 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -48,6 +48,15 @@ pub type Entry = (String, Value); #[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)] pub struct Value(pub [F; 4]); +impl Value { + pub fn to_bytes(self) -> Vec { + self.0 + .iter() + .flat_map(|e| e.to_canonical_u64().to_le_bytes()) + .collect() + } +} + impl Ord for Value { fn cmp(&self, other: &Self) -> Ordering { for (lhs, rhs) in self.0.iter().zip(other.0.iter()).rev() { diff --git a/src/primitives/merkletree_new.rs b/src/primitives/merkletree_new.rs new file mode 100644 index 00000000..33f87c95 --- /dev/null +++ b/src/primitives/merkletree_new.rs @@ -0,0 +1,236 @@ +#![allow(unused)] +#![allow(dead_code, unused_variables)] +// NOTE: starting in this file (merkletree_new.rs), once we have the implementation ready we just +// place it in the merkletree.rs file. +use anyhow::{anyhow, Result}; +use itertools::Itertools; +use plonky2::field::types::Field; +use plonky2::hash::{hash_types::HashOut, poseidon::PoseidonHash}; +use plonky2::plonk::config::GenericConfig; +use plonky2::plonk::config::Hasher; +use std::collections::HashMap; +use std::iter::IntoIterator; + +use crate::middleware::{Hash, Value, C, D, F, NULL}; + +pub struct MerkleTree { + max_depth: usize, + root: Intermediate, +} + +#[derive(Clone, Debug)] +enum Node { + None, + Leaf(Leaf), + Intermediate(Intermediate), +} +impl Node { + fn is_empty(self) -> bool { + match self { + Self::None => true, + Self::Leaf(l) => false, + Self::Intermediate(n) => false, + } + } + fn hash(self) -> Hash { + match self { + Self::None => NULL, + Self::Leaf(l) => l.hash(), + Self::Intermediate(n) => n.hash(), + } + } + fn add_leaf(&mut self, lvl: usize, leaf: Leaf) -> Result<()> { + // TODO check that lvl<=maxlevels + + match self { + Self::Intermediate(n) => { + if leaf.path[lvl] { + if (*n.right).clone().is_empty() { + // empty sub-node, add the leaf here + n.right = Box::new(Node::Leaf(leaf)); + return Ok(()); + } + n.right.add_leaf(lvl + 1, leaf)?; + } else { + if (*n.left).clone().is_empty() { + // empty sub-node, add the leaf here + n.left = Box::new(Node::Leaf(leaf)); + return Ok(()); + } + n.left.add_leaf(lvl + 1, leaf)?; + } + } + Self::Leaf(l) => { + // in this case, it means that we found a leaf in the new-leaf path, thus we need + // to push both leaves (old-leaf and new-leaf) down the path till their paths + // diverge. + + // first check that keys of both leafs are different + // (l: old-leaf, leaf: new-leaf) + if l.key == leaf.key { + // TODO decide if we want to return an error when trying to add a leaf that + // allready exists, or if we just ignore it + return Err(anyhow!("key already exists")); + } + let old_leaf = l.clone(); + // set self as an intermediate node + *self = Node::Intermediate(Intermediate::empty()); + return self.down_till_divergence(lvl, old_leaf, leaf); + } + Self::None => { + return Err(anyhow!("reached empty node, should not have entered")); + } + } + Ok(()) + } + + fn down_till_divergence(&mut self, lvl: usize, old_leaf: Leaf, new_leaf: Leaf) -> Result<()> { + // TODO check that lvl<=maxlevels + + if let Node::Intermediate(ref mut n) = self { + // let current_node: Intermediate = *self; + if old_leaf.path[lvl] != new_leaf.path[lvl] { + // reached divergence in next level, set the leafs as childs at the current node + if new_leaf.path[lvl] { + n.left = Box::new(Node::Leaf(old_leaf)); + n.right = Box::new(Node::Leaf(new_leaf)); + } else { + n.left = Box::new(Node::Leaf(new_leaf)); + n.right = Box::new(Node::Leaf(old_leaf)); + } + return Ok(()); + } + + // no divergence yet, continue going down + if new_leaf.path[lvl] { + n.right = Box::new(Node::Intermediate(Intermediate::empty())); + return n.right.down_till_divergence(lvl + 1, old_leaf, new_leaf); + } else { + n.left = Box::new(Node::Intermediate(Intermediate::empty())); + return n.left.down_till_divergence(lvl + 1, old_leaf, new_leaf); + } + } + Ok(()) + } +} + +#[derive(Clone, Debug)] +struct Intermediate { + left: Box, + right: Box, +} +impl Intermediate { + fn empty() -> Self { + Self { + left: Box::new(Node::None), + right: Box::new(Node::None), + } + } + + // TODO move to a Node/Hashable trait? + fn hash(self) -> Hash { + let l_hash = self.left.hash(); + let r_hash = self.right.hash(); + let input: Vec = [l_hash.0, r_hash.0].concat(); + Hash(PoseidonHash::hash_no_pad(&input).elements) + } +} + +#[derive(Clone, Debug)] +struct Leaf { + path: Vec, + key: Value, + value: Value, +} +impl Leaf { + fn new(key: Value, value: Value) -> Self { + Self { + path: keypath(key), + key, + value, + } + } +} +impl Leaf { + // TODO move to a Node/Hashable trait? + fn hash(self) -> Hash { + let input: Vec = [self.key.0, self.value.0].concat(); + Hash(PoseidonHash::hash_no_pad(&input).elements) + } +} + +// TODO 1: think if maybe the length of the returned vector can be <256 (8*bytes.len()), so that +// we can do fewer iterations. For example, if the tree.max_depth is set to 20, we just need 20 +// iterations of the loop, not 256. +// TODO 2: which approach do we take with keys that are longer than the max-depth? ie, what +// happens when two keys share the same path for more bits than the max_depth? +fn keypath(k: Value) -> Vec { + let bytes = k.to_bytes(); + (0..8 * bytes.len()) + .map(|n| bytes[n / 8] & (1 << (n % 8)) != 0) + .collect() +} + +pub struct MerkleProof { + existence: bool, +} + +impl MerkleTree { + /// returns the root of the tree + fn root(&self) -> Hash { + todo!(); + } + + /// returns the value at the given key + pub fn get(&self, key: &Value) -> Result { + todo!(); + } + + /// returns a boolean indicating whether the key exists in the tree + pub fn contains(&self, key: &Value) -> bool { + todo!(); + } + + /// returns a proof of existence, which proves that the given key exists in + /// the tree. It returns the `value` of the leaf at the given `key`, and + /// the `MerkleProof`. + fn prove(&self, key: &Value) -> Result { + todo!(); + } + + /// returns a proof of non-existence, which proves that the given `key` + /// does not exist in the tree + fn prove_nonexistence(&self, key: &Value) -> Result { + todo!(); + } + + /// verifies an inclusion proof for the given `key` and `value` + fn verify(root: Hash, proof: &MerkleProof, key: &Value, value: &Value) -> Result<()> { + todo!(); + } + + /// verifies a non-inclusion proof for the given `key`, that is, the given + /// `key` does not exist in the tree + fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Value) -> Result<()> { + todo!(); + } + + /// returns an iterator over the leaves of the tree + fn iter(&self) -> std::collections::hash_map::Iter { + todo!(); + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + use crate::middleware::hash_str; + + #[test] + fn test_keypath() -> Result<()> { + let key = Value(hash_str("key".into()).0); + // dbg!(keypath(key)); + + Ok(()) + } +} diff --git a/src/primitives/mod.rs b/src/primitives/mod.rs index ad3ec6a6..8faa69ee 100644 --- a/src/primitives/mod.rs +++ b/src/primitives/mod.rs @@ -1 +1,2 @@ pub mod merkletree; +pub mod merkletree_new; From c4b1707540711abc34bb963f38399e88a574bd3b Mon Sep 17 00:00:00 2001 From: arnaucube Date: Mon, 17 Feb 2025 21:56:42 +0100 Subject: [PATCH 06/12] mt: add hash computation of all the nodes in the tree, add method to print the tree to visualize it as a graphviz --- src/primitives/merkletree_new.rs | 265 ++++++++++++++++++++++--------- 1 file changed, 189 insertions(+), 76 deletions(-) diff --git a/src/primitives/merkletree_new.rs b/src/primitives/merkletree_new.rs index 33f87c95..be6fb5d9 100644 --- a/src/primitives/merkletree_new.rs +++ b/src/primitives/merkletree_new.rs @@ -9,13 +9,92 @@ use plonky2::hash::{hash_types::HashOut, poseidon::PoseidonHash}; use plonky2::plonk::config::GenericConfig; use plonky2::plonk::config::Hasher; use std::collections::HashMap; +use std::fmt; use std::iter::IntoIterator; use crate::middleware::{Hash, Value, C, D, F, NULL}; pub struct MerkleTree { max_depth: usize, - root: Intermediate, + root: Node, +} + +impl MerkleTree { + /// builds a new `MerkleTree` where the leaves contain the given key-values + pub fn new(max_depth: usize, kvs: &HashMap) -> Self { + let mut root = Node::Intermediate(Intermediate::empty()); + + for (k, v) in kvs.iter() { + let leaf = Leaf::new(*k, *v); + root.add_leaf(0, max_depth, leaf).unwrap(); // TODO unwrap + } + + let _ = root.compute_hash(); + Self { max_depth, root } + } +} + +impl fmt::Display for MerkleTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "\nPaste in GraphViz (https://dreampuf.github.io/GraphvizOnline/):\n-----\n" + ); + write!(f, "digraph hierarchy {{\n"); + write!(f, "node [fontname=Monospace,fontsize=10,shape=box]\n"); + write!(f, "{}", self.root); + write!(f, "\n}}\n-----\n") + } +} + +pub struct MerkleProof { + existence: bool, +} + +impl MerkleTree { + /// returns the root of the tree + fn root(&self) -> Hash { + self.root.hash() + } + + /// returns the value at the given key + pub fn get(&self, key: &Value) -> Result { + todo!(); + } + + /// returns a boolean indicating whether the key exists in the tree + pub fn contains(&self, key: &Value) -> bool { + todo!(); + } + + /// returns a proof of existence, which proves that the given key exists in + /// the tree. It returns the `value` of the leaf at the given `key`, and + /// the `MerkleProof`. + fn prove(&self, key: &Value) -> Result { + todo!(); + } + + /// returns a proof of non-existence, which proves that the given `key` + /// does not exist in the tree + fn prove_nonexistence(&self, key: &Value) -> Result { + todo!(); + } + + /// verifies an inclusion proof for the given `key` and `value` + fn verify(root: Hash, proof: &MerkleProof, key: &Value, value: &Value) -> Result<()> { + todo!(); + } + + /// verifies a non-inclusion proof for the given `key`, that is, the given + /// `key` does not exist in the tree + fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Value) -> Result<()> { + todo!(); + } + + /// returns an iterator over the leaves of the tree + fn iter(&self) -> std::collections::hash_map::Iter { + todo!(); + } } #[derive(Clone, Debug)] @@ -24,40 +103,82 @@ enum Node { Leaf(Leaf), Intermediate(Intermediate), } + +impl fmt::Display for Node { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Intermediate(n) => { + write!( + f, + "\"{}\" -> {{ \"{}\" \"{}\" }}\n", + n.hash(), + n.left.hash(), + n.right.hash() + ); + write!(f, "{}", n.left); + write!(f, "{}", n.right) + } + Self::Leaf(l) => { + write!(f, "\"{}\" [style=filled]\n", l.hash()); + write!(f, "\"k:{}\\nv:{}\" [style=dashed]\n", l.key, l.value); + write!( + f, + "\"{}\" -> {{ \"k:{}\\nv:{}\" }}\n", + l.hash(), + l.key, + l.value, + ) + } + Self::None => Ok(()), + } + } +} + impl Node { - fn is_empty(self) -> bool { + fn is_empty(&self) -> bool { match self { Self::None => true, Self::Leaf(l) => false, Self::Intermediate(n) => false, } } - fn hash(self) -> Hash { + fn compute_hash(&mut self) -> Hash { + match self { + Self::None => NULL, + Self::Leaf(l) => l.compute_hash(), + Self::Intermediate(n) => n.compute_hash(), + } + } + fn hash(&self) -> Hash { match self { Self::None => NULL, Self::Leaf(l) => l.hash(), Self::Intermediate(n) => n.hash(), } } - fn add_leaf(&mut self, lvl: usize, leaf: Leaf) -> Result<()> { - // TODO check that lvl<=maxlevels + + // adds the leaf at the tree from the current node (self), without computing any hash + fn add_leaf(&mut self, lvl: usize, max_depth: usize, leaf: Leaf) -> Result<()> { + if lvl >= max_depth { + return Err(anyhow!("max depth reached")); + } match self { Self::Intermediate(n) => { if leaf.path[lvl] { - if (*n.right).clone().is_empty() { + if n.right.is_empty() { // empty sub-node, add the leaf here n.right = Box::new(Node::Leaf(leaf)); return Ok(()); } - n.right.add_leaf(lvl + 1, leaf)?; + n.right.add_leaf(lvl + 1, max_depth, leaf)?; } else { - if (*n.left).clone().is_empty() { + if n.left.is_empty() { // empty sub-node, add the leaf here n.left = Box::new(Node::Leaf(leaf)); return Ok(()); } - n.left.add_leaf(lvl + 1, leaf)?; + n.left.add_leaf(lvl + 1, max_depth, leaf)?; } } Self::Leaf(l) => { @@ -66,7 +187,7 @@ impl Node { // diverge. // first check that keys of both leafs are different - // (l: old-leaf, leaf: new-leaf) + // (l=old-leaf, leaf=new-leaf) if l.key == leaf.key { // TODO decide if we want to return an error when trying to add a leaf that // allready exists, or if we just ignore it @@ -75,7 +196,7 @@ impl Node { let old_leaf = l.clone(); // set self as an intermediate node *self = Node::Intermediate(Intermediate::empty()); - return self.down_till_divergence(lvl, old_leaf, leaf); + return self.down_till_divergence(lvl, max_depth, old_leaf, leaf); } Self::None => { return Err(anyhow!("reached empty node, should not have entered")); @@ -84,8 +205,16 @@ impl Node { Ok(()) } - fn down_till_divergence(&mut self, lvl: usize, old_leaf: Leaf, new_leaf: Leaf) -> Result<()> { - // TODO check that lvl<=maxlevels + fn down_till_divergence( + &mut self, + lvl: usize, + max_depth: usize, + old_leaf: Leaf, + new_leaf: Leaf, + ) -> Result<()> { + if lvl >= max_depth { + return Err(anyhow!("max depth reached")); + } if let Node::Intermediate(ref mut n) = self { // let current_node: Intermediate = *self; @@ -104,10 +233,14 @@ impl Node { // no divergence yet, continue going down if new_leaf.path[lvl] { n.right = Box::new(Node::Intermediate(Intermediate::empty())); - return n.right.down_till_divergence(lvl + 1, old_leaf, new_leaf); + return n + .right + .down_till_divergence(lvl + 1, max_depth, old_leaf, new_leaf); } else { n.left = Box::new(Node::Intermediate(Intermediate::empty())); - return n.left.down_till_divergence(lvl + 1, old_leaf, new_leaf); + return n + .left + .down_till_divergence(lvl + 1, max_depth, old_leaf, new_leaf); } } Ok(()) @@ -116,28 +249,39 @@ impl Node { #[derive(Clone, Debug)] struct Intermediate { + hash: Option, left: Box, right: Box, } impl Intermediate { fn empty() -> Self { Self { + hash: None, left: Box::new(Node::None), right: Box::new(Node::None), } } - // TODO move to a Node/Hashable trait? - fn hash(self) -> Hash { - let l_hash = self.left.hash(); - let r_hash = self.right.hash(); + fn compute_hash(&mut self) -> Hash { + if self.left.clone().is_empty() && self.right.clone().is_empty() { + self.hash = Some(NULL); + return NULL; + } + let l_hash = self.left.compute_hash(); + let r_hash = self.right.compute_hash(); let input: Vec = [l_hash.0, r_hash.0].concat(); - Hash(PoseidonHash::hash_no_pad(&input).elements) + let h = Hash(PoseidonHash::hash_no_pad(&input).elements); + self.hash = Some(h); + h + } + fn hash(&self) -> Hash { + self.hash.unwrap() } } #[derive(Clone, Debug)] struct Leaf { + hash: Option, path: Vec, key: Value, value: Value, @@ -145,6 +289,7 @@ struct Leaf { impl Leaf { fn new(key: Value, value: Value) -> Self { Self { + hash: None, path: keypath(key), key, value, @@ -152,10 +297,14 @@ impl Leaf { } } impl Leaf { - // TODO move to a Node/Hashable trait? - fn hash(self) -> Hash { + fn compute_hash(&mut self) -> Hash { let input: Vec = [self.key.0, self.value.0].concat(); - Hash(PoseidonHash::hash_no_pad(&input).elements) + let h = Hash(PoseidonHash::hash_no_pad(&input).elements); + self.hash = Some(h); + h + } + fn hash(&self) -> Hash { + self.hash.unwrap() } } @@ -171,65 +320,29 @@ fn keypath(k: Value) -> Vec { .collect() } -pub struct MerkleProof { - existence: bool, -} - -impl MerkleTree { - /// returns the root of the tree - fn root(&self) -> Hash { - todo!(); - } - - /// returns the value at the given key - pub fn get(&self, key: &Value) -> Result { - todo!(); - } - - /// returns a boolean indicating whether the key exists in the tree - pub fn contains(&self, key: &Value) -> bool { - todo!(); - } - - /// returns a proof of existence, which proves that the given key exists in - /// the tree. It returns the `value` of the leaf at the given `key`, and - /// the `MerkleProof`. - fn prove(&self, key: &Value) -> Result { - todo!(); - } - - /// returns a proof of non-existence, which proves that the given `key` - /// does not exist in the tree - fn prove_nonexistence(&self, key: &Value) -> Result { - todo!(); - } - - /// verifies an inclusion proof for the given `key` and `value` - fn verify(root: Hash, proof: &MerkleProof, key: &Value, value: &Value) -> Result<()> { - todo!(); - } - - /// verifies a non-inclusion proof for the given `key`, that is, the given - /// `key` does not exist in the tree - fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Value) -> Result<()> { - todo!(); - } - - /// returns an iterator over the leaves of the tree - fn iter(&self) -> std::collections::hash_map::Iter { - todo!(); - } -} - #[cfg(test)] pub mod tests { use super::*; use crate::middleware::hash_str; #[test] - fn test_keypath() -> Result<()> { - let key = Value(hash_str("key".into()).0); - // dbg!(keypath(key)); + fn test_merkletree() -> Result<()> { + // let v = Value(hash_str("value_0".into()).0); + let v = crate::middleware::EMPTY; + + let mut kvs = HashMap::new(); + for i in 0..8 { + if i == 1 { + continue; + } + kvs.insert(Value::from(i), v); + } + kvs.insert(Value::from(13), v); + + let tree = MerkleTree::new(32, &kvs); + // it should print the same tree as in + // https://0xparc.github.io/pod2/merkletree.html#example-2 + println!("{}", tree); Ok(()) } From 62df6ea859797454b37d0043642c9e76e0f17a4a Mon Sep 17 00:00:00 2001 From: arnaucube Date: Mon, 17 Feb 2025 22:34:10 +0100 Subject: [PATCH 07/12] mt: add (till the leaf) method which is used by get,contains,prove methods --- src/primitives/merkletree_new.rs | 88 +++++++++++++++++++++++++++++--- 1 file changed, 80 insertions(+), 8 deletions(-) diff --git a/src/primitives/merkletree_new.rs b/src/primitives/merkletree_new.rs index be6fb5d9..01e81800 100644 --- a/src/primitives/merkletree_new.rs +++ b/src/primitives/merkletree_new.rs @@ -47,8 +47,22 @@ impl fmt::Display for MerkleTree { } } +#[derive(Clone, Debug)] pub struct MerkleProof { existence: bool, + siblings: Vec, +} + +impl fmt::Display for MerkleProof { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for (i, s) in self.siblings.iter().enumerate() { + if i > 0 { + write!(f, ", "); + } + write!(f, "{}", s); + } + Ok(()) + } } impl MerkleTree { @@ -59,19 +73,37 @@ impl MerkleTree { /// returns the value at the given key pub fn get(&self, key: &Value) -> Result { - todo!(); + let path = keypath(*key); + let (v, _) = self.root.down(0, self.max_depth, path, None)?; + Ok(v) } /// returns a boolean indicating whether the key exists in the tree pub fn contains(&self, key: &Value) -> bool { + let path = keypath(*key); + // TODO once thiserror is added to pod2 + // match self.root.down(0, self.max_depth, path, None) { + // Ok((_, _)) => true, + // Err("leaf not found")) => false, + // Err(_) => false, + // } todo!(); } /// returns a proof of existence, which proves that the given key exists in /// the tree. It returns the `value` of the leaf at the given `key`, and /// the `MerkleProof`. - fn prove(&self, key: &Value) -> Result { - todo!(); + fn prove(&self, key: &Value) -> Result<(Value, MerkleProof)> { + let path = keypath(*key); + let (v, siblings) = self.root.down(0, self.max_depth, path, Some(Vec::new()))?; + Ok(( + v, + MerkleProof { + existence: true, + // `unwrap` is safe since we've called `down` passing a vector + siblings: siblings.unwrap(), + }, + )) } /// returns a proof of non-existence, which proves that the given `key` @@ -157,6 +189,45 @@ impl Node { } } + /// the `siblings` parameter is used to store the siblings while going down to the leaf, if the + /// given parameter is set to `None`, then no siblings are stored. In this way, the same method + /// `down` can be used by MerkleTree methods `get`, `contains`, `prove` and + /// `prove_nonexistence`. + fn down( + &self, + lvl: usize, + max_depth: usize, + path: Vec, + mut siblings: Option>, + ) -> Result<(Value, Option>)> { + if lvl >= max_depth { + return Err(anyhow!("max depth reached")); + } + + match self { + Self::Intermediate(n) => { + if path[lvl] { + if let Some(ref mut s) = siblings { + s.push(n.left.hash()); + } + return n.right.down(lvl + 1, max_depth, path, siblings); + } else { + if let Some(ref mut s) = siblings { + s.push(n.right.hash()); + } + return n.left.down(lvl + 1, max_depth, path, siblings); + } + } + Self::Leaf(l) => { + return Ok((l.value, siblings)); + } + Self::None => { + return Err(anyhow!("leaf not found")); + } + } + Err(anyhow!("leaf not found")) + } + // adds the leaf at the tree from the current node (self), without computing any hash fn add_leaf(&mut self, lvl: usize, max_depth: usize, leaf: Leaf) -> Result<()> { if lvl >= max_depth { @@ -327,23 +398,24 @@ pub mod tests { #[test] fn test_merkletree() -> Result<()> { - // let v = Value(hash_str("value_0".into()).0); - let v = crate::middleware::EMPTY; - let mut kvs = HashMap::new(); for i in 0..8 { if i == 1 { continue; } - kvs.insert(Value::from(i), v); + kvs.insert(Value::from(i), Value::from(1000 + i)); } - kvs.insert(Value::from(13), v); + kvs.insert(Value::from(13), Value::from(1013)); let tree = MerkleTree::new(32, &kvs); // it should print the same tree as in // https://0xparc.github.io/pod2/merkletree.html#example-2 println!("{}", tree); + let (v, proof) = tree.prove(&Value::from(13))?; + assert_eq!(v, Value::from(1013)); + println!("{}", proof); + Ok(()) } } From dd8b5a544f0eca2e58faae36fe8c5dcb8d29b1af Mon Sep 17 00:00:00 2001 From: arnaucube Date: Tue, 18 Feb 2025 07:45:34 +0100 Subject: [PATCH 08/12] mt: add verify (of inclusion) method --- src/primitives/merkletree_new.rs | 128 +++++++++++++++++++++---------- 1 file changed, 87 insertions(+), 41 deletions(-) diff --git a/src/primitives/merkletree_new.rs b/src/primitives/merkletree_new.rs index 01e81800..2bb1135f 100644 --- a/src/primitives/merkletree_new.rs +++ b/src/primitives/merkletree_new.rs @@ -21,29 +21,29 @@ pub struct MerkleTree { impl MerkleTree { /// builds a new `MerkleTree` where the leaves contain the given key-values - pub fn new(max_depth: usize, kvs: &HashMap) -> Self { + pub fn new(max_depth: usize, kvs: &HashMap) -> Result { let mut root = Node::Intermediate(Intermediate::empty()); for (k, v) in kvs.iter() { - let leaf = Leaf::new(*k, *v); - root.add_leaf(0, max_depth, leaf).unwrap(); // TODO unwrap + let leaf = Leaf::new(max_depth, *k, *v)?; + root.add_leaf(0, max_depth, leaf)?; } let _ = root.compute_hash(); - Self { max_depth, root } + Ok(Self { max_depth, root }) } } impl fmt::Display for MerkleTree { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( + writeln!( f, - "\nPaste in GraphViz (https://dreampuf.github.io/GraphvizOnline/):\n-----\n" + "\nPaste in GraphViz (https://dreampuf.github.io/GraphvizOnline/):\n-----" ); - write!(f, "digraph hierarchy {{\n"); - write!(f, "node [fontname=Monospace,fontsize=10,shape=box]\n"); + writeln!(f, "digraph hierarchy {{"); + writeln!(f, "node [fontname=Monospace,fontsize=10,shape=box]"); write!(f, "{}", self.root); - write!(f, "\n}}\n-----\n") + writeln!(f, "\n}}\n-----") } } @@ -73,28 +73,29 @@ impl MerkleTree { /// returns the value at the given key pub fn get(&self, key: &Value) -> Result { - let path = keypath(*key); + let path = keypath(self.max_depth, *key)?; let (v, _) = self.root.down(0, self.max_depth, path, None)?; Ok(v) } /// returns a boolean indicating whether the key exists in the tree pub fn contains(&self, key: &Value) -> bool { - let path = keypath(*key); - // TODO once thiserror is added to pod2 + // WIP once thiserror is added to pod2, this method is just like `.get` but returning + // true/false if the error matches the key-non-existing error returned by `down` + // let path = keypath(self.max_depth, *key)?; // match self.root.down(0, self.max_depth, path, None) { // Ok((_, _)) => true, // Err("leaf not found")) => false, // Err(_) => false, // } - todo!(); + unimplemented!(); } /// returns a proof of existence, which proves that the given key exists in /// the tree. It returns the `value` of the leaf at the given `key`, and /// the `MerkleProof`. fn prove(&self, key: &Value) -> Result<(Value, MerkleProof)> { - let path = keypath(*key); + let path = keypath(self.max_depth, *key)?; let (v, siblings) = self.root.down(0, self.max_depth, path, Some(Vec::new()))?; Ok(( v, @@ -109,12 +110,41 @@ impl MerkleTree { /// returns a proof of non-existence, which proves that the given `key` /// does not exist in the tree fn prove_nonexistence(&self, key: &Value) -> Result { + // note: non-existence of a key can be in 2 cases: + // - the expected leaf does not exist + // - the expected leaf does exist in the tree, but it has a different `key` + // both cases prove that the given key don't exist in the tree. todo!(); } /// verifies an inclusion proof for the given `key` and `value` - fn verify(root: Hash, proof: &MerkleProof, key: &Value, value: &Value) -> Result<()> { - todo!(); + fn verify( + max_depth: usize, + root: Hash, + proof: &MerkleProof, + key: &Value, + value: &Value, + ) -> Result<()> { + if proof.siblings.len() >= max_depth { + return Err(anyhow!("max depth reached")); + } + + let path = keypath(max_depth, *key)?; + let input: Vec = [key.0, value.0].concat(); + let mut h = Hash(PoseidonHash::hash_no_pad(&input).elements); + for (i, sibling) in proof.siblings.iter().enumerate().rev() { + let input: Vec = if path[i] { + [sibling.0, h.0].concat() + } else { + [h.0, sibling.0].concat() + }; + h = Hash(PoseidonHash::hash_no_pad(&input).elements); + } + + if h != root { + return Err(anyhow!("proof of inclusion does not verify")); + } + Ok(()) } /// verifies a non-inclusion proof for the given `key`, that is, the given @@ -140,9 +170,9 @@ impl fmt::Display for Node { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Intermediate(n) => { - write!( + writeln!( f, - "\"{}\" -> {{ \"{}\" \"{}\" }}\n", + "\"{}\" -> {{ \"{}\" \"{}\" }}", n.hash(), n.left.hash(), n.right.hash() @@ -151,11 +181,11 @@ impl fmt::Display for Node { write!(f, "{}", n.right) } Self::Leaf(l) => { - write!(f, "\"{}\" [style=filled]\n", l.hash()); - write!(f, "\"k:{}\\nv:{}\" [style=dashed]\n", l.key, l.value); - write!( + writeln!(f, "\"{}\" [style=filled]", l.hash()); + writeln!(f, "\"k:{}\\nv:{}\" [style=dashed]", l.key, l.value); + writeln!( f, - "\"{}\" -> {{ \"k:{}\\nv:{}\" }}\n", + "\"{}\" -> {{ \"k:{}\\nv:{}\" }}", l.hash(), l.key, l.value, @@ -189,7 +219,8 @@ impl Node { } } - /// the `siblings` parameter is used to store the siblings while going down to the leaf, if the + /// goes down from the current node till finding a leaf or reaching the max_depth. The + /// `siblings` parameter is used to store the siblings while going down to the leaf, if the /// given parameter is set to `None`, then no siblings are stored. In this way, the same method /// `down` can be used by MerkleTree methods `get`, `contains`, `prove` and /// `prove_nonexistence`. @@ -257,11 +288,12 @@ impl Node { // to push both leaves (old-leaf and new-leaf) down the path till their paths // diverge. - // first check that keys of both leafs are different + // first check that keys of both leaves are different // (l=old-leaf, leaf=new-leaf) if l.key == leaf.key { // TODO decide if we want to return an error when trying to add a leaf that - // allready exists, or if we just ignore it + // already exists, or if we just ignore it. For the moment we return the error + // if the key already exists in the leaf. return Err(anyhow!("key already exists")); } let old_leaf = l.clone(); @@ -276,6 +308,9 @@ impl Node { Ok(()) } + /// goes down through a 'virtual' path till finding a divergence. This method is used for when + /// adding a new leaf another already existing leaf is found, so that both leaves (new and old) + /// are pushed down the path till their keys diverge. fn down_till_divergence( &mut self, lvl: usize, @@ -290,7 +325,7 @@ impl Node { if let Node::Intermediate(ref mut n) = self { // let current_node: Intermediate = *self; if old_leaf.path[lvl] != new_leaf.path[lvl] { - // reached divergence in next level, set the leafs as childs at the current node + // reached divergence in next level, set the leaves as children at the current node if new_leaf.path[lvl] { n.left = Box::new(Node::Leaf(old_leaf)); n.right = Box::new(Node::Leaf(new_leaf)); @@ -332,7 +367,6 @@ impl Intermediate { right: Box::new(Node::None), } } - fn compute_hash(&mut self) -> Hash { if self.left.clone().is_empty() && self.right.clone().is_empty() { self.hash = Some(NULL); @@ -358,16 +392,14 @@ struct Leaf { value: Value, } impl Leaf { - fn new(key: Value, value: Value) -> Self { - Self { + fn new(max_depth: usize, key: Value, value: Value) -> Result { + Ok(Self { hash: None, - path: keypath(key), + path: keypath(max_depth, key)?, key, value, - } + }) } -} -impl Leaf { fn compute_hash(&mut self) -> Hash { let input: Vec = [self.key.0, self.value.0].concat(); let h = Hash(PoseidonHash::hash_no_pad(&input).elements); @@ -379,16 +411,26 @@ impl Leaf { } } -// TODO 1: think if maybe the length of the returned vector can be <256 (8*bytes.len()), so that +// NOTE 1: think if maybe the length of the returned vector can be <256 (8*bytes.len()), so that // we can do fewer iterations. For example, if the tree.max_depth is set to 20, we just need 20 // iterations of the loop, not 256. -// TODO 2: which approach do we take with keys that are longer than the max-depth? ie, what +// NOTE 2: which approach do we take with keys that are longer than the max-depth? ie, what // happens when two keys share the same path for more bits than the max_depth? -fn keypath(k: Value) -> Vec { +/// returns the path of the given key +fn keypath(max_depth: usize, k: Value) -> Result> { let bytes = k.to_bytes(); - (0..8 * bytes.len()) + if max_depth > 8 * bytes.len() { + // note that our current keys are of Value type, which are 4 Goldilocks field elements, ie + // ~256 bits, therefore the max_depth can not be bigger than 256. + return Err(anyhow!( + "key to short (key length: {}) for the max_depth: {}", + 8 * bytes.len(), + max_depth + )); + } + Ok((0..max_depth) .map(|n| bytes[n / 8] & (1 << (n % 8)) != 0) - .collect() + .collect()) } #[cfg(test)] @@ -405,10 +447,12 @@ pub mod tests { } kvs.insert(Value::from(i), Value::from(1000 + i)); } - kvs.insert(Value::from(13), Value::from(1013)); + let key = Value::from(13); + let value = Value::from(1013); + kvs.insert(key, value); - let tree = MerkleTree::new(32, &kvs); - // it should print the same tree as in + let tree = MerkleTree::new(32, &kvs)?; + // when printing the tree, it should print the same tree as in // https://0xparc.github.io/pod2/merkletree.html#example-2 println!("{}", tree); @@ -416,6 +460,8 @@ pub mod tests { assert_eq!(v, Value::from(1013)); println!("{}", proof); + MerkleTree::verify(32, tree.root(), &proof, &key, &value)?; + Ok(()) } } From 14965ba9490072856ac8e63cd5f4775a5ecbbfdb Mon Sep 17 00:00:00 2001 From: arnaucube Date: Wed, 19 Feb 2025 16:11:51 +0100 Subject: [PATCH 09/12] mt: update 'down' method to reuse siblings, update get,contains,prove methods (the three use 'down' under the hood) --- src/primitives/merkletree_new.rs | 53 ++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/src/primitives/merkletree_new.rs b/src/primitives/merkletree_new.rs index 2bb1135f..b04fc82a 100644 --- a/src/primitives/merkletree_new.rs +++ b/src/primitives/merkletree_new.rs @@ -74,21 +74,26 @@ impl MerkleTree { /// returns the value at the given key pub fn get(&self, key: &Value) -> Result { let path = keypath(self.max_depth, *key)?; - let (v, _) = self.root.down(0, self.max_depth, path, None)?; + let (k, v) = self.root.down(0, self.max_depth, path, None)?; + if &k != key { + return Err(anyhow!("key not found")); + } Ok(v) } /// returns a boolean indicating whether the key exists in the tree - pub fn contains(&self, key: &Value) -> bool { - // WIP once thiserror is added to pod2, this method is just like `.get` but returning - // true/false if the error matches the key-non-existing error returned by `down` - // let path = keypath(self.max_depth, *key)?; - // match self.root.down(0, self.max_depth, path, None) { - // Ok((_, _)) => true, - // Err("leaf not found")) => false, - // Err(_) => false, - // } - unimplemented!(); + pub fn contains(&self, key: &Value) -> Result { + let path = keypath(self.max_depth, *key)?; + match self.root.down(0, self.max_depth, path, None) { + Ok((k, _)) => { + if &k == key { + Ok(true) + } else { + Ok(false) + } + } + Err(_) => Ok(false), + } } /// returns a proof of existence, which proves that the given key exists in @@ -96,13 +101,21 @@ impl MerkleTree { /// the `MerkleProof`. fn prove(&self, key: &Value) -> Result<(Value, MerkleProof)> { let path = keypath(self.max_depth, *key)?; - let (v, siblings) = self.root.down(0, self.max_depth, path, Some(Vec::new()))?; + + let mut siblings: Vec = Vec::new(); + let (k, v) = self + .root + .down(0, self.max_depth, path, Some(&mut siblings))?; + + if &k != key { + return Err(anyhow!("key not found")); + } + Ok(( v, MerkleProof { existence: true, - // `unwrap` is safe since we've called `down` passing a vector - siblings: siblings.unwrap(), + siblings, }, )) } @@ -224,13 +237,15 @@ impl Node { /// given parameter is set to `None`, then no siblings are stored. In this way, the same method /// `down` can be used by MerkleTree methods `get`, `contains`, `prove` and /// `prove_nonexistence`. + /// Be aware that this method will return the found leaf at the given path, which may contain a + /// different key and value than the expected one. fn down( &self, lvl: usize, max_depth: usize, path: Vec, - mut siblings: Option>, - ) -> Result<(Value, Option>)> { + mut siblings: Option<&mut Vec>, + ) -> Result<(Value, Value)> { if lvl >= max_depth { return Err(anyhow!("max depth reached")); } @@ -238,19 +253,19 @@ impl Node { match self { Self::Intermediate(n) => { if path[lvl] { - if let Some(ref mut s) = siblings { + if let Some(s) = siblings.as_mut() { s.push(n.left.hash()); } return n.right.down(lvl + 1, max_depth, path, siblings); } else { - if let Some(ref mut s) = siblings { + if let Some(s) = siblings.as_mut() { s.push(n.right.hash()); } return n.left.down(lvl + 1, max_depth, path, siblings); } } Self::Leaf(l) => { - return Ok((l.value, siblings)); + return Ok((l.key, l.value)); } Self::None => { return Err(anyhow!("leaf not found")); From e39bba4607893623048ac0eea125b3f717e861d6 Mon Sep 17 00:00:00 2001 From: Ahmad Date: Thu, 20 Feb 2025 16:33:20 +1000 Subject: [PATCH 10/12] Add nonexistence proofs and iterator --- src/primitives/merkletree_new.rs | 229 ++++++++++++++++++++++--------- 1 file changed, 166 insertions(+), 63 deletions(-) diff --git a/src/primitives/merkletree_new.rs b/src/primitives/merkletree_new.rs index b04fc82a..a294d8dc 100644 --- a/src/primitives/merkletree_new.rs +++ b/src/primitives/merkletree_new.rs @@ -4,6 +4,7 @@ // place it in the merkletree.rs file. use anyhow::{anyhow, Result}; use itertools::Itertools; +use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::types::Field; use plonky2::hash::{hash_types::HashOut, poseidon::PoseidonHash}; use plonky2::plonk::config::GenericConfig; @@ -32,6 +33,11 @@ impl MerkleTree { let _ = root.compute_hash(); Ok(Self { max_depth, root }) } + + /// returns an iterator over the leaves of the tree + pub fn iter(&self) -> Iter { + Iter { state: vec![&self.root] } + } } impl fmt::Display for MerkleTree { @@ -53,6 +59,35 @@ pub struct MerkleProof { siblings: Vec, } +impl MerkleProof { + /// Computes the root of the Merkle tree suggested by a Merkle proof. + /// If a value is not provided, the terminal node is assumed to be empty. + pub fn compute_root( + &self, + max_depth: usize, + key: &Value, + value: Option<&Value>, + ) -> Result { + if self.siblings.len() >= max_depth { + return Err(anyhow!("max depth reached")); + } + + let path = keypath(max_depth, *key)?; + let mut h = value + .map(|v| Hash(PoseidonHash::hash_no_pad(&[key.0, v.0].concat()).elements)) + .unwrap_or(Hash([GoldilocksField(0); 4])); + for (i, sibling) in self.siblings.iter().enumerate().rev() { + let input: Vec = if path[i] { + [sibling.0, h.0].concat() + } else { + [h.0, sibling.0].concat() + }; + h = Hash(PoseidonHash::hash_no_pad(&input).elements); + } + Ok(h) + } +} + impl fmt::Display for MerkleProof { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { for (i, s) in self.siblings.iter().enumerate() { @@ -74,25 +109,25 @@ impl MerkleTree { /// returns the value at the given key pub fn get(&self, key: &Value) -> Result { let path = keypath(self.max_depth, *key)?; - let (k, v) = self.root.down(0, self.max_depth, path, None)?; - if &k != key { - return Err(anyhow!("key not found")); + let key_resolution = self.root.down(0, self.max_depth, path, None)?; + match key_resolution { + Some((k, v)) if &k == key => Ok(v), + _ => Err(anyhow!("key not found")), } - Ok(v) } /// returns a boolean indicating whether the key exists in the tree pub fn contains(&self, key: &Value) -> Result { let path = keypath(self.max_depth, *key)?; match self.root.down(0, self.max_depth, path, None) { - Ok((k, _)) => { + Ok(Some((k, _))) => { if &k == key { Ok(true) } else { Ok(false) } } - Err(_) => Ok(false), + _ => Ok(false), } } @@ -103,31 +138,55 @@ impl MerkleTree { let path = keypath(self.max_depth, *key)?; let mut siblings: Vec = Vec::new(); - let (k, v) = self - .root - .down(0, self.max_depth, path, Some(&mut siblings))?; - if &k != key { - return Err(anyhow!("key not found")); + match self + .root + .down(0, self.max_depth, path, Some(&mut siblings))? + { + Some((k, v)) if &k == key => Ok(( + v, + MerkleProof { + existence: true, + siblings, + }, + )), + _ => Err(anyhow!("key not found")), } - - Ok(( - v, - MerkleProof { - existence: true, - siblings, - }, - )) } - /// returns a proof of non-existence, which proves that the given `key` - /// does not exist in the tree - fn prove_nonexistence(&self, key: &Value) -> Result { + /// returns a proof of non-existence, which proves that the given + /// `key` does not exist in the tree. The return value specifies + /// the key-value pair in the leaf reached as a result of + /// resolving `key` as well as a `MerkleProof`. + fn prove_nonexistence(&self, key: &Value) -> Result<(Option<(Value, Value)>, MerkleProof)> { + let path = keypath(self.max_depth, *key)?; + + let mut siblings: Vec = Vec::new(); + // note: non-existence of a key can be in 2 cases: - // - the expected leaf does not exist - // - the expected leaf does exist in the tree, but it has a different `key` - // both cases prove that the given key don't exist in the tree. - todo!(); + match self + .root + .down(0, self.max_depth, path, Some(&mut siblings))? + { + // - the expected leaf does not exist + None => Ok(( + None, + MerkleProof { + existence: false, + siblings, + }, + )), + // - the expected leaf does exist in the tree, but it has a different `key` + Some((k, v)) if &k != key => Ok(( + Some((k, v)), + MerkleProof { + existence: false, + siblings, + }, + )), + _ => Err(anyhow!("key found")), + } + // both cases prove that the given key don't exist in the tree. ∎ } /// verifies an inclusion proof for the given `key` and `value` @@ -138,38 +197,40 @@ impl MerkleTree { key: &Value, value: &Value, ) -> Result<()> { - if proof.siblings.len() >= max_depth { - return Err(anyhow!("max depth reached")); - } - - let path = keypath(max_depth, *key)?; - let input: Vec = [key.0, value.0].concat(); - let mut h = Hash(PoseidonHash::hash_no_pad(&input).elements); - for (i, sibling) in proof.siblings.iter().enumerate().rev() { - let input: Vec = if path[i] { - [sibling.0, h.0].concat() - } else { - [h.0, sibling.0].concat() - }; - h = Hash(PoseidonHash::hash_no_pad(&input).elements); - } + let h = proof.compute_root(max_depth, key, Some(value))?; if h != root { return Err(anyhow!("proof of inclusion does not verify")); + } else { + Ok(()) } - Ok(()) } /// verifies a non-inclusion proof for the given `key`, that is, the given /// `key` does not exist in the tree - fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Value) -> Result<()> { - todo!(); + fn verify_nonexistence( + max_depth: usize, + root: Hash, + proof: &MerkleProof, + key: &Value, + other_leaf: Option<&(Value, Value)>, + ) -> Result<()> { + match other_leaf { + Some((k, v)) if k == key => Err(anyhow!("Invalid non-existence proof.")), + _ => { + let k = other_leaf.map(|(k, _)| k).unwrap_or(key); + let v = other_leaf.map(|(_, v)| v); + let h = proof.compute_root(max_depth, k, v)?; + + if h != root { + return Err(anyhow!("proof of exclusion does not verify")); + } else { + Ok(()) + } + } + } } - /// returns an iterator over the leaves of the tree - fn iter(&self) -> std::collections::hash_map::Iter { - todo!(); - } } #[derive(Clone, Debug)] @@ -232,20 +293,24 @@ impl Node { } } - /// goes down from the current node till finding a leaf or reaching the max_depth. The - /// `siblings` parameter is used to store the siblings while going down to the leaf, if the - /// given parameter is set to `None`, then no siblings are stored. In this way, the same method - /// `down` can be used by MerkleTree methods `get`, `contains`, `prove` and - /// `prove_nonexistence`. - /// Be aware that this method will return the found leaf at the given path, which may contain a - /// different key and value than the expected one. + /// Goes down from the current node until it encounters a terminal + /// node, viz. a leaf or empty node, or until it reaches the + /// maximum depth. The `siblings` parameter is used to store the + /// siblings while going down to the leaf, if the given parameter + /// is set to `None`, then no siblings are stored. In this way, + /// the same method `down` can be used by MerkleTree methods + /// `get`, `contains`, `prove` and `prove_nonexistence`. + /// + /// Be aware that this method will return the found leaf at the + /// given path, which may contain a different key and value than + /// the expected one. fn down( &self, lvl: usize, max_depth: usize, path: Vec, mut siblings: Option<&mut Vec>, - ) -> Result<(Value, Value)> { + ) -> Result> { if lvl >= max_depth { return Err(anyhow!("max depth reached")); } @@ -264,14 +329,14 @@ impl Node { return n.left.down(lvl + 1, max_depth, path, siblings); } } - Self::Leaf(l) => { - return Ok((l.key, l.value)); - } - Self::None => { - return Err(anyhow!("leaf not found")); - } + Self::Leaf(Leaf { + key, + value, + path: _p, + hash: _h, + }) => Ok(Some((key.clone(), value.clone()))), + _ => Ok(None), } - Err(anyhow!("leaf not found")) } // adds the leaf at the tree from the current node (self), without computing any hash @@ -448,6 +513,28 @@ fn keypath(max_depth: usize, k: Value) -> Result> { .collect()) } +pub struct Iter<'a> { + state: Vec<&'a Node> +} + +impl<'a> Iterator for Iter<'a> { + type Item = (&'a Value, &'a Value); + + fn next(&mut self) -> Option { + let node = self.state.pop(); + match node { + Some(Node::None) => self.next(), + Some(Node::Leaf(Leaf { hash: _, path: _, key, value })) => Some((key,value)), + Some(Node::Intermediate(Intermediate { hash: _, left, right })) => { + self.state.push(&right); + self.state.push(&left); + self.next() + }, + _ => None + } + } +} + #[cfg(test)] pub mod tests { use super::*; @@ -471,12 +558,28 @@ pub mod tests { // https://0xparc.github.io/pod2/merkletree.html#example-2 println!("{}", tree); + // Inclusion checks let (v, proof) = tree.prove(&Value::from(13))?; assert_eq!(v, Value::from(1013)); println!("{}", proof); MerkleTree::verify(32, tree.root(), &proof, &key, &value)?; + // Exclusion checks + let key = Value::from(12); + let (other_leaf, proof) = tree.prove_nonexistence(&key)?; + assert_eq!(other_leaf, Some((Value::from(4), Value::from(1004)))); + println!("{}", proof); + + MerkleTree::verify_nonexistence(32, tree.root(), &proof, &key, other_leaf.as_ref())?; + + let key = Value::from(1); + let (other_leaf, proof) = tree.prove_nonexistence(&Value::from(1))?; + assert_eq!(other_leaf, None); + println!("{}", proof); + + MerkleTree::verify_nonexistence(32, tree.root(), &proof, &key, other_leaf.as_ref())?; + Ok(()) } } From d3221861425c2138d06857cc8b860c199377d673 Mon Sep 17 00:00:00 2001 From: Ahmad Date: Thu, 20 Feb 2025 20:37:31 +1000 Subject: [PATCH 11/12] Add iterator test --- src/primitives/merkletree_new.rs | 66 ++++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/src/primitives/merkletree_new.rs b/src/primitives/merkletree_new.rs index a294d8dc..24b2c905 100644 --- a/src/primitives/merkletree_new.rs +++ b/src/primitives/merkletree_new.rs @@ -36,7 +36,9 @@ impl MerkleTree { /// returns an iterator over the leaves of the tree pub fn iter(&self) -> Iter { - Iter { state: vec![&self.root] } + Iter { + state: vec![&self.root], + } } } @@ -53,6 +55,15 @@ impl fmt::Display for MerkleTree { } } +impl<'a> IntoIterator for &'a MerkleTree { + type Item = (&'a Value, &'a Value); + type IntoIter = Iter<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + #[derive(Clone, Debug)] pub struct MerkleProof { existence: bool, @@ -230,7 +241,6 @@ impl MerkleTree { } } } - } #[derive(Clone, Debug)] @@ -514,7 +524,7 @@ fn keypath(max_depth: usize, k: Value) -> Result> { } pub struct Iter<'a> { - state: Vec<&'a Node> + state: Vec<&'a Node>, } impl<'a> Iterator for Iter<'a> { @@ -524,19 +534,30 @@ impl<'a> Iterator for Iter<'a> { let node = self.state.pop(); match node { Some(Node::None) => self.next(), - Some(Node::Leaf(Leaf { hash: _, path: _, key, value })) => Some((key,value)), - Some(Node::Intermediate(Intermediate { hash: _, left, right })) => { + Some(Node::Leaf(Leaf { + hash: _, + path: _, + key, + value, + })) => Some((key, value)), + Some(Node::Intermediate(Intermediate { + hash: _, + left, + right, + })) => { self.state.push(&right); self.state.push(&left); self.next() - }, - _ => None + } + _ => None, } } } #[cfg(test)] pub mod tests { + use std::cmp::Ordering; + use super::*; use crate::middleware::hash_str; @@ -580,6 +601,37 @@ pub mod tests { MerkleTree::verify_nonexistence(32, tree.root(), &proof, &key, other_leaf.as_ref())?; + // Check iterator + let collected_kvs: Vec<_> = tree.into_iter().collect::>(); + + // Expected key ordering + let cmp = |max_depth: usize| { + move |k1, k2| { + let path1 = keypath(max_depth, k1).unwrap(); + let path2 = keypath(max_depth, k2).unwrap(); + + let first_unequal_bits = std::iter::zip(path1, path2).find(|(b1, b2)| b1 != b2); + + match first_unequal_bits { + Some((b1, b2)) => { + if b1 < b2 { + Ordering::Less + } else { + Ordering::Greater + } + } + _ => Ordering::Equal, + } + } + }; + + let sorted_kvs = kvs + .iter() + .sorted_by(|(k1, _), (k2, _)| cmp(32)(**k1, **k2)) + .collect::>(); + + assert_eq!(collected_kvs, sorted_kvs); + Ok(()) } } From 7a22f521a22e968c84b8558233402850c9fac257 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Thu, 20 Feb 2025 17:02:39 +0100 Subject: [PATCH 12/12] migrate usage of old merkletree to the new merkletree impl in POD2 code --- src/backends/mock_main/mod.rs | 6 +- src/backends/mock_signed.rs | 25 +- src/constants.rs | 1 + src/examples.rs | 23 +- src/frontend/mod.rs | 6 +- src/lib.rs | 1 + src/middleware/containers.rs | 52 +-- src/primitives/merkletree.rs | 724 ++++++++++++++++++++++++------- src/primitives/merkletree_new.rs | 637 --------------------------- src/primitives/mod.rs | 1 - 10 files changed, 640 insertions(+), 836 deletions(-) create mode 100644 src/constants.rs delete mode 100644 src/primitives/merkletree_new.rs diff --git a/src/backends/mock_main/mod.rs b/src/backends/mock_main/mod.rs index cbb0c7dc..4db4359b 100644 --- a/src/backends/mock_main/mod.rs +++ b/src/backends/mock_main/mod.rs @@ -485,7 +485,7 @@ pub mod tests { pk: "ZooDeel".into(), }; let pay_stub_pod = pay_stub_builder.sign(&mut signer)?; - let kyc_builder = zu_kyc_pod_builder(¶ms, &gov_id_pod, &pay_stub_pod); + let kyc_builder = zu_kyc_pod_builder(¶ms, &gov_id_pod, &pay_stub_pod)?; let mut prover = MockProver {}; let kyc_pod = kyc_builder.prove(&mut prover)?; @@ -501,7 +501,7 @@ pub mod tests { #[test] fn test_mock_main_great_boy() -> Result<()> { - let great_boy_builder = great_boy_pod_full_flow(); + let great_boy_builder = great_boy_pod_full_flow()?; let mut prover = MockProver {}; let great_boy_pod = great_boy_builder.prove(&mut prover)?; @@ -520,7 +520,7 @@ pub mod tests { #[test] fn test_mock_main_tickets() -> Result<()> { - let tickets_builder = tickets_pod_full_flow(); + let tickets_builder = tickets_pod_full_flow()?; let mut prover = MockProver {}; let proof_pod = tickets_builder.prove(&mut prover)?; let pod = proof_pod.pod.into_any().downcast::().unwrap(); diff --git a/src/backends/mock_signed.rs b/src/backends/mock_signed.rs index afbaf760..51ce54e6 100644 --- a/src/backends/mock_signed.rs +++ b/src/backends/mock_signed.rs @@ -2,6 +2,7 @@ use anyhow::Result; use std::any::Any; use std::collections::HashMap; +use crate::constants::MAX_DEPTH; use crate::middleware::{ containers::Dictionary, hash_str, AnchoredKey, Hash, Params, Pod, PodId, PodSigner, PodType, Statement, Value, KEY_SIGNER, KEY_TYPE, @@ -19,7 +20,7 @@ impl PodSigner for MockSigner { kvs.insert(hash_str(&KEY_SIGNER), Value(pk_hash.0)); kvs.insert(hash_str(&KEY_TYPE), Value::from(PodType::MockSigned)); - let dict = Dictionary::new(&kvs); + let dict = Dictionary::new(&kvs)?; let id = PodId(dict.commitment()); let signature = format!("{}_signed_by_{}", id, pk_hash); Ok(Box::new(MockSignedPod { @@ -49,13 +50,17 @@ impl Pod for MockSignedPod { } // Verify id - let mt = MerkleTree::new( + let mt = match MerkleTree::new( + MAX_DEPTH, &self .dict .iter() .map(|(&k, &v)| (k, v)) .collect::>(), - ); + ) { + Ok(mt) => mt, + Err(_) => return false, + }; let id = PodId(mt.root()); if id != self.id { return false; @@ -93,14 +98,16 @@ impl Pod for MockSignedPod { #[cfg(test)] pub mod tests { + use plonky2::field::types::Field; + use std::iter; + use super::*; + use crate::constants::MAX_DEPTH; use crate::frontend; use crate::middleware::{self, F, NULL}; - use plonky2::field::types::Field; - use std::iter; #[test] - fn test_mock_signed_0() { + fn test_mock_signed_0() -> Result<()> { let params = middleware::Params::default(); let mut pod = frontend::SignedPodBuilder::new(¶ms); pod.insert("idNumber", "4242424242"); @@ -131,7 +138,7 @@ pub mod tests { .map(|(AnchoredKey(_, k), v)| (Value(k.0), v)) .chain(iter::once(bad_kv)) .collect::>(); - let bad_mt = MerkleTree::new(&bad_kvs_mt); + let bad_mt = MerkleTree::new(MAX_DEPTH, &bad_kvs_mt)?; bad_pod.dict.mt = bad_mt; assert_eq!(bad_pod.verify(), false); @@ -143,8 +150,10 @@ pub mod tests { .map(|(AnchoredKey(_, k), v)| (Value(k.0), v)) .chain(iter::once(bad_kv)) .collect::>(); - let bad_mt = MerkleTree::new(&bad_kvs_mt); + let bad_mt = MerkleTree::new(MAX_DEPTH, &bad_kvs_mt)?; bad_pod.dict.mt = bad_mt; assert_eq!(bad_pod.verify(), false); + + Ok(()) } } diff --git a/src/constants.rs b/src/constants.rs new file mode 100644 index 00000000..1b5be372 --- /dev/null +++ b/src/constants.rs @@ -0,0 +1 @@ +pub const MAX_DEPTH: usize = 32; diff --git a/src/examples.rs b/src/examples.rs index b617aaea..818807a0 100644 --- a/src/examples.rs +++ b/src/examples.rs @@ -1,3 +1,4 @@ +use anyhow::Result; use std::collections::HashMap; use crate::backends::mock_signed::MockSigner; @@ -24,8 +25,8 @@ pub fn zu_kyc_pod_builder( params: &Params, gov_id: &SignedPod, pay_stub: &SignedPod, -) -> MainPodBuilder { - let sanction_list = Value::Dictionary(Dictionary::new(&HashMap::new())); // empty dictionary +) -> Result { + let sanction_list = Value::Dictionary(Dictionary::new(&HashMap::new())?); // empty dictionary let now_minus_18y: i64 = 1169909388; let now_minus_1y: i64 = 1706367566; @@ -41,7 +42,7 @@ pub fn zu_kyc_pod_builder( )); kyc.pub_op(op!(eq, (pay_stub, "startDate"), now_minus_1y)); - kyc + Ok(kyc) } // GreatBoy @@ -130,7 +131,7 @@ pub fn great_boy_pod_builder( great_boy } -pub fn great_boy_pod_full_flow() -> MainPodBuilder { +pub fn great_boy_pod_full_flow() -> Result { let params = Params { max_input_signed_pods: 6, max_statements: 100, @@ -179,8 +180,8 @@ pub fn great_boy_pod_full_flow() -> MainPodBuilder { alice_friend_pods.push(friend.sign(&mut bob_signer).unwrap()); alice_friend_pods.push(friend.sign(&mut charlie_signer).unwrap()); - let good_boy_issuers_dict = Value::Dictionary(Dictionary::new(&HashMap::new())); // empty - great_boy_pod_builder( + let good_boy_issuers_dict = Value::Dictionary(Dictionary::new(&HashMap::new())?); // empty + Ok(great_boy_pod_builder( ¶ms, [ &bob_good_boys[0], @@ -191,7 +192,7 @@ pub fn great_boy_pod_full_flow() -> MainPodBuilder { [&alice_friend_pods[0], &alice_friend_pods[1]], &good_boy_issuers_dict, alice, - ) + )) } // Tickets @@ -229,15 +230,15 @@ pub fn tickets_pod_builder( builder } -pub fn tickets_pod_full_flow() -> MainPodBuilder { +pub fn tickets_pod_full_flow() -> Result { let params = Params::default(); let builder = tickets_sign_pod_builder(¶ms); let signed_pod = builder.sign(&mut MockSigner { pk: "test".into() }).unwrap(); - tickets_pod_builder( + Ok(tickets_pod_builder( ¶ms, &signed_pod, 123, true, - &Value::Dictionary(Dictionary::new(&HashMap::new())), - ) + &Value::Dictionary(Dictionary::new(&HashMap::new())?), + )) } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 085b093c..f1335cbe 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -510,7 +510,7 @@ pub mod tests { let pay_stub = pay_stub.sign(&mut signer).unwrap(); println!("{}", pay_stub); - let kyc = zu_kyc_pod_builder(¶ms, &gov_id, &pay_stub); + let kyc = zu_kyc_pod_builder(¶ms, &gov_id, &pay_stub)?; println!("{}", kyc); // TODO: prove kyc with MockProver and print it @@ -520,7 +520,7 @@ pub mod tests { #[test] fn test_front_great_boy() -> Result<()> { - let great_boy = great_boy_pod_full_flow(); + let great_boy = great_boy_pod_full_flow()?; println!("{}", great_boy); // TODO: prove kyc with MockProver and print it @@ -530,7 +530,7 @@ pub mod tests { #[test] fn test_front_tickets() -> Result<()> { - let builder = tickets_pod_full_flow(); + let builder = tickets_pod_full_flow()?; println!("{}", builder); Ok(()) diff --git a/src/lib.rs b/src/lib.rs index 95245d9f..963de0b4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod backends; +pub mod constants; pub mod frontend; pub mod middleware; pub mod primitives; diff --git a/src/middleware/containers.rs b/src/middleware/containers.rs index 43204b43..743402db 100644 --- a/src/middleware/containers.rs +++ b/src/middleware/containers.rs @@ -6,6 +6,7 @@ use plonky2::plonk::config::Hasher; use std::collections::HashMap; use super::{Hash, Value, EMPTY}; +use crate::constants::MAX_DEPTH; use crate::primitives::merkletree::{MerkleProof, MerkleTree}; /// Dictionary: the user original keys and values are hashed to be used in the leaf. @@ -18,11 +19,11 @@ pub struct Dictionary { } impl Dictionary { - pub fn new(kvs: &HashMap) -> Self { + pub fn new(kvs: &HashMap) -> Result { let kvs: HashMap = kvs.into_iter().map(|(&k, &v)| (Value(k.0), v)).collect(); - Self { - mt: MerkleTree::new(&kvs), - } + Ok(Self { + mt: MerkleTree::new(MAX_DEPTH, &kvs)?, + }) } pub fn commitment(&self) -> Hash { self.mt.root() @@ -30,25 +31,25 @@ impl Dictionary { pub fn get(&self, key: &Value) -> Result { self.mt.get(key) } - pub fn prove(&self, key: &Value) -> Result { + pub fn prove(&self, key: &Value) -> Result<(Value, MerkleProof)> { self.mt.prove(key) } pub fn prove_nonexistence(&self, key: &Value) -> Result { self.mt.prove_nonexistence(key) } pub fn verify(root: Hash, proof: &MerkleProof, key: &Value, value: &Value) -> Result<()> { - MerkleTree::verify(root, proof, key, value) + MerkleTree::verify(MAX_DEPTH, root, proof, key, value) } pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, key: &Value) -> Result<()> { - MerkleTree::verify_nonexistence(root, proof, key) + MerkleTree::verify_nonexistence(MAX_DEPTH, root, proof, key) } - pub fn iter(&self) -> std::collections::hash_map::Iter { + pub fn iter(&self) -> crate::primitives::merkletree::Iter { self.mt.iter() } } impl<'a> IntoIterator for &'a Dictionary { type Item = (&'a Value, &'a Value); - type IntoIter = std::collections::hash_map::Iter<'a, Value, Value>; + type IntoIter = crate::primitives::merkletree::Iter<'a>; fn into_iter(self) -> Self::IntoIter { self.mt.iter() @@ -71,7 +72,7 @@ pub struct Set { } impl Set { - pub fn new(set: &Vec) -> Self { + pub fn new(set: &Vec) -> Result { let kvs: HashMap = set .into_iter() .map(|e| { @@ -79,29 +80,30 @@ impl Set { (Value(h), EMPTY) }) .collect(); - Self { - mt: MerkleTree::new(&kvs), - } + Ok(Self { + mt: MerkleTree::new(MAX_DEPTH, &kvs)?, + }) } pub fn commitment(&self) -> Hash { self.mt.root() } - pub fn contains(&self, value: &Value) -> bool { + pub fn contains(&self, value: &Value) -> Result { self.mt.contains(value) } pub fn prove(&self, value: &Value) -> Result { - self.mt.prove(value) + let (_, proof) = self.mt.prove(value)?; + Ok(proof) } pub fn prove_nonexistence(&self, value: &Value) -> Result { self.mt.prove_nonexistence(value) } pub fn verify(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> { - MerkleTree::verify(root, proof, value, &EMPTY) + MerkleTree::verify(MAX_DEPTH, root, proof, value, &EMPTY) } pub fn verify_nonexistence(root: Hash, proof: &MerkleProof, value: &Value) -> Result<()> { - MerkleTree::verify_nonexistence(root, proof, value) + MerkleTree::verify_nonexistence(MAX_DEPTH, root, proof, value) } - pub fn iter(&self) -> std::collections::hash_map::Iter { + pub fn iter(&self) -> crate::primitives::merkletree::Iter { self.mt.iter() } } @@ -123,16 +125,16 @@ pub struct Array { } impl Array { - pub fn new(array: &Vec) -> Self { + pub fn new(array: &Vec) -> Result { let kvs: HashMap = array .into_iter() .enumerate() .map(|(i, &e)| (Value::from(i as i64), e)) .collect(); - Self { - mt: MerkleTree::new(&kvs), - } + Ok(Self { + mt: MerkleTree::new(MAX_DEPTH, &kvs)?, + }) } pub fn commitment(&self) -> Hash { self.mt.root() @@ -140,13 +142,13 @@ impl Array { pub fn get(&self, i: usize) -> Result { self.mt.get(&Value::from(i as i64)) } - pub fn prove(&self, i: usize) -> Result { + pub fn prove(&self, i: usize) -> Result<(Value, MerkleProof)> { self.mt.prove(&Value::from(i as i64)) } pub fn verify(root: Hash, proof: &MerkleProof, i: usize, value: &Value) -> Result<()> { - MerkleTree::verify(root, proof, &Value::from(i as i64), value) + MerkleTree::verify(MAX_DEPTH, root, proof, &Value::from(i as i64), value) } - pub fn iter(&self) -> std::collections::hash_map::Iter { + pub fn iter(&self) -> crate::primitives::merkletree::Iter { self.mt.iter() } } diff --git a/src/primitives/merkletree.rs b/src/primitives/merkletree.rs index 8a393af6..9a214478 100644 --- a/src/primitives/merkletree.rs +++ b/src/primitives/merkletree.rs @@ -1,214 +1,642 @@ -/// MerkleTree implementation for POD2. -/// -/// Current implementation is a wrapper on top of Plonky2's MerkleTree, but the future iteration -/// will replace it by the MerkleTree specified at https://0xparc.github.io/pod2/merkletree.html . +//! Module that implements the MerkleTree specified at +//! https://0xparc.github.io/pod2/merkletree.html . use anyhow::{anyhow, Result}; -use itertools::Itertools; -use plonky2::field::types::Field; -use plonky2::hash::{ - hash_types::HashOut, - merkle_proofs::{verify_merkle_proof, MerkleProof as PlonkyMerkleProof}, - merkle_tree::MerkleTree as PlonkyMerkleTree, - poseidon::PoseidonHash, -}; -use plonky2::plonk::config::GenericConfig; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::hash::poseidon::PoseidonHash; use plonky2::plonk::config::Hasher; use std::collections::HashMap; +use std::fmt; use std::iter::IntoIterator; -use crate::middleware::{Hash, Value, C, D, F}; +use crate::middleware::{Hash, Value, F, NULL}; -const CAP_HEIGHT: usize = 0; - -/// MerkleTree currently is a wrapper on top of Plonky2's MerkleTree. A future iteration will -/// replace it by the MerkleTree specified at https://0xparc.github.io/pod2/merkletree.html . +/// Implements the MerkleTree specified at +/// https://0xparc.github.io/pod2/merkletree.html #[derive(Clone, Debug)] pub struct MerkleTree { - tree: PlonkyMerkleTree>::Hasher>, - // keyindex: key -> index mapping. This is just for the current plonky-tree wrapper - keyindex: HashMap, - // kvs are a field in the MerkleTree in order to be able to iterate over the keyvalues. This is - // specific of the current implementation (Plonky2's tree wrapper), in the next iteration this - // will not be needed since the tree implementation itself will offer the hashmap - // functionality. - pub kvs: HashMap, - // leaves_map is a map between the leaf (leaf=Hash(key,value)) and the actual (key, value). It - // is used to get the actual value from a leaf for a given key (through the method - // `MerkleTree.get`. - leaves_map: HashMap, -} - -pub struct MerkleProof { - existence: bool, - index: usize, - proof: PlonkyMerkleProof>::Hasher>, + max_depth: usize, + root: Node, } impl MerkleTree { /// builds a new `MerkleTree` where the leaves contain the given key-values - pub fn new(kvs: &HashMap) -> Self { - let mut keyindex: HashMap = HashMap::new(); - let mut leaves: Vec> = Vec::new(); - let mut leaves_map: HashMap = HashMap::new(); - // Note: current version iterates sorting by keys of the kvs, but the merkletree defined at - // https://0xparc.github.io/pod2/merkletree.html will not need it since it will be - // deterministic based on the keys values not on the order of the keys when added into the - // tree. - for (i, (k, v)) in kvs.iter().sorted_by_key(|kv| kv.0).enumerate() { - let input: Vec = [k.0, v.0].concat(); - let leaf = PoseidonHash::hash_no_pad(&input).elements; - leaves.push(leaf.into()); - keyindex.insert(*k, i); - leaves_map.insert(Hash(leaf), (*k, *v)); - } - - // pad to a power of two if needed - let leaf_empty: Vec = vec![F::ZERO, F::ZERO, F::ZERO, F::ZERO]; - for _ in leaves.len()..leaves.len().next_power_of_two() { - leaves.push(leaf_empty.clone()); - } - - let tree = PlonkyMerkleTree::>::Hasher>::new(leaves, CAP_HEIGHT); - Self { - tree, - keyindex, - kvs: kvs.clone(), - leaves_map, + pub fn new(max_depth: usize, kvs: &HashMap) -> Result { + let mut root = Node::Intermediate(Intermediate::empty()); + + for (k, v) in kvs.iter() { + let leaf = Leaf::new(max_depth, *k, *v)?; + root.add_leaf(0, max_depth, leaf)?; } + + let _ = root.compute_hash(); + Ok(Self { max_depth, root }) } -} -impl MerkleTree { /// returns the root of the tree pub fn root(&self) -> Hash { - if self.tree.cap.is_empty() { - return crate::middleware::NULL; - } - Hash(self.tree.cap.0[0].elements) + self.root.hash() + } + + pub fn max_depth(&self) -> usize { + self.max_depth } /// returns the value at the given key pub fn get(&self, key: &Value) -> Result { - let i = self.keyindex.get(&key).ok_or(anyhow!("key not in tree"))?; - let leaf_hash_raw = self.tree.get(*i); - let leaf_hash_f: [F; 4] = leaf_hash_raw - .try_into() - .map_err(|_| anyhow!("unexpected length (len!=4)"))?; - let leaf_hash: Hash = Hash(leaf_hash_f); - let (_, value) = self.leaves_map.get(&leaf_hash).unwrap(); - Ok(*value) + let path = keypath(self.max_depth, *key)?; + let key_resolution = self.root.down(0, self.max_depth, path, None)?; + match key_resolution { + Some((k, v)) if &k == key => Ok(v), + _ => Err(anyhow!("key not found")), + } } /// returns a boolean indicating whether the key exists in the tree - pub fn contains(&self, key: &Value) -> bool { - self.keyindex.get(&key).is_some() + pub fn contains(&self, key: &Value) -> Result { + let path = keypath(self.max_depth, *key)?; + match self.root.down(0, self.max_depth, path, None) { + Ok(Some((k, _))) => { + if &k == key { + Ok(true) + } else { + Ok(false) + } + } + _ => Ok(false), + } } /// returns a proof of existence, which proves that the given key exists in - /// the tree. It returns the `MerkleProof`. - pub fn prove(&self, key: &Value) -> Result { - let i = self.keyindex.get(&key).ok_or(anyhow!("key not in tree"))?; - let proof = self.tree.prove(*i); - Ok(MerkleProof { - existence: true, - index: *i, - proof, - }) + /// the tree. It returns the `value` of the leaf at the given `key`, and the + /// `MerkleProof`. + pub fn prove(&self, key: &Value) -> Result<(Value, MerkleProof)> { + let path = keypath(self.max_depth, *key)?; + + let mut siblings: Vec = Vec::new(); + + match self + .root + .down(0, self.max_depth, path, Some(&mut siblings))? + { + Some((k, v)) if &k == key => Ok(( + v, + MerkleProof { + existence: true, + siblings, + other_leaf: None, + }, + )), + _ => Err(anyhow!("key not found")), + } } - /// returns a proof of non-existence, which proves that the given `key` - /// does not exist in the tree - pub fn prove_nonexistence(&self, _key: &Value) -> Result { - // mock method - println!("WARNING: MerkleTree::verify_nonexistence is currently a mock"); - Ok(MerkleProof { - existence: false, - index: 0, - proof: PlonkyMerkleProof { siblings: vec![] }, - }) + /// returns a proof of non-existence, which proves that the given + /// `key` does not exist in the tree. The return value specifies + /// the key-value pair in the leaf reached as a result of + /// resolving `key` as well as a `MerkleProof`. + pub fn prove_nonexistence(&self, key: &Value) -> Result { + let path = keypath(self.max_depth, *key)?; + + let mut siblings: Vec = Vec::new(); + + // note: non-existence of a key can be in 2 cases: + match self + .root + .down(0, self.max_depth, path, Some(&mut siblings))? + { + // case i) the expected leaf does not exist + None => Ok(MerkleProof { + existence: false, + siblings, + other_leaf: None, + }), + // case ii) the expected leaf does exist in the tree, but it has a different `key` + Some((k, v)) if &k != key => Ok(MerkleProof { + existence: false, + siblings, + other_leaf: Some((k, v)), + }), + _ => Err(anyhow!("key found")), + } + // both cases prove that the given key don't exist in the tree. ∎ } /// verifies an inclusion proof for the given `key` and `value` - pub fn verify(root: Hash, proof: &MerkleProof, key: &Value, value: &Value) -> Result<()> { - if !proof.existence { - return Err(anyhow!( - "expected proof of existence, found proof of non-existence" - )); + pub fn verify( + max_depth: usize, + root: Hash, + proof: &MerkleProof, + key: &Value, + value: &Value, + ) -> Result<()> { + let h = proof.compute_root_from_leaf(max_depth, key, Some(*value))?; + + if h != root { + return Err(anyhow!("proof of inclusion does not verify")); + } else { + Ok(()) } - let leaf = PoseidonHash::hash_no_pad(&[key.0, value.0].concat()).elements; - let root = HashOut::from_vec(root.0.to_vec()); - verify_merkle_proof(leaf.into(), proof.index, root, &proof.proof) } /// verifies a non-inclusion proof for the given `key`, that is, the given /// `key` does not exist in the tree - pub fn verify_nonexistence(_root: Hash, proof: &MerkleProof, _key: &Value) -> Result<()> { - // mock method - if proof.existence { - return Err(anyhow!( - "expected proof of non-existence, found proof of existence" - )); - } - println!("WARNING: MerkleTree::verify_nonexistence is currently a mock"); - Ok(()) + pub fn verify_nonexistence( + max_depth: usize, + root: Hash, + proof: &MerkleProof, + key: &Value, + ) -> Result<()> { + match proof.other_leaf { + Some((k, _v)) if &k == key => Err(anyhow!("Invalid non-existence proof.")), + _ => { + let k = proof.other_leaf.map(|(k, _)| k).unwrap_or(*key); + let v: Option = proof.other_leaf.map(|(_, v)| v); + let h = proof.compute_root_from_leaf(max_depth, &k, v)?; + + if h != root { + return Err(anyhow!("proof of exclusion does not verify")); + } else { + Ok(()) + } + } + } } /// returns an iterator over the leaves of the tree - pub fn iter(&self) -> std::collections::hash_map::Iter { - self.kvs.iter() + pub fn iter(&self) -> Iter { + Iter { + state: vec![&self.root], + } } } impl<'a> IntoIterator for &'a MerkleTree { type Item = (&'a Value, &'a Value); - type IntoIter = std::collections::hash_map::Iter<'a, Value, Value>; + type IntoIter = Iter<'a>; fn into_iter(self) -> Self::IntoIter { - self.kvs.iter() + self.iter() + } +} + +impl fmt::Display for MerkleTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "\nPaste in GraphViz (https://dreampuf.github.io/GraphvizOnline/):\n-----" + )?; + writeln!(f, "digraph hierarchy {{")?; + writeln!(f, "node [fontname=Monospace,fontsize=10,shape=box]")?; + write!(f, "{}", self.root)?; + writeln!(f, "\n}}\n-----") + } +} + +#[derive(Clone, Debug)] +pub struct MerkleProof { + // note: currently we don't use the `_existence` field, we would use if we merge the methods + // `verify` and `verify_nonexistence` into a single one + #[allow(unused)] + existence: bool, + siblings: Vec, + // other_leaf is used for non-existence proofs + other_leaf: Option<(Value, Value)>, +} + +impl fmt::Display for MerkleProof { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for (i, s) in self.siblings.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", s)?; + } + Ok(()) + } +} + +impl MerkleProof { + /// Computes the root of the Merkle tree suggested by a Merkle proof given a + /// key & value. If a value is not provided, the terminal node is assumed to + /// be empty. + fn compute_root_from_leaf( + &self, + max_depth: usize, + key: &Value, + value: Option, + ) -> Result { + if self.siblings.len() >= max_depth { + return Err(anyhow!("max depth reached")); + } + + let path = keypath(max_depth, *key)?; + let mut h = value + .map(|v| Hash(PoseidonHash::hash_no_pad(&[key.0, v.0].concat()).elements)) + .unwrap_or(Hash([GoldilocksField(0); 4])); + for (i, sibling) in self.siblings.iter().enumerate().rev() { + let input: Vec = if path[i] { + [sibling.0, h.0].concat() + } else { + [h.0, sibling.0].concat() + }; + h = Hash(PoseidonHash::hash_no_pad(&input).elements); + } + Ok(h) + } +} + +#[derive(Clone, Debug)] +enum Node { + None, + Leaf(Leaf), + Intermediate(Intermediate), +} + +impl fmt::Display for Node { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Intermediate(n) => { + writeln!( + f, + "\"{}\" -> {{ \"{}\" \"{}\" }}", + n.hash(), + n.left.hash(), + n.right.hash() + )?; + write!(f, "{}", n.left)?; + write!(f, "{}", n.right) + } + Self::Leaf(l) => { + writeln!(f, "\"{}\" [style=filled]", l.hash())?; + writeln!(f, "\"k:{}\\nv:{}\" [style=dashed]", l.key, l.value)?; + writeln!( + f, + "\"{}\" -> {{ \"k:{}\\nv:{}\" }}", + l.hash(), + l.key, + l.value, + ) + } + Self::None => Ok(()), + } + } +} + +impl Node { + fn is_empty(&self) -> bool { + match self { + Self::None => true, + Self::Leaf(_l) => false, + Self::Intermediate(_n) => false, + } + } + fn compute_hash(&mut self) -> Hash { + match self { + Self::None => NULL, + Self::Leaf(l) => l.compute_hash(), + Self::Intermediate(n) => n.compute_hash(), + } + } + fn hash(&self) -> Hash { + match self { + Self::None => NULL, + Self::Leaf(l) => l.hash(), + Self::Intermediate(n) => n.hash(), + } + } + + /// Goes down from the current node until it encounters a terminal node, + /// viz. a leaf or empty node, or until it reaches the maximum depth. The + /// `siblings` parameter is used to store the siblings while going down to + /// the leaf, if the given parameter is set to `None`, then no siblings are + /// stored. In this way, the same method `down` can be used by MerkleTree + /// methods `get`, `contains`, `prove` and `prove_nonexistence`. + /// + /// Be aware that this method will return the found leaf at the given path, + /// which may contain a different key and value than the expected one. + fn down( + &self, + lvl: usize, + max_depth: usize, + path: Vec, + mut siblings: Option<&mut Vec>, + ) -> Result> { + if lvl >= max_depth { + return Err(anyhow!("max depth reached")); + } + + match self { + Self::Intermediate(n) => { + if path[lvl] { + if let Some(s) = siblings.as_mut() { + s.push(n.left.hash()); + } + return n.right.down(lvl + 1, max_depth, path, siblings); + } else { + if let Some(s) = siblings.as_mut() { + s.push(n.right.hash()); + } + return n.left.down(lvl + 1, max_depth, path, siblings); + } + } + Self::Leaf(Leaf { + key, + value, + path: _p, + hash: _h, + }) => Ok(Some((key.clone(), value.clone()))), + _ => Ok(None), + } + } + + // adds the leaf at the tree from the current node (self), without computing any hash + fn add_leaf(&mut self, lvl: usize, max_depth: usize, leaf: Leaf) -> Result<()> { + if lvl >= max_depth { + return Err(anyhow!("max depth reached")); + } + + match self { + Self::Intermediate(n) => { + if leaf.path[lvl] { + if n.right.is_empty() { + // empty sub-node, add the leaf here + n.right = Box::new(Node::Leaf(leaf)); + return Ok(()); + } + n.right.add_leaf(lvl + 1, max_depth, leaf)?; + } else { + if n.left.is_empty() { + // empty sub-node, add the leaf here + n.left = Box::new(Node::Leaf(leaf)); + return Ok(()); + } + n.left.add_leaf(lvl + 1, max_depth, leaf)?; + } + } + Self::Leaf(l) => { + // in this case, it means that we found a leaf in the new-leaf + // path, thus we need to push both leaves (old-leaf and + // new-leaf) down the path till their paths diverge. + + // first check that keys of both leaves are different + // (l=old-leaf, leaf=new-leaf) + if l.key == leaf.key { + // Note: current approach returns an error when trying to + // add to a leaf where the key already exists. We could also + // ignore it if needed. + return Err(anyhow!("key already exists")); + } + let old_leaf = l.clone(); + // set self as an intermediate node + *self = Node::Intermediate(Intermediate::empty()); + return self.down_till_divergence(lvl, max_depth, old_leaf, leaf); + } + Self::None => { + return Err(anyhow!("reached empty node, should not have entered")); + } + } + Ok(()) + } + + /// goes down through a 'virtual' path till finding a divergence. This + /// method is used for when adding a new leaf another already existing leaf + /// is found, so that both leaves (new and old) are pushed down the path + /// till their keys diverge. + fn down_till_divergence( + &mut self, + lvl: usize, + max_depth: usize, + old_leaf: Leaf, + new_leaf: Leaf, + ) -> Result<()> { + if lvl >= max_depth { + return Err(anyhow!("max depth reached")); + } + + if let Node::Intermediate(ref mut n) = self { + if old_leaf.path[lvl] != new_leaf.path[lvl] { + // reached divergence in next level, set the leaves as children + // at the current node + if new_leaf.path[lvl] { + n.left = Box::new(Node::Leaf(old_leaf)); + n.right = Box::new(Node::Leaf(new_leaf)); + } else { + n.left = Box::new(Node::Leaf(new_leaf)); + n.right = Box::new(Node::Leaf(old_leaf)); + } + return Ok(()); + } + + // no divergence yet, continue going down + if new_leaf.path[lvl] { + n.right = Box::new(Node::Intermediate(Intermediate::empty())); + return n + .right + .down_till_divergence(lvl + 1, max_depth, old_leaf, new_leaf); + } else { + n.left = Box::new(Node::Intermediate(Intermediate::empty())); + return n + .left + .down_till_divergence(lvl + 1, max_depth, old_leaf, new_leaf); + } + } + Ok(()) + } +} + +#[derive(Clone, Debug)] +struct Intermediate { + hash: Option, + left: Box, + right: Box, +} +impl Intermediate { + fn empty() -> Self { + Self { + hash: None, + left: Box::new(Node::None), + right: Box::new(Node::None), + } + } + fn compute_hash(&mut self) -> Hash { + if self.left.clone().is_empty() && self.right.clone().is_empty() { + self.hash = Some(NULL); + return NULL; + } + let l_hash = self.left.compute_hash(); + let r_hash = self.right.compute_hash(); + let input: Vec = [l_hash.0, r_hash.0].concat(); + let h = Hash(PoseidonHash::hash_no_pad(&input).elements); + self.hash = Some(h); + h + } + fn hash(&self) -> Hash { + self.hash.unwrap() + } +} + +#[derive(Clone, Debug)] +struct Leaf { + hash: Option, + path: Vec, + key: Value, + value: Value, +} +impl Leaf { + fn new(max_depth: usize, key: Value, value: Value) -> Result { + Ok(Self { + hash: None, + path: keypath(max_depth, key)?, + key, + value, + }) + } + fn compute_hash(&mut self) -> Hash { + let input: Vec = [self.key.0, self.value.0].concat(); + let h = Hash(PoseidonHash::hash_no_pad(&input).elements); + self.hash = Some(h); + h + } + fn hash(&self) -> Hash { + self.hash.unwrap() + } +} + +// NOTE 1: think if maybe the length of the returned vector can be <256 +// (8*bytes.len()), so that we can do fewer iterations. For example, if the +// tree.max_depth is set to 20, we just need 20 iterations of the loop, not 256. +// NOTE 2: which approach do we take with keys that are longer than the +// max-depth? ie, what happens when two keys share the same path for more bits +// than the max_depth? +/// returns the path of the given key +fn keypath(max_depth: usize, k: Value) -> Result> { + let bytes = k.to_bytes(); + if max_depth > 8 * bytes.len() { + // note that our current keys are of Value type, which are 4 Goldilocks + // field elements, ie ~256 bits, therefore the max_depth can not be + // bigger than 256. + return Err(anyhow!( + "key to short (key length: {}) for the max_depth: {}", + 8 * bytes.len(), + max_depth + )); + } + Ok((0..max_depth) + .map(|n| bytes[n / 8] & (1 << (n % 8)) != 0) + .collect()) +} + +pub struct Iter<'a> { + state: Vec<&'a Node>, +} + +impl<'a> Iterator for Iter<'a> { + type Item = (&'a Value, &'a Value); + + fn next(&mut self) -> Option { + let node = self.state.pop(); + match node { + Some(Node::None) => self.next(), + Some(Node::Leaf(Leaf { + hash: _, + path: _, + key, + value, + })) => Some((key, value)), + Some(Node::Intermediate(Intermediate { + hash: _, + left, + right, + })) => { + self.state.push(&right); + self.state.push(&left); + self.next() + } + _ => None, + } } } #[cfg(test)] pub mod tests { - use super::*; + use itertools::Itertools; + use std::cmp::Ordering; - use crate::middleware::hash_str; + use super::*; #[test] fn test_merkletree() -> Result<()> { - let (k0, v0) = ( - Value(hash_str("key_0".into()).0), - Value(hash_str("value_0".into()).0), - ); - let (k1, v1) = ( - Value(hash_str("key_1".into()).0), - Value(hash_str("value_1".into()).0), - ); - let (k2, v2) = ( - Value(hash_str("key_2".into()).0), - Value(hash_str("value_2".into()).0), + let mut kvs = HashMap::new(); + for i in 0..8 { + if i == 1 { + continue; + } + kvs.insert(Value::from(i), Value::from(1000 + i)); + } + let key = Value::from(13); + let value = Value::from(1013); + kvs.insert(key, value); + + let tree = MerkleTree::new(32, &kvs)?; + // when printing the tree, it should print the same tree as in + // https://0xparc.github.io/pod2/merkletree.html#example-2 + println!("{}", tree); + + // Inclusion checks + let (v, proof) = tree.prove(&Value::from(13))?; + assert_eq!(v, Value::from(1013)); + println!("{}", proof); + + MerkleTree::verify(32, tree.root(), &proof, &key, &value)?; + + // Exclusion checks + let key = Value::from(12); + let proof = tree.prove_nonexistence(&key)?; + assert_eq!( + proof.other_leaf.unwrap(), + (Value::from(4), Value::from(1004)) ); + println!("{}", proof); - let mut kvs = HashMap::new(); - kvs.insert(k0, v0); - kvs.insert(k1, v1); - kvs.insert(k2, v2); + MerkleTree::verify_nonexistence(32, tree.root(), &proof, &key)?; + + let key = Value::from(1); + let proof = tree.prove_nonexistence(&Value::from(1))?; + assert_eq!(proof.other_leaf, None); + println!("{}", proof); + + MerkleTree::verify_nonexistence(32, tree.root(), &proof, &key)?; + + // Check iterator + let collected_kvs: Vec<_> = tree.into_iter().collect::>(); - let tree = MerkleTree::new(&kvs); + // Expected key ordering + let cmp = |max_depth: usize| { + move |k1, k2| { + let path1 = keypath(max_depth, k1).unwrap(); + let path2 = keypath(max_depth, k2).unwrap(); - let proof = tree.prove(&k2)?; - MerkleTree::verify(tree.root(), &proof, &k2, &v2)?; + let first_unequal_bits = std::iter::zip(path1, path2).find(|(b1, b2)| b1 != b2); - // expect verification to fail with different key / value - assert!(MerkleTree::verify(tree.root(), &proof, &k2, &v0).is_err()); - assert!(MerkleTree::verify(tree.root(), &proof, &k0, &v2).is_err()); + match first_unequal_bits { + Some((b1, b2)) => { + if b1 < b2 { + Ordering::Less + } else { + Ordering::Greater + } + } + _ => Ordering::Equal, + } + } + }; - // non-existence proofs - let proof_ne = tree.prove_nonexistence(&k2)?; - let _ = MerkleTree::verify_nonexistence(tree.root(), &proof_ne, &k2)?; + let sorted_kvs = kvs + .iter() + .sorted_by(|(k1, _), (k2, _)| cmp(32)(**k1, **k2)) + .collect::>(); - // expect verification of existence fail for nonexistence proof - let _ = MerkleTree::verify(tree.root(), &proof_ne, &k2, &v2).is_err(); + assert_eq!(collected_kvs, sorted_kvs); Ok(()) } diff --git a/src/primitives/merkletree_new.rs b/src/primitives/merkletree_new.rs deleted file mode 100644 index 24b2c905..00000000 --- a/src/primitives/merkletree_new.rs +++ /dev/null @@ -1,637 +0,0 @@ -#![allow(unused)] -#![allow(dead_code, unused_variables)] -// NOTE: starting in this file (merkletree_new.rs), once we have the implementation ready we just -// place it in the merkletree.rs file. -use anyhow::{anyhow, Result}; -use itertools::Itertools; -use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::field::types::Field; -use plonky2::hash::{hash_types::HashOut, poseidon::PoseidonHash}; -use plonky2::plonk::config::GenericConfig; -use plonky2::plonk::config::Hasher; -use std::collections::HashMap; -use std::fmt; -use std::iter::IntoIterator; - -use crate::middleware::{Hash, Value, C, D, F, NULL}; - -pub struct MerkleTree { - max_depth: usize, - root: Node, -} - -impl MerkleTree { - /// builds a new `MerkleTree` where the leaves contain the given key-values - pub fn new(max_depth: usize, kvs: &HashMap) -> Result { - let mut root = Node::Intermediate(Intermediate::empty()); - - for (k, v) in kvs.iter() { - let leaf = Leaf::new(max_depth, *k, *v)?; - root.add_leaf(0, max_depth, leaf)?; - } - - let _ = root.compute_hash(); - Ok(Self { max_depth, root }) - } - - /// returns an iterator over the leaves of the tree - pub fn iter(&self) -> Iter { - Iter { - state: vec![&self.root], - } - } -} - -impl fmt::Display for MerkleTree { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "\nPaste in GraphViz (https://dreampuf.github.io/GraphvizOnline/):\n-----" - ); - writeln!(f, "digraph hierarchy {{"); - writeln!(f, "node [fontname=Monospace,fontsize=10,shape=box]"); - write!(f, "{}", self.root); - writeln!(f, "\n}}\n-----") - } -} - -impl<'a> IntoIterator for &'a MerkleTree { - type Item = (&'a Value, &'a Value); - type IntoIter = Iter<'a>; - - fn into_iter(self) -> Self::IntoIter { - self.iter() - } -} - -#[derive(Clone, Debug)] -pub struct MerkleProof { - existence: bool, - siblings: Vec, -} - -impl MerkleProof { - /// Computes the root of the Merkle tree suggested by a Merkle proof. - /// If a value is not provided, the terminal node is assumed to be empty. - pub fn compute_root( - &self, - max_depth: usize, - key: &Value, - value: Option<&Value>, - ) -> Result { - if self.siblings.len() >= max_depth { - return Err(anyhow!("max depth reached")); - } - - let path = keypath(max_depth, *key)?; - let mut h = value - .map(|v| Hash(PoseidonHash::hash_no_pad(&[key.0, v.0].concat()).elements)) - .unwrap_or(Hash([GoldilocksField(0); 4])); - for (i, sibling) in self.siblings.iter().enumerate().rev() { - let input: Vec = if path[i] { - [sibling.0, h.0].concat() - } else { - [h.0, sibling.0].concat() - }; - h = Hash(PoseidonHash::hash_no_pad(&input).elements); - } - Ok(h) - } -} - -impl fmt::Display for MerkleProof { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - for (i, s) in self.siblings.iter().enumerate() { - if i > 0 { - write!(f, ", "); - } - write!(f, "{}", s); - } - Ok(()) - } -} - -impl MerkleTree { - /// returns the root of the tree - fn root(&self) -> Hash { - self.root.hash() - } - - /// returns the value at the given key - pub fn get(&self, key: &Value) -> Result { - let path = keypath(self.max_depth, *key)?; - let key_resolution = self.root.down(0, self.max_depth, path, None)?; - match key_resolution { - Some((k, v)) if &k == key => Ok(v), - _ => Err(anyhow!("key not found")), - } - } - - /// returns a boolean indicating whether the key exists in the tree - pub fn contains(&self, key: &Value) -> Result { - let path = keypath(self.max_depth, *key)?; - match self.root.down(0, self.max_depth, path, None) { - Ok(Some((k, _))) => { - if &k == key { - Ok(true) - } else { - Ok(false) - } - } - _ => Ok(false), - } - } - - /// returns a proof of existence, which proves that the given key exists in - /// the tree. It returns the `value` of the leaf at the given `key`, and - /// the `MerkleProof`. - fn prove(&self, key: &Value) -> Result<(Value, MerkleProof)> { - let path = keypath(self.max_depth, *key)?; - - let mut siblings: Vec = Vec::new(); - - match self - .root - .down(0, self.max_depth, path, Some(&mut siblings))? - { - Some((k, v)) if &k == key => Ok(( - v, - MerkleProof { - existence: true, - siblings, - }, - )), - _ => Err(anyhow!("key not found")), - } - } - - /// returns a proof of non-existence, which proves that the given - /// `key` does not exist in the tree. The return value specifies - /// the key-value pair in the leaf reached as a result of - /// resolving `key` as well as a `MerkleProof`. - fn prove_nonexistence(&self, key: &Value) -> Result<(Option<(Value, Value)>, MerkleProof)> { - let path = keypath(self.max_depth, *key)?; - - let mut siblings: Vec = Vec::new(); - - // note: non-existence of a key can be in 2 cases: - match self - .root - .down(0, self.max_depth, path, Some(&mut siblings))? - { - // - the expected leaf does not exist - None => Ok(( - None, - MerkleProof { - existence: false, - siblings, - }, - )), - // - the expected leaf does exist in the tree, but it has a different `key` - Some((k, v)) if &k != key => Ok(( - Some((k, v)), - MerkleProof { - existence: false, - siblings, - }, - )), - _ => Err(anyhow!("key found")), - } - // both cases prove that the given key don't exist in the tree. ∎ - } - - /// verifies an inclusion proof for the given `key` and `value` - fn verify( - max_depth: usize, - root: Hash, - proof: &MerkleProof, - key: &Value, - value: &Value, - ) -> Result<()> { - let h = proof.compute_root(max_depth, key, Some(value))?; - - if h != root { - return Err(anyhow!("proof of inclusion does not verify")); - } else { - Ok(()) - } - } - - /// verifies a non-inclusion proof for the given `key`, that is, the given - /// `key` does not exist in the tree - fn verify_nonexistence( - max_depth: usize, - root: Hash, - proof: &MerkleProof, - key: &Value, - other_leaf: Option<&(Value, Value)>, - ) -> Result<()> { - match other_leaf { - Some((k, v)) if k == key => Err(anyhow!("Invalid non-existence proof.")), - _ => { - let k = other_leaf.map(|(k, _)| k).unwrap_or(key); - let v = other_leaf.map(|(_, v)| v); - let h = proof.compute_root(max_depth, k, v)?; - - if h != root { - return Err(anyhow!("proof of exclusion does not verify")); - } else { - Ok(()) - } - } - } - } -} - -#[derive(Clone, Debug)] -enum Node { - None, - Leaf(Leaf), - Intermediate(Intermediate), -} - -impl fmt::Display for Node { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Intermediate(n) => { - writeln!( - f, - "\"{}\" -> {{ \"{}\" \"{}\" }}", - n.hash(), - n.left.hash(), - n.right.hash() - ); - write!(f, "{}", n.left); - write!(f, "{}", n.right) - } - Self::Leaf(l) => { - writeln!(f, "\"{}\" [style=filled]", l.hash()); - writeln!(f, "\"k:{}\\nv:{}\" [style=dashed]", l.key, l.value); - writeln!( - f, - "\"{}\" -> {{ \"k:{}\\nv:{}\" }}", - l.hash(), - l.key, - l.value, - ) - } - Self::None => Ok(()), - } - } -} - -impl Node { - fn is_empty(&self) -> bool { - match self { - Self::None => true, - Self::Leaf(l) => false, - Self::Intermediate(n) => false, - } - } - fn compute_hash(&mut self) -> Hash { - match self { - Self::None => NULL, - Self::Leaf(l) => l.compute_hash(), - Self::Intermediate(n) => n.compute_hash(), - } - } - fn hash(&self) -> Hash { - match self { - Self::None => NULL, - Self::Leaf(l) => l.hash(), - Self::Intermediate(n) => n.hash(), - } - } - - /// Goes down from the current node until it encounters a terminal - /// node, viz. a leaf or empty node, or until it reaches the - /// maximum depth. The `siblings` parameter is used to store the - /// siblings while going down to the leaf, if the given parameter - /// is set to `None`, then no siblings are stored. In this way, - /// the same method `down` can be used by MerkleTree methods - /// `get`, `contains`, `prove` and `prove_nonexistence`. - /// - /// Be aware that this method will return the found leaf at the - /// given path, which may contain a different key and value than - /// the expected one. - fn down( - &self, - lvl: usize, - max_depth: usize, - path: Vec, - mut siblings: Option<&mut Vec>, - ) -> Result> { - if lvl >= max_depth { - return Err(anyhow!("max depth reached")); - } - - match self { - Self::Intermediate(n) => { - if path[lvl] { - if let Some(s) = siblings.as_mut() { - s.push(n.left.hash()); - } - return n.right.down(lvl + 1, max_depth, path, siblings); - } else { - if let Some(s) = siblings.as_mut() { - s.push(n.right.hash()); - } - return n.left.down(lvl + 1, max_depth, path, siblings); - } - } - Self::Leaf(Leaf { - key, - value, - path: _p, - hash: _h, - }) => Ok(Some((key.clone(), value.clone()))), - _ => Ok(None), - } - } - - // adds the leaf at the tree from the current node (self), without computing any hash - fn add_leaf(&mut self, lvl: usize, max_depth: usize, leaf: Leaf) -> Result<()> { - if lvl >= max_depth { - return Err(anyhow!("max depth reached")); - } - - match self { - Self::Intermediate(n) => { - if leaf.path[lvl] { - if n.right.is_empty() { - // empty sub-node, add the leaf here - n.right = Box::new(Node::Leaf(leaf)); - return Ok(()); - } - n.right.add_leaf(lvl + 1, max_depth, leaf)?; - } else { - if n.left.is_empty() { - // empty sub-node, add the leaf here - n.left = Box::new(Node::Leaf(leaf)); - return Ok(()); - } - n.left.add_leaf(lvl + 1, max_depth, leaf)?; - } - } - Self::Leaf(l) => { - // in this case, it means that we found a leaf in the new-leaf path, thus we need - // to push both leaves (old-leaf and new-leaf) down the path till their paths - // diverge. - - // first check that keys of both leaves are different - // (l=old-leaf, leaf=new-leaf) - if l.key == leaf.key { - // TODO decide if we want to return an error when trying to add a leaf that - // already exists, or if we just ignore it. For the moment we return the error - // if the key already exists in the leaf. - return Err(anyhow!("key already exists")); - } - let old_leaf = l.clone(); - // set self as an intermediate node - *self = Node::Intermediate(Intermediate::empty()); - return self.down_till_divergence(lvl, max_depth, old_leaf, leaf); - } - Self::None => { - return Err(anyhow!("reached empty node, should not have entered")); - } - } - Ok(()) - } - - /// goes down through a 'virtual' path till finding a divergence. This method is used for when - /// adding a new leaf another already existing leaf is found, so that both leaves (new and old) - /// are pushed down the path till their keys diverge. - fn down_till_divergence( - &mut self, - lvl: usize, - max_depth: usize, - old_leaf: Leaf, - new_leaf: Leaf, - ) -> Result<()> { - if lvl >= max_depth { - return Err(anyhow!("max depth reached")); - } - - if let Node::Intermediate(ref mut n) = self { - // let current_node: Intermediate = *self; - if old_leaf.path[lvl] != new_leaf.path[lvl] { - // reached divergence in next level, set the leaves as children at the current node - if new_leaf.path[lvl] { - n.left = Box::new(Node::Leaf(old_leaf)); - n.right = Box::new(Node::Leaf(new_leaf)); - } else { - n.left = Box::new(Node::Leaf(new_leaf)); - n.right = Box::new(Node::Leaf(old_leaf)); - } - return Ok(()); - } - - // no divergence yet, continue going down - if new_leaf.path[lvl] { - n.right = Box::new(Node::Intermediate(Intermediate::empty())); - return n - .right - .down_till_divergence(lvl + 1, max_depth, old_leaf, new_leaf); - } else { - n.left = Box::new(Node::Intermediate(Intermediate::empty())); - return n - .left - .down_till_divergence(lvl + 1, max_depth, old_leaf, new_leaf); - } - } - Ok(()) - } -} - -#[derive(Clone, Debug)] -struct Intermediate { - hash: Option, - left: Box, - right: Box, -} -impl Intermediate { - fn empty() -> Self { - Self { - hash: None, - left: Box::new(Node::None), - right: Box::new(Node::None), - } - } - fn compute_hash(&mut self) -> Hash { - if self.left.clone().is_empty() && self.right.clone().is_empty() { - self.hash = Some(NULL); - return NULL; - } - let l_hash = self.left.compute_hash(); - let r_hash = self.right.compute_hash(); - let input: Vec = [l_hash.0, r_hash.0].concat(); - let h = Hash(PoseidonHash::hash_no_pad(&input).elements); - self.hash = Some(h); - h - } - fn hash(&self) -> Hash { - self.hash.unwrap() - } -} - -#[derive(Clone, Debug)] -struct Leaf { - hash: Option, - path: Vec, - key: Value, - value: Value, -} -impl Leaf { - fn new(max_depth: usize, key: Value, value: Value) -> Result { - Ok(Self { - hash: None, - path: keypath(max_depth, key)?, - key, - value, - }) - } - fn compute_hash(&mut self) -> Hash { - let input: Vec = [self.key.0, self.value.0].concat(); - let h = Hash(PoseidonHash::hash_no_pad(&input).elements); - self.hash = Some(h); - h - } - fn hash(&self) -> Hash { - self.hash.unwrap() - } -} - -// NOTE 1: think if maybe the length of the returned vector can be <256 (8*bytes.len()), so that -// we can do fewer iterations. For example, if the tree.max_depth is set to 20, we just need 20 -// iterations of the loop, not 256. -// NOTE 2: which approach do we take with keys that are longer than the max-depth? ie, what -// happens when two keys share the same path for more bits than the max_depth? -/// returns the path of the given key -fn keypath(max_depth: usize, k: Value) -> Result> { - let bytes = k.to_bytes(); - if max_depth > 8 * bytes.len() { - // note that our current keys are of Value type, which are 4 Goldilocks field elements, ie - // ~256 bits, therefore the max_depth can not be bigger than 256. - return Err(anyhow!( - "key to short (key length: {}) for the max_depth: {}", - 8 * bytes.len(), - max_depth - )); - } - Ok((0..max_depth) - .map(|n| bytes[n / 8] & (1 << (n % 8)) != 0) - .collect()) -} - -pub struct Iter<'a> { - state: Vec<&'a Node>, -} - -impl<'a> Iterator for Iter<'a> { - type Item = (&'a Value, &'a Value); - - fn next(&mut self) -> Option { - let node = self.state.pop(); - match node { - Some(Node::None) => self.next(), - Some(Node::Leaf(Leaf { - hash: _, - path: _, - key, - value, - })) => Some((key, value)), - Some(Node::Intermediate(Intermediate { - hash: _, - left, - right, - })) => { - self.state.push(&right); - self.state.push(&left); - self.next() - } - _ => None, - } - } -} - -#[cfg(test)] -pub mod tests { - use std::cmp::Ordering; - - use super::*; - use crate::middleware::hash_str; - - #[test] - fn test_merkletree() -> Result<()> { - let mut kvs = HashMap::new(); - for i in 0..8 { - if i == 1 { - continue; - } - kvs.insert(Value::from(i), Value::from(1000 + i)); - } - let key = Value::from(13); - let value = Value::from(1013); - kvs.insert(key, value); - - let tree = MerkleTree::new(32, &kvs)?; - // when printing the tree, it should print the same tree as in - // https://0xparc.github.io/pod2/merkletree.html#example-2 - println!("{}", tree); - - // Inclusion checks - let (v, proof) = tree.prove(&Value::from(13))?; - assert_eq!(v, Value::from(1013)); - println!("{}", proof); - - MerkleTree::verify(32, tree.root(), &proof, &key, &value)?; - - // Exclusion checks - let key = Value::from(12); - let (other_leaf, proof) = tree.prove_nonexistence(&key)?; - assert_eq!(other_leaf, Some((Value::from(4), Value::from(1004)))); - println!("{}", proof); - - MerkleTree::verify_nonexistence(32, tree.root(), &proof, &key, other_leaf.as_ref())?; - - let key = Value::from(1); - let (other_leaf, proof) = tree.prove_nonexistence(&Value::from(1))?; - assert_eq!(other_leaf, None); - println!("{}", proof); - - MerkleTree::verify_nonexistence(32, tree.root(), &proof, &key, other_leaf.as_ref())?; - - // Check iterator - let collected_kvs: Vec<_> = tree.into_iter().collect::>(); - - // Expected key ordering - let cmp = |max_depth: usize| { - move |k1, k2| { - let path1 = keypath(max_depth, k1).unwrap(); - let path2 = keypath(max_depth, k2).unwrap(); - - let first_unequal_bits = std::iter::zip(path1, path2).find(|(b1, b2)| b1 != b2); - - match first_unequal_bits { - Some((b1, b2)) => { - if b1 < b2 { - Ordering::Less - } else { - Ordering::Greater - } - } - _ => Ordering::Equal, - } - } - }; - - let sorted_kvs = kvs - .iter() - .sorted_by(|(k1, _), (k2, _)| cmp(32)(**k1, **k2)) - .collect::>(); - - assert_eq!(collected_kvs, sorted_kvs); - - Ok(()) - } -} diff --git a/src/primitives/mod.rs b/src/primitives/mod.rs index 8faa69ee..ad3ec6a6 100644 --- a/src/primitives/mod.rs +++ b/src/primitives/mod.rs @@ -1,2 +1 @@ pub mod merkletree; -pub mod merkletree_new;