Skip to content

Commit 9f992fe

Browse files
feboilya-bobyr
andauthored
pubkey: Optional bump parameter (#196)
* Remove separate bump parameter * Add optional bump * Tweak docs * Tweaks * More tweaks * Update sdk/pubkey/src/lib.rs Co-authored-by: Illia Bobyr <[email protected]> --------- Co-authored-by: Illia Bobyr <[email protected]>
1 parent f92c336 commit 9f992fe

File tree

1 file changed

+59
-45
lines changed

1 file changed

+59
-45
lines changed

sdk/pubkey/src/lib.rs

Lines changed: 59 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,68 +14,77 @@ use pinocchio::pubkey::{Pubkey, MAX_SEEDS, PDA_MARKER};
1414
#[cfg(target_os = "solana")]
1515
use pinocchio::syscalls::sol_sha256;
1616

17-
/// Derive a [program address][pda] from the given seeds, bump and program id.
17+
/// Derive a [program address][pda] from the given seeds, optional bump and
18+
/// program id.
1819
///
1920
/// [pda]: https://solana.com/docs/core/pda
2021
///
21-
/// This function avoids the cost of the `create_program_address` syscall,
22-
/// which is `1500` compute units, by directly computing the derived address
23-
/// calculating the hash of the seeds, bump, and program id using the
24-
/// `sol_sha256` syscall.
22+
/// In general, the derivation uses an optional bump (byte) value to ensure a
23+
/// valid PDA (off-curve) is generated. Even when a program stores a bump to
24+
/// derive a program address, it is necessary to use the
25+
/// [`pinocchio::pubkey::create_program_address`] to validate the derivation. In
26+
/// most cases, the program has the correct seeds for the derivation, so it would
27+
/// be sufficient to just perform the derivation and compare it against the
28+
/// expected resulting address.
2529
///
26-
/// Even when a program stores a bump to derive a program address, it is necessary
27-
/// to use the [`pinocchio::pubkey::create_program_address`] to validate the
28-
/// derivation. In most cases, the program has the correct seeds for the derivation,
29-
/// so it would be sufficient to just perform the derivation and compare it against
30-
/// the expected resulting address.
30+
/// This function avoids the cost of the `create_program_address` syscall
31+
/// (`1500` compute units) by directly computing the derived address
32+
/// calculating the hash of the seeds, bump and program id using the
33+
/// `sol_sha256` syscall.
3134
///
3235
/// # Important
3336
///
34-
/// This function differs from [`pinocchio::pubkey::create_program_address`] in that it
35-
/// does not perform a validation to ensure that the derived address is a valid
37+
/// This function differs from [`pinocchio::pubkey::create_program_address`] in that
38+
/// it does not perform a validation to ensure that the derived address is a valid
3639
/// (off-curve) program derived address. It is intended for use in cases where the
3740
/// seeds, bump, and program id are known to be valid, and the caller wants to derive
3841
/// the address without incurring the cost of the `create_program_address` syscall.
39-
pub fn derive_address<const N: usize>(seeds: &[&[u8]; N], bump: u8, program_id: &Pubkey) -> Pubkey {
42+
pub fn derive_address<const N: usize>(
43+
seeds: &[&[u8]; N],
44+
bump: Option<u8>,
45+
program_id: &Pubkey,
46+
) -> Pubkey {
4047
const {
41-
assert!(
42-
N <= MAX_SEEDS,
43-
"number of seeds must be less than MAX_SEEDS"
44-
);
48+
assert!(N < MAX_SEEDS, "number of seeds must be less than MAX_SEEDS");
4549
}
4650

4751
const UNINIT: MaybeUninit<&[u8]> = MaybeUninit::<&[u8]>::uninit();
48-
let mut data = [UNINIT; MAX_SEEDS + 3];
52+
let mut data = [UNINIT; MAX_SEEDS + 2];
4953
let mut i = 0;
5054

5155
while i < N {
52-
// SAFETY: `data` is guanranteed to have enough space for `N` seeds,
56+
// SAFETY: `data` is guaranteed to have enough space for `N` seeds,
5357
// so `i` will always be within bounds.
5458
unsafe {
5559
data.get_unchecked_mut(i).write(seeds.get_unchecked(i));
5660
}
5761
i += 1;
5862
}
5963

60-
let bump = [bump];
64+
// TODO: replace this with `as_slice` when the MSRV is upgraded
65+
// to `1.84.0+`.
66+
let bump_seed = [bump.unwrap_or_default()];
6167

62-
// SAFETY: `data` is guaranteed to have enough space for `MAX_SEEDS + 3`
68+
// SAFETY: `data` is guaranteed to have enough space for `MAX_SEEDS + 2`
6369
// elements, and `MAX_SEEDS` is as large as `N`.
6470
unsafe {
65-
data.get_unchecked_mut(i).write(bump.as_ref());
66-
data.get_unchecked_mut(i + 1).write(program_id.as_ref());
67-
data.get_unchecked_mut(i + 2).write(PDA_MARKER.as_ref());
71+
if bump.is_some() {
72+
data.get_unchecked_mut(i).write(&bump_seed);
73+
i += 1;
74+
}
75+
data.get_unchecked_mut(i).write(program_id.as_ref());
76+
data.get_unchecked_mut(i + 1).write(PDA_MARKER.as_ref());
6877
}
6978

7079
#[cfg(target_os = "solana")]
7180
{
7281
let mut pda = MaybeUninit::<[u8; 32]>::uninit();
7382

74-
// SAFETY: `data` has `i + 3` elements initialized.
83+
// SAFETY: `data` has `i + 2` elements initialized.
7584
unsafe {
7685
sol_sha256(
7786
data.as_ptr() as *const u8,
78-
(N + 3) as u64,
87+
(i + 2) as u64,
7988
pda.as_mut_ptr() as *mut u8,
8089
);
8190
}
@@ -88,20 +97,24 @@ pub fn derive_address<const N: usize>(seeds: &[&[u8]; N], bump: u8, program_id:
8897
unreachable!("deriving a pda is only available on target `solana`");
8998
}
9099

91-
/// Derive a [program address][pda] from the given seeds, bump and program id.
100+
/// Derive a [program address][pda] from the given seeds, optional bump and
101+
/// program id.
92102
///
93103
/// [pda]: https://solana.com/docs/core/pda
94104
///
95-
/// This function avoids the cost of the `create_program_address` syscall,
96-
/// which is `1500` compute units, by directly computing the derived address
97-
/// using the SHA-256 hash of the seeds, bump, and program id.
105+
/// In general, the derivation uses an optional bump (byte) value to ensure a
106+
/// valid PDA (off-curve) is generated.
98107
///
99-
/// This function is intended for use in `const` contexts.
108+
/// This function is intended for use in `const` contexts - i.e., the seeds and
109+
/// bump are known at compile time and the program id is also a constant. It avoids
110+
/// the cost of the `create_program_address` syscall (`1500` compute units) by
111+
/// directly computing the derived address using the SHA-256 hash of the seeds,
112+
/// bump and program id.
100113
///
101114
/// # Important
102115
///
103-
/// This function differs from [`pinocchio::pubkey::create_program_address`] in that it
104-
/// does not perform a validation to ensure that the derived address is a valid
116+
/// This function differs from [`pinocchio::pubkey::create_program_address`] in that
117+
/// it does not perform a validation to ensure that the derived address is a valid
105118
/// (off-curve) program derived address. It is intended for use in cases where the
106119
/// seeds, bump, and program id are known to be valid, and the caller wants to derive
107120
/// the address without incurring the cost of the `create_program_address` syscall.
@@ -110,14 +123,11 @@ pub fn derive_address<const N: usize>(seeds: &[&[u8]; N], bump: u8, program_id:
110123
#[cfg(feature = "const")]
111124
pub const fn derive_address_const<const N: usize>(
112125
seeds: &[&[u8]; N],
113-
bump: u8,
126+
bump: Option<u8>,
114127
program_id: &Pubkey,
115128
) -> Pubkey {
116129
const {
117-
assert!(
118-
N <= MAX_SEEDS,
119-
"number of seeds must be less than MAX_SEEDS"
120-
);
130+
assert!(N < MAX_SEEDS, "number of seeds must be less than MAX_SEEDS");
121131
}
122132

123133
let mut hasher = Sha256::new();
@@ -128,13 +138,17 @@ pub const fn derive_address_const<const N: usize>(
128138
i += 1;
129139
}
130140

131-
let bump = [bump];
132-
133-
hasher
134-
.update(&bump)
135-
.update(program_id)
136-
.update(PDA_MARKER)
137-
.finalize()
141+
// TODO: replace this with `is_some` when the MSRV is upgraded
142+
// to `1.84.0+`.
143+
if let Some(bump) = bump {
144+
hasher
145+
.update(&[bump])
146+
.update(program_id)
147+
.update(PDA_MARKER)
148+
.finalize()
149+
} else {
150+
hasher.update(program_id).update(PDA_MARKER).finalize()
151+
}
138152
}
139153

140154
/// Convenience macro to define a static `Pubkey` value.

0 commit comments

Comments
 (0)