Skip to content

Commit ba33ab9

Browse files
authored
Merge pull request #43 from crytic/improve-generate-cli
Improve unit test generation CLI arguments
2 parents 3a6bf68 + 3bb3141 commit ba33ab9

File tree

12 files changed

+145
-125
lines changed

12 files changed

+145
-125
lines changed

README.md

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,21 @@ The available tool commands are:
4040
The `generate` command is used to generate Foundry unit tests from Echidna or Medusa corpus call sequences.
4141

4242
**Command-line options:**
43-
- `compilation_path`: The path to the Solidity file or Foundry directory
44-
- `-cd`/`--corpus-dir` `path_to_corpus_dir`: The path to the corpus directory relative to the working directory.
45-
- `-c`/`--contract` `contract_name`: The name of the target contract.
46-
- `-td`/`--test-directory` `path_to_test_directory`: The path to the test directory relative to the working directory.
47-
- `-i`/`--inheritance-path` `relative_path_to_contract`: The relative path from the test directory to the contract (used for inheritance).
48-
- `-f`/`--fuzzer` `fuzzer_name`: The name of the fuzzer, currently supported: `echidna` and `medusa`
49-
- `--named-inputs`: Includes function input names when making calls
50-
- `--config`: Path to the fuzz-utils config JSON file
51-
- `--all-sequences`: Include all corpus sequences when generating unit tests.
43+
- `compilation_path`: The path to the Solidity file or Foundry directory. By default `.`
44+
- `-cd`/`--corpus-dir` `path_to_corpus_dir`: The path to the corpus directory relative to the working directory. By default `corpus`
45+
- `-c`/`--contract` `contract_name`: The name of the target contract. If the compilation path only contains one contract the target will be automatically derived.
46+
- `-td`/`--test-directory` `path_to_test_directory`: The path to the test directory relative to the working directory. By default `test`
47+
- `-i`/`--inheritance-path` `relative_path_to_contract`: The relative path from the test directory to the contract (used for overriding inheritance). If this configuration option is not provided the inheritance path will be automatically derived.
48+
- `-f`/`--fuzzer` `fuzzer_name`: The name of the fuzzer, currently supported: `echidna` and `medusa`. By default `medusa`
49+
- `--named-inputs`: Includes function input names when making calls. By default`false`
50+
- `--config`: Path to the fuzz-utils config JSON file. Empty by default.
51+
- `--all-sequences`: Include all corpus sequences when generating unit tests. By default `false`
5252

5353
**Example**
5454

5555
In order to generate a test file for the [BasicTypes.sol](tests/test_data/src/BasicTypes.sol) contract, based on the Echidna corpus reproducers for this contract ([corpus-basic](tests/test_data/echidna-corpora/corpus-basic/)), we need to `cd` into the `tests/test_data` directory which contains the Foundry project and run the command:
5656
```bash
57-
fuzz-utils generate ./src/BasicTypes.sol --corpus-dir echidna-corpora/corpus-basic --contract "BasicTypes" --test-directory "./test/" --inheritance-path "../src/" --fuzzer echidna
57+
fuzz-utils generate ./src/BasicTypes.sol --corpus-dir echidna-corpora/corpus-basic --contract "BasicTypes" --fuzzer echidna
5858
```
5959

6060
Running this command should generate a `BasicTypes_Echidna_Test.sol` file in the [test](/tests/test_data/test/) directory of the Foundry project.
@@ -65,9 +65,9 @@ The `template` command is used to generate a fuzzing harness. The harness can in
6565

