Skip to content

Commit ee65c72

Browse files
authored
Expose mock directly (#335)
1 parent 39dbff6 commit ee65c72

File tree

7 files changed

+80
-51
lines changed

7 files changed

+80
-51
lines changed

docs/how-to/write-tests-for-devices.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ In addition this example also utilizes helper functions like ``assert_reading``
4141
:pyobject: test_sensor_reading_shows_value
4242

4343

44+
Given that the mock signal holds a ``unittest.mock.Mock`` object you can retrieve this object and assert that the device has been set correctly using ``get_mock_put``. You are also free to use any other behaviour that ``unittest.mock.Mock`` provides, such as in this example which sets the parent of the mock to allow ordering across signals to be asserted:
45+
46+
.. literalinclude:: ../../tests/epics/demo/test_demo.py
47+
:pyobject: test_retrieve_mock_and_assert
48+
4449
There are several other test utility functions:
4550

4651
Use ``callback_on_mock_put``, for hooking in logic when a mock value changes (e.g. because someone puts to it). This can be called directly, or used as a context, with the callbacks ending after exit.

src/ophyd_async/core/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,10 @@
2424
walk_rw_signals,
2525
)
2626
from .flyer import HardwareTriggeredFlyable, TriggerLogic
27-
from .mock_signal_backend import (
28-
MockSignalBackend,
29-
)
27+
from .mock_signal_backend import MockSignalBackend
3028
from .mock_signal_utils import (
31-
assert_mock_put_called_with,
3229
callback_on_mock_put,
30+
get_mock_put,
3331
mock_puts_blocked,
3432
reset_mock_put_calls,
3533
set_mock_put_proceeds,
@@ -70,7 +68,7 @@
7068
)
7169

