Skip to content

Commit a798efb

Browse files
glihmpiotmag769
andauthored
get_class_hash cheatcode (#441)
<!-- Reference any GitHub issues resolved by this PR --> Closes #335 ## Introduced changes Add a cheatcode to get the class hash associated to a given address: ```rust let class_hash = declare(...); // Prepare let contract_address = deploy(...); assert(get_class_hash(contract_address) == class_hash, 'Wrong class hash'); ``` ## Breaking changes <!-- List of all breaking changes, if applicable --> ## Checklist <!-- Make sure all of these are complete --> - [X] Linked relevant issue - [X] Updated relevant documentation - [X] Added relevant tests - [X] Performed self-review of the code - [X] Added changes to `CHANGELOG.md` --------- Co-authored-by: Piotr Magiera <[email protected]>
1 parent e7e44ad commit a798efb

File tree

12 files changed

+252
-36
lines changed

12 files changed

+252
-36
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515
- Documentation to public methods
1616
- Information sections to documentation about importing `snforge_std`
1717
- Added print support for basic numeric data types
18+
- `get_class_hash` cheatcode
1819

1920
#### Changed
2021

crates/cheatnet/src/cheatcodes.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use thiserror::Error;
1010

1111
pub mod declare;
1212
pub mod deploy;
13+
pub mod get_class_hash;
1314
pub mod prank;
1415
pub mod roll;
1516
pub mod warp;
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
use crate::{
2+
cheatcodes::{CheatcodeError, EnhancedHintError},
3+
CheatnetState,
4+
};
5+
use blockifier::state::state_api::StateReader;
6+
use starknet_api::core::{ClassHash, ContractAddress};
7+
8+
impl CheatnetState {
9+
/// Gets the class hash at the given address.
10+
pub fn get_class_hash(
11+
&mut self,
12+
contract_address: ContractAddress,
13+
) -> Result<ClassHash, CheatcodeError> {
14+
match self.blockifier_state.get_class_hash_at(contract_address) {
15+
Ok(class_hash) => Ok(class_hash),
16+
Err(e) => Err(CheatcodeError::Unrecoverable(EnhancedHintError::State(e))),
17+
}
18+
}
19+
}

crates/forge/src/cheatcodes_hint_processor.rs

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ use std::path::PathBuf;
44

55
use crate::scarb::StarknetContractArtifacts;
66
use anyhow::{anyhow, Result};
7-
use blockifier::execution::execution_utils::stark_felt_to_felt;
7+
use blockifier::execution::execution_utils::{felt_to_stark_felt, stark_felt_to_felt};
8+
89
use cairo_felt::Felt252;
910
use cairo_vm::hint_processor::hint_processor_definition::HintProcessorLogic;
1011
use cairo_vm::hint_processor::hint_processor_definition::HintReference;
@@ -18,7 +19,6 @@ use cheatnet::{
1819
cheatcodes::{CheatcodeError, ContractArtifacts, EnhancedHintError},
1920
CheatnetState,
2021
};
21-
use num_traits::Num;
2222
use num_traits::ToPrimitive;
2323
use serde::Deserialize;
2424
use starknet_api::core::{ContractAddress, PatriciaKey};
@@ -237,8 +237,7 @@ impl CairoHintProcessor<'_> {
237237
.collect(),
238238
) {
239239
Ok(class_hash) => {
240-
let felt_class_hash =
241-
felt252_from_hex_string(&class_hash.to_string()).unwrap();
240+
let felt_class_hash = stark_felt_to_felt(class_hash.0);
242241

243242
buffer
244243
.write(Felt252::from(0))
@@ -287,6 +286,22 @@ impl CairoHintProcessor<'_> {
287286
print(inputs);
288287
Ok(())
289288
}
289+
"get_class_hash" => {
290+
let contract_address = contract_address_from_felt252(&inputs[0])?;
291+
292+
match self.cheatnet_state.get_class_hash(contract_address) {
293+
Ok(class_hash) => {
294+
let felt_class_hash = stark_felt_to_felt(class_hash.0);
295+
296+
buffer
297+
.write(felt_class_hash)
298+
.expect("Failed to insert contract class hash");
299+
Ok(())
300+
}
301+
Err(CheatcodeError::Recoverable(_)) => unreachable!(),
302+
Err(CheatcodeError::Unrecoverable(err)) => Err(err),
303+
}
304+
}
290305
_ => Err(anyhow!("Unknown cheatcode selector: {selector}")).map_err(Into::into),
291306
}?;
292307

@@ -376,12 +391,6 @@ fn print(inputs: Vec<Felt252>) {
376391
}
377392
}
378393

