Skip to content

Commit e47eda0

Browse files
authored
fix account zeroed discriminator detection (#4645)
1 parent 0bc86d6 commit e47eda0

1 file changed

Lines changed: 64 additions & 24 deletions

File tree

  • lang/attribute/account/src

lang/attribute/account/src/lib.rs

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,30 @@ mod id;
1818
#[cfg(feature = "lazy-account")]
1919
mod lazy;
2020

21+
fn is_zero_lit(lit: &syn::Lit) -> bool {
22+
match lit {
23+
syn::Lit::Int(val) => val.base10_parse::<u128>().is_ok_and(|v| v == 0),
24+
syn::Lit::Byte(val) => val.value() == 0,
25+
syn::Lit::ByteStr(val) => val.value().iter().all(|byte| *byte == 0),
26+
_ => false,
27+
}
28+
}
29+
30+
fn is_zeroed_discriminator(discr: &Expr) -> bool {
31+
match discr {
32+
Expr::Reference(syn::ExprReference { expr, .. })
33+
| Expr::Paren(syn::ExprParen { expr, .. })
34+
| Expr::Group(syn::ExprGroup { expr, .. }) => is_zeroed_discriminator(expr),
35+
Expr::Lit(syn::ExprLit { lit, .. }) => is_zero_lit(lit),
36+
Expr::Array(arr) => arr.elems.iter().all(is_zeroed_discriminator),
37+
// [0; N] is all zeroed for any N, and [X; 0] is empty.
38+
Expr::Repeat(rep) => {
39+
is_zeroed_discriminator(&rep.expr) || is_zeroed_discriminator(&rep.len)
40+
}
41+
_ => false,
42+
}
43+
}
44+
2145
/// An attribute for a data structure representing a Solana account.
2246
///
2347
/// `#[account]` generates trait implementations for the following traits:
@@ -56,7 +80,7 @@ mod lazy;
5680
/// - `discriminator = MY_DISC`
5781
/// - `discriminator = get_disc(...)`
5882
///
59-
/// All-zeroed discriminators are not supported.
83+
/// All-zero or empty discriminators are not supported.
6084
///
6185
/// # Zero Copy Deserialization
6286
///
@@ -113,32 +137,10 @@ pub fn account(
113137
let account_name_str = account_name.to_string();
114138
let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
115139

116-
fn is_zero_lit(expr: &Expr) -> bool {
117-
matches!(
118-
expr,
119-
Expr::Lit(syn::ExprLit { lit: syn::Lit::Int(val), .. })
120-
if val.base10_parse::<u128>().is_ok_and(|v| v == 0)
121-
)
122-
}
123-
124-
fn is_zeroed_discriminator(mut discr: &Expr) -> bool {
125-
// Peel references
126-
while let Expr::Reference(syn::ExprReference { expr, .. }) = discr {
127-
discr = expr;
128-
}
129-
match discr {
130-
Expr::Lit(_) => is_zero_lit(discr),
131-
Expr::Array(arr) => arr.elems.iter().all(is_zero_lit),
132-
// [0; N] — repeat expression
133-
Expr::Repeat(rep) => is_zero_lit(&rep.expr),
134-
_ => false,
135-
}
136-
}
137-
138140
let discriminator = match args.overrides.and_then(|ov| ov.discriminator) {
139141
Some(discrim) => {
140142
let zero_err = is_zeroed_discriminator(&discrim).then(||
141-
quote_spanned! {discrim.span() => compile_error!("all-zero discriminators are not supported");}
143+
quote_spanned! {discrim.span() => compile_error!("all-zero or empty discriminators are not supported");}
142144
);
143145
quote! {
144146
{
@@ -643,3 +645,41 @@ pub fn declare_id(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
643645
#[allow(unreachable_code)]
644646
proc_macro::TokenStream::from(ret)
645647
}
648+
649+
#[cfg(test)]
650+
mod tests {
651+
use super::*;
652+
653+
#[allow(clippy::expect_used)]
654+
fn zeroed(source: &str) -> bool {
655+
let expr = syn::parse_str(source).expect("test expression should parse");
656+
is_zeroed_discriminator(&expr)
657+
}
658+
659+
#[test]
660+
fn detects_zeroed_discriminator_literals() {
661+
assert!(zeroed("0"));
662+
assert!(zeroed("b'\\x00'"));
663+
assert!(zeroed("b\"\""));
664+
assert!(zeroed("b\"\\x00\\x00\""));
665+
666+
assert!(!zeroed("1"));
667+
assert!(!zeroed("b'a'"));
668+
assert!(!zeroed("b\"\\x00\\x01\""));
669+
assert!(!zeroed("\"\""));
670+
}
671+
672+
#[test]
673+
fn detects_zeroed_discriminator_collections() {
674+
assert!(zeroed("&[0, (0)]"));
675+
assert!(zeroed("[]"));
676+
assert!(zeroed("[0; N]"));
677+
assert!(zeroed("[1; 0]"));
678+
assert!(zeroed("(&b\"\\x00\" )"));
679+
680+
assert!(!zeroed("&[0, 1]"));
681+
assert!(!zeroed("[1; N]"));
682+
assert!(!zeroed("MY_DISC"));
683+
assert!(!zeroed("get_disc()"));
684+
}
685+
}

0 commit comments

Comments
 (0)