Skip to content

Commit abd58a0

Browse files
authored
address!: Add custom PartialEq impl (#318)
* Add custom partial eq impl * Add inline attribute
1 parent ac902c4 commit abd58a0

File tree

1 file changed

+41
-1
lines changed

1 file changed

+41
-1
lines changed

address/src/lib.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use core::{
3434
array,
3535
convert::TryFrom,
3636
hash::{Hash, Hasher},
37+
ptr::read_unaligned,
3738
};
3839
#[cfg(feature = "serde")]
3940
use serde_derive::{Deserialize, Serialize};
@@ -82,7 +83,7 @@ const PDA_MARKER: &[u8; 21] = b"ProgramDerivedAddress";
8283
#[cfg_attr(all(feature = "borsh", feature = "std"), derive(BorshSchema))]
8384
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
8485
#[cfg_attr(feature = "bytemuck", derive(Pod, Zeroable))]
85-
#[derive(Clone, Copy, Default, Eq, Ord, PartialEq, PartialOrd)]
86+
#[derive(Clone, Copy, Default, Eq, Ord, PartialOrd)]
8687
#[cfg_attr(feature = "dev-context-only-utils", derive(Arbitrary))]
8788
pub struct Address(pub(crate) [u8; 32]);
8889

@@ -307,6 +308,25 @@ impl core::fmt::Display for Address {
307308
}
308309
}
309310

311+
/// Custom impl of `PartialEq` for `Address`.
312+
///
313+
/// The implementation compares the address in 4 chunks of 8 bytes (`u64` values),
314+
/// which is currently more efficient (CU-wise) than the default implementation.
315+
impl PartialEq for Address {
316+
#[inline(always)]
317+
fn eq(&self, other: &Self) -> bool {
318+
let p1_ptr = self.0.as_ptr().cast::<u64>();
319+
let p2_ptr = other.0.as_ptr().cast::<u64>();
320+
321+
unsafe {
322+
read_unaligned(p1_ptr) == read_unaligned(p2_ptr)
323+
&& read_unaligned(p1_ptr.add(1)) == read_unaligned(p2_ptr.add(1))
324+
&& read_unaligned(p1_ptr.add(2)) == read_unaligned(p2_ptr.add(2))
325+
&& read_unaligned(p1_ptr.add(3)) == read_unaligned(p2_ptr.add(3))
326+
}
327+
}
328+
}
329+
310330
/// Convenience macro to define a static `Address` value.
311331
///
312332
/// Input: a single literal base58 string representation of an `Address`.
@@ -609,4 +629,24 @@ mod tests {
609629
ADDRESS
610630
);
611631
}
632+
633+
#[test]
634+
fn test_address_eq_matches_default_eq() {
635+
for i in 0..u8::MAX {
636+
let p1 = Address::from([i; ADDRESS_BYTES]);
637+
let p2 = Address::from([i; ADDRESS_BYTES]);
638+
639+
// Identical addresses must be equal.
640+
assert!(p1 == p2);
641+
assert!(p1.eq(&p2));
642+
assert_eq!(p1.eq(&p2), p1.0 == p2.0);
643+
644+
let p3 = Address::from([u8::MAX - i; ADDRESS_BYTES]);
645+
646+
// Different addresses must not be equal.
647+
assert!(p1 != p3);
648+
assert!(!p1.eq(&p3));
649+
assert_eq!(!p1.eq(&p3), p1.0 != p3.0);
650+
}
651+
}
612652
}

0 commit comments

Comments
 (0)