Skip to content

Commit 3d9f508

Browse files
authored
Speed up device creation and connection in mock mode (#641)
1 parent f8fae43 commit 3d9f508

File tree

15 files changed

+288
-245
lines changed

15 files changed

+288
-245
lines changed

src/ophyd_async/core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
DEFAULT_TIMEOUT,
8484
CalculatableTimeout,
8585
Callback,
86+
LazyMock,
8687
NotConnected,
8788
Reference,
8889
StrictEnum,
@@ -176,6 +177,7 @@
176177
"DEFAULT_TIMEOUT",
177178
"CalculatableTimeout",
178179
"Callback",
180+
"LazyMock",
179181
"CALCULATE_TIMEOUT",
180182
"NotConnected",
181183
"Reference",

src/ophyd_async/core/_device.py

Lines changed: 70 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,15 @@
33
import asyncio
44
import sys
55
from collections.abc import Coroutine, Iterator, Mapping, MutableMapping
6+
from functools import cached_property
67
from logging import LoggerAdapter, getLogger
78
from typing import Any, TypeVar
8-
from unittest.mock import Mock
99

1010
from bluesky.protocols import HasName
1111
from bluesky.run_engine import call_in_bluesky_event_loop, in_bluesky_event_loop
1212

1313
from ._protocol import Connectable
14-
from ._utils import DEFAULT_TIMEOUT, NotConnected, wait_for_connection
15-
16-
_device_mocks: dict[Device, Mock] = {}
14+
from ._utils import DEFAULT_TIMEOUT, LazyMock, NotConnected, wait_for_connection
1715

1816

1917
class DeviceConnector:
@@ -37,25 +35,23 @@ def create_children_from_annotations(self, device: Device):
3735
during ``__init__``.
3836
"""
3937

40-
async def connect(
41-
self,
42-
device: Device,
43-
mock: bool | Mock,
44-
timeout: float,
45-
force_reconnect: bool,
46-
):
38+
async def connect_mock(self, device: Device, mock: LazyMock):
39+
# Connect serially, no errors to gather up as in mock mode
40+
for name, child_device in device.children():
41+
await child_device.connect(mock=mock.child(name))
42+
43+
async def connect_real(self, device: Device, timeout: float, force_reconnect: bool):
4744
"""Used during ``Device.connect``.
4845
4946
This is called when a previous connect has not been done, or has been
5047
done in a different mock more. It should connect the Device and all its
5148
children.
5249
"""
53-
coros = {}
54-
for name, child_device in device.children():
55-
child_mock = getattr(mock, name) if mock else mock # Mock() or False
56-
coros[name] = child_device.connect(
57-
mock=child_mock, timeout=timeout, force_reconnect=force_reconnect
58-
)
50+
# Connect in parallel, gathering up NotConnected errors
51+
coros = {
52+
name: child_device.connect(timeout=timeout, force_reconnect=force_reconnect)
53+
for name, child_device in device.children()
54+
}
5955
await wait_for_connection(**coros)
6056

6157

@@ -67,9 +63,8 @@ class Device(HasName, Connectable):
6763
parent: Device | None = None
6864
# None if connect hasn't started, a Task if it has
6965
_connect_task: asyncio.Task | None = None
70-
# If not None, then this is the mock arg of the previous connect
71-
# to let us know if we can reuse an existing connection
72-
_connect_mock_arg: bool | None = None
66+
# The mock if we have connected in mock mode
67+
_mock: LazyMock | None = None
7368

7469
def __init__(
7570
self, name: str = "", connector: DeviceConnector | None = None
@@ -83,10 +78,18 @@ def name(self) -> str:
8378
"""Return the name of the Device"""
8479
return self._name
8580

81+
@cached_property
82+
def _child_devices(self) -> dict[str, Device]:
83+
return {}
84+
8685
def children(self) -> Iterator[tuple[str, Device]]:
87-
for attr_name, attr in self.__dict__.items():
88-
if attr_name != "parent" and isinstance(attr, Device):
89-
yield attr_name, attr
86+
yield from self._child_devices.items()
87+
88+
@cached_property
89+
def log(self) -> LoggerAdapter:
90+
return LoggerAdapter(
91+
getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name}
92+
)
9093

9194
def set_name(self, name: str):
9295
"""Set ``self.name=name`` and each ``self.child.name=name+"-child"``.
@@ -97,28 +100,33 @@ def set_name(self, name: str):
97100
New name to set
98101
"""
99102
self._name = name
100-
# Ensure self.log is recreated after a name change
101-
self.log = LoggerAdapter(
102-
getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name}
103-
)
103+
# Ensure logger is recreated after a name change
104+
if "log" in self.__dict__:
105+
del self.log
104106
for child_name, child in self.children():
105107
child_name = f"{self.name}-{child_name.strip('_')}" if self.name else ""
106108
child.set_name(child_name)
107109

108110
def __setattr__(self, name: str, value: Any) -> None:
111+
# Bear in mind that this function is called *a lot*, so
112+
# we need to make sure nothing expensive happens in it...
109113
if name == "parent":
110114
if self.parent not in (value, None):
111115
raise TypeError(
112116
f"Cannot set the parent of {self} to be {value}: "
113117
f"it is already a child of {self.parent}"
114118
)
115-
elif isinstance(value, Device):
119+
# ...hence not doing an isinstance check for attributes we
120+
# know not to be Devices
121+
elif name not in _not_device_attrs and isinstance(value, Device):
116122
value.parent = self
117-
return super().__setattr__(name, value)
123+
self._child_devices[name] = value
124+
# ...and avoiding the super call as we know it resolves to `object`
125+
return object.__setattr__(self, name, value)
118126

119127
async def connect(
120128
self,
121-
mock: bool | Mock = False,
129+
mock: bool | LazyMock = False,
122130
timeout: float = DEFAULT_TIMEOUT,
123131
force_reconnect: bool = False,
124132
) -> None:
@@ -133,26 +141,39 @@ async def connect(
133141
timeout:
134142
Time to wait before failing with a TimeoutError.
135143
"""
136-
uses_mock = bool(mock)
137-
can_use_previous_connect = (
138-
uses_mock is self._connect_mock_arg
139-
and self._connect_task
140-
and not (self._connect_task.done() and self._connect_task.exception())
141-
)
142-
if mock is True:
143-
mock = Mock() # create a new Mock if one not provided
144-
if force_reconnect or not can_use_previous_connect:
145-
self._connect_mock_arg = uses_mock
146-
if self._connect_mock_arg:
147-
_device_mocks[self] = mock
148-
coro = self._connector.connect(
149-
device=self, mock=mock, timeout=timeout, force_reconnect=force_reconnect
144+
if mock:
145+
# Always connect in mock mode serially
146+
if isinstance(mock, LazyMock):
147+
# Use the provided mock
148+
self._mock = mock
149+
elif not self._mock:
150+
# Make one
151+
self._mock = LazyMock()
152+
await self._connector.connect_mock(self, self._mock)
153+
else:
154+
# Try to cache the connect in real mode
155+
can_use_previous_connect = (
156+
self._mock is None
157+
and self._connect_task
158+
and not (self._connect_task.done() and self._connect_task.exception())
150159
)
151-
self._connect_task = asyncio.create_task(coro)
152-
153-
assert self._connect_task, "Connect task not created, this shouldn't happen"
154-
# Wait for it to complete
155-
await self._connect_task
160+
if force_reconnect or not can_use_previous_connect:
161+
self._mock = None
162+
coro = self._connector.connect_real(self, timeout, force_reconnect)
163+
self._connect_task = asyncio.create_task(coro)
164+
assert self._connect_task, "Connect task not created, this shouldn't happen"
165+
# Wait for it to complete
166+
await self._connect_task
167+
168+
169+
_not_device_attrs = {
170+
"_name",
171+
"_children",
172+
"_connector",
173+
"_timeout",
174+
"_mock",
175+
"_connect_task",
176+
}
156177

157178

158179
DeviceT = TypeVar("DeviceT", bound=Device)

src/ophyd_async/core/_mock_signal_backend.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import asyncio
22
from collections.abc import Callable
33
from functools import cached_property
4-
from unittest.mock import AsyncMock, Mock
4+
from unittest.mock import AsyncMock
55

66
from bluesky.protocols import Descriptor, Reading
77

88
from ._signal_backend import SignalBackend, SignalDatatypeT
99
from ._soft_signal_backend import SoftSignalBackend
10-
from ._utils import Callback
10+
from ._utils import Callback, LazyMock
1111

1212

1313
class MockSignalBackend(SignalBackend[SignalDatatypeT]):
@@ -16,7 +16,7 @@ class MockSignalBackend(SignalBackend[SignalDatatypeT]):
1616
def __init__(
1717
self,
1818
initial_backend: SignalBackend[SignalDatatypeT],
19-
mock: Mock,
19+
mock: LazyMock,
2020
) -> None:
2121
if isinstance(initial_backend, MockSignalBackend):
2222
raise ValueError("Cannot make a MockSignalBackend for a MockSignalBackend")
@@ -34,19 +34,22 @@ def __init__(
3434

3535
# use existing Mock if provided
3636
self.mock = mock
37-
self.put_mock = AsyncMock(name="put", spec=Callable)
38-
self.mock.attach_mock(self.put_mock, "put")
39-
4037
super().__init__(datatype=self.initial_backend.datatype)
4138

39+
@cached_property
40+
def put_mock(self) -> AsyncMock:
41+
put_mock = AsyncMock(name="put", spec=Callable)
42+
self.mock().attach_mock(put_mock, "put")
43+
return put_mock
44+
4245
def set_value(self, value: SignalDatatypeT):
4346
self.soft_backend.set_value(value)
4447

4548
def source(self, name: str, read: bool) -> str:
4649
return f"mock+{self.initial_backend.source(name, read)}"
4750

4851
async def connect(self, timeout: float) -> None:
49-
pass
52+
raise RuntimeError("It is not possible to connect a MockSignalBackend")
5053

5154
@cached_property
5255
def put_proceeds(self) -> asyncio.Event:

src/ophyd_async/core/_mock_signal_utils.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,26 @@
22
from contextlib import asynccontextmanager, contextmanager
33
from unittest.mock import AsyncMock, Mock
44

5-
from ._device import Device, _device_mocks
5+
from ._device import Device
66
from ._mock_signal_backend import MockSignalBackend
7-
from ._signal import Signal, SignalR, _mock_signal_backends
7+
from ._signal import Signal, SignalConnector, SignalR
88
from ._soft_signal_backend import SignalDatatypeT
9+
from ._utils import LazyMock
10+
11+
12+
def get_mock(device: Device | Signal) -> Mock:
13+
mock = device._mock # noqa: SLF001
14+
assert isinstance(mock, LazyMock), f"Device {device} not connected in mock mode"
15+
return mock()
916

1017

1118
def _get_mock_signal_backend(signal: Signal) -> MockSignalBackend:
12-
assert (
13-
signal in _mock_signal_backends
19+
connector = signal._connector # noqa: SLF001
20+
assert isinstance(connector, SignalConnector), f"Expected Signal, got {signal}"
21+
assert isinstance(
22+
connector.backend, MockSignalBackend
1423
), f"Signal {signal} not connected in mock mode"
15-
return _mock_signal_backends[signal]
24+
return connector.backend
1625

1726

1827
def set_mock_value(signal: Signal[SignalDatatypeT], value: SignalDatatypeT):
@@ -45,12 +54,6 @@ def get_mock_put(signal: Signal) -> AsyncMock:
4554
return _get_mock_signal_backend(signal).put_mock
4655

4756

48-
def get_mock(device: Device | Signal) -> Mock:
49-
if isinstance(device, Signal):
50-
return _get_mock_signal_backend(device).mock
51-
return _device_mocks[device]
52-
53-
5457
def reset_mock_put_calls(signal: Signal):
5558
backend = _get_mock_signal_backend(signal)
5659
backend.put_mock.reset_mock()

src/ophyd_async/core/_signal.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import functools
55
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
66
from typing import Any, Generic, cast
7-
from unittest.mock import Mock
87

98
from bluesky.protocols import (
109
Locatable,
@@ -30,9 +29,14 @@
3029
)
3130
from ._soft_signal_backend import SoftSignalBackend
3231
from ._status import AsyncStatus
33-
from ._utils import CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, Callback, T
34-
35-
_mock_signal_backends: dict[Device, MockSignalBackend] = {}
32+
from ._utils import (
33+
CALCULATE_TIMEOUT,
34+
DEFAULT_TIMEOUT,
35+
CalculatableTimeout,
36+
Callback,
37+
LazyMock,
38+
T,
39+
)
3640

3741

3842
async def _wait_for(coro: Awaitable[T], timeout: float | None, source: str) -> T:
@@ -54,26 +58,28 @@ class SignalConnector(DeviceConnector):
5458
def __init__(self, backend: SignalBackend):
5559
self.backend = self._init_backend = backend
5660

57-
async def connect(
58-
self,
59-
device: Device,
60-
mock: bool | Mock,
61-
timeout: float,
62-
force_reconnect: bool,
63-
):
64-
if mock:
65-
self.backend = MockSignalBackend(self._init_backend, mock)
66-
_mock_signal_backends[device] = self.backend
67-
else:
68-
self.backend = self._init_backend
61+
async def connect_mock(self, device: Device, mock: LazyMock):
62+
self.backend = MockSignalBackend(self._init_backend, mock)
63+
64+
async def connect_real(self, device: Device, timeout: float, force_reconnect: bool):
65+
self.backend = self._init_backend
6966
device.log.debug(f"Connecting to {self.backend.source(device.name, read=True)}")
7067
await self.backend.connect(timeout)
7168

7269

70+
class _ChildrenNotAllowed(dict[str, Device]):
71+
def __setitem__(self, key: str, value: Device) -> None:
72+
raise AttributeError(
73+
f"Cannot add Device or Signal child {key}={value} of Signal, "
74+
"make a subclass of Device instead"
75+
)
76+
77+
7378
class Signal(Device, Generic[SignalDatatypeT]):
7479
"""A Device with the concept of a value, with R, RW, W and X flavours"""
7580

7681
_connector: SignalConnector
82+
_child_devices = _ChildrenNotAllowed() # type: ignore
7783

7884
def __init__(
7985
self,
@@ -89,14 +95,6 @@ def source(self) -> str:
8995
"""Like ca://PV_PREFIX:SIGNAL, or "" if not set"""
9096
return self._connector.backend.source(self.name, read=True)
9197

92-
def __setattr__(self, name: str, value: Any) -> None:
93-
if name != "parent" and isinstance(value, Device):
94-
raise AttributeError(
95-
f"Cannot add Device or Signal {value} as a child of Signal {self}, "
96-
"make a subclass of Device instead"
97-
)
98-
return super().__setattr__(name, value)
99-
10098

10199
class _SignalCache(Generic[SignalDatatypeT]):
102100
def __init__(self, backend: SignalBackend[SignalDatatypeT], signal: Signal):

0 commit comments

Comments
 (0)