diff --git a/programs/token-2022/src/state/account_state.rs b/programs/token-2022/src/state/account_state.rs index 5380c96f..6f98b852 100644 --- a/programs/token-2022/src/state/account_state.rs +++ b/programs/token-2022/src/state/account_state.rs @@ -34,3 +34,28 @@ impl From for u8 { } } } + +/// Different kinds of accounts. Note that `Mint`, `TokenAccount`, and `Multisig` +/// types are determined exclusively by the size of the account, and are not +/// included in the account data. `AccountType` is only included if extensions +/// have been initialized. +#[repr(u8)] +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum AccountType { + /// Marker for 0 data + Uninitialized, + /// Mint account with additional extensions + Mint, + /// Token holding account with additional extensions + TokenAccount, +} + +impl From for u8 { + fn from(value: AccountType) -> Self { + match value { + AccountType::Uninitialized => 0, + AccountType::Mint => 1, + AccountType::TokenAccount => 2, + } + } +} diff --git a/programs/token-2022/src/state/mint.rs b/programs/token-2022/src/state/mint.rs index ba47c796..d0cdc35d 100644 --- a/programs/token-2022/src/state/mint.rs +++ b/programs/token-2022/src/state/mint.rs @@ -1,10 +1,14 @@ +use super::AccountType; use pinocchio::{ account_info::{AccountInfo, Ref}, program_error::ProgramError, pubkey::Pubkey, }; -use crate::ID; +use crate::{ + state::{Multisig, TokenAccount}, + ID, +}; /// Mint data. #[repr(C)] @@ -44,13 +48,26 @@ impl Mint { /// the account data. #[inline] pub fn from_account_info(account_info: &AccountInfo) -> Result, ProgramError> { - if account_info.data_len() < Self::BASE_LEN { + let len = account_info.data_len(); + let data = account_info.try_borrow_data()?; + + if len < Self::BASE_LEN || len == Multisig::LEN { return Err(ProgramError::InvalidAccountData); } + if len > Self::BASE_LEN { + if len > TokenAccount::BASE_LEN { + let byte = data[Self::BASE_LEN]; + if byte != AccountType::Mint.into() && byte != AccountType::Uninitialized.into() { + return Err(ProgramError::InvalidAccountData); + } + } else { + return Err(ProgramError::InvalidAccountData); + } + } if !account_info.is_owned_by(&ID) { return Err(ProgramError::InvalidAccountOwner); } - Ok(Ref::map(account_info.try_borrow_data()?, |data| unsafe { + Ok(Ref::map(data, |data| unsafe { Self::from_bytes_unchecked(data) })) } @@ -68,15 +85,26 @@ impl Mint { pub unsafe fn from_account_info_unchecked( account_info: &AccountInfo, ) -> Result<&Self, ProgramError> { - if account_info.data_len() < Self::BASE_LEN { + let len = account_info.data_len(); + let data = account_info.borrow_data_unchecked(); + + if len < Self::BASE_LEN || len == Multisig::LEN { return Err(ProgramError::InvalidAccountData); } + if len > Self::BASE_LEN { + if len > TokenAccount::BASE_LEN { + let byte = data[Self::BASE_LEN]; + if byte != AccountType::Mint.into() && byte != AccountType::Uninitialized.into() { + return Err(ProgramError::InvalidAccountData); + } + } else { + return Err(ProgramError::InvalidAccountData); + } + } if account_info.owner() != &ID { return Err(ProgramError::InvalidAccountOwner); } - Ok(Self::from_bytes_unchecked( - account_info.borrow_data_unchecked(), - )) + Ok(Self::from_bytes_unchecked(data)) } /// Return a `Mint` from the given bytes. diff --git a/programs/token-2022/src/state/token.rs b/programs/token-2022/src/state/token.rs index 56f3afec..72e184d9 100644 --- a/programs/token-2022/src/state/token.rs +++ b/programs/token-2022/src/state/token.rs @@ -1,11 +1,11 @@ -use super::AccountState; +use super::{AccountState, AccountType}; use pinocchio::{ account_info::{AccountInfo, Ref}, program_error::ProgramError, pubkey::Pubkey, }; -use crate::ID; +use crate::{state::Multisig, ID}; /// Token account data. #[repr(C)] @@ -59,13 +59,23 @@ impl TokenAccount { pub fn from_account_info( account_info: &AccountInfo, ) -> Result, ProgramError> { - if account_info.data_len() < Self::BASE_LEN { + let len = account_info.data_len(); + let data = account_info.try_borrow_data()?; + + if len < Self::BASE_LEN || len == Multisig::LEN { return Err(ProgramError::InvalidAccountData); } + if len > Self::BASE_LEN { + let byte = data[Self::BASE_LEN]; + if byte != AccountType::TokenAccount.into() && byte != AccountType::Uninitialized.into() + { + return Err(ProgramError::InvalidAccountData); + } + } if !account_info.is_owned_by(&ID) { return Err(ProgramError::InvalidAccountData); } - Ok(Ref::map(account_info.try_borrow_data()?, |data| unsafe { + Ok(Ref::map(data, |data| unsafe { Self::from_bytes_unchecked(data) })) } @@ -83,15 +93,23 @@ impl TokenAccount { pub unsafe fn from_account_info_unchecked( account_info: &AccountInfo, ) -> Result<&TokenAccount, ProgramError> { - if account_info.data_len() < Self::BASE_LEN { + let len = account_info.data_len(); + let data = account_info.borrow_data_unchecked(); + + if len < Self::BASE_LEN || len == Multisig::LEN { return Err(ProgramError::InvalidAccountData); } + if len > Self::BASE_LEN { + let byte = data[Self::BASE_LEN]; + if byte != AccountType::TokenAccount.into() && byte != AccountType::Uninitialized.into() + { + return Err(ProgramError::InvalidAccountData); + } + } if account_info.owner() != &ID { return Err(ProgramError::InvalidAccountData); } - Ok(Self::from_bytes_unchecked( - account_info.borrow_data_unchecked(), - )) + Ok(Self::from_bytes_unchecked(data)) } /// Return a `TokenAccount` from the given bytes.