Skip to content

Commit 056e72c

Browse files
committed
Trampoline for generator functions
1 parent b4b553d commit 056e72c

File tree

3 files changed

+70
-5
lines changed

3 files changed

+70
-5
lines changed

mutmut/__main__.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def _mutmut_trampoline(orig, mutants, *args, **kwargs):
197197
return mutants[mutant_name](*args, **kwargs)
198198
199199
"""
200+
yield_from_trampoline_impl = trampoline_impl.replace('return ', 'yield from ').replace('_mutmut_trampoline', '_mutmut_yield_from_trampoline')
200201

201202

202203
def create_mutants():
@@ -297,7 +298,7 @@ def write_all_mutants_to_file(*, out, source, filename):
297298
return mutant_names, hash_by_function_name
298299

299300

300-
def build_trampoline(orig_name, mutants, class_name=None):
301+
def build_trampoline(*, orig_name, mutants, class_name, is_generator):
301302
assert orig_name not in NEVER_MUTATE_FUNCTION_NAMES
302303

303304
mangled_name = mangle_function_name(name=orig_name, class_name=class_name)
@@ -309,11 +310,18 @@ def build_trampoline(orig_name, mutants, class_name=None):
309310
access_prefix = f'object.__getattribute__(self, "'
310311
access_suffix = '")'
311312

313+
if is_generator:
314+
return_or_yield_statement = 'yield from'
315+
trampoline_name = '_mutmut_yield_from_trampoline'
316+
else:
317+
return_or_yield_statement = 'return'
318+
trampoline_name = '_mutmut_trampoline'
319+
312320
return f"""
313321
{mutants_dict}
314322
315323
def {orig_name}({'self, ' if class_name is not None else ''}*args, **kwargs):
316-
return _mutmut_trampoline({access_prefix}{mangled_name}__mutmut_orig{access_suffix}, {access_prefix}{mangled_name}__mutmut_mutants{access_suffix}, *args, **kwargs)
324+
{return_or_yield_statement} {trampoline_name}({access_prefix}{mangled_name}__mutmut_orig{access_suffix}, {access_prefix}{mangled_name}__mutmut_mutants{access_suffix}, *args, **kwargs)
317325
318326
{orig_name}.__signature__ = _mutmut_signature({mangled_name}__mutmut_orig)
319327
{mangled_name}__mutmut_orig.__name__ = '{mangled_name}'
@@ -447,6 +455,23 @@ def is_inside_dict_synonym_call(self):
447455
return False
448456

449457

458+
def is_generator(node):
459+
assert node.type == 'funcdef'
460+
461+
def _is_generator(n):
462+
if n is not node and n.type in ('funcdef', 'classdef'):
463+
return False
464+
465+
if n.type == 'keyword' and n.value == 'yield':
466+
return True
467+
468+
for c in getattr(n, 'children', []):
469+
if _is_generator(c):
470+
return True
471+
return False
472+
return _is_generator(node)
473+
474+
450475
def yield_mutants_for_function(node, *, class_name=None, no_mutate_lines):
451476
assert node.type == 'funcdef'
452477

@@ -481,7 +506,7 @@ def yield_mutants_for_function(node, *, class_name=None, no_mutate_lines):
481506
finally:
482507
context.stack.pop()
483508

484-
trampoline = build_trampoline(node.name.value, context.mutants, class_name=class_name)
509+
trampoline = build_trampoline(orig_name=node.name.value, mutants=context.mutants, class_name=class_name, is_generator=is_generator(node))
485510
if class_name is not None:
486511
trampoline = indent(trampoline, ' ')
487512
yield 'trampoline', trampoline, None, None
@@ -530,6 +555,7 @@ def yield_mutants_for_module(node, no_mutate_lines):
530555
yield from yield_future_imports(node)
531556

532557
yield 'trampoline_impl', trampoline_impl, None, None
558+
yield 'trampoline_impl', yield_from_trampoline_impl, None, None
533559
yield 'filler', '\n', None, None
534560
for child_node in node.children:
535561
if child_node.type == 'funcdef':

tests/test_mutation.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
CLASS_NAME_SEPARATOR,
88
FuncContext,
99
get_diff_for_mutant,
10+
is_generator,
1011
mangle_function_name,
1112
orig_function_and_class_names_from_key,
1213
pragma_no_mutate_lines,
@@ -358,3 +359,40 @@ def foo():
358359
mutated_source = full_mutated_source(source)
359360
assert mutated_source.split('\n')[0] == 'from __future__ import annotations'
360361
assert mutated_source.count('from __future__') == 1
362+
363+
364+
def test_preserve_generators():
365+
source = '''
366+
def foo():
367+
yield 1
368+
'''.strip()
369+
mutated_source = full_mutated_source(source)
370+
assert 'yield from _mutmut_yield_from_trampoline' in mutated_source
371+
372+
373+
def test_is_generator():
374+
source = '''
375+
def foo():
376+
yield 1
377+
'''.strip()
378+
assert is_generator(parse(source).children[0])
379+
380+
source = '''
381+
def foo():
382+
yield from bar()
383+
'''.strip()
384+
assert is_generator(parse(source).children[0])
385+
386+
source = '''
387+
def foo():
388+
return 1
389+
'''.strip()
390+
assert not is_generator(parse(source).children[0])
391+
392+
source = '''
393+
def foo():
394+
def bar():
395+
yield 2
396+
return 1
397+
'''.strip()
398+
assert not is_generator(parse(source).children[0])

tests/test_mutmut3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from mutmut.__main__ import (
44
trampoline_impl,
5+
yield_from_trampoline_impl,
56
yield_mutants_for_module,
67
)
78

@@ -14,7 +15,7 @@ def foo(a, b, c):
1415
return a + b * c
1516
"""
1617

17-
expected = trampoline_impl + """
18+
expected = trampoline_impl + yield_from_trampoline_impl + """
1819
1920
a + 1
2021
@@ -53,7 +54,7 @@ def foo(a: List[int]) -> int:
5354
return 1
5455
"""
5556

56-
expected = trampoline_impl + """
57+
expected = trampoline_impl + yield_from_trampoline_impl + """
5758
5859
def x_foo__mutmut_orig(a: List[int]) -> int:
5960
return 1

0 commit comments

Comments
 (0)