Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 87 additions & 15 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use pinocchio::{
ProgramResult, account_info::AccountInfo, entrypoint, program_error::ProgramError,
pubkey::Pubkey, sysvars::Sysvar,
account_info::AccountInfo, entrypoint, program_error::ProgramError, pubkey::Pubkey,
sysvars::Sysvar, ProgramResult,
};
use pinocchio_system::instructions::CreateAccount;

Expand Down Expand Up @@ -34,17 +34,47 @@ pub struct Punchcard<'a> {
pub bits: Bits<'a>,
}

const PUNCHCARD_HEADER_LEN: usize = size_of::<PunchcardHeader>();

fn bitset_len(capacity: u64) -> Option<usize> {
let capacity = usize::try_from(capacity).ok()?;
capacity.checked_add(7).map(|value| value / 8)
}

impl<'a> Punchcard<'a> {
pub fn space(capacity: u64) -> usize {
size_of::<PunchcardHeader>() + ((capacity as usize + 7) / 8)
pub fn space(capacity: u64) -> Option<usize> {
let bits_len = bitset_len(capacity)?;
PUNCHCARD_HEADER_LEN.checked_add(bits_len)
}

pub fn from_bytes(data: &'a mut [u8]) -> Self {
let (header, bits) = data.split_at_mut(size_of::<PunchcardHeader>());
Self {
header: bytemuck::from_bytes_mut(header),
bits: Bits(bits),
fn split(data: &'a mut [u8]) -> Result<(&'a mut PunchcardHeader, &'a mut [u8]), ProgramError> {
if data.len() < PUNCHCARD_HEADER_LEN {
return Err(ProgramError::InvalidAccountData);
}

let (header, bits) = data.split_at_mut(PUNCHCARD_HEADER_LEN);
let header =
bytemuck::try_from_bytes_mut(header).map_err(|_| ProgramError::InvalidAccountData)?;

Ok((header, bits))
}

pub fn from_bytes(data: &'a mut [u8]) -> Result<Self, ProgramError> {
let (header, bits) = Self::split(data)?;
let expected_bits_len =
bitset_len(header.capacity).ok_or(Error::InvalidCapacity.into_program_error())?;

if bits.len() != expected_bits_len {
return Err(ProgramError::InvalidAccountData);
}
if header.claimed > header.capacity {
return Err(ProgramError::InvalidAccountData);
}

Ok(Self {
header,
bits: Bits(bits),
})
}

pub fn claim(&mut self, index: u64) -> ProgramResult {
Expand Down Expand Up @@ -72,6 +102,7 @@ pub enum Error {
InvalidAuthority = 0,
IndexOutOfBounds = 1,
AlreadyClaimed = 2,
InvalidCapacity = 3,
}

impl Error {
Expand All @@ -97,7 +128,7 @@ fn create(program_id: &Pubkey, accounts: &[AccountInfo], capacity: u64) -> Progr
return Err(ProgramError::NotEnoughAccountKeys);
};

let space = Punchcard::space(capacity);
let space = Punchcard::space(capacity).ok_or(Error::InvalidCapacity.into_program_error())?;
let rent = pinocchio::sysvars::rent::Rent::get()?.minimum_balance(space);

CreateAccount {
Expand All @@ -110,10 +141,11 @@ fn create(program_id: &Pubkey, accounts: &[AccountInfo], capacity: u64) -> Progr
.invoke()?;

let mut data = punchcard.try_borrow_mut_data()?;
let card = Punchcard::from_bytes(&mut data);
card.header.authority = *payer.key();
card.header.capacity = capacity;
card.header.claimed = 0;
let (header, bits) = Punchcard::split(&mut data)?;
header.authority = *payer.key();
header.capacity = capacity;
header.claimed = 0;
bits.fill(0);

Ok(())
}
Expand All @@ -132,7 +164,7 @@ fn claim(program_id: &Pubkey, accounts: &[AccountInfo], indices: &[u64]) -> Prog

let (capacity, claimed) = {
let mut data = punchcard.try_borrow_mut_data()?;
let mut card = Punchcard::from_bytes(&mut data);
let mut card = Punchcard::from_bytes(&mut data)?;

if card.header.authority != *authority.key() {
return Err(Error::InvalidAuthority.into_program_error());
Expand All @@ -158,3 +190,43 @@ fn claim(program_id: &Pubkey, accounts: &[AccountInfo], indices: &[u64]) -> Prog

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn space_rejects_overflowing_capacities() {
assert_eq!(Punchcard::space(u64::MAX), None);
assert_eq!(Punchcard::space(u64::MAX - 1), None);
assert_eq!(Punchcard::space(u64::MAX - 6), None);
assert!(Punchcard::space(u64::MAX - 7).is_some());
}

#[test]
fn from_bytes_rejects_mismatched_bitset_length() {
let mut data = vec![0u8; PUNCHCARD_HEADER_LEN + 1];
let (header, _) = data.split_at_mut(PUNCHCARD_HEADER_LEN);
let header = bytemuck::from_bytes_mut::<PunchcardHeader>(header);
header.capacity = 0;
header.claimed = 0;

assert!(matches!(
Punchcard::from_bytes(&mut data),
Err(ProgramError::InvalidAccountData)
));
}

#[test]
fn from_bytes_rejects_claimed_greater_than_capacity() {
let mut data = vec![0u8; PUNCHCARD_HEADER_LEN];
let header = bytemuck::from_bytes_mut::<PunchcardHeader>(&mut data[..PUNCHCARD_HEADER_LEN]);
header.capacity = 0;
header.claimed = 1;

assert!(matches!(
Punchcard::from_bytes(&mut data),
Err(ProgramError::InvalidAccountData)
));
}
}