@@ -18,6 +18,30 @@ mod id;
1818#[ cfg( feature = "lazy-account" ) ]
1919mod lazy;
2020
21+ fn is_zero_lit ( lit : & syn:: Lit ) -> bool {
22+ match lit {
23+ syn:: Lit :: Int ( val) => val. base10_parse :: < u128 > ( ) . is_ok_and ( |v| v == 0 ) ,
24+ syn:: Lit :: Byte ( val) => val. value ( ) == 0 ,
25+ syn:: Lit :: ByteStr ( val) => val. value ( ) . iter ( ) . all ( |byte| * byte == 0 ) ,
26+ _ => false ,
27+ }
28+ }
29+
30+ fn is_zeroed_discriminator ( discr : & Expr ) -> bool {
31+ match discr {
32+ Expr :: Reference ( syn:: ExprReference { expr, .. } )
33+ | Expr :: Paren ( syn:: ExprParen { expr, .. } )
34+ | Expr :: Group ( syn:: ExprGroup { expr, .. } ) => is_zeroed_discriminator ( expr) ,
35+ Expr :: Lit ( syn:: ExprLit { lit, .. } ) => is_zero_lit ( lit) ,
36+ Expr :: Array ( arr) => arr. elems . iter ( ) . all ( is_zeroed_discriminator) ,
37+ // [0; N] is all zeroed for any N, and [X; 0] is empty.
38+ Expr :: Repeat ( rep) => {
39+ is_zeroed_discriminator ( & rep. expr ) || is_zeroed_discriminator ( & rep. len )
40+ }
41+ _ => false ,
42+ }
43+ }
44+
2145/// An attribute for a data structure representing a Solana account.
2246///
2347/// `#[account]` generates trait implementations for the following traits:
@@ -56,7 +80,7 @@ mod lazy;
5680/// - `discriminator = MY_DISC`
5781/// - `discriminator = get_disc(...)`
5882///
59- /// All-zeroed discriminators are not supported.
83+ /// All-zero or empty discriminators are not supported.
6084///
6185/// # Zero Copy Deserialization
6286///
@@ -113,32 +137,10 @@ pub fn account(
113137 let account_name_str = account_name. to_string ( ) ;
114138 let ( impl_gen, type_gen, where_clause) = account_strct. generics . split_for_impl ( ) ;
115139
116- fn is_zero_lit ( expr : & Expr ) -> bool {
117- matches ! (
118- expr,
119- Expr :: Lit ( syn:: ExprLit { lit: syn:: Lit :: Int ( val) , .. } )
120- if val. base10_parse:: <u128 >( ) . is_ok_and( |v| v == 0 )
121- )
122- }
123-
124- fn is_zeroed_discriminator ( mut discr : & Expr ) -> bool {
125- // Peel references
126- while let Expr :: Reference ( syn:: ExprReference { expr, .. } ) = discr {
127- discr = expr;
128- }
129- match discr {
130- Expr :: Lit ( _) => is_zero_lit ( discr) ,
131- Expr :: Array ( arr) => arr. elems . iter ( ) . all ( is_zero_lit) ,
132- // [0; N] — repeat expression
133- Expr :: Repeat ( rep) => is_zero_lit ( & rep. expr ) ,
134- _ => false ,
135- }
136- }
137-
138140 let discriminator = match args. overrides . and_then ( |ov| ov. discriminator ) {
139141 Some ( discrim) => {
140142 let zero_err = is_zeroed_discriminator ( & discrim) . then ( ||
141- quote_spanned ! { discrim. span( ) => compile_error!( "all-zero discriminators are not supported" ) ; }
143+ quote_spanned ! { discrim. span( ) => compile_error!( "all-zero or empty discriminators are not supported" ) ; }
142144 ) ;
143145 quote ! {
144146 {
@@ -643,3 +645,41 @@ pub fn declare_id(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
643645 #[ allow( unreachable_code) ]
644646 proc_macro:: TokenStream :: from ( ret)
645647}
648+
649+ #[ cfg( test) ]
650+ mod tests {
651+ use super :: * ;
652+
653+ #[ allow( clippy:: expect_used) ]
654+ fn zeroed ( source : & str ) -> bool {
655+ let expr = syn:: parse_str ( source) . expect ( "test expression should parse" ) ;
656+ is_zeroed_discriminator ( & expr)
657+ }
658+
659+ #[ test]
660+ fn detects_zeroed_discriminator_literals ( ) {
661+ assert ! ( zeroed( "0" ) ) ;
662+ assert ! ( zeroed( "b'\\ x00'" ) ) ;
663+ assert ! ( zeroed( "b\" \" " ) ) ;
664+ assert ! ( zeroed( "b\" \\ x00\\ x00\" " ) ) ;
665+
666+ assert ! ( !zeroed( "1" ) ) ;
667+ assert ! ( !zeroed( "b'a'" ) ) ;
668+ assert ! ( !zeroed( "b\" \\ x00\\ x01\" " ) ) ;
669+ assert ! ( !zeroed( "\" \" " ) ) ;
670+ }
671+
672+ #[ test]
673+ fn detects_zeroed_discriminator_collections ( ) {
674+ assert ! ( zeroed( "&[0, (0)]" ) ) ;
675+ assert ! ( zeroed( "[]" ) ) ;
676+ assert ! ( zeroed( "[0; N]" ) ) ;
677+ assert ! ( zeroed( "[1; 0]" ) ) ;
678+ assert ! ( zeroed( "(&b\" \\ x00\" )" ) ) ;
679+
680+ assert ! ( !zeroed( "&[0, 1]" ) ) ;
681+ assert ! ( !zeroed( "[1; N]" ) ) ;
682+ assert ! ( !zeroed( "MY_DISC" ) ) ;
683+ assert ! ( !zeroed( "get_disc()" ) ) ;
684+ }
685+ }
0 commit comments