Skip to content

Commit 0b21f27

Browse files
authored
Merge pull request #87 from staratlasmeta/stegaBOB/fix/args
Fix: account set args
2 parents f2236d8 + c922113 commit 0b21f27

File tree

6 files changed

+52
-14
lines changed

6 files changed

+52
-14
lines changed

framework/star_frame/src/account_set/data_account.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ where
124124
T: ProgramAccount + UnsizedType + ?Sized,
125125
{
126126
/// Validates the owner and the discriminant of the account.
127+
#[inline]
127128
pub fn validate(&self) -> Result<()> {
128129
if self.owner() != &T::OwnerProgram::PROGRAM_ID {
129130
bail!(ProgramError::IllegalOwner);
@@ -133,6 +134,7 @@ where
133134
Ok(())
134135
}
135136

137+
#[inline]
136138
fn check_discriminant(bytes: &[u8]) -> Result<()> {
137139
if bytes.len() < size_of::<<T::OwnerProgram as StarFrameProgram>::AccountDiscriminant>()
138140
|| from_bytes::<PackedValue<<T::OwnerProgram as StarFrameProgram>::AccountDiscriminant>>(
@@ -144,6 +146,7 @@ where
144146
Ok(())
145147
}
146148

149+
#[inline]
147150
pub fn data<'a>(&'a self) -> Result<RefWrapper<AccountInfoRef<'a>, T::RefData>> {
148151
let r: Ref<'a, _> = self.info_data_bytes()?;
149152
Self::check_discriminant(&r)?;
@@ -158,6 +161,7 @@ where
158161
T::from_bytes(account_info_ref).map(|ret| ret.ref_wrapper)
159162
}
160163

164+
#[inline]
161165
pub fn data_mut<'a>(
162166
&'a mut self,
163167
) -> Result<RefWrapper<AccountInfoRefMut<'a, 'info, T::OwnerProgram>, T::RefData>> {

framework/star_frame/src/util.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::fmt::Debug;
88
use std::mem::size_of;
99

1010
/// Similar to [`Ref::map`], but the closure can return an error.
11+
#[inline]
1112
pub fn try_map_ref<'a, I: 'a + ?Sized, O: 'a + ?Sized, E>(
1213
r: Ref<'a, I>,
1314
f: impl FnOnce(&I) -> Result<&O, E>,
@@ -20,7 +21,8 @@ pub fn try_map_ref<'a, I: 'a + ?Sized, O: 'a + ?Sized, E>(
2021
}
2122
}
2223

