Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions masp_primitives/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::collections::BTreeMap;
use std::{
io::{self, Write},
iter::Sum,
ops::{Add, AddAssign, Sub, SubAssign},
ops::{Add, AddAssign, Neg, Sub, SubAssign},
};

#[derive(Clone, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -143,20 +143,18 @@ impl BorshSerialize for AllowedConversion {
}

impl BorshDeserialize for AllowedConversion {
/// This deserialization is unsafe because it does not do the expensive
/// computation of checking whether the asset generator corresponds to the
/// deserialized amount.
fn deserialize_reader<R: io::Read>(reader: &mut R) -> io::Result<Self> {
let assets = I128Sum::read(reader)?;
let gen_bytes =
<<jubjub::ExtendedPoint as GroupEncoding>::Repr as BorshDeserialize>::deserialize_reader(reader)?;
let generator = Option::from(jubjub::ExtendedPoint::from_bytes(&gen_bytes))
.ok_or_else(|| io::Error::from(io::ErrorKind::InvalidData))?;
let allowed_conversion: AllowedConversion = assets.clone().into();
if allowed_conversion.generator != generator {
return Err(io::Error::from(io::ErrorKind::InvalidData));
// Use the unchecked reader to ensure that same format is supported
let unchecked_conv = UncheckedAllowedConversion::deserialize_reader(reader)?.0;
// Recompute the generator using only the value sum
let safe_conv: AllowedConversion = unchecked_conv.assets.clone().into();
// Check that the computed generator is identical to what was read
if safe_conv.generator == unchecked_conv.generator {
Ok(safe_conv)
} else {
// The generators do not match, so the bytes cannot be from Self::serialize
Err(io::Error::from(io::ErrorKind::InvalidData))
}
Ok(AllowedConversion { assets, generator })
}
}

Expand Down Expand Up @@ -196,12 +194,42 @@ impl SubAssign for AllowedConversion {
}
}

impl Neg for AllowedConversion {
type Output = Self;

fn neg(self) -> Self {
Self {
assets: -self.assets,
generator: -self.generator,
}
}
}

impl Sum for AllowedConversion {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(AllowedConversion::from(ValueSum::zero()), Add::add)
}
}

/// A seprate type to allow unchecked deserializations of AllowedConversions
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UncheckedAllowedConversion(pub AllowedConversion);

impl BorshDeserialize for UncheckedAllowedConversion {
/// This deserialization is unchecked because it does not do the expensive
/// computation of checking whether the asset generator corresponds to the
/// deserialized amount.
fn deserialize_reader<R: io::Read>(reader: &mut R) -> io::Result<Self> {
let assets = I128Sum::read(reader)?;
let gen_bytes =
<<jubjub::ExtendedPoint as GroupEncoding>::Repr as BorshDeserialize>::deserialize_reader(reader)?;
let generator = Option::from(jubjub::ExtendedPoint::from_bytes(&gen_bytes))
.ok_or_else(|| io::Error::from(io::ErrorKind::InvalidData))?;
// Assume that the generator just read corresponds to the value sum
Ok(Self(AllowedConversion { assets, generator }))
}
}

#[cfg(test)]
mod tests {
use crate::asset_type::AssetType;
Expand Down
21 changes: 16 additions & 5 deletions masp_primitives/src/transaction/components/amount.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,15 +440,26 @@ where
type Output = ValueSum<Unit, <Lhs as CheckedMul<Rhs>>::Output>;

fn mul(self, rhs: Rhs) -> Self::Output {
self.checked_mul(rhs).expect("overflow detected")
}
}

impl<Unit, Lhs, Rhs> CheckedMul<Rhs> for ValueSum<Unit, Lhs>
where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Lhs: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + CheckedMul<Rhs>,
Rhs: Copy,
<Lhs as CheckedMul<Rhs>>::Output: Default + BorshSerialize + BorshDeserialize + Eq,
{
type Output = ValueSum<Unit, <Lhs as CheckedMul<Rhs>>::Output>;

fn checked_mul(self, rhs: Rhs) -> Option<Self::Output> {
let mut comps = BTreeMap::new();
for (atype, amount) in self.0.iter() {
comps.insert(
atype.clone(),
amount.checked_mul(rhs).expect("overflow detected"),
);
comps.insert(atype.clone(), amount.checked_mul(rhs)?);
}
comps.retain(|_, v| *v != <Lhs as CheckedMul<Rhs>>::Output::default());
ValueSum(comps)
Some(ValueSum(comps))
}
}

Expand Down
Loading