Skip to content

Commit a84cf38

Browse files
authored
fix: issue with extra suffix parts in imported sources (#148)
1 parent f378390 commit a84cf38

7 files changed

Lines changed: 199 additions & 140 deletions

File tree

ape_solidity/compiler.py

Lines changed: 81 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,15 @@
5454
# Define a regex pattern that matches import statements
5555
# Both single and multi-line imports will be matched
5656
IMPORTS_PATTERN = re.compile(
57-
r"import\s+((.*?)(?=;)|[\s\S]*?from\s+(.*?)(?=;));\s", flags=re.MULTILINE
57+
r"import\s+(([\s\S]*?)(?=;)|[\s\S]*?from\s+([^\s;]+));\s*", flags=re.MULTILINE
5858
)
5959
LICENSES_PATTERN = re.compile(r"(// SPDX-License-Identifier:\s*([^\n]*)\s)")
60+
61+
# Comment patterns
62+
SINGLE_LINE_COMMENT_PATTERN = re.compile(r"^\s*//")
63+
MULTI_LINE_COMMENT_START_PATTERN = re.compile(r"/\*")
64+
MULTI_LINE_COMMENT_END_PATTERN = re.compile(r"\*/")
65+
6066
VERSION_PRAGMA_PATTERN = re.compile(r"pragma solidity[^;]*;")
6167
DEFAULT_OPTIMIZATION_RUNS = 200
6268

@@ -142,7 +148,7 @@ class SolidityConfig(PluginConfig):
142148
def _get_flattened_source(path: Path, name: Optional[str] = None) -> str:
143149
name = name or path.name
144150
result = f"// File: {name}\n"
145-
result += path.read_text() + "\n"
151+
result += f"{path.read_text().rstrip()}\n"
146152
return result
147153

148154

@@ -373,12 +379,15 @@ def _get_settings_from_imports(
373379
files_by_solc_version = self.get_version_map_from_imports(
374380
contract_filepaths, import_map, project=pm
375381
)
376-
return self._get_settings_from_version_map(files_by_solc_version, remappings, project=pm)
382+
return self._get_settings_from_version_map(
383+
files_by_solc_version, remappings, import_map=import_map, project=pm
384+
)
377385

378386
def _get_settings_from_version_map(
379387
self,
380388
version_map: dict,
381389
import_remappings: dict[str, str],
390+
import_map: Optional[dict[str, list[str]]] = None,
382391
project: Optional[ProjectManager] = None,
383392
**kwargs,
384393
) -> dict[Version, dict]:
@@ -397,7 +406,9 @@ def _get_settings_from_version_map(
397406
},
398407
**kwargs,
399408
}
400-
if remappings_used := self._get_used_remappings(sources, import_remappings, project=pm):
409+
if remappings_used := self._get_used_remappings(
410+
sources, import_remappings, import_map=import_map, project=pm
411+
):
401412
remappings_str = [f"{k}={v}" for k, v in remappings_used.items()]
402413

403414
# Standard JSON input requires remappings to be sorted.
@@ -421,6 +432,7 @@ def _get_used_remappings(
421432
self,
422433
sources: Iterable[Path],
423434
remappings: dict[str, str],
435+
import_map: Optional[dict[str, list[str]]] = None,
424436
project: Optional[ProjectManager] = None,
425437
) -> dict[str, str]:
426438
pm = project or self.local_project
@@ -435,7 +447,8 @@ def _get_used_remappings(
435447
# Filter out unused import remapping.
436448
result = {}
437449
sources = list(sources)
438-
imports = self.get_imports(sources, project=pm).values()
450+
import_map = import_map or self.get_imports(sources, project=pm)
451+
imports = import_map.values()
439452

440453
for source_list in imports:
441454
for src in source_list:
@@ -461,32 +474,20 @@ def get_standard_input_json(
461474
import_map = self.get_imports_from_remapping(paths, remapping, project=pm)
462475
version_map = self.get_version_map_from_imports(paths, import_map, project=pm)
463476
return self.get_standard_input_json_from_version_map(
464-
version_map, remapping, project=pm, **overrides
477+
version_map, remapping, project=pm, import_map=import_map, **overrides
465478
)
466479

467-
def get_standard_input_json_from(
468-
self,
469-
version_map: dict[Version, set[Path]],
470-
import_remappings: dict[str, str],
471-
project: Optional[ProjectManager] = None,
472-
**overrides,
473-
):
474-
pm = project or self.local_project
475-
settings = self._get_settings_from_version_map(
476-
version_map, import_remappings, project=pm, **overrides
477-
)
478-
return self.get_standard_input_json_from_settings(settings, version_map, project=pm)
479-
480480
def get_standard_input_json_from_version_map(
481481
self,
482482
version_map: dict[Version, set[Path]],
483483
import_remapping: dict[str, str],
484+
import_map: Optional[dict[str, list[str]]] = None,
484485
project: Optional[ProjectManager] = None,
485486
**overrides,
486487
):
487488
pm = project or self.local_project
488489
settings = self._get_settings_from_version_map(
489-
version_map, import_remapping, project=pm, **overrides
490+
version_map, import_remapping, import_map=import_map, project=pm, **overrides
490491
)
491492
return self.get_standard_input_json_from_settings(settings, version_map, project=pm)
492493

@@ -571,8 +572,16 @@ def _compile(
571572
settings: Optional[dict] = None,
572573
):
573574
pm = project or self.local_project
574-
input_jsons = self.get_standard_input_json(
575-
contract_filepaths, project=pm, **(settings or {})
575+
remapping = self.get_import_remapping(project=pm)
576+
paths = list(contract_filepaths) # Handle if given generator=
577+
import_map = self.get_imports_from_remapping(paths, remapping, project=pm)
578+
version_map = self.get_version_map_from_imports(paths, import_map, project=pm)
579+
input_jsons = self.get_standard_input_json_from_version_map(
580+
version_map,
581+
remapping,
582+
project=pm,
583+
import_map=import_map,
584+
**(settings or {}),
576585
)
577586
contract_versions: dict[str, Version] = {}
578587
contract_types: list[ContractType] = []
@@ -608,7 +617,7 @@ def _compile(
608617
for name, _ in contracts_out.items():
609618
# Filter source files that the user did not ask for, such as
610619
# imported relative files that are not part of the input.
611-
for input_file_path in contract_filepaths:
620+
for input_file_path in paths:
612621
if source_id in str(input_file_path):
613622
input_contract_names.append(name)
614623

@@ -1096,14 +1105,17 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError:
10961105

10971106
def _flatten_source(
10981107
self,
1099-
path: Path,
1108+
path: Union[Path, str],
11001109
project: Optional[ProjectManager] = None,
11011110
raw_import_name: Optional[str] = None,
11021111
handled: Optional[set[str]] = None,
11031112
) -> str:
11041113
pm = project or self.local_project
11051114
handled = handled or set()
1106-
source_id = f"{get_relative_path(path, pm.path)}"
1115+
1116+
path = Path(path)
1117+
source_id = f"{get_relative_path(path, pm.path)}" if path.is_absolute() else f"{path}"
1118+
11071119
handled.add(source_id)
11081120
remapping = self.get_import_remapping(project=project)
11091121
imports = self._get_imports((path,), remapping, pm, tracked=set(), include_raw=True)
@@ -1116,26 +1128,36 @@ def _flatten_source(
11161128
continue
11171129

11181130
sub_import_name = import_str.replace("import ", "").strip(" \n\t;\"'")
1119-
final_source += self._flatten_source(
1131+
sub_source = self._flatten_source(
11201132
pm.path / source_id,
11211133
project=pm,
11221134
raw_import_name=sub_import_name,
11231135
handled=handled,
11241136
)
1137+
final_source += sub_source
1138+
1139+
flattened_src = _get_flattened_source(path, name=raw_import_name)
1140+
if flattened_src and final_source.rstrip():
1141+
final_source = f"{final_source.rstrip()}\n\n{flattened_src}"
1142+
elif flattened_src:
1143+
final_source = flattened_src
11251144

1126-
final_source += _get_flattened_source(path, name=raw_import_name)
11271145
return final_source
11281146

11291147
def flatten_contract(
11301148
self, path: Path, project: Optional[ProjectManager] = None, **kwargs
11311149
) -> Content:
1132-
# try compiling in order to validate it works
11331150
res = self._flatten_source(path, project=project)
11341151
res = remove_imports(res)
11351152
res = process_licenses(res)
11361153
res = remove_version_pragmas(res)
11371154
pragma = get_first_version_pragma(path.read_text())
11381155
res = "\n".join([pragma, res])
1156+
1157+
# Simple auto-format.
1158+
while "\n\n\n" in res:
1159+
res = res.replace("\n\n\n", "\n\n")
1160+
11391161
lines = res.splitlines()
11401162
line_dict = {i + 1: line for i, line in enumerate(lines)}
11411163
return Content(root=line_dict)
@@ -1244,11 +1266,37 @@ def _import_str_to_source_id(
12441266
return f"{get_relative_path(path.absolute(), pm.path)}"
12451267

12461268

1247-
def remove_imports(flattened_contract: str) -> str:
1248-
# Use regex.sub() to remove matched import statements
1249-
no_imports_contract = IMPORTS_PATTERN.sub("", flattened_contract)
1269+
def remove_imports(source_code: str) -> str:
1270+
in_multi_line_comment = False
1271+
result_lines = []
1272+
1273+
lines = source_code.splitlines()
1274+
for line in lines:
1275+
# Check if we're entering a multi-line comment
1276+
if MULTI_LINE_COMMENT_START_PATTERN.search(line):
1277+
in_multi_line_comment = True
1278+
1279+
# If inside a multi-line comment, just add the line to the result
1280+
if in_multi_line_comment:
1281+
result_lines.append(line)
1282+
# Check if this line ends the multi-line comment
1283+
if MULTI_LINE_COMMENT_END_PATTERN.search(line):
1284+
in_multi_line_comment = False
1285+
continue
1286+
1287+
# Skip single-line comments
1288+
if SINGLE_LINE_COMMENT_PATTERN.match(line):
1289+
result_lines.append(line)
1290+
continue
1291+
1292+
# Skip import statements in non-comment lines
1293+
if IMPORTS_PATTERN.search(line):
1294+
continue
1295+
1296+
# Add the line to the result if it's not an import statement
1297+
result_lines.append(line)
12501298

1251-
return no_imports_contract
1299+
return "\n".join(result_lines)
12521300

12531301

12541302
def remove_version_pragmas(flattened_contract: str) -> str:
@@ -1285,9 +1333,7 @@ def process_licenses(contract: str) -> str:
12851333
license_line, root_license = extracted_licenses[-1]
12861334

12871335
# Get the unique license identifiers. All licenses in a contract _should_ be the same.
1288-
unique_license_identifiers = {
1289-
license_identifier for _, license_identifier in extracted_licenses
1290-
}
1336+
unique_license_identifiers = {lid for _, lid in extracted_licenses}
12911337

12921338
# If we have more than one unique license identifier, warn the user and use the root.
12931339
if len(unique_license_identifiers) > 1:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
include_package_data=True,
7070
install_requires=[
7171
"py-solc-x>=2.0.2,<3",
72-
"eth-ape>=0.8.1,<0.9",
72+
"eth-ape>=0.8.4,<0.9",
7373
"ethpm-types", # Use the version ape requires
7474
"eth-pydantic-types", # Use the version ape requires
7575
"packaging", # Use the version ape requires

tests/contracts/Imports.sol

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ import "@safe/contracts/common/Enum.sol";
2626
// Purposely exclude the contracts folder to test older Ape-style project imports.
2727
import "@noncompilingdependency/subdir/SubCompilingContract.sol";
2828

29+
// Showing sources with extra extensions are by default excluded,
30+
// unless used as an import somewhere in a non-excluded source.
31+
import "./Source.extra.ext.sol";
32+
2933
contract Imports {
3034
function foo() pure public returns(bool) {
3135
return true;
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// SPDX-License-Identifier: MIT
2+
pragma solidity ^0.8.4;
3+
4+
// Showing sources with extra extensions are by default excluded,
5+
// unless used as an import somewhere in a non-excluded source.
6+
contract SourceExtraExt {
7+
function foo() pure public returns(bool) {
8+
return true;
9+
}
10+
}

tests/data/ImportingLessConstrainedVersionFlat.sol

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ pragma solidity =0.8.12;
33

44
// File: ./SpecificVersionRange.sol
55

6-
7-
86
contract SpecificVersionRange {
97
function foo() pure public returns(bool) {
108
return true;
@@ -13,8 +11,6 @@ contract SpecificVersionRange {
1311

1412
// File: ImportingLessConstrainedVersion.sol
1513

16-
17-
1814
// The file we are importing specific range '>=0.8.12 <0.8.15';
1915
// This means on its own, the plugin would use 0.8.14 if its installed.
2016
// However - it should use 0.8.12 because of this file's requirements.

0 commit comments

Comments
 (0)