diff --git a/crytic_compile/cryticparser/cryticparser.py b/crytic_compile/cryticparser/cryticparser.py index 19df89e7..f525338d 100755 --- a/crytic_compile/cryticparser/cryticparser.py +++ b/crytic_compile/cryticparser/cryticparser.py @@ -304,6 +304,14 @@ def _init_etherscan(parser: ArgumentParser) -> None: default=DEFAULTS_FLAG_IN_CONFIG["etherscan_only_bytecode"], ) + group_etherscan.add_argument( + "--etherscan-target-only", + help="Etherscan only include target contract.", + action="store_true", + dest="etherscan_target_only", + default=DEFAULTS_FLAG_IN_CONFIG["etherscan_target_only"], + ) + group_etherscan.add_argument( "--etherscan-apikey", help="Etherscan API key.", diff --git a/crytic_compile/cryticparser/defaults.py b/crytic_compile/cryticparser/defaults.py index b8f99677..b196a392 100755 --- a/crytic_compile/cryticparser/defaults.py +++ b/crytic_compile/cryticparser/defaults.py @@ -30,6 +30,7 @@ "etherlime_compile_arguments": None, "etherscan_only_source_code": False, "etherscan_only_bytecode": False, + "etherscan_target_only": False, "etherscan_api_key": None, "etherscan_export_directory": "etherscan-contracts", "waffle_ignore_compile": False, diff --git a/crytic_compile/platform/etherscan.py b/crytic_compile/platform/etherscan.py index e83b206a..da1fdcb4 100644 --- a/crytic_compile/platform/etherscan.py +++ b/crytic_compile/platform/etherscan.py @@ -202,6 +202,8 @@ def compile(self, crytic_compile: "CryticCompile", **kwargs: str) -> None: only_source = kwargs.get("etherscan_only_source_code", False) only_bytecode = kwargs.get("etherscan_only_bytecode", False) + target_only = kwargs.get("etherscan_target_only", False) + etherscan_api_key = kwargs.get("etherscan_api_key", None) arbiscan_api_key = kwargs.get("arbiscan_api_key", None) polygonscan_api_key = kwargs.get("polygonscan_api_key", None) @@ -330,6 +332,9 @@ def compile(self, crytic_compile: "CryticCompile", **kwargs: str) -> None: solc_standard_json.standalone_compile(filenames, compilation_unit, working_dir=working_dir) + if target_only: + _remove_unused_contracts(compilation_unit) + @staticmethod def is_supported(target: str, **kwargs: str) -> bool: """Check if the target is a etherscan project @@ -390,3 +395,82 @@ def _relative_to_short(relative: Path) -> Path: Path: Translated path """ return relative + + +# pylint: disable=too-many-locals,too-many-branches +def _remove_unused_contracts(compilation_unit: CompilationUnit) -> None: + """ + Removes unused contracts from the compilation unit + + Args: + compilation_unit (CompilationUnit): compilation unit to populate + + Returns: + + """ + if len(list(compilation_unit.asts.keys())) == 1: + # there is only 1 file + return + + # for etherscan this will be the value the etherscan api returns in 'ContractName' + root_contract_name = compilation_unit.unique_id + + # find the root file path according to a contract with the correct name being defined in it + # and also get the base path used by all paths (the keys under 'asts') + root_file_path = None + base_path = "" + for file_path, file_ast in compilation_unit.asts.items(): + if root_file_path is not None: # already found target contract + break + for node in file_ast["nodes"]: + if node["nodeType"] == "ContractDefinition" and node["name"] == root_contract_name: + root_file_path = file_path + base_path = file_path.replace(file_ast["absolutePath"], "") + break + + if root_file_path is None: + # we could not find a contract with that name in any of the files + return + + # Starting with the root contract, fetch all dependencies (and their dependencies, etc.) + files_to_include = [] + files_to_check = [root_file_path] + while any(files_to_check): + target_file_path = files_to_check.pop() + for node in compilation_unit.asts[target_file_path]["nodes"]: + if node["nodeType"] == "ImportDirective": + import_path = os.path.join(base_path, node["absolutePath"]) + if import_path not in files_to_check and import_path not in files_to_include: + files_to_check.append(import_path) + files_to_include.append(target_file_path) + + if len(list(compilation_unit.asts.keys())) == len(files_to_include): + # all of the files need to be included + return + + # Remove all of the unused files from the compilation unit + included_contractnames = set() + for target_file_path in files_to_include: + for node in compilation_unit.asts[target_file_path]["nodes"]: + if node["nodeType"] == "ContractDefinition": + included_contractnames.add(node["name"]) + + for contractname in list(compilation_unit.contracts_names): + if contractname not in included_contractnames: + compilation_unit.contracts_names.remove(contractname) + del compilation_unit.abis[contractname] + del compilation_unit.natspec[contractname] + del compilation_unit.bytecodes_init[contractname] + del compilation_unit.bytecodes_runtime[contractname] + del compilation_unit.srcmaps_init[contractname] + del compilation_unit.srcmaps_runtime[contractname] + + for contract_filename in list(compilation_unit.filename_to_contracts.keys()): + if contract_filename.absolute not in files_to_include: + del compilation_unit.filename_to_contracts[contract_filename] + + for fileobj in list(compilation_unit.crytic_compile.filenames): + if fileobj.absolute not in files_to_include: + compilation_unit.crytic_compile.filenames.remove(fileobj) + compilation_unit.filenames.remove(fileobj) + del compilation_unit.asts[fileobj.absolute]