5454# Define a regex pattern that matches import statements
5555# Both single and multi-line imports will be matched
5656IMPORTS_PATTERN = re .compile (
57- r"import\s+((. *?)(?=;)|[\s\S]*?from\s+(.*?)(?=; ));\s" , flags = re .MULTILINE
57+ r"import\s+(([\s\S] *?)(?=;)|[\s\S]*?from\s+([^\s;]+ ));\s* " , flags = re .MULTILINE
5858)
5959LICENSES_PATTERN = re .compile (r"(// SPDX-License-Identifier:\s*([^\n]*)\s)" )
60+
61+ # Comment patterns
62+ SINGLE_LINE_COMMENT_PATTERN = re .compile (r"^\s*//" )
63+ MULTI_LINE_COMMENT_START_PATTERN = re .compile (r"/\*" )
64+ MULTI_LINE_COMMENT_END_PATTERN = re .compile (r"\*/" )
65+
6066VERSION_PRAGMA_PATTERN = re .compile (r"pragma solidity[^;]*;" )
6167DEFAULT_OPTIMIZATION_RUNS = 200
6268
@@ -142,7 +148,7 @@ class SolidityConfig(PluginConfig):
142148def _get_flattened_source (path : Path , name : Optional [str ] = None ) -> str :
143149 name = name or path .name
144150 result = f"// File: { name } \n "
145- result += path .read_text () + " \n "
151+ result += f" { path .read_text (). rstrip () } \n "
146152 return result
147153
148154
@@ -373,12 +379,15 @@ def _get_settings_from_imports(
373379 files_by_solc_version = self .get_version_map_from_imports (
374380 contract_filepaths , import_map , project = pm
375381 )
376- return self ._get_settings_from_version_map (files_by_solc_version , remappings , project = pm )
382+ return self ._get_settings_from_version_map (
383+ files_by_solc_version , remappings , import_map = import_map , project = pm
384+ )
377385
378386 def _get_settings_from_version_map (
379387 self ,
380388 version_map : dict ,
381389 import_remappings : dict [str , str ],
390+ import_map : Optional [dict [str , list [str ]]] = None ,
382391 project : Optional [ProjectManager ] = None ,
383392 ** kwargs ,
384393 ) -> dict [Version , dict ]:
@@ -397,7 +406,9 @@ def _get_settings_from_version_map(
397406 },
398407 ** kwargs ,
399408 }
400- if remappings_used := self ._get_used_remappings (sources , import_remappings , project = pm ):
409+ if remappings_used := self ._get_used_remappings (
410+ sources , import_remappings , import_map = import_map , project = pm
411+ ):
401412 remappings_str = [f"{ k } ={ v } " for k , v in remappings_used .items ()]
402413
403414 # Standard JSON input requires remappings to be sorted.
@@ -421,6 +432,7 @@ def _get_used_remappings(
421432 self ,
422433 sources : Iterable [Path ],
423434 remappings : dict [str , str ],
435+ import_map : Optional [dict [str , list [str ]]] = None ,
424436 project : Optional [ProjectManager ] = None ,
425437 ) -> dict [str , str ]:
426438 pm = project or self .local_project
@@ -435,7 +447,8 @@ def _get_used_remappings(
435447 # Filter out unused import remapping.
436448 result = {}
437449 sources = list (sources )
438- imports = self .get_imports (sources , project = pm ).values ()
450+ import_map = import_map or self .get_imports (sources , project = pm )
451+ imports = import_map .values ()
439452
440453 for source_list in imports :
441454 for src in source_list :
@@ -461,32 +474,20 @@ def get_standard_input_json(
461474 import_map = self .get_imports_from_remapping (paths , remapping , project = pm )
462475 version_map = self .get_version_map_from_imports (paths , import_map , project = pm )
463476 return self .get_standard_input_json_from_version_map (
464- version_map , remapping , project = pm , ** overrides
477+ version_map , remapping , project = pm , import_map = import_map , ** overrides
465478 )
466479
467- def get_standard_input_json_from (
468- self ,
469- version_map : dict [Version , set [Path ]],
470- import_remappings : dict [str , str ],
471- project : Optional [ProjectManager ] = None ,
472- ** overrides ,
473- ):
474- pm = project or self .local_project
475- settings = self ._get_settings_from_version_map (
476- version_map , import_remappings , project = pm , ** overrides
477- )
478- return self .get_standard_input_json_from_settings (settings , version_map , project = pm )
479-
480480 def get_standard_input_json_from_version_map (
481481 self ,
482482 version_map : dict [Version , set [Path ]],
483483 import_remapping : dict [str , str ],
484+ import_map : Optional [dict [str , list [str ]]] = None ,
484485 project : Optional [ProjectManager ] = None ,
485486 ** overrides ,
486487 ):
487488 pm = project or self .local_project
488489 settings = self ._get_settings_from_version_map (
489- version_map , import_remapping , project = pm , ** overrides
490+ version_map , import_remapping , import_map = import_map , project = pm , ** overrides
490491 )
491492 return self .get_standard_input_json_from_settings (settings , version_map , project = pm )
492493
@@ -571,8 +572,16 @@ def _compile(
571572 settings : Optional [dict ] = None ,
572573 ):
573574 pm = project or self .local_project
574- input_jsons = self .get_standard_input_json (
575- contract_filepaths , project = pm , ** (settings or {})
575+ remapping = self .get_import_remapping (project = pm )
576+ paths = list (contract_filepaths ) # Handle if given generator=
577+ import_map = self .get_imports_from_remapping (paths , remapping , project = pm )
578+ version_map = self .get_version_map_from_imports (paths , import_map , project = pm )
579+ input_jsons = self .get_standard_input_json_from_version_map (
580+ version_map ,
581+ remapping ,
582+ project = pm ,
583+ import_map = import_map ,
584+ ** (settings or {}),
576585 )
577586 contract_versions : dict [str , Version ] = {}
578587 contract_types : list [ContractType ] = []
@@ -608,7 +617,7 @@ def _compile(
608617 for name , _ in contracts_out .items ():
609618 # Filter source files that the user did not ask for, such as
610619 # imported relative files that are not part of the input.
611- for input_file_path in contract_filepaths :
620+ for input_file_path in paths :
612621 if source_id in str (input_file_path ):
613622 input_contract_names .append (name )
614623
@@ -1096,14 +1105,17 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError:
10961105
10971106 def _flatten_source (
10981107 self ,
1099- path : Path ,
1108+ path : Union [ Path , str ] ,
11001109 project : Optional [ProjectManager ] = None ,
11011110 raw_import_name : Optional [str ] = None ,
11021111 handled : Optional [set [str ]] = None ,
11031112 ) -> str :
11041113 pm = project or self .local_project
11051114 handled = handled or set ()
1106- source_id = f"{ get_relative_path (path , pm .path )} "
1115+
1116+ path = Path (path )
1117+ source_id = f"{ get_relative_path (path , pm .path )} " if path .is_absolute () else f"{ path } "
1118+
11071119 handled .add (source_id )
11081120 remapping = self .get_import_remapping (project = project )
11091121 imports = self ._get_imports ((path ,), remapping , pm , tracked = set (), include_raw = True )
@@ -1116,26 +1128,36 @@ def _flatten_source(
11161128 continue
11171129
11181130 sub_import_name = import_str .replace ("import " , "" ).strip (" \n \t ;\" '" )
1119- final_source + = self ._flatten_source (
1131+ sub_source = self ._flatten_source (
11201132 pm .path / source_id ,
11211133 project = pm ,
11221134 raw_import_name = sub_import_name ,
11231135 handled = handled ,
11241136 )
1137+ final_source += sub_source
1138+
1139+ flattened_src = _get_flattened_source (path , name = raw_import_name )
1140+ if flattened_src and final_source .rstrip ():
1141+ final_source = f"{ final_source .rstrip ()} \n \n { flattened_src } "
1142+ elif flattened_src :
1143+ final_source = flattened_src
11251144
1126- final_source += _get_flattened_source (path , name = raw_import_name )
11271145 return final_source
11281146
11291147 def flatten_contract (
11301148 self , path : Path , project : Optional [ProjectManager ] = None , ** kwargs
11311149 ) -> Content :
1132- # try compiling in order to validate it works
11331150 res = self ._flatten_source (path , project = project )
11341151 res = remove_imports (res )
11351152 res = process_licenses (res )
11361153 res = remove_version_pragmas (res )
11371154 pragma = get_first_version_pragma (path .read_text ())
11381155 res = "\n " .join ([pragma , res ])
1156+
1157+ # Simple auto-format.
1158+ while "\n \n \n " in res :
1159+ res = res .replace ("\n \n \n " , "\n \n " )
1160+
11391161 lines = res .splitlines ()
11401162 line_dict = {i + 1 : line for i , line in enumerate (lines )}
11411163 return Content (root = line_dict )
@@ -1244,11 +1266,37 @@ def _import_str_to_source_id(
12441266 return f"{ get_relative_path (path .absolute (), pm .path )} "
12451267
12461268
1247- def remove_imports (flattened_contract : str ) -> str :
1248- # Use regex.sub() to remove matched import statements
1249- no_imports_contract = IMPORTS_PATTERN .sub ("" , flattened_contract )
1269+ def remove_imports (source_code : str ) -> str :
1270+ in_multi_line_comment = False
1271+ result_lines = []
1272+
1273+ lines = source_code .splitlines ()
1274+ for line in lines :
1275+ # Check if we're entering a multi-line comment
1276+ if MULTI_LINE_COMMENT_START_PATTERN .search (line ):
1277+ in_multi_line_comment = True
1278+
1279+ # If inside a multi-line comment, just add the line to the result
1280+ if in_multi_line_comment :
1281+ result_lines .append (line )
1282+ # Check if this line ends the multi-line comment
1283+ if MULTI_LINE_COMMENT_END_PATTERN .search (line ):
1284+ in_multi_line_comment = False
1285+ continue
1286+
1287+ # Skip single-line comments
1288+ if SINGLE_LINE_COMMENT_PATTERN .match (line ):
1289+ result_lines .append (line )
1290+ continue
1291+
1292+ # Skip import statements in non-comment lines
1293+ if IMPORTS_PATTERN .search (line ):
1294+ continue
1295+
1296+ # Add the line to the result if it's not an import statement
1297+ result_lines .append (line )
12501298
1251- return no_imports_contract
1299+ return " \n " . join ( result_lines )
12521300
12531301
12541302def remove_version_pragmas (flattened_contract : str ) -> str :
@@ -1285,9 +1333,7 @@ def process_licenses(contract: str) -> str:
12851333 license_line , root_license = extracted_licenses [- 1 ]
12861334
12871335 # Get the unique license identifiers. All licenses in a contract _should_ be the same.
1288- unique_license_identifiers = {
1289- license_identifier for _ , license_identifier in extracted_licenses
1290- }
1336+ unique_license_identifiers = {lid for _ , lid in extracted_licenses }
12911337
12921338 # If we have more than one unique license identifier, warn the user and use the root.
12931339 if len (unique_license_identifiers ) > 1 :
0 commit comments