Skip to content

Commit 968d56f

Browse files
authored
Support async generators (#407)
Functions with "yield" return a generator which we can simply return as a value to the caller.
1 parent 873019e commit 968d56f

File tree

9 files changed

+69
-89
lines changed

9 files changed

+69
-89
lines changed

e2e_projects/my_lib/pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,7 @@ dev = [
1717
]
1818

1919
[tool.mutmut]
20-
debug = true
20+
debug = true
21+
22+
[tool.pytest]
23+
asyncio_default_fixture_loop_scope = "function"

e2e_projects/my_lib/src/my_lib/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,29 @@ def fibonacci(n: int) -> int:
2626
return n
2727
return fibonacci(n - 1) + fibonacci(n - 2)
2828

29+
async def async_consumer():
30+
results = []
31+
async for i in async_generator():
32+
results.append(i)
33+
return results
34+
35+
async def async_generator():
36+
for i in range(10):
37+
yield i
38+
39+
def simple_consumer():
40+
generator = double_generator()
41+
next(generator) # skip the initial yield
42+
results = []
43+
for i in range(10):
44+
results.append(generator.send(i))
45+
return results
46+
47+
def double_generator():
48+
while True:
49+
x = yield
50+
yield x * 2
51+
2952
@cache
3053
def cached_fibonacci(n: int) -> int:
3154
if n <= 1:

e2e_projects/my_lib/tests/test_my_lib.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from my_lib import hello, Point, badly_tested, make_greeter, fibonacci, cached_fibonacci, escape_sequences
1+
from my_lib import hello, Point, badly_tested, make_greeter, fibonacci, cached_fibonacci, escape_sequences, simple_consumer, async_consumer
2+
import pytest
23

34
"""These tests are flawed on purpose, some mutants survive and some are killed."""
45

@@ -33,4 +34,14 @@ def test_fibonacci():
3334
assert cached_fibonacci(1) == 1
3435

3536
def test_escape_sequences():
36-
assert escape_sequences().lower() == "foofoo\\\'\"\a\b\f\n\r\t\v\111\x10\N{ghost}\u1234\U0001F51F".lower()
37+
assert escape_sequences().lower() == "foofoo\\\'\"\a\b\f\n\r\t\v\111\x10\N{ghost}\u1234\U0001F51F".lower()
38+
39+
def test_simple_consumer():
40+
# only verifying length, should report surviving mutants for the contents
41+
assert len(simple_consumer()) == 10
42+
43+
@pytest.mark.asyncio
44+
async def test_async_consumer():
45+
result = await async_consumer()
46+
assert result == list(range(10))
47+

mutmut/file_mutation.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import libcst as cst
99
from libcst.metadata import PositionProvider, MetadataWrapper
1010
import libcst.matchers as m
11-
from mutmut.trampoline_templates import build_trampoline, mangle_function_name, trampoline_impl, yield_from_trampoline_impl
11+
from mutmut.trampoline_templates import build_trampoline, mangle_function_name, trampoline_impl
1212
from mutmut.node_mutation import mutation_operators, OPERATORS_TYPE
1313

1414
NEVER_MUTATE_FUNCTION_NAMES = { "__getattribute__", "__setattr__", "__new__" }
@@ -165,8 +165,6 @@ def _skip_node_and_children(self, node: cst.CSTNode):
165165
# convert str trampoline implementations to CST nodes with some whitespace
166166
trampoline_impl_cst = list(cst.parse_module(trampoline_impl).body)
167167
trampoline_impl_cst[-1] = trampoline_impl_cst[-1].with_changes(leading_lines = [cst.EmptyLine(), cst.EmptyLine()])
168-
yield_from_trampoline_impl_cst = list(cst.parse_module(yield_from_trampoline_impl).body)
169-
yield_from_trampoline_impl_cst[-1] = yield_from_trampoline_impl_cst[-1].with_changes(leading_lines = [cst.EmptyLine(), cst.EmptyLine()])
170168

171169

172170
def combine_mutations_to_source(module: cst.Module, mutations: Sequence[Mutation]) -> tuple[str, Sequence[str]]:
@@ -185,7 +183,6 @@ def combine_mutations_to_source(module: cst.Module, mutations: Sequence[Mutation
185183

186184
# trampoline functions
187185
result.extend(trampoline_impl_cst)
188-
result.extend(yield_from_trampoline_impl_cst)
189186

190187
mutations_within_function = group_by_top_level_node(mutations)
191188

@@ -234,7 +231,6 @@ def function_trampoline_arrangement(function: cst.FunctionDef, mutants: Iterable
234231

235232
name = function.name.value
236233
mangled_name = mangle_function_name(name=name, class_name=class_name) + '__mutmut'
237-
_is_generator = is_generator(function)
238234

239235
# copy of original function
240236
nodes.append(function.with_changes(name=cst.Name(mangled_name + '_orig')))
@@ -248,7 +244,7 @@ def function_trampoline_arrangement(function: cst.FunctionDef, mutants: Iterable
248244
nodes.append(mutated_method) # type: ignore
249245

250246
# trampoline that forwards the calls
251-
trampoline = list(cst.parse_module(build_trampoline(orig_name=name, mutants=mutant_names, class_name=class_name, is_generator=_is_generator)).body)
247+
trampoline = list(cst.parse_module(build_trampoline(orig_name=name, mutants=mutant_names, class_name=class_name)).body)
252248
trampoline[0] = trampoline[0].with_changes(leading_lines=[cst.EmptyLine()])
253249
nodes.extend(trampoline)
254250

@@ -274,28 +270,6 @@ def group_by_top_level_node(mutations: Sequence[Mutation]) -> Mapping[cst.CSTNod
274270

275271
return grouped
276272

277-
def is_generator(function: cst.FunctionDef) -> bool:
278-
"""Return True if the function has yield statement(s)."""
279-
visitor = IsGeneratorVisitor(function)
280-
function.visit(visitor)
281-
return visitor.is_generator
282-
283-
class IsGeneratorVisitor(cst.CSTVisitor):
284-
"""Check if a function is a generator.
285-
We do so by checking if any child is a Yield statement, but not looking into inner function definitions."""
286-
def __init__(self, original_function: cst.FunctionDef):
287-
self.is_generator = False
288-
self.original_function: cst.FunctionDef = original_function
289-
290-
def visit_FunctionDef(self, node):
291-
# do not recurse into inner function definitions
292-
if self.original_function != node:
293-
return False
294-
295-
def visit_Yield(self, node):
296-
self.is_generator = True
297-
return False
298-
299273
def pragma_no_mutate_lines(source: str) -> set[int]:
300274
return {
301275
i + 1

mutmut/trampoline_templates.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
CLASS_NAME_SEPARATOR = 'ǁ'
22

3-
def build_trampoline(*, orig_name, mutants, class_name, is_generator):
3+
def build_trampoline(*, orig_name, mutants, class_name):
44
mangled_name = mangle_function_name(name=orig_name, class_name=class_name)
55

66
mutants_dict = f'{mangled_name}__mutmut_mutants : ClassVar[MutantDict] = {{\n' + ', \n '.join(f'{repr(m)}: {m}' for m in mutants) + '\n}'
@@ -12,18 +12,13 @@ def build_trampoline(*, orig_name, mutants, class_name, is_generator):
1212
access_suffix = '")'
1313
self_arg = ', self'
1414

15-
if is_generator:
16-
yield_statement = 'yield from ' # note the space at the end!
17-
trampoline_name = '_mutmut_yield_from_trampoline'
18-
else:
19-
yield_statement = ''
20-
trampoline_name = '_mutmut_trampoline'
15+
trampoline_name = '_mutmut_trampoline'
2116

2217
return f"""
2318
{mutants_dict}
2419
2520
def {orig_name}({'self, ' if class_name is not None else ''}*args, **kwargs):
26-
result = {yield_statement}{trampoline_name}({access_prefix}{mangled_name}__mutmut_orig{access_suffix}, {access_prefix}{mangled_name}__mutmut_mutants{access_suffix}, args, kwargs{self_arg})
21+
result = {trampoline_name}({access_prefix}{mangled_name}__mutmut_orig{access_suffix}, {access_prefix}{mangled_name}__mutmut_mutants{access_suffix}, args, kwargs{self_arg})
2722
return result
2823
2924
{orig_name}.__signature__ = _mutmut_signature({mangled_name}__mutmut_orig)
@@ -62,11 +57,11 @@ def _mutmut_trampoline(orig, mutants, call_args, call_kwargs, self_arg = None):
6257
from mutmut.__main__ import record_trampoline_hit
6358
record_trampoline_hit(orig.__module__ + '.' + orig.__name__)
6459
result = orig(*call_args, **call_kwargs)
65-
return result # for the yield case
60+
return result
6661
prefix = orig.__module__ + '.' + orig.__name__ + '__mutmut_'
6762
if not mutant_under_test.startswith(prefix):
6863
result = orig(*call_args, **call_kwargs)
69-
return result # for the yield case
64+
return result
7065
mutant_name = mutant_under_test.rpartition('.')[-1]
7166
if self_arg:
7267
# call to a class method where self is not bound
@@ -76,4 +71,3 @@ def _mutmut_trampoline(orig, mutants, call_args, call_kwargs, self_arg = None):
7671
return result
7772
7873
"""
79-
yield_from_trampoline_impl = trampoline_impl.replace('result = ', 'result = yield from ').replace('_mutmut_trampoline', '_mutmut_yield_from_trampoline')

test_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pytest
2+
pytest-asyncio>=1.0.0
23
mock>=2.0.0
34
coverage
45
whatthepatch==0.0.6

tests/e2e/snapshots/my_lib.json

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,21 @@
2525
"my_lib.x_fibonacci__mutmut_7": 0,
2626
"my_lib.x_fibonacci__mutmut_8": 0,
2727
"my_lib.x_fibonacci__mutmut_9": 0,
28+
"my_lib.x_async_consumer__mutmut_1": 1,
29+
"my_lib.x_async_consumer__mutmut_2": 1,
30+
"my_lib.x_async_generator__mutmut_1": 1,
31+
"my_lib.x_async_generator__mutmut_2": 1,
32+
"my_lib.x_simple_consumer__mutmut_1": 1,
33+
"my_lib.x_simple_consumer__mutmut_2": 1,
34+
"my_lib.x_simple_consumer__mutmut_3": 1,
35+
"my_lib.x_simple_consumer__mutmut_4": 1,
36+
"my_lib.x_simple_consumer__mutmut_5": 1,
37+
"my_lib.x_simple_consumer__mutmut_6": 0,
38+
"my_lib.x_simple_consumer__mutmut_7": 1,
39+
"my_lib.x_double_generator__mutmut_1": 1,
40+
"my_lib.x_double_generator__mutmut_2": 1,
41+
"my_lib.x_double_generator__mutmut_3": 0,
42+
"my_lib.x_double_generator__mutmut_4": 0,
2843
"my_lib.x\u01c1Point\u01c1__init____mutmut_1": 1,
2944
"my_lib.x\u01c1Point\u01c1__init____mutmut_2": 1,
3045
"my_lib.x\u01c1Point\u01c1abs__mutmut_1": 33,

tests/test_mutation.py

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
MutmutProgrammaticFailException,
1414
CatchOutput,
1515
)
16-
from mutmut.trampoline_templates import trampoline_impl, yield_from_trampoline_impl, mangle_function_name
17-
from mutmut.file_mutation import create_mutations, mutate_file_contents, is_generator
16+
from mutmut.trampoline_templates import trampoline_impl, mangle_function_name
17+
from mutmut.file_mutation import create_mutations, mutate_file_contents
1818

1919
def mutants_for_source(source: str) -> list[str]:
2020
module, mutated_nodes = create_mutations(source)
@@ -514,43 +514,6 @@ def foo():
514514
assert mutated_source.count('from __future__') == 1
515515

516516

517-
def test_preserve_generators():
518-
source = '''
519-
def foo():
520-
yield 1
521-
'''.strip()
522-
mutated_source = mutated_module(source)
523-
assert 'yield from _mutmut_yield_from_trampoline' in mutated_source
524-
525-
526-
def test_is_generator():
527-
source = '''
528-
def foo():
529-
yield 1
530-
'''.strip()
531-
assert is_generator(parse_statement(source)) # type: ignore
532-
533-
source = '''
534-
def foo():
535-
yield from bar()
536-
'''.strip()
537-
assert is_generator(parse_statement(source)) # type: ignore
538-
539-
source = '''
540-
def foo():
541-
return 1
542-
'''.strip()
543-
assert not is_generator(parse_statement(source)) # type: ignore
544-
545-
source = '''
546-
def foo():
547-
def bar():
548-
yield 2
549-
return 1
550-
'''.strip()
551-
assert not is_generator(parse_statement(source)) # type: ignore
552-
553-
554517
# Negate the effects of CatchOutput because it does not play nicely with capfd in GitHub Actions
555518
@patch.object(CatchOutput, 'dump_output')
556519
@patch.object(CatchOutput, 'stop')
@@ -678,7 +641,6 @@ def add(self, value):
678641
679642
lib.foo()
680643
{trampoline_impl.strip()}
681-
{yield_from_trampoline_impl.strip()}
682644
683645
def x_foo__mutmut_orig(a, b):
684646
return a > b
@@ -708,7 +670,7 @@ def x_bar__mutmut_1():
708670
}}
709671
710672
def bar(*args, **kwargs):
711-
result = yield from _mutmut_yield_from_trampoline(x_bar__mutmut_orig, x_bar__mutmut_mutants, args, kwargs)
673+
result = _mutmut_trampoline(x_bar__mutmut_orig, x_bar__mutmut_mutants, args, kwargs)
712674
return result
713675
714676
bar.__signature__ = _mutmut_signature(x_bar__mutmut_orig)

tests/test_mutmut3.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
from mutmut.trampoline_templates import (
2-
trampoline_impl,
3-
yield_from_trampoline_impl,
4-
)
1+
from mutmut.trampoline_templates import trampoline_impl
52
from mutmut.file_mutation import mutate_file_contents
63

74
def mutated_module(source: str) -> str:
@@ -16,7 +13,7 @@ def test_mutate_file_contents():
1613
def foo(a, b, c):
1714
return a + b * c
1815
"""
19-
trampolines = trampoline_impl.removesuffix('\n\n') + yield_from_trampoline_impl.removesuffix('\n\n')
16+
trampolines = trampoline_impl.removesuffix('\n\n')
2017

2118
expected = f"""
2219
a + 1{trampolines}
@@ -54,7 +51,7 @@ def foo(a: List[int]) -> int:
5451
return 1
5552
"""
5653

57-
expected = trampoline_impl.removesuffix('\n\n') + yield_from_trampoline_impl.removesuffix('\n\n') + """
54+
expected = trampoline_impl.removesuffix('\n\n') + """
5855
def x_foo__mutmut_orig(a: List[int]) -> int:
5956
return 1
6057
def x_foo__mutmut_1(a: List[int]) -> int:

0 commit comments

Comments
 (0)