@@ -173,8 +173,8 @@ def combine_mutations_to_source(module: cst.Module, mutations: Sequence[Mutation
173173 :param mutations: Mutations that should be applied.
174174 :return: Mutated code and list of mutation names"""
175175
176- # add original imports (in particular __future__ imports)
177- result : list [MODULE_STATEMENT ] = get_leading_import_statements (module .body )
176+ # copy start of the module (in particular __future__ imports)
177+ result : list [MODULE_STATEMENT ] = get_statements_until_func_or_class (module .body )
178178 mutation_names : list [str ] = []
179179
180180 # statements we still need to potentially mutate and add to the result
@@ -252,17 +252,16 @@ def function_trampoline_arrangement(function: cst.FunctionDef, mutants: Iterable
252252 return nodes , mutant_names
253253
254254
255- def get_leading_import_statements (statements : Sequence [MODULE_STATEMENT ]) -> list [MODULE_STATEMENT ]:
256- """Get all `import ...` and `from ... import ...` statements at the start of the module """
257- leading_import_statements = []
255+ def get_statements_until_func_or_class (statements : Sequence [MODULE_STATEMENT ]) -> list [MODULE_STATEMENT ]:
256+ """Get all statements until we encounter the first function or class definition """
257+ result = []
258258
259259 for stmt in statements :
260- if m .matches (stmt , m .SimpleStatementLine ([m .AtLeastN (matcher = m .Import () | m .ImportFrom (), n = 1 )])):
261- leading_import_statements .append (stmt )
262- else :
263- break
260+ if m .matches (stmt , m .FunctionDef () | m .ClassDef ()):
261+ return result
262+ result .append (stmt )
264263
265- return leading_import_statements
264+ return result
266265
267266def group_by_top_level_node (mutations : Sequence [Mutation ]) -> Mapping [cst .CSTNode , Sequence [Mutation ]]:
268267 grouped : dict [cst .CSTNode , list [Mutation ]] = defaultdict (list )
0 commit comments