Skip to content

Commit 0f93869

Browse files
committed
Port test utils
1 parent 3d53826 commit 0f93869

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed

tests/utils.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# from unittest import TestCase
2+
3+
4+
def call(*args, **kwargs):
5+
return (tuple(args), dict(kwargs))
6+
7+
8+
class Mock:
9+
"""A mock callable object that stores its calls."""
10+
11+
def __init__(self, return_value=None, side_effect=None):
12+
self._return_value = return_value
13+
self._side_effect = side_effect
14+
self._calls = []
15+
16+
def __call__(self, *args, **kwargs):
17+
self._calls.append(call(*args, **kwargs))
18+
if self._side_effect:
19+
raise self._side_effect
20+
return self._return_value
21+
22+
def assert_called(self):
23+
"""Assert that the mock was called at least once."""
24+
assert len(self._calls) > 0, "Expected mock to be called, but it was not."
25+
26+
def assert_not_called(self):
27+
"""Assert that the mock was not called."""
28+
assert len(self._calls) == 0, "Expected mock to not be called, but it was."
29+
30+
def assert_called_with(self, *args, **kwargs):
31+
"""Assert that the mock was last called with the given arguments."""
32+
# First call should be self, so we prepend it
33+
expected_args = [self] + list(args)
34+
expectation = call(*expected_args, **kwargs)
35+
36+
# Try to have a useful output for assertion failures
37+
assert self._calls[-1] == expectation, "Expected call with {}, got {}".format(
38+
expectation, self._calls[-1]
39+
)
40+
41+
def assert_has_calls(self, calls):
42+
"""Assert that the mock has the expected calls with arguments."""
43+
assert self._calls == calls, "Expected calls {}, got {}".format(
44+
calls, self._calls
45+
)
46+
47+
48+
class AsyncMock(Mock):
49+
"""An async version of Mock that can be awaited."""
50+
51+
async def __call__(self, *args, **kwargs):
52+
return super().__call__(self, *args, **kwargs)
53+
54+
def assert_awaited(self):
55+
"""Assert that the async mock was awaited at least once."""
56+
return super().assert_called()
57+
58+
def assert_not_awaited(self):
59+
"""Assert that the async mock was not awaited."""
60+
return super().assert_not_called()
61+
62+
def assert_awaited_with(self, *args, **kwargs):
63+
"""Assert that the async mock was last awaited with the given arguments."""
64+
return super().assert_called_with(*args, **kwargs)
65+
66+
def assert_has_awaits(self, awaits):
67+
"""Assert that the async mock has the expected awaits with arguments."""
68+
return super().assert_has_calls(awaits)

0 commit comments

Comments
 (0)