diff --git a/Cargo.toml b/Cargo.toml index 8b311f6..1f9ca9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,4 @@ itertools = "0.10.0" [dev-dependencies] serde_json = "1.0.0" tree_hash_derive = "0.5.0" +ethereum_ssz_derive = "0.5.0" diff --git a/src/lib.rs b/src/lib.rs index 3e181da..36c72d1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,12 +40,14 @@ #[macro_use] mod bitfield; mod fixed_vector; +mod optional; pub mod serde_utils; mod tree_hash; mod variable_list; pub use bitfield::{BitList, BitVector, Bitfield}; pub use fixed_vector::FixedVector; +pub use optional::Optional; pub use typenum; pub use variable_list::VariableList; diff --git a/src/optional.rs b/src/optional.rs new file mode 100644 index 0000000..ae3b3ee --- /dev/null +++ b/src/optional.rs @@ -0,0 +1,251 @@ +use crate::tree_hash::optional_tree_hash_root; +use derivative::Derivative; +use serde_derive::{Deserialize, Serialize}; +use tree_hash::Hash256; + +pub use typenum; + +/// Emulates a SSZ `Optional` (distinct from a Rust `Option`). +/// +/// This SSZ type is defined in EIP-6475. +/// +/// This struct is backed by a Rust `Option` and its behaviour is defined by the variant. +/// +/// If `Some`, it will serialize with a 1-byte identifying prefix with a value of 1 followed by the +/// serialized internal type. +/// If `None`, it will serialize as `null`. +/// +/// `Optional` will Merklize in the following ways: +/// `if None`: Merklize as an empty `VariableList` +/// `if Some(T)`: Merklize as a `VariableList` of length 1 whose single value is `T`. +/// +/// ## Example +/// +/// ``` +/// use ssz_types::{Optional, typenum::*, VariableList}; +/// use tree_hash::TreeHash; +/// use ssz::Encode; +/// +/// // Create an `Optional` from an `Option` that is `Some`. +/// let some: Option = Some(9); +/// let ssz: Optional = Optional::from(some); +/// let serialized: &[u8] = &ssz.as_ssz_bytes(); +/// assert_eq!(serialized, &[1, 9]); +/// +/// let root = ssz.tree_hash_root(); +/// let equivalent_list: VariableList = VariableList::from(vec![9; 1]); +/// assert_eq!(root, equivalent_list.tree_hash_root()); +/// +/// // Create an `Optional` from an `Option` that is `None`. +/// let none: Option = None; +/// let ssz: Optional = Optional::from(none); +/// let serialized: &[u8] = &ssz.as_ssz_bytes(); +/// let null: &[u8] = &[]; +/// assert_eq!(serialized, null); +/// +/// let root = ssz.tree_hash_root(); +/// let equivalent_list: VariableList = VariableList::from(vec![]); +/// assert_eq!(root, equivalent_list.tree_hash_root()); +/// +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize, Derivative)] +#[derivative(PartialEq, Hash(bound = "T: std::hash::Hash"))] +#[serde(transparent)] +pub struct Optional { + optional: Option, +} + +impl From> for Optional { + fn from(optional: Option) -> Self { + Self { optional } + } +} + +impl From> for Option { + fn from(val: Optional) -> Option { + val.optional + } +} + +impl Default for Optional { + fn default() -> Self { + Self { optional: None } + } +} + +impl tree_hash::TreeHash for Optional +where + T: tree_hash::TreeHash, +{ + fn tree_hash_type() -> tree_hash::TreeHashType { + tree_hash::TreeHashType::List + } + + fn tree_hash_packed_encoding(&self) -> tree_hash::PackedEncoding { + unreachable!("List should never be packed.") + } + + fn tree_hash_packing_factor() -> usize { + unreachable!("List should never be packed.") + } + + fn tree_hash_root(&self) -> Hash256 { + let root = optional_tree_hash_root::(&self.optional); + + let length = match &self.optional { + None => 0, + Some(_) => 1, + }; + + tree_hash::mix_in_length(&root, length) + } +} + +impl ssz::Encode for Optional +where + T: ssz::Encode, +{ + fn is_ssz_fixed_len() -> bool { + false + } + + fn ssz_bytes_len(&self) -> usize { + match &self.optional { + None => 0, + Some(val) => val.ssz_bytes_len() + 1, + } + } + + fn ssz_append(&self, buf: &mut Vec) { + match &self.optional { + None => (), + Some(val) => { + let mut optional_identifier = vec![1]; + buf.append(&mut optional_identifier); + val.ssz_append(buf) + } + } + } +} + +impl ssz::Decode for Optional +where + T: ssz::Decode, +{ + fn is_ssz_fixed_len() -> bool { + false + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + if let Some((first, rest)) = bytes.split_first() { + if first == &0x01 { + return Ok(Optional { + optional: Some(T::from_ssz_bytes(&rest)?), + }); + } else { + // An `Optional` must always contains `0x01` as the first byte. + // Might be worth having an explicit error variant in ssz::DecodeError. + return Err(ssz::DecodeError::BytesInvalid( + "Missing Optional identifier byte".to_string(), + )); + } + } else { + Ok(Optional { optional: None }) + } + } +} + +/// TODO Use a more robust `Arbitrary` impl. +#[cfg(feature = "arbitrary")] +impl<'a, T: arbitrary::Arbitrary<'a>> arbitrary::Arbitrary<'a> for Optional { + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + let option = Some(::arbitrary(u).unwrap()); + Ok(Self::from(option)) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{FixedVector, VariableList}; + use ssz::*; + use ssz_derive::{Decode, Encode}; + use tree_hash::TreeHash; + use tree_hash_derive::TreeHash; + use typenum::*; + + #[test] + fn encode() { + let some: Optional = Some(42).into(); + let bytes: Vec = vec![1, 42]; + assert_eq!(some.as_ssz_bytes(), bytes); + + let none: Optional = None.into(); + let empty: Vec = vec![]; + assert_eq!(none.as_ssz_bytes(), empty); + } + + #[test] + fn decode() { + let bytes = &[1, 42, 0, 0, 0, 0, 0, 0, 0]; + let some: Optional = Optional::from_ssz_bytes(bytes).unwrap(); + assert_eq!(Some(42), some.optional); + + let empty = &[]; + let none: Optional = Optional::from_ssz_bytes(empty).unwrap(); + assert_eq!(None, none.optional); + } + + #[test] + fn tree_hash_none() { + // None should merklize the same as an empty VariableList. + let none: Optional = Optional::from(None); + let empty_list: VariableList = VariableList::from(vec![]); + assert_eq!(none.tree_hash_root(), empty_list.tree_hash_root()); + } + + #[test] + fn tree_hash_some_int() { + // Optional should merklize the same as a length 1 VariableList. + let some_int: Optional = Optional::from(Some(9)); + let list_int: VariableList = VariableList::from(vec![9; 1]); + assert_eq!(some_int.tree_hash_root(), list_int.tree_hash_root()); + } + + #[test] + fn tree_hash_some_list() { + // Optional should merklize the same as a length 1 VariableList. + let list: VariableList = VariableList::from(vec![9; 16]); + let some_list: Optional> = Optional::from(Some(list.clone())); + let list_list: VariableList, U1> = VariableList::from(vec![list; 1]); + assert_eq!(some_list.tree_hash_root(), list_list.tree_hash_root()); + } + + #[test] + fn tree_hash_some_vec() { + // Optional should merklize the same as a length 1 VariableList. + let vec: FixedVector = FixedVector::from(vec![9; 16]); + let some_vec: Optional> = Optional::from(Some(vec.clone())); + let list_vec: VariableList, U1> = VariableList::from(vec![vec; 1]); + assert_eq!(some_vec.tree_hash_root(), list_vec.tree_hash_root()); + } + + #[test] + fn tree_hash_some_object() { + #[derive(TreeHash, Decode, Encode)] + struct Object { + a: u8, + b: u8, + } + + // Optional should merklize the same as a length 1 VariableList. Note the 1-byte identifier + // during deserialization. + let optional_object: Optional = Optional::from_ssz_bytes(&[1, 11, 9]).unwrap(); + let list_object: VariableList = VariableList::from_ssz_bytes(&[11, 9]).unwrap(); + + assert_eq!( + optional_object.tree_hash_root(), + list_object.tree_hash_root() + ); + } +} diff --git a/src/tree_hash.rs b/src/tree_hash.rs index e08c1d6..9b97ed4 100644 --- a/src/tree_hash.rs +++ b/src/tree_hash.rs @@ -56,3 +56,22 @@ pub fn bitfield_bytes_tree_hash_root(bytes: &[u8]) -> Hash256 { .finish() .expect("bitfield tree hash buffer should not exceed leaf limit") } + +/// A helper function providing common functionality for finding the Merkle root of some bytes that +/// represent an optional value. +pub fn optional_tree_hash_root(option: &Option) -> Hash256 { + let mut hasher = MerkleHasher::with_leaves(1); + + match option { + None => (), + Some(val) => { + hasher + .write(val.tree_hash_root().as_bytes()) + .expect("ssz_types optional should only contain 1 element"); + } + } + + hasher + .finish() + .expect("ssz_types optional should not have a remaining buffer") +}