diff --git a/address/src/lib.rs b/address/src/lib.rs index ac9e97d85..8c53928e9 100644 --- a/address/src/lib.rs +++ b/address/src/lib.rs @@ -34,6 +34,7 @@ use core::{ array, convert::TryFrom, hash::{Hash, Hasher}, + ptr::read_unaligned, }; #[cfg(feature = "serde")] use serde_derive::{Deserialize, Serialize}; @@ -82,7 +83,7 @@ const PDA_MARKER: &[u8; 21] = b"ProgramDerivedAddress"; #[cfg_attr(all(feature = "borsh", feature = "std"), derive(BorshSchema))] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] #[cfg_attr(feature = "bytemuck", derive(Pod, Zeroable))] -#[derive(Clone, Copy, Default, Eq, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Copy, Default, Eq, Ord, PartialOrd)] #[cfg_attr(feature = "dev-context-only-utils", derive(Arbitrary))] pub struct Address(pub(crate) [u8; 32]); @@ -307,6 +308,25 @@ impl core::fmt::Display for Address { } } +/// Custom impl of `PartialEq` for `Address`. +/// +/// The implementation compares the address in 4 chunks of 8 bytes (`u64` values), +/// which is currently more efficient (CU-wise) than the default implementation. +impl PartialEq for Address { + #[inline(always)] + fn eq(&self, other: &Self) -> bool { + let p1_ptr = self.0.as_ptr().cast::(); + let p2_ptr = other.0.as_ptr().cast::(); + + unsafe { + read_unaligned(p1_ptr) == read_unaligned(p2_ptr) + && read_unaligned(p1_ptr.add(1)) == read_unaligned(p2_ptr.add(1)) + && read_unaligned(p1_ptr.add(2)) == read_unaligned(p2_ptr.add(2)) + && read_unaligned(p1_ptr.add(3)) == read_unaligned(p2_ptr.add(3)) + } + } +} + /// Convenience macro to define a static `Address` value. /// /// Input: a single literal base58 string representation of an `Address`. @@ -609,4 +629,24 @@ mod tests { ADDRESS ); } + + #[test] + fn test_address_eq_matches_default_eq() { + for i in 0..u8::MAX { + let p1 = Address::from([i; ADDRESS_BYTES]); + let p2 = Address::from([i; ADDRESS_BYTES]); + + // Identical addresses must be equal. + assert!(p1 == p2); + assert!(p1.eq(&p2)); + assert_eq!(p1.eq(&p2), p1.0 == p2.0); + + let p3 = Address::from([u8::MAX - i; ADDRESS_BYTES]); + + // Different addresses must not be equal. + assert!(p1 != p3); + assert!(!p1.eq(&p3)); + assert_eq!(!p1.eq(&p3), p1.0 != p3.0); + } + } }