Skip to content
This repository was archived by the owner on Oct 1, 2024. It is now read-only.

Commit c0639a6

Browse files
committed
Support decorations on tests
Decorations on functions are not re-applied after rewriting their assertions. This is because a function like @foo def a(): pass is (effectively) sugar for def a(): pass a = foo(a) However, rewrite_assertion only extracts the function from the recompiled code. To include the decorations, we have to execute the code and extract the function from the globals. However, this presents a problem, since the decorator must be in scope when executing the code. We could use the module's __dict__, but it's possible that the decorator is redefined between when it is applied to the function and when the module finishes execution: def foo(f): return f @foo def a(): pass foo = 5 This will load properly, but would fail if we tried to execute code for the function with the module's final __dict__. We have similar problems for constructs like for i in range(4): def _(i=i): return i To really be correct here, we have to rewrite assertions when importing the module. This is actually much simpler than the existing strategy (as can be seen by the negative diffstat). It does result in a behavioral change where all assertions in a test module are rewritten instead of just those in tests. This patch does not handle cases like # a_test.py import b_test # b_test.py assert 1 == 2 because if a_test is imported first, then it will import b_test with the regular importer and the assertions will not be rewritten. To fix this correctly, we need to replace the loader to add an exec hook which applies only to test modules. This patch does not implement this and so I have marked this patch RFC. However, if such a scenario does not occur, this is more than enough to get hypothesis working.
1 parent f41c616 commit c0639a6

File tree

5 files changed

+44
-107
lines changed

5 files changed

+44
-107
lines changed

tests/test_rewrite.py

+29-49
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import ast
22

3-
from tests.utilities import testable_test
3+
from tests.utilities import testable_test, failing_assertion
44
from ward import fixture, test
55
from ward._rewrite import (
66
RewriteAssert,
77
get_assertion_msg,
88
is_binary_comparison,
99
is_comparison_type,
1010
make_call_node,
11-
rewrite_assertions_in_tests,
1211
)
12+
from ward.expect import TestAssertionFailure, raises
1313
from ward.testing import Test, each
1414

1515

@@ -34,37 +34,6 @@ def as_dict(node):
3434
return node
3535

3636