7270
__all__ = [
73-
"assert_mock_put_called_with",
71+
"get_mock_put",
7472
"callback_on_mock_put",
7573
"mock_puts_blocked",
7674
"set_mock_values",

src/ophyd_async/core/mock_signal_backend.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from functools import cached_property
3-
from typing import Optional, Type
3+
from typing import Callable, Optional, Type
44
from unittest.mock import Mock
55

66
from bluesky.protocols import Descriptor, Reading
@@ -10,7 +10,7 @@
1010
from ophyd_async.core.utils import DEFAULT_TIMEOUT, ReadingValueCallback, T
1111

1212

13-
class MockSignalBackend(SignalBackend):
13+
class MockSignalBackend(SignalBackend[T]):
1414
def __init__(
1515
self,
1616
datatype: Optional[Type[T]] = None,
@@ -31,11 +31,11 @@ def __init__(
3131

3232
if not isinstance(self.initial_backend, SoftSignalBackend):
3333
# If the backend is a hard signal backend, or not provided,
34-
# then we create a soft signal to mimick it
34+
# then we create a soft signal to mimic it
3535

3636
self.soft_backend = SoftSignalBackend(datatype=datatype)
3737
else:
38-
self.soft_backend = initial_backend
38+
self.soft_backend = self.initial_backend
3939

4040
def source(self, name: str) -> str:
4141
if self.initial_backend:
@@ -47,7 +47,7 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT) -> None:
4747

4848
@cached_property
4949
def put_mock(self) -> Mock:
50-
return Mock(name="put")
50+
return Mock(name="put", spec=Callable)
5151

5252
@cached_property
5353
def put_proceeds(self) -> asyncio.Event:
@@ -65,9 +65,6 @@ async def put(self, value: Optional[T], wait=True, timeout=None):
6565
def set_value(self, value: T):
6666
self.soft_backend.set_value(value)
6767

68-
async def get_descriptor(self, source: str) -> Descriptor:
69-
return await self.soft_backend.get_descriptor(source)
70-
7168
async def get_reading(self) -> Reading:
7269
return await self.soft_backend.get_reading()
7370

src/ophyd_async/core/mock_signal_utils.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from contextlib import asynccontextmanager, contextmanager
2-
from typing import Any, Callable, Iterable, Iterator, List
3-
from unittest.mock import ANY, Mock
2+
from typing import Any, Callable, Iterable
3+
from unittest.mock import Mock
44

55
from ophyd_async.core.signal import Signal
66
from ophyd_async.core.utils import T
@@ -22,7 +22,7 @@ def set_mock_value(signal: Signal[T], value: T):
2222
backend.set_value(value)
2323

2424

25-
def set_mock_put_proceeds(signal: Signal[T], proceeds: bool):
25+
def set_mock_put_proceeds(signal: Signal, proceeds: bool):
2626
"""Allow or block a put with wait=True from proceeding"""
2727
backend = _get_mock_signal_backend(signal)
2828

@@ -33,17 +33,17 @@ def set_mock_put_proceeds(signal: Signal[T], proceeds: bool):
3333

3434

3535
@asynccontextmanager
36-
async def mock_puts_blocked(*signals: List[Signal]):
36+
async def mock_puts_blocked(*signals: Signal):
3737
for signal in signals:
3838
set_mock_put_proceeds(signal, False)
3939
yield
4040
for signal in signals:
4141
set_mock_put_proceeds(signal, True)
4242

4343

44-
def assert_mock_put_called_with(signal: Signal, value: Any, wait=ANY, timeout=ANY):
45-
backend = _get_mock_signal_backend(signal)
46-
backend.put_mock.assert_called_with(value, wait=wait, timeout=timeout)
44+
def get_mock_put(signal: Signal) -> Mock:
45+
"""Get the mock associated with the put call on the signal."""
46+
return _get_mock_signal_backend(signal).put_mock
4747

4848

4949
def reset_mock_put_calls(signal: Signal):
@@ -79,15 +79,15 @@ def __next__(self):
7979
return next_value
8080

8181
def __del__(self):
82-
if self.require_all_consumed and self.index != len(self.values):
82+
if self.require_all_consumed and self.index != len(list(self.values)):
8383
raise AssertionError("Not all values have been consumed.")
8484

8585

8686
def set_mock_values(
8787
signal: Signal,
8888
values: Iterable[Any],
8989
require_all_consumed: bool = False,
90-
) -> Iterator[Any]:
90+
) -> _SetValuesIterator:
9191
"""Iterator to set a signal to a sequence of values, optionally repeating the
9292
sequence.
9393
@@ -127,7 +127,7 @@ def _unset_side_effect_cm(put_mock: Mock):
127127
put_mock.side_effect = None
128128

129129

130-
def callback_on_mock_put(signal: Signal, callback: Callable[[T], None]):
130+
def callback_on_mock_put(signal: Signal[T], callback: Callable[[T], None]):
131131
"""For setting a callback when a backend is put to.
132132
133133
Can either be used in a context, with the callback being

src/ophyd_async/core/signal_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class SignalBackend(Generic[T]):
1414

1515
#: Like ca://PV_PREFIX:SIGNAL
1616
@abstractmethod
17-
def source(name: str) -> str:
17+
def source(self, name: str) -> str:
1818
"""Return source of signal. Signals may pass a name to the backend, which can be
1919
used or discarded."""
2020

tests/core/test_mock_signal_backend.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,22 @@
11
import asyncio
22
import re
33
from itertools import repeat
4-
from unittest.mock import MagicMock, call
4+
from unittest.mock import ANY, MagicMock, call
55

66
import pytest
77

88
from ophyd_async.core import MockSignalBackend, SignalRW
99
from ophyd_async.core.device import Device, DeviceCollector
1010
from ophyd_async.core.mock_signal_utils import (
11-
assert_mock_put_called_with,
1211
callback_on_mock_put,
12+
get_mock_put,
1313
mock_puts_blocked,
1414
reset_mock_put_calls,
1515
set_mock_put_proceeds,
1616
set_mock_value,
1717
set_mock_values,
1818
)
19-
from ophyd_async.core.signal import (
20-
SignalW,
21-
soft_signal_r_and_setter,
22-
soft_signal_rw,
23-
)
19+
from ophyd_async.core.signal import SignalW, soft_signal_r_and_setter, soft_signal_rw
2420
from ophyd_async.core.soft_signal_backend import SoftSignalBackend
2521
from ophyd_async.epics.signal.signal import epics_signal_r, epics_signal_rw
2622

@@ -31,6 +27,7 @@ async def test_mock_signal_backend(connect_mock_mode):
3127
# If mock is false it will be handled like a normal signal, otherwise it will
3228
# initalize a new backend from the one in the line above
3329
await mock_signal.connect(mock=connect_mock_mode)
30+
assert isinstance(mock_signal._backend, MockSignalBackend)
3431

3532
assert await mock_signal._backend.get_value() == ""
3633
await mock_signal._backend.put("test")
@@ -74,6 +71,8 @@ async def test_set_mock_put_proceeds():
7471
mock_signal = SignalW(SoftSignalBackend(str))
7572
await mock_signal.connect(mock=True)
7673

74+
assert isinstance(mock_signal._backend, MockSignalBackend)
75+
7776
assert mock_signal._backend.put_proceeds.is_set() is True
7877

7978
set_mock_put_proceeds(mock_signal, False)
@@ -95,6 +94,7 @@ async def test_set_mock_put_proceeds_timeout():
9594
async def test_put_proceeds_timeout():
9695
mock_signal = SignalW(SoftSignalBackend(str))
9796
await mock_signal.connect(mock=True)
97+
assert isinstance(mock_signal._backend, MockSignalBackend)
9898

9999
assert mock_signal._backend.put_proceeds.is_set() is True
100100

@@ -112,14 +112,14 @@ async def test_mock_utils_throw_error_if_backend_isnt_mock_signal_backend():
112112
set_mock_value(signal, 10)
113113
exc_msgs.append(str(exc.value))
114114
with pytest.raises(AssertionError) as exc:
115-
assert_mock_put_called_with(signal, 10)
115+
get_mock_put(signal).assert_called_once_with(10)
116116
exc_msgs.append(str(exc.value))
117117
with pytest.raises(AssertionError) as exc:
118-
async with mock_puts_blocked(signal, 10):
118+
async with mock_puts_blocked(signal):
119119
...
120120
exc_msgs.append(str(exc.value))
121121
with pytest.raises(AssertionError) as exc:
122-
with callback_on_mock_put(signal, 10):
122+
with callback_on_mock_put(signal, lambda x: _):
123123
...
124124
exc_msgs.append(str(exc.value))
125125
with pytest.raises(AssertionError) as exc:
@@ -137,16 +137,13 @@ async def test_mock_utils_throw_error_if_backend_isnt_mock_signal_backend():
137137
)
138138