379-
fn felt252_from_hex_string(value: &str) -> Result<Felt252> {
380-
let stripped_value = value.replace("0x", "");
381-
Felt252::from_str_radix(&stripped_value, 16)
382-
.map_err(|_| anyhow!("Failed to convert value = {value} to Felt252"))
383-
}
384-
385394
fn write_cheatcode_panic(buffer: &mut MemBuffer, panic_data: &[Felt252]) {
386395
buffer.write(1).expect("Failed to insert err code");
387396
buffer
@@ -392,30 +401,8 @@ fn write_cheatcode_panic(buffer: &mut MemBuffer, panic_data: &[Felt252]) {
392401
.expect("Failed to insert error in memory");
393402
}
394403

395-
#[cfg(test)]
396-
mod test {
397-
use super::*;
398-
399-
#[test]
400-
fn felt_2525_from_prefixed_hex() {
401-
assert_eq!(
402-
felt252_from_hex_string("0x1234").unwrap(),
403-
Felt252::from(0x1234)
404-
);
405-
}
406-
407-
#[test]
408-
fn felt_2525_from_non_prefixed_hex() {
409-
assert_eq!(
410-
felt252_from_hex_string("1234").unwrap(),
411-
Felt252::from(0x1234)
412-
);
413-
}
414-
415-
#[test]
416-
fn felt_252_err_on_failed_conversion() {
417-
let result = felt252_from_hex_string("yyyy");
418-
let err = result.unwrap_err();
419-
assert_eq!(err.to_string(), "Failed to convert value = yyyy to Felt252");
420-
}
404+
fn contract_address_from_felt252(felt: &Felt252) -> Result<ContractAddress, EnhancedHintError> {
405+
Ok(ContractAddress(PatriciaKey::try_from(felt_to_stark_felt(
406+
felt,
407+
))?))
421408
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
use starknet::ClassHash;
2+
3+
#[starknet::interface]
4+
trait IUpgradeable<T> {
5+
fn upgrade(ref self: T, class_hash: ClassHash);
6+
}
7+
8+
#[starknet::contract]
9+
mod GetClassHashCheckerUpg {
10+
11+
use starknet::ClassHash;
12+
use result::ResultTrait;
13+
14+
#[storage]
15+
struct Storage {
16+
inner: felt252,
17+
}
18+
19+
#[external(v0)]
20+
impl IUpgradeableImpl of super::IUpgradeable<ContractState> {
21+
fn upgrade(ref self: ContractState, class_hash: ClassHash) {
22+
_upgrade(class_hash);
23+
}
24+
}
25+
26+
fn _upgrade(class_hash: ClassHash) {
27+
match starknet::replace_class_syscall(class_hash) {
28+
Result::Ok(()) => {},
29+
Result::Err(e) => panic(e),
30+
};
31+
}
32+
}
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
use crate::integration::common::corelib::{corelib_path, predeployed_contracts};
2+
use crate::integration::common::runner::Contract;
3+
use crate::{assert_passed, test_case};
4+
use camino::Utf8PathBuf;
5+
use forge::run;
6+
use indoc::indoc;
7+
use std::path::Path;
8+
9+
#[test]
10+
fn get_class_hash() {
11+
let test = test_case!(
12+
indoc!(
13+
r#"
14+
use array::ArrayTrait;
15+
use result::ResultTrait;
16+
use snforge_std::{ declare, PreparedContract, deploy, start_prank, get_class_hash };
17+
18+
#[test]
19+
fn test_get_class_hash() {
20+
let class_hash = declare('GetClassHashCheckerUpg');
21+
let prepared = PreparedContract { class_hash: class_hash, constructor_calldata: @ArrayTrait::new() };
22+
let contract_address = deploy(prepared).unwrap();
23+
assert(get_class_hash(contract_address) == class_hash, 'Incorrect class hash');
24+
}
25+
"#
26+
),
27+
Contract::from_code_path(
28+
"GetClassHashCheckerUpg".to_string(),
29+
Path::new("tests/data/contracts/get_class_hash_checker.cairo"),
30+
)
31+
.unwrap()
32+
);
33+
34+
let result = run(
35+
&test.path().unwrap(),
36+
&String::from("src"),
37+
&test.path().unwrap().join("src/lib.cairo"),
38+
&Some(test.linked_libraries()),
39+
&Default::default(),
40+
&corelib_path(),
41+
&test.contracts(&corelib_path()).unwrap(),
42+
&Utf8PathBuf::from_path_buf(predeployed_contracts().to_path_buf()).unwrap(),
43+
)
44+
.unwrap();
45+
46+
assert_passed!(result);
47+
}
48+
49+
#[test]
50+
fn get_class_hash_replace_class() {
51+
let test = test_case!(
52+
indoc!(
53+
r#"
54+
use array::{ArrayTrait, SpanTrait};
55+
use core::result::ResultTrait;
56+
use starknet::ClassHash;
57+
use snforge_std::{declare, deploy, get_class_hash, PreparedContract};
58+
59+
#[starknet::interface]
60+
trait IUpgradeable<T> {
61+
fn upgrade(ref self: T, class_hash: ClassHash);
62+
}
63+
64+
#[starknet::interface]
65+
trait IHelloStarknet<TContractState> {
66+
fn increase_balance(ref self: TContractState, amount: felt252);
67+
fn get_balance(self: @TContractState) -> felt252;
68+
fn do_a_panic(self: @TContractState);
69+
fn do_a_panic_with(self: @TContractState, panic_data: Array<felt252>);
70+
}
71+
72+
#[test]
73+
fn test_get_class_hash_replace_class() {
74+
let class_hash = declare('GetClassHashCheckerUpg');
75+
76+
let prepared = PreparedContract {
77+
class_hash: class_hash,
78+
constructor_calldata: @ArrayTrait::new()
79+
};
80+
81+
let contract_address = deploy(prepared).unwrap();
82+
83+
assert(get_class_hash(contract_address) == class_hash, 'Incorrect class hash');
84+
85+
let hsn_class_hash = declare('HelloStarknet');
86+
IUpgradeableDispatcher { contract_address }.upgrade(hsn_class_hash);
87+
assert(get_class_hash(contract_address) == hsn_class_hash, 'Incorrect upgrade class hash');
88+
89+
let hello_dispatcher = IHelloStarknetDispatcher { contract_address };
90+
hello_dispatcher.increase_balance(42);
91+
assert(hello_dispatcher.get_balance() == 42, 'Invalid balance');
92+
}
93+
"#
94+
),
95+
Contract::from_code_path(
96+
"GetClassHashCheckerUpg".to_string(),
97+
Path::new("tests/data/contracts/get_class_hash_checker.cairo"),
98+
)
99+
.unwrap(),
100+
Contract::from_code_path(
101+
"HelloStarknet".to_string(),
102+
Path::new("tests/data/contracts/hello_starknet.cairo"),
103+
)
104+
.unwrap()
105+
);
106+
107+
let result = run(
108+
&test.path().unwrap(),
109+
&String::from("src"),
110+
&test.path().unwrap().join("src/lib.cairo"),
111+
&Some(test.linked_libraries()),
112+
&Default::default(),
113+
&corelib_path(),
114+
&test.contracts(&corelib_path()).unwrap(),
115+
&Utf8PathBuf::from_path_buf(predeployed_contracts().to_path_buf()).unwrap(),
116+
)
117+
.unwrap();
118+
119+
assert_passed!(result);
120+
}

crates/forge/tests/integration/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ pub(crate) mod common;
22
mod declare;
33
mod deploy;
44
mod dispatchers;
5+
mod get_class_hash;
56
mod prank;
67
mod pure_cairo;
78
mod roll;

docs/src/SUMMARY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
* [stop_roll](appendix/cheatcodes/stop_roll.md)
4545
* [start_warp](appendix/cheatcodes/start_warp.md)
4646
* [stop_warp](appendix/cheatcodes/stop_warp.md)
47+
* [get_class_hash](appendix/cheatcodes/get_class_hash.md)
4748
* [Forge Library Functions References](appendix/forge-library.md)
4849
* [declare](appendix/forge-library/declare.md)
4950
* [deploy](appendix/forge-library/deploy.md)

docs/src/appendix/cheatcodes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* [`stop_roll`](cheatcodes/stop_roll.md)
77
* [`start_warp`](cheatcodes/start_warp.md)
88
* [`stop_warp`](cheatcodes/stop_warp.md)
9+
* [`get_class_hash`](cheatcodes/get_class_hash.md)
910

1011
> ℹ️ **Info**
1112
> To use cheatcodes you need to add `snforge_std` package as a dependency in
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# `get_class_hash`
2+
3+
> `fn get_class_hash(contract_address: ContractAddress) -> ClassHash`
4+
5+
Retrieves a class hash of a contract deployed under the given address.
6+
7+
- `contract_address` - target contract address
8+
9+
The main purpose of this cheatcode is to test upgradable contracts. For contract implementation:
10+
11+
```rust
12+
// ...
13+
#[external(v0)]
14+
impl IUpgradeableImpl of super::IUpgradeable<ContractState> {
15+
fn upgrade(ref self: ContractState, class_hash: starknet::ClassHash) {
16+
starknet::replace_class_syscall(class_hash).unwrap_syscall();
17+
}
18+
}
19+
// ...
20+
```
21+
22+
We can use `get_class_hash` to check if it upgraded properly:
23+
24+
```rust
25+
#[test]
26+
fn test_get_class_hash() {
27+
let class_hash = declare('Contract1');
28+
29+
let prepared = PreparedContract {
30+
class_hash: class_hash,
31+
constructor_calldata: @ArrayTrait::new()
32+
};
33+
34+
let contract_address = deploy(prepared).unwrap();
35+
36+
assert(get_class_hash(contract_address) == class_hash, 'Incorrect class hash');
37+
38+
let other_class_hash = declare('OtherContract');
39+
40+
IUpgradeableDispatcher { contract_address }.upgrade(other_class_hash);
41+
42+
assert(get_class_hash(contract_address) == other_class_hash, 'Incorrect class hash upgrade');
43+
}
44+
```

0 commit comments

Comments
 (0)