23-
/// Similar to [`RefMut::map`], but the closure can return an error.
24+
/// Similar to [`RefMut::map`], but the closure can return an error
25+
#[inline]
2426
pub fn try_map_ref_mut<'a, I: 'a + ?Sized, O: 'a + ?Sized, E>(
2527
mut r: RefMut<'a, I>,
2628
f: impl FnOnce(&mut I) -> Result<&mut O, E>,
@@ -60,6 +62,7 @@ unsafe impl<S> RefBytes<S> for OffsetRef
6062
where
6163
S: AsBytes,
6264
{
65+
#[inline]
6366
fn bytes(wrapper: &RefWrapper<S, Self>) -> Result<&[u8]> {
6467
let mut bytes = wrapper.sup().as_bytes()?;
6568
bytes.try_advance(wrapper.r().0)?;
@@ -70,6 +73,7 @@ unsafe impl<S> RefBytesMut<S> for OffsetRef
7073
where
7174
S: AsMutBytes,
7275
{
76+
#[inline]
7377
fn bytes_mut(wrapper: &mut RefWrapper<S, Self>) -> Result<&mut [u8]> {
7478
let (sup, r) = unsafe { wrapper.s_r_mut() };
7579
let mut bytes = sup.as_mut_bytes()?;
@@ -79,6 +83,7 @@ where
7983
}
8084

8185
/// Returns a slice of bytes from an array of [`NoUninit`] types.
86+
#[inline]
8287
pub fn uninit_array_bytes<T: NoUninit, const N: usize>(array: &[T; N]) -> &[u8] {
8388
// Safety: `T` is `NoUninit`, so all underlying reads are valid since there's no padding
8489
// between array elements. The pointer is valid. The entire memory is valid.

framework/star_frame_proc/src/account_set/struct_impl/cleanup.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,12 @@ pub(super) fn cleanups(
143143
arg: #cleanup_type,
144144
syscalls: &mut impl #syscall_invoke<#info_lifetime>,
145145
) -> #result<()> {
146-
#(<#field_type as #account_set_cleanup<#info_lifetime, _>>::cleanup_accounts(&mut self.#field_name, #cleanup_args, syscalls)?;)*
146+
#(
147+
{
148+
let __cleanup_arg = #cleanup_args;
149+
<#field_type as #account_set_cleanup<#info_lifetime, _>>::cleanup_accounts(&mut self.#field_name, __cleanup_arg, syscalls)?;
150+
}
151+
)*
147152
#extra_cleanup
148153
Ok(())
149154
}

framework/star_frame_proc/src/account_set/struct_impl/decode.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,10 @@ pub(super) fn decodes(
165165
let decode_inner = init(&mut decode_field_ty.iter().zip_eq(&decode_args).map(|(field_ty, decode_args)| {
166166
match &field_ty {
167167
DecodeFieldTy::Type(field_type) => quote! {
168-
<#field_type as #account_set_decode<#decode_lifetime, #info_lifetime, _>>::decode_accounts(accounts, #decode_args, syscalls)?
168+
{
169+
let __arg = #decode_args;
170+
<#field_type as #account_set_decode<#decode_lifetime, #info_lifetime, _>>::decode_accounts(accounts, __arg, syscalls)?
171+
}
169172
},
170173
DecodeFieldTy::Default(default) => quote!(#default)
171174
}

framework/star_frame_proc/src/account_set/struct_impl/validate.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,15 @@ pub(super) fn validates(
182182
quote! {}
183183
} else {
184184
let address_check = validate_address.as_ref().map(|address| quote! {
185-
<#field_type as #prelude::SingleAccountSet<#info_lifetime>>::check_key(&self.#field_name, #address)?;
185+
let __address = #address;
186+
<#field_type as #prelude::SingleAccountSet<#info_lifetime>>::check_key(&self.#field_name, __address)?;
186187
});
187188
quote! {
188-
#address_check
189-
<#field_type as #account_set_validate<#info_lifetime, #validate_ty>>::validate_accounts(&mut self.#field_name, #validate_arg, syscalls)?;
189+
{
190+
#address_check
191+
let __validate_arg = #validate_arg;
192+
<#field_type as #account_set_validate<#info_lifetime, #validate_ty>>::validate_accounts(&mut self.#field_name, __validate_arg, syscalls)?;
193+
}
190194
}
191195
})
192196
.collect::<Vec<_>>();

framework/star_frame_spl/src/token/state.rs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use star_frame::account_set::AccountSet;
55
use star_frame::anyhow::{bail, Context};
66
use star_frame::bytemuck;
77
use star_frame::prelude::*;
8+
use star_frame::util::try_map_ref;
89
use std::cell::Ref;
910

1011
/// A wrapper around `AccountInfo` for the [`spl_token::state::Mint`] account.
@@ -51,6 +52,7 @@ impl<'info> MintAccount<'info> {
5152
/// ```
5253
pub const LEN: usize = 82;
5354

55+
#[inline]
5456
pub fn validate(&mut self) -> Result<()> {
5557
if self.validated {
5658
return Ok(());
@@ -62,20 +64,27 @@ impl<'info> MintAccount<'info> {
6264
if self.info_data_bytes()?.len() != Self::LEN {
6365
bail!(ProgramError::InvalidAccountData);
6466
}
67+
// set validate before checking state to allow us to call .data()
68+
self.validated = true;
6569
// check initialized
6670
if !self.data()?.is_initialized {
6771
bail!(ProgramError::UninitializedAccount);
6872
}
69-
self.validated = true;
7073
Ok(())
7174
}
7275

76+
#[inline]
7377
pub fn data(&self) -> Result<Ref<MintData>> {
74-
Ok(Ref::map(self.info_data_bytes()?, |data| {
75-
bytemuck::checked::from_bytes::<MintData>(data)
76-
}))
78+
if !self.validated {
79+
return Err(ProgramError::InvalidAccountData)
80+
.context("Called `.data()` on MintAccount before validation");
81+
}
82+
Ok(try_map_ref(self.info_data_bytes()?, |data| {
83+
bytemuck::checked::try_from_bytes::<MintData>(data)
84+
})?)
7785
}
7886

87+
#[inline]
7988
pub fn validate_mint(&self, validate_mint: ValidateMint) -> Result<()> {
8089
let data = self.data()?;
8190
if let Some(decimals) = validate_mint.decimals {
@@ -266,6 +275,7 @@ impl<'info> TokenAccount<'info> {
266275
/// ```
267276
pub const LEN: usize = 165;
268277

278+
#[inline]
269279
pub fn validate(&mut self) -> Result<()> {
270280
if self.validated {
271281
return Ok(());
@@ -277,19 +287,26 @@ impl<'info> TokenAccount<'info> {
277287
if self.info_data_bytes()?.len() != Self::LEN {
278288
bail!(ProgramError::InvalidAccountData);
279289
}
290+
// set validate before checking state to allow us to call .data()
291+
self.validated = true;
280292
if self.data()?.state == AccountState::Uninitialized {
281293
bail!(ProgramError::UninitializedAccount);
282294
}
283-
self.validated = true;
284295
Ok(())
285296
}
286297

298+
#[inline]
287299
pub fn data(&self) -> Result<Ref<TokenAccountData>> {
288-
Ok(Ref::map(self.info_data_bytes()?, |data| {
289-
bytemuck::checked::from_bytes::<TokenAccountData>(data)
290-
}))
300+
if !self.validated {
301+
return Err(ProgramError::InvalidAccountData)
302+
.context("Called `.data()` on TokenAccount before validation");
303+
}
304+
Ok(try_map_ref(self.info_data_bytes()?, |data| {
305+
bytemuck::checked::try_from_bytes::<TokenAccountData>(data)
306+
})?)
291307
}
292308

309+
#[inline]
293310
pub fn validate_token(&self, validate_token: ValidateToken) -> Result<()> {
294311
let data = self.data()?;
295312
if let Some(mint) = validate_token.mint {

0 commit comments

Comments
 (0)