139139

140-
async def test_assert_mock_put_called_with():
140+
async def test_get_mock_put():
141141
mock_signal = epics_signal_rw(str, "READ_PV", "WRITE_PV", name="mock_name")
142142
await mock_signal.connect(mock=True)
143143
await mock_signal.set("test_value", wait=True, timeout=100)
144144

145-
# can leave out kwargs
146-
assert_mock_put_called_with(mock_signal, "test_value")
147-
assert_mock_put_called_with(mock_signal, "test_value", wait=True)
148-
assert_mock_put_called_with(mock_signal, "test_value", timeout=100)
149-
assert_mock_put_called_with(mock_signal, "test_value", wait=True, timeout=100)
145+
mock = get_mock_put(mock_signal)
146+
mock.assert_called_once_with("test_value", wait=True, timeout=100)
150147

151148
def err_text(text, wait, timeout):
152149
return (
@@ -162,7 +159,7 @@ def err_text(text, wait, timeout):
162159
("test_value", False, 0), # wait and timeout wrong
163160
]:
164161
with pytest.raises(AssertionError) as exc:
165-
assert_mock_put_called_with(mock_signal, text, wait=wait, timeout=timeout)
162+
mock.assert_called_once_with(text, wait=wait, timeout=timeout)
166163
for err_substr in err_text(text, wait, timeout):
167164
assert err_substr in str(exc.value)
168165

