Skip to content

Commit ac908b1

Browse files
committed
Add DiamondCutFacet tests
1 parent 1bc5d81 commit ac908b1

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// SPDX-License-Identifier: AGPL-3.0-only
2+
pragma solidity ^0.8.19;
3+
4+
import "forge-std/Test.sol";
5+
6+
import {DiamondCutFacet} from "../src/diamonds/facets/DiamondCutFacet.sol";
7+
import {IDiamond} from "../src/diamonds/interfaces/IDiamond.sol";
8+
import {IDiamondCut} from "../src/diamonds/interfaces/IDiamondCut.sol";
9+
import {
10+
LibDiamond,
11+
NotContractOwner
12+
} from "../src/diamonds/libraries/LibDiamond.sol";
13+
14+
contract DiamondCutHarness {
15+
error FunctionDoesNotExist(bytes4 selector);
16+
17+
uint256 public initValue;
18+
19+
constructor(address _owner) {
20+
LibDiamond.setContractOwner(_owner);
21+
}
22+
23+
function rawCut(IDiamond.FacetCut[] memory _diamondCut) external {
24+
LibDiamond.enforceIsContractOwner();
25+
LibDiamond.diamondCut(_diamondCut, address(0), "");
26+
}
27+
28+
function owner() external view returns (address) {
29+
return LibDiamond.contractOwner();
30+
}
31+
32+
fallback() external payable {
33+
LibDiamond.DiamondStorage storage ds = LibDiamond.diamondStorage();
34+
address facet = ds.facetAddressAndSelectorPosition[msg.sig].facetAddress;
35+
36+
if (facet == address(0)) {
37+
revert FunctionDoesNotExist(msg.sig);
38+
}
39+
40+
assembly {
41+
calldatacopy(0, 0, calldatasize())
42+
let result := delegatecall(gas(), facet, 0, calldatasize(), 0, 0)
43+
returndatacopy(0, 0, returndatasize())
44+
switch result
45+
case 0 { revert(0, returndatasize()) }
46+
default { return(0, returndatasize()) }
47+
}
48+
}
49+
50+
receive() external payable {}
51+
}
52+
53+
contract SampleFacet {
54+
function ping() external pure returns (uint256) {
55+
return 7;
56+
}
57+
}
58+
59+
contract InitFacet {
60+
uint256 public initValue;
61+
62+
function record(uint256 value) external {
63+
initValue = value;
64+
}
65+
}
66+
67+
contract DiamondCutFacetTest is Test {
68+
DiamondCutHarness private diamond;
69+
DiamondCutFacet private cutFacet;
70+
SampleFacet private sampleFacet;
71+
InitFacet private initFacet;
72+
73+
address private owner;
74+
address private attacker;
75+
76+
function setUp() public {
77+
owner = makeAddr("owner");
78+
attacker = makeAddr("attacker");
79+
80+
diamond = new DiamondCutHarness(owner);
81+
cutFacet = new DiamondCutFacet();
82+
sampleFacet = new SampleFacet();
83+
initFacet = new InitFacet();
84+
85+
vm.prank(owner);
86+
_installDiamondCutFacet();
87+
}
88+
89+
function _installDiamondCutFacet() internal {
90+
IDiamond.FacetCut[] memory cuts = new IDiamond.FacetCut[](1);
91+
bytes4[] memory selectors = new bytes4[](1);
92+
selectors[0] = DiamondCutFacet.diamondCut.selector;
93+
cuts[0] = IDiamond.FacetCut({
94+
facetAddress: address(cutFacet),
95+
action: IDiamond.FacetCutAction.Add,
96+
functionSelectors: selectors
97+
});
98+
diamond.rawCut(cuts);
99+
}
100+
101+
function _singleCut(address facetAddress, bytes4 selector) internal pure returns (IDiamond.FacetCut[] memory cuts) {
102+
cuts = new IDiamond.FacetCut[](1);
103+
bytes4[] memory selectors = new bytes4[](1);
104+
selectors[0] = selector;
105+
cuts[0] = IDiamond.FacetCut({
106+
facetAddress: facetAddress,
107+
action: IDiamond.FacetCutAction.Add,
108+
functionSelectors: selectors
109+
});
110+
}
111+
112+
function test_diamondCut_addsFacetAndRoutesCall() public {
113+
IDiamond.FacetCut[] memory cuts = _singleCut(address(sampleFacet), SampleFacet.ping.selector);
114+
115+
vm.prank(owner);
116+
DiamondCutFacet(address(diamond)).diamondCut(cuts, address(0), "");
117+
118+
assertEq(SampleFacet(address(diamond)).ping(), 7);
119+
}
120+
121+
function test_diamondCut_revertsForNonOwner() public {
122+
IDiamond.FacetCut[] memory cuts = _singleCut(address(sampleFacet), SampleFacet.ping.selector);
123+
124+
vm.prank(attacker);
125+
vm.expectRevert(abi.encodeWithSelector(NotContractOwner.selector, attacker, owner));
126+
DiamondCutFacet(address(diamond)).diamondCut(cuts, address(0), "");
127+
}
128+
129+
function test_diamondCut_executesInitializerCall() public {
130+
IDiamond.FacetCut[] memory cuts = _singleCut(address(sampleFacet), SampleFacet.ping.selector);
131+
132+
vm.prank(owner);
133+
DiamondCutFacet(address(diamond)).diamondCut(
134+
cuts,
135+
address(initFacet),
136+
abi.encodeWithSelector(InitFacet.record.selector, 42)
137+
);
138+
139+
assertEq(diamond.initValue(), 42);
140+
assertEq(SampleFacet(address(diamond)).ping(), 7);
141+
}
142+
}

0 commit comments

Comments
 (0)