6666
**Command-line options:**
6767
- `compilation_path`: The path to the Solidity file or Foundry directory
68-
- `-n`/`--name` `name: str`: The name of the fuzzing harness.
69-
- `-c`/`--contracts` `target_contracts: list`: The name of the target contract.
70-
- `-o`/`--output-dir` `output_directory: str`: Output directory name. By default it is `fuzzing`
68+
- `-n`/`--name` `name: str`: The name of the fuzzing harness. By default `DefaultHarness`
69+
- `-c`/`--contracts` `target_contracts: list`: The name of the target contract. Empty by default.
70+
- `-o`/`--output-dir` `output_directory: str`: Output directory name. By default `fuzzing`
7171
- `--config`: Path to the `fuzz-utils` config JSON file
7272
- `--mode`: The strategy to use when generating the harnesses. Valid options: `simple`, `prank`, `actor`
7373

fuzz_utils/generate/FoundryTest.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,49 @@
11
"""The FoundryTest class that handles generation of unit tests from call sequences"""
22
import os
3-
import sys
43
import json
4+
import copy
55
from typing import Any
66
import jinja2
77

88
from slither import Slither
9-
from slither.core.declarations.contract import Contract
109
from fuzz_utils.utils.crytic_print import CryticPrint
10+
from fuzz_utils.utils.slither_utils import get_target_contract
11+
from fuzz_utils.templates.default_config import default_config
1112

1213
from fuzz_utils.generate.fuzzers.Medusa import Medusa
1314
from fuzz_utils.generate.fuzzers.Echidna import Echidna
1415
from fuzz_utils.templates.foundry_templates import templates
1516

16-
17-
class FoundryTest: # pylint: disable=too-many-instance-attributes
17+
# pylint: disable=too-few-public-methods,too-many-instance-attributes
18+
class FoundryTest:
1819
"""
1920
Handles the generation of Foundry test files
2021
"""
2122

23+
config: dict = copy.deepcopy(default_config["generate"])
24+
2225
def __init__(
2326
self,
2427
config: dict,
2528
slither: Slither,
2629
fuzzer: Echidna | Medusa,
2730
) -> None:
28-
self.inheritance_path = config["inheritancePath"]
29-
self.target_name = config["targetContract"]
30-
self.corpus_path = config["corpusDir"]
31-
self.test_dir = config["testsDir"]
32-
self.all_sequences = config["allSequences"]
3331
self.slither = slither
34-
self.target = self.get_target_contract()
35-
self.fuzzer = fuzzer
32+
for key, value in config.items():
33+
if key in self.config:
34+
self.config[key] = value
3635

37-
def get_target_contract(self) -> Contract:
38-
"""Gets the Slither Contract object for the specified contract file"""
39-
contracts = self.slither.get_contract_from_name(self.target_name)
40-
# Loop in case slither fetches multiple contracts for some reason (e.g., similar names?)
41-
for contract in contracts:
42-
if contract.name == self.target_name:
43-
return contract
44-
45-
# TODO throw error if no contract found
46-
sys.exit(-1)
36+
self.target = get_target_contract(self.slither, self.config["targetContract"])
37+
self.target_file_name = self.target.source_mapping.filename.relative.split("/")[-1]
38+
self.fuzzer = fuzzer
4739

4840
def create_poc(self) -> str:
4941
"""Takes in a directory path to the echidna reproducers and generates a test file"""
5042

5143
file_list: list[dict[str, Any]] = []
5244
tests_list = []
5345
dir_list = []
54-
if self.all_sequences:
46+
if self.config["allSequences"]:
5547
dir_list = self.fuzzer.corpus_dirs
5648
else:
5749
dir_list = [self.fuzzer.reproducer_dir]
@@ -79,13 +71,12 @@ def create_poc(self) -> str:
7971

