@@ -34,6 +34,7 @@ use core::{
3434 array,
3535 convert:: TryFrom ,
3636 hash:: { Hash , Hasher } ,
37+ ptr:: read_unaligned,
3738} ;
3839#[ cfg( feature = "serde" ) ]
3940use serde_derive:: { Deserialize , Serialize } ;
@@ -82,7 +83,7 @@ const PDA_MARKER: &[u8; 21] = b"ProgramDerivedAddress";
8283#[ cfg_attr( all( feature = "borsh" , feature = "std" ) , derive( BorshSchema ) ) ]
8384#[ cfg_attr( feature = "serde" , derive( Deserialize , Serialize ) ) ]
8485#[ cfg_attr( feature = "bytemuck" , derive( Pod , Zeroable ) ) ]
85- #[ derive( Clone , Copy , Default , Eq , Ord , PartialEq , PartialOrd ) ]
86+ #[ derive( Clone , Copy , Default , Eq , Ord , PartialOrd ) ]
8687#[ cfg_attr( feature = "dev-context-only-utils" , derive( Arbitrary ) ) ]
8788pub struct Address ( pub ( crate ) [ u8 ; 32 ] ) ;
8889
@@ -307,6 +308,25 @@ impl core::fmt::Display for Address {
307308 }
308309}
309310
311+ /// Custom impl of `PartialEq` for `Address`.
312+ ///
313+ /// The implementation compares the address in 4 chunks of 8 bytes (`u64` values),
314+ /// which is currently more efficient (CU-wise) than the default implementation.
315+ impl PartialEq for Address {
316+ #[ inline( always) ]
317+ fn eq ( & self , other : & Self ) -> bool {
318+ let p1_ptr = self . 0 . as_ptr ( ) . cast :: < u64 > ( ) ;
319+ let p2_ptr = other. 0 . as_ptr ( ) . cast :: < u64 > ( ) ;
320+
321+ unsafe {
322+ read_unaligned ( p1_ptr) == read_unaligned ( p2_ptr)
323+ && read_unaligned ( p1_ptr. add ( 1 ) ) == read_unaligned ( p2_ptr. add ( 1 ) )
324+ && read_unaligned ( p1_ptr. add ( 2 ) ) == read_unaligned ( p2_ptr. add ( 2 ) )
325+ && read_unaligned ( p1_ptr. add ( 3 ) ) == read_unaligned ( p2_ptr. add ( 3 ) )
326+ }
327+ }
328+ }
329+
310330/// Convenience macro to define a static `Address` value.
311331///
312332/// Input: a single literal base58 string representation of an `Address`.
@@ -609,4 +629,24 @@ mod tests {
609629 ADDRESS
610630 ) ;
611631 }
632+
633+ #[ test]
634+ fn test_address_eq_matches_default_eq ( ) {
635+ for i in 0 ..u8:: MAX {
636+ let p1 = Address :: from ( [ i; ADDRESS_BYTES ] ) ;
637+ let p2 = Address :: from ( [ i; ADDRESS_BYTES ] ) ;
638+
639+ // Identical addresses must be equal.
640+ assert ! ( p1 == p2) ;
641+ assert ! ( p1. eq( & p2) ) ;
642+ assert_eq ! ( p1. eq( & p2) , p1. 0 == p2. 0 ) ;
643+
644+ let p3 = Address :: from ( [ u8:: MAX - i; ADDRESS_BYTES ] ) ;
645+
646+ // Different addresses must not be equal.
647+ assert ! ( p1 != p3) ;
648+ assert ! ( !p1. eq( & p3) ) ;
649+ assert_eq ! ( !p1. eq( & p3) , p1. 0 != p3. 0 ) ;
650+ }
651+ }
612652}
0 commit comments