diff --git a/crytic_compile/platform/vyper.py b/crytic_compile/platform/vyper.py index 4d6fc13f..680f5904 100644 --- a/crytic_compile/platform/vyper.py +++ b/crytic_compile/platform/vyper.py @@ -7,7 +7,7 @@ import shutil import subprocess from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from crytic_compile.compilation_unit import CompilationUnit from crytic_compile.compiler.compiler import CompilerVersion @@ -55,7 +55,37 @@ def __init__(self, target: Optional[Path] = None, **_kwargs: str): } }, } - + # https://github.com/ApeWorX/ape-vyper/blob/08115bc581e8a4e959d60028dbd2a71e4c635d43/ape_vyper/compiler.py#L103-L135 + def get_imports( + self, contract_filepaths: List[str], base_path: Optional[Path] = None + ): + base_path = (base_path or Path.cwd()).absolute() + imports = [] + for path in contract_filepaths: + content = Path(path).read_text().splitlines() + source_id = Path(base_path, path).resolve().absolute() + for line in content: + if line.startswith("import "): + import_line_parts = line.replace("import ", "").split(" ") + suffix = import_line_parts[0].strip().replace(".", os.path.sep) + + + elif line.startswith("from ") and " import " in line: + import_line_parts = line.replace("from ", "").split(" ") + module_name = import_line_parts[0].strip().replace(".", os.path.sep) + suffix = os.path.sep.join([module_name, import_line_parts[2].strip()]) + + + else: + # Not an import line + continue + + imported = source_id.parent / f"{suffix}.vy" + if imported.is_file(): + imports.append((imported, f"{suffix}.vy")) + + self.add_import_files(imports) + def compile(self, crytic_compile: "CryticCompile", **kwargs: str) -> None: """Compile the target @@ -71,6 +101,8 @@ def compile(self, crytic_compile: "CryticCompile", **kwargs: str) -> None: if self._target is not None and os.path.isfile(self._target): self.add_source_files([target]) + + self.get_imports(self.standard_json_input["sources"].keys()) vyper_bin = kwargs.get("vyper", "vyper") compilation_artifacts = _run_vyper_standard_json(self.standard_json_input, vyper_bin) @@ -111,6 +143,13 @@ def compile(self, crytic_compile: "CryticCompile", **kwargs: str) -> None: source_unit = compilation_unit.create_source_unit(filename) source_unit.ast = ast + def add_import_files(self, file_paths: Tuple[str, str]) -> None: + for file_path, import_directive in file_paths: + with open(file_path, "r", encoding="utf8") as f: + self.standard_json_input["sources"][import_directive] = { # type: ignore + "content": f.read(), + } + def add_source_files(self, file_paths: List[str]) -> None: """ Append files