8072
# 4. Generate the test file
8173
template = jinja2.Template(templates["CONTRACT"])
82-
write_path = f"{self.test_dir}{self.target_name}"
83-
inheritance_path = f"{self.inheritance_path}{self.target_name}"
84-
74+
write_path = os.path.join(self.config["testsDir"], self.config["targetContract"])
75+
inheritance_path = os.path.join(self.config["inheritancePath"])
8576
# 5. Save the test file
8677
test_file_str = template.render(
87-
file_path=f"{inheritance_path}.sol",
88-
target_name=self.target_name,
78+
file_path=inheritance_path,
79+
target_name=self.config["targetContract"],
8980
amount=0,
9081
tests=tests_list,
9182
fuzzer=self.fuzzer.name,

fuzz_utils/generate/fuzzers/Echidna.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import jinja2
55

66
from slither import Slither
7-
from slither.core.declarations.contract import Contract
87
from slither.core.declarations.function_contract import FunctionContract
98
from slither.core.solidity_types.elementary_type import ElementaryType
109
from slither.core.solidity_types.user_defined_type import UserDefinedType
@@ -16,9 +15,10 @@
1615
from fuzz_utils.templates.foundry_templates import templates
1716
from fuzz_utils.utils.encoding import parse_echidna_byte_string
1817
from fuzz_utils.utils.error_handler import handle_exit
18+
from fuzz_utils.utils.slither_utils import get_target_contract
1919

2020

21-
# pylint: disable=too-many-instance-attributes
21+
# pylint: disable=too-few-public-methods,too-many-instance-attributes
2222
class Echidna:
2323
"""
2424
Handles the generation of Foundry test files from Echidna reproducers
@@ -30,22 +30,12 @@ def __init__(
3030
self.name = "Echidna"
3131
self.target_name = target_name
3232
self.slither = slither
33-
self.target = self.get_target_contract()
33+
self.target = get_target_contract(slither, target_name)
3434
self.reproducer_dir = f"{corpus_path}/reproducers"
3535
self.corpus_dirs = [f"{corpus_path}/coverage", self.reproducer_dir]
3636
self.named_inputs = named_inputs
3737
self.declared_variables: set[tuple[str, str]] = set()
3838

39-
def get_target_contract(self) -> Contract:
40-
"""Finds and returns Slither Contract"""
41-
contracts = self.slither.get_contract_from_name(self.target_name)
42-
# Loop in case slither fetches multiple contracts for some reason (e.g., similar names?)
43-
for contract in contracts:
44-
if contract.name == self.target_name:
45-
return contract
46-
47-
handle_exit(f"\n* Slither could not find the specified contract `{self.target_name}`.")
48-
4939
def parse_reproducer(self, file_path: str, calls: Any, index: int) -> str:
5040
"""
5141
Takes a list of call dicts and returns a Foundry unit test string containing the call sequence.

fuzz_utils/generate/fuzzers/Medusa.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from eth_abi import abi
55
from eth_utils import to_checksum_address
66
from slither import Slither
7-
from slither.core.declarations.contract import Contract
87
from slither.core.declarations.function_contract import FunctionContract
98
from slither.core.solidity_types.elementary_type import ElementaryType
109
from slither.core.solidity_types.user_defined_type import UserDefinedType
@@ -16,9 +15,10 @@
1615
from fuzz_utils.templates.foundry_templates import templates
1716
from fuzz_utils.utils.encoding import byte_to_escape_sequence
1817
from fuzz_utils.utils.error_handler import handle_exit
18+
from fuzz_utils.utils.slither_utils import get_target_contract
1919

20-
21-
class Medusa: # pylint: disable=too-many-instance-attributes
20+
# pylint: disable=too-few-public-methods,too-many-instance-attributes
21+
class Medusa:
2222
"""
2323
Handles the generation of Foundry test files from Medusa reproducers
2424
"""
@@ -30,7 +30,7 @@ def __init__(
3030
self.target_name = target_name
3131
self.corpus_path = corpus_path
3232
self.slither = slither
33-
self.target = self.get_target_contract()
33+
self.target = get_target_contract(slither, target_name)
3434
self.reproducer_dir = f"{corpus_path}/test_results"
3535
self.corpus_dirs = [
3636
f"{corpus_path}/call_sequences/immutable",
@@ -40,16 +40,6 @@ def __init__(
4040
self.named_inputs = named_inputs
4141
self.declared_variables: set[tuple[str, str]] = set()
4242

43-
def get_target_contract(self) -> Contract:
44-
"""Finds and returns Slither Contract"""
45-
contracts = self.slither.get_contract_from_name(self.target_name)
46-
# Loop in case slither fetches multiple contracts for some reason (e.g., similar names?)
47-
for contract in contracts:
48-
if contract.name == self.target_name:
49-
return contract
50-
51-
handle_exit(f"\n* Slither could not find the specified contract `{self.target_name}`.")
52-
5343
def parse_reproducer(self, file_path: str, calls: Any, index: int) -> str:
5444
"""
5545
Takes a list of call dicts and returns a Foundry unit test string containing the call sequence.

fuzz_utils/parsing/commands/generate.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
"""Defines the flags and logic associated with the `generate` command"""
2-
import json
2+
from pathlib import Path
33
from argparse import Namespace, ArgumentParser
44
from slither import Slither
55
from fuzz_utils.utils.crytic_print import CryticPrint
66
from fuzz_utils.generate.FoundryTest import FoundryTest
77
from fuzz_utils.generate.fuzzers.Medusa import Medusa
88
from fuzz_utils.generate.fuzzers.Echidna import Echidna
99
from fuzz_utils.utils.error_handler import handle_exit
10+
from fuzz_utils.parsing.parser_util import check_config_and_set_default_values, open_config
11+
from fuzz_utils.utils.slither_utils import get_target_contract
12+
13+
COMMAND: str = "generate"
1014

1115

1216
def generate_flags(parser: ArgumentParser) -> None:
@@ -57,15 +61,13 @@ def generate_flags(parser: ArgumentParser) -> None:
5761
)
5862

5963

64+
# pylint: disable=too-many-branches
6065
def generate_command(args: Namespace) -> None:
6166
"""The execution logic of the `generate` command"""
6267
config: dict = {}
6368
# If the config file is defined, read it
6469
if args.config:
65-
with open(args.config, "r", encoding="utf-8") as readFile:
66-
complete_config = json.load(readFile)
67-
if "generate" in complete_config:
68-
config = complete_config["generate"]
70+
config = open_config(args.config, COMMAND)
6971
# Override the config with the CLI values
7072
if args.compilation_path:
7173
config["compilationPath"] = args.compilation_path
@@ -90,10 +92,18 @@ def generate_command(args: Namespace) -> None:
9092
if "allSequences" not in config:
9193
config["allSequences"] = False
9294

95+
check_config_and_set_default_values(
96+
config,
97+
["compilationPath", "testsDir", "fuzzer", "corpusDir"],
98+
[".", "test", "medusa", "corpus"],
99+
)
100+
93101
CryticPrint().print_information("Running Slither...")
94102
slither = Slither(args.compilation_path)
95103
fuzzer: Echidna | Medusa
96104

105+
derive_config(slither, config)
106+
97107
match config["fuzzer"]:
98108
case "echidna":
99109
fuzzer = Echidna(
@@ -114,3 +124,27 @@ def generate_command(args: Namespace) -> None:
114124
foundry_test = FoundryTest(config, slither, fuzzer)
115125
foundry_test.create_poc()
116126
CryticPrint().print_success("Done!")
127+
128+
129+
def derive_config(slither: Slither, config: dict) -> None:
130+
"""Derive values for the target contract and inheritance path"""
131+
# Derive target if it is not defined but the compilationPath only contains one contract
132+
if "targetContract" not in config or len(config["targetContract"]) == 0:
133+
if len(slither.contracts_derived) == 1:
134+
config["targetContract"] = slither.contracts_derived[0].name
135+
CryticPrint().print_information(
136+
f"Target contract not specified. Using derived target: {config['targetContract']}."
137+
)
138+
else:
139+
handle_exit(
140+
"Target contract cannot be determined. Please specify the target with `-c targetName`"
141+
)
142+
143+
# Derive inheritance path if it is not defined
144+
if "inheritancePath" not in config or len(config["inheritancePath"]) == 0:
145+
contract = get_target_contract(slither, config["targetContract"])
146+
contract_path = Path(contract.source_mapping.filename.relative)
147+
tests_path = Path(config["testsDir"])
148+
config["inheritancePath"] = str(
149+
Path(*([".." * len(tests_path.parts)])).joinpath(contract_path)
150+
)

fuzz_utils/parsing/commands/template.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
"""Defines the flags and logic associated with the `template` command"""
22
import os
3-
import json
43
from argparse import Namespace, ArgumentParser
54
from slither import Slither
65
from fuzz_utils.template.HarnessGenerator import HarnessGenerator
76
from fuzz_utils.utils.crytic_print import CryticPrint
87
from fuzz_utils.utils.remappings import find_remappings
98
from fuzz_utils.utils.error_handler import handle_exit
9+
from fuzz_utils.parsing.parser_util import (
10+
check_configuration_field_exists_and_non_empty,
11+
open_config,
12+
)
13+
14+
COMMAND: str = "template"
1015

1116

1217
def template_flags(parser: ArgumentParser) -> None:
@@ -42,10 +47,7 @@ def template_command(args: Namespace) -> None:
4247
else:
4348
output_dir = os.path.join("./test", "fuzzing")
4449
if args.config:
45-
with open(args.config, "r", encoding="utf-8") as readFile:
46-
complete_config = json.load(readFile)
47-
if "template" in complete_config:
48-
config = complete_config["template"]
50+
config = open_config(args.config, COMMAND)
4951

5052
if args.target_contracts:
5153
config["targets"] = args.target_contracts
@@ -72,15 +74,9 @@ def check_configuration(config: dict) -> None:
7274
"""Checks the configuration"""
7375
mandatory_configuration_fields = ["mode", "targets", "compilationPath"]
7476
for field in mandatory_configuration_fields:
75-
check_configuration_field_exists_and_non_empty(config, field)
77+
check_configuration_field_exists_and_non_empty(config, COMMAND, field)
7678

7779
if config["mode"].lower() not in ("simple", "prank", "actor"):
7880
handle_exit(
7981
f"The selected mode {config['mode']} is not a valid harness generation strategy."
8082
)
81-
82-
83-
def check_configuration_field_exists_and_non_empty(config: dict, field: str) -> None:
84-
"""Checks that the configuration dictionary contains a non-empty field"""
85-
if field not in config or len(config[field]) == 0:
86-
handle_exit(f"The template configuration field {field} is not configured.")

fuzz_utils/parsing/parser_util.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""Utility functions used in the command parsers"""
2+
import json
3+
from fuzz_utils.utils.error_handler import handle_exit
4+
5+
6+
def check_config_and_set_default_values(
7+
config: dict, fields: list[str], defaults: list[str]
8+
) -> None:
9+
"""Checks that the configuration dictionary contains a non-empty field"""
10+
assert len(fields) == len(defaults)
11+
for idx, field in enumerate(fields):
12+
if field not in config or len(config[field]) == 0:
13+
config[field] = defaults[idx]
14+
15+
16+
def check_configuration_field_exists_and_non_empty(config: dict, command: str, field: str) -> None:
17+
"""Checks that the configuration dictionary contains a non-empty field"""
18+
if field not in config or len(config[field]) == 0:
19+
handle_exit(f"The {command} configuration field {field} is not configured.")
20+
21+
22+
def open_config(cli_config: str, command: str) -> dict:
23+
"""Open config file if provided return its contents"""
24+
with open(cli_config, "r", encoding="utf-8") as readFile:
25+
complete_config = json.load(readFile)
26+
if command in complete_config:
27+
return complete_config[command]
28+
29+
handle_exit(
30+
f"The provided configuration file does not contain the `{command}` command configuration field."
31+
)

0 commit comments

Comments
 (0)