37-
@testable_test
38-
def passing_fn():
39-
assert 1 == 1
40-
41-
42-
@testable_test
43-
def failing_fn():
44-
assert 1 == 2
45-
46-
47-
@fixture
48-
def passing():
49-
yield Test(fn=passing_fn, module_name="m", id="id-pass")
50-
51-
52-
@fixture
53-
def failing():
54-
yield Test(fn=failing_fn, module_name="m", id="id-fail")
55-
56-
57-
@test("rewrite_assertions_in_tests returns all tests, keeping metadata")
58-
def _(p=passing, f=failing):
59-
in_tests = [p, f]
60-
out_tests = rewrite_assertions_in_tests(in_tests)
61-
62-
def meta(test):
63-
return test.description, test.id, test.module_name, test.fn.ward_meta
64-
65-
assert [meta(test) for test in in_tests] == [meta(test) for test in out_tests]
66-
67-
6837
@test("RewriteAssert.visit_Assert doesn't transform `{src}`")
6938
def _(
7039
src=each(
@@ -121,6 +90,33 @@ def _(
12190
assert out_tree.value.args[1].id == "y"
12291
assert out_tree.value.args[2].s == ""
12392

93+
@test("This test suite's assertions are themselves rewritten")
94+
def _():
95+
with raises(TestAssertionFailure):
96+
assert 1 == 2
97+
with raises(TestAssertionFailure):
98+
assert 1 != 1
99+
with raises(TestAssertionFailure):
100+
assert 1 in ()
101+
with raises(TestAssertionFailure):
102+
assert 1 not in (1,)
103+
with raises(TestAssertionFailure):
104+
assert None is Ellipsis
105+
with raises(TestAssertionFailure):
106+
assert None is not None
107+
with raises(TestAssertionFailure):
108+
assert 2 < 1
109+
with raises(TestAssertionFailure):
110+
assert 2 <= 1
111+
with raises(TestAssertionFailure):
112+
assert 1 > 2
113+
with raises(TestAssertionFailure):
114+
assert 1 >= 2
115+
116+
@test("Non-test modules' assertions aren't rewritten")
117+
def _():
118+
with raises(AssertionError):
119+
failing_assertion()
124120

125121
@test("RewriteAssert.visit_Assert transforms `{src}`")
126122
def _(src="assert 1 == 2, 'msg'"):
@@ -210,19 +206,3 @@ def _():
210206
@test("test with indentation level of 2")
211207
def _():
212208
assert 2 + 3 == 5
213-
214-
215-
@test("rewriter finds correct function when there is a lambda in an each")
216-
def _():
217-
@testable_test
218-
def _(x=each(lambda: 5)):
219-
assert x == 5
220-
221-
t = Test(fn=_, module_name="m")
222-
223-
rewritten = rewrite_assertions_in_tests([t])[0]
224-
225-
# https://github.com/darrenburns/ward/issues/169
226-
# The assertion rewriter thought the lambda function stored in co_consts was the test function,
227-
# so it was rebuilding the test function using the lambda as the test instead of the original function.
228-
assert rewritten.fn.__code__.co_name != "<lambda>"

tests/utilities.py

+4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ def testable_test(func):
3131
testable_test.path = FORCE_TEST_PATH # type: ignore[attr-defined]
3232

3333

34+
def failing_assertion():
35+
assert 1 == 2
36+
37+
3438
@fixture
3539
def dummy_fixture():
3640
"""

ward/_collect.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from cucumber_tag_expressions.model import Expression
1616

1717
from ward._errors import CollectionError
18+
from ward._rewrite import exec_module
1819
from ward._testing import COLLECTED_TESTS, is_test_module_name
1920
from ward._utilities import get_absolute_path
2021
from ward.fixtures import Fixture
@@ -149,7 +150,7 @@ def load_modules(modules: Iterable[pkgutil.ModuleInfo]) -> List[ModuleType]:
149150
if pkg_data.pkg_root not in sys.path:
150151
sys.path.append(str(pkg_data.pkg_root))
151152
m.__package__ = pkg_data.pkg_name
152-
m.__loader__.exec_module(m)
153+
exec_module(m)
153154
loaded_modules.append(m)
154155

155156
return loaded_modules

ward/_rewrite.py

+8-53
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
import textwrap
55
import types
6+
from pathlib import Path
67
from typing import Iterable, List
78

89
from ward.expect import (
@@ -87,57 +88,11 @@ def visit_Assert(self, node): # noqa: C901 - no chance to reduce complexity
8788
return node
8889

8990

90-
def rewrite_assertions_in_tests(tests: Iterable[Test]) -> List[Test]:
91-
return [rewrite_assertion(test) for test in tests]
92-
93-
94-
def rewrite_assertion(test: Test) -> Test:
95-
# Get the old code and code object
96-
code_lines, line_no = inspect.getsourcelines(test.fn)
97-
98-
code = "".join(code_lines)
99-
indents = textwrap._leading_whitespace_re.findall(code)
100-
col_offset = len(indents[0]) if len(indents) > 0 else 0
101-
code = textwrap.dedent(code)
102-
code_obj = test.fn.__code__
103-
104-
# Rewrite the AST of the code
105-
tree = ast.parse(code)
106-
ast.increment_lineno(tree, line_no - 1)
107-
91+
def exec_module(module: types.ModuleType):
92+
filename = module.__spec__.origin
93+
code = module.__loader__.get_source(module.__name__)
94+
tree = ast.parse(code, filename=filename)
10895
new_tree = RewriteAssert().visit(tree)
109-
110-
if sys.version_info[:2] < (3, 11):
111-
# We dedented the code so that it was a valid tree, now re-apply the indent
112-
for child in ast.walk(new_tree):
113-
if hasattr(child, "col_offset"):
114-
child.col_offset = getattr(child, "col_offset", 0) + col_offset
115-
116-
# Reconstruct the test function
117-
new_mod_code_obj = compile(new_tree, code_obj.co_filename, "exec")
118-
119-
# TODO: This probably isn't correct for nested closures
120-
clo_glob = {}
121-
if test.fn.__closure__:
122-
clo_glob = test.fn.__closure__[0].cell_contents.__globals__
123-
124-
# Look through the new module,
125-
# find the code object with the same name as the original code object,
126-
# and build a new function with the injected assert functions added to the global namespace.
127-
# Filtering on the code object name prevents finding other kinds of code objects,
128-
# like lambdas stored directly in test function arguments.
129-
for const in new_mod_code_obj.co_consts:
130-
if isinstance(const, types.CodeType) and const.co_name == code_obj.co_name:
131-
new_test_func = types.FunctionType(
132-
const,
133-
{**assert_func_namespace, **test.fn.__globals__, **clo_glob},
134-
test.fn.__name__,
135-
test.fn.__defaults__,
136-
)
137-
new_test_func.ward_meta = test.fn.ward_meta
138-
return Test(
139-
**{k: vars(test)[k] for k in vars(test) if k != "fn"},
140-
fn=new_test_func,
141-
)
142-
143-
return test
96+
code = compile(new_tree, filename, "exec", dont_inherit=True)
97+
module.__dict__.update(assert_func_namespace)
98+
exec(code, module.__dict__)

ward/_run.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
)
2323
from ward._config import set_defaults_from_config
2424
from ward._debug import init_breakpointhooks
25-
from ward._rewrite import rewrite_assertions_in_tests
2625
from ward._suite import Suite
2726
from ward._terminal import (
2827
SessionPrelude,
@@ -204,11 +203,9 @@ def test(
204203
if config.order == "random":
205204
shuffle(filtered_tests)
206205

207-
tests = rewrite_assertions_in_tests(filtered_tests)
208-
209206
time_to_collect_secs = default_timer() - start_run
210207

211-
suite = Suite(tests=tests)
208+
suite = Suite(tests=filtered_tests)
212209
test_results = suite.generate_test_runs(
213210
dry_run=dry_run, capture_output=capture_output
214211
)

0 commit comments

Comments
 (0)