@@ -216,10 +213,8 @@ async def test_callback_on_mock_put_no_ctx():
216213
mock_signal = SignalRW(SoftSignalBackend(float))
217214
await mock_signal.connect(mock=True)
218215
calls = []
219-
(
220-
callback_on_mock_put(
221-
mock_signal, lambda *args, **kwargs: calls.append({**kwargs, "_args": args})
222-
),
216+
callback_on_mock_put(
217+
mock_signal, lambda *args, **kwargs: calls.append({**kwargs, "_args": args})
223218
)
224219
await mock_signal.set(10.0)
225220
assert calls == [
@@ -249,16 +244,16 @@ def some_function_without_kwargs(arg):
249244
async def test_set_mock_values(mock_signals):
250245
signal1, signal2 = mock_signals
251246

252-
await signal2.get_value() == "first_value"
247+
assert await signal2.get_value() == "first_value"
253248
for value_set in set_mock_values(signal1, ["second_value", "third_value"]):
254249
assert await signal1.get_value() == value_set
255250

256251
iterator = set_mock_values(signal2, ["second_value", "third_value"])
257-
await signal2.get_value() == "first_value"
252+
assert await signal2.get_value() == "first_value"
258253
next(iterator)
259-
await signal2.get_value() == "second_value"
254+
assert await signal2.get_value() == "second_value"
260255
next(iterator)
261-
await signal2.get_value() == "third_value"
256+
assert await signal2.get_value() == "third_value"
262257

263258

264259
async def test_set_mock_values_exhausted_passes(mock_signals):
@@ -300,10 +295,10 @@ async def test_set_mock_values_exhausted_fails(mock_signals):
300295
async def test_reset_mock_put_calls(mock_signals):
301296
signal1, signal2 = mock_signals
302297
await signal1.set("test_value", wait=True, timeout=1)
303-
assert_mock_put_called_with(signal1, "test_value")
298+
get_mock_put(signal1).assert_called_with("test_value", wait=ANY, timeout=ANY)
304299
reset_mock_put_calls(signal1)
305300
with pytest.raises(AssertionError) as exc:
306-
assert_mock_put_called_with(signal1, "test_value")
301+
get_mock_put(signal1).assert_called_with("test_value", wait=ANY, timeout=ANY)
307302
# Replacing spaces because they change between runners
308303
# (e.g the github actions runner has more)
309304
assert str(exc.value).replace(" ", "").replace("\n", "") == (
@@ -350,3 +345,13 @@ async def set(self):
350345
assert await signal.get_value() == 0
351346
backend_put(100)
352347
assert await signal.get_value() == 100
348+
349+
350+
async def test_when_put_mock_called_with_typo_then_fails_but_calling_directly_passes():
351+
mock_signal = SignalRW(SoftSignalBackend(int))
352+
await mock_signal.connect(mock=True)
353+
assert isinstance(mock_signal._backend, MockSignalBackend)
354+
mock = mock_signal._backend.put_mock
355+
with pytest.raises(AttributeError):
356+
mock.asssert_called_once() # Note typo here is deliberate!
357+
mock()

tests/epics/demo/test_demo.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
assert_reading,
1717
assert_value,
1818
callback_on_mock_put,
19+
get_mock_put,
1920
set_mock_value,
2021
)
2122
from ophyd_async.epics import demo
@@ -176,6 +177,29 @@ async def test_sensor_reading_shows_value(mock_sensor: demo.Sensor):
176177
)
177178

178179

180+
async def test_retrieve_mock_and_assert(mock_mover: demo.Mover):
181+
mover_setpoint_mock = get_mock_put(mock_mover.setpoint)
182+
await mock_mover.setpoint.set(10)
183+
mover_setpoint_mock.assert_called_once_with(10, wait=ANY, timeout=ANY)
184+
185+
# Assert that velocity is set before move
186+
mover_velocity_mock = get_mock_put(mock_mover.velocity)
187+
188+
parent_mock = Mock()
189+
parent_mock.attach_mock(mover_setpoint_mock, "setpoint")
190+
parent_mock.attach_mock(mover_velocity_mock, "velocity")
191+
192+
await mock_mover.velocity.set(100)
193+
await mock_mover.setpoint.set(67)
194+
195+
parent_mock.assert_has_calls(
196+
[
197+
call.velocity(100, wait=True, timeout=ANY),
198+
call.setpoint(67, wait=True, timeout=ANY),
199+
]
200+
)
201+
202+
179203
async def test_read_mover(mock_mover: demo.Mover):
180204
await mock_mover.stage()
181205
assert (await mock_mover.read())["mock_mover"]["value"] == 0.0

0 commit comments

Comments
 (0)