Skip to content

Commit 33f65f7

Browse files
committed
Fix __future__ imports with leading comments
1 parent 1548d03 commit 33f65f7

File tree

3 files changed

+30
-16
lines changed

3 files changed

+30
-16
lines changed

mutmut/file_mutation.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,8 @@ def combine_mutations_to_source(module: cst.Module, mutations: Sequence[Mutation
173173
:param mutations: Mutations that should be applied.
174174
:return: Mutated code and list of mutation names"""
175175

176-
# add original imports (in particular __future__ imports)
177-
result: list[MODULE_STATEMENT] = get_leading_import_statements(module.body)
176+
# copy start of the module (in particular __future__ imports)
177+
result: list[MODULE_STATEMENT] = get_statements_until_func_or_class(module.body)
178178
mutation_names: list[str] = []
179179

180180
# statements we still need to potentially mutate and add to the result
@@ -252,17 +252,16 @@ def function_trampoline_arrangement(function: cst.FunctionDef, mutants: Iterable
252252
return nodes, mutant_names
253253

254254

255-
def get_leading_import_statements(statements: Sequence[MODULE_STATEMENT]) -> list[MODULE_STATEMENT]:
256-
"""Get all `import ...` and `from ... import ...` statements at the start of the module"""
257-
leading_import_statements = []
255+
def get_statements_until_func_or_class(statements: Sequence[MODULE_STATEMENT]) -> list[MODULE_STATEMENT]:
256+
"""Get all statements until we encounter the first function or class definition"""
257+
result = []
258258

259259
for stmt in statements:
260-
if m.matches(stmt, m.SimpleStatementLine([m.AtLeastN(matcher=m.Import() | m.ImportFrom(), n=1)])):
261-
leading_import_statements.append(stmt)
262-
else:
263-
break
260+
if m.matches(stmt, m.FunctionDef() | m.ClassDef()):
261+
return result
262+
result.append(stmt)
264263

265-
return leading_import_statements
264+
return result
266265

267266
def group_by_top_level_node(mutations: Sequence[Mutation]) -> Mapping[cst.CSTNode, Sequence[Mutation]]:
268267
grouped: dict[cst.CSTNode, list[Mutation]] = defaultdict(list)

tests/test_mutation.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,20 @@ def foo():
417417
assert mutated_source.split('\n')[0] == 'from __future__ import annotations'
418418
assert mutated_source.count('from __future__') == 1
419419

420+
def test_from_future_with_docstring_still_first():
421+
source = """
422+
'''This documents the module'''
423+
from __future__ import annotations
424+
from collections.abc import Iterable
425+
426+
def foo():
427+
return 1
428+
""".strip()
429+
mutated_source = mutated_module(source)
430+
assert mutated_source.split('\n')[0] == "'''This documents the module'''"
431+
assert mutated_source.split('\n')[1] == 'from __future__ import annotations'
432+
assert mutated_source.count('from __future__') == 1
433+
420434

421435
def test_preserve_generators():
422436
source = '''
@@ -582,10 +596,10 @@ def add(self, value):
582596

583597
assert src == f"""from __future__ import division
584598
import lib
585-
{trampoline_impl.strip()}
586-
{yield_from_trampoline_impl.strip()}
587599
588600
lib.foo()
601+
{trampoline_impl.strip()}
602+
{yield_from_trampoline_impl.strip()}
589603
590604
def x_foo__mutmut_orig(a, b):
591605
return a > b

tests/test_mutmut3.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ def test_mutate_file_contents():
1616
def foo(a, b, c):
1717
return a + b * c
1818
"""
19+
trampolines = trampoline_impl.removesuffix('\n\n') + yield_from_trampoline_impl.removesuffix('\n\n')
1920

20-
expected = trampoline_impl.removesuffix('\n\n') + yield_from_trampoline_impl.removesuffix('\n\n') + """
21-
a + 1
21+
expected = f"""
22+
a + 1{trampolines}
2223
2324
def x_foo__mutmut_orig(a, b, c):
2425
return a + b * c
@@ -29,10 +30,10 @@ def x_foo__mutmut_1(a, b, c):
2930
def x_foo__mutmut_2(a, b, c):
3031
return a + b / c
3132
32-
x_foo__mutmut_mutants : ClassVar[MutantDict] = {
33+
x_foo__mutmut_mutants : ClassVar[MutantDict] = {{
3334
'x_foo__mutmut_1': x_foo__mutmut_1,
3435
'x_foo__mutmut_2': x_foo__mutmut_2
35-
}
36+
}}
3637
3738
def foo(*args, **kwargs):
3839
result = _mutmut_trampoline(x_foo__mutmut_orig, x_foo__mutmut_mutants, args, kwargs)

0 commit comments

Comments
 (0)