33import asyncio
44import sys
55from collections .abc import Coroutine , Iterator , Mapping , MutableMapping
6+ from functools import cached_property
67from logging import LoggerAdapter , getLogger
78from typing import Any , TypeVar
8- from unittest .mock import Mock
99
1010from bluesky .protocols import HasName
1111from bluesky .run_engine import call_in_bluesky_event_loop , in_bluesky_event_loop
1212
1313from ._protocol import Connectable
14- from ._utils import DEFAULT_TIMEOUT , NotConnected , wait_for_connection
15-
16- _device_mocks : dict [Device , Mock ] = {}
14+ from ._utils import DEFAULT_TIMEOUT , LazyMock , NotConnected , wait_for_connection
1715
1816
1917class DeviceConnector :
@@ -37,25 +35,23 @@ def create_children_from_annotations(self, device: Device):
3735 during ``__init__``.
3836 """
3937
40- async def connect (
41- self ,
42- device : Device ,
43- mock : bool | Mock ,
44- timeout : float ,
45- force_reconnect : bool ,
46- ):
38+ async def connect_mock (self , device : Device , mock : LazyMock ):
39+ # Connect serially, no errors to gather up as in mock mode
40+ for name , child_device in device .children ():
41+ await child_device .connect (mock = mock .child (name ))
42+
43+ async def connect_real (self , device : Device , timeout : float , force_reconnect : bool ):
4744 """Used during ``Device.connect``.
4845
4946 This is called when a previous connect has not been done, or has been
5047 done in a different mock more. It should connect the Device and all its
5148 children.
5249 """
53- coros = {}
54- for name , child_device in device .children ():
55- child_mock = getattr (mock , name ) if mock else mock # Mock() or False
56- coros [name ] = child_device .connect (
57- mock = child_mock , timeout = timeout , force_reconnect = force_reconnect
58- )
50+ # Connect in parallel, gathering up NotConnected errors
51+ coros = {
52+ name : child_device .connect (timeout = timeout , force_reconnect = force_reconnect )
53+ for name , child_device in device .children ()
54+ }
5955 await wait_for_connection (** coros )
6056
6157
@@ -67,9 +63,8 @@ class Device(HasName, Connectable):
6763 parent : Device | None = None
6864 # None if connect hasn't started, a Task if it has
6965 _connect_task : asyncio .Task | None = None
70- # If not None, then this is the mock arg of the previous connect
71- # to let us know if we can reuse an existing connection
72- _connect_mock_arg : bool | None = None
66+ # The mock if we have connected in mock mode
67+ _mock : LazyMock | None = None
7368
7469 def __init__ (
7570 self , name : str = "" , connector : DeviceConnector | None = None
@@ -83,10 +78,18 @@ def name(self) -> str:
8378 """Return the name of the Device"""
8479 return self ._name
8580
81+ @cached_property
82+ def _child_devices (self ) -> dict [str , Device ]:
83+ return {}
84+
8685 def children (self ) -> Iterator [tuple [str , Device ]]:
87- for attr_name , attr in self .__dict__ .items ():
88- if attr_name != "parent" and isinstance (attr , Device ):
89- yield attr_name , attr
86+ yield from self ._child_devices .items ()
87+
88+ @cached_property
89+ def log (self ) -> LoggerAdapter :
90+ return LoggerAdapter (
91+ getLogger ("ophyd_async.devices" ), {"ophyd_async_device_name" : self .name }
92+ )
9093
9194 def set_name (self , name : str ):
9295 """Set ``self.name=name`` and each ``self.child.name=name+"-child"``.
@@ -97,28 +100,33 @@ def set_name(self, name: str):
97100 New name to set
98101 """
99102 self ._name = name
100- # Ensure self.log is recreated after a name change
101- self .log = LoggerAdapter (
102- getLogger ("ophyd_async.devices" ), {"ophyd_async_device_name" : self .name }
103- )
103+ # Ensure logger is recreated after a name change
104+ if "log" in self .__dict__ :
105+ del self .log
104106 for child_name , child in self .children ():
105107 child_name = f"{ self .name } -{ child_name .strip ('_' )} " if self .name else ""
106108 child .set_name (child_name )
107109
108110 def __setattr__ (self , name : str , value : Any ) -> None :
111+ # Bear in mind that this function is called *a lot*, so
112+ # we need to make sure nothing expensive happens in it...
109113 if name == "parent" :
110114 if self .parent not in (value , None ):
111115 raise TypeError (
112116 f"Cannot set the parent of { self } to be { value } : "
113117 f"it is already a child of { self .parent } "
114118 )
115- elif isinstance (value , Device ):
119+ # ...hence not doing an isinstance check for attributes we
120+ # know not to be Devices
121+ elif name not in _not_device_attrs and isinstance (value , Device ):
116122 value .parent = self
117- return super ().__setattr__ (name , value )
123+ self ._child_devices [name ] = value
124+ # ...and avoiding the super call as we know it resolves to `object`
125+ return object .__setattr__ (self , name , value )
118126
119127 async def connect (
120128 self ,
121- mock : bool | Mock = False ,
129+ mock : bool | LazyMock = False ,
122130 timeout : float = DEFAULT_TIMEOUT ,
123131 force_reconnect : bool = False ,
124132 ) -> None :
@@ -133,26 +141,39 @@ async def connect(
133141 timeout:
134142 Time to wait before failing with a TimeoutError.
135143 """
136- uses_mock = bool (mock )
137- can_use_previous_connect = (
138- uses_mock is self ._connect_mock_arg
139- and self ._connect_task
140- and not (self ._connect_task .done () and self ._connect_task .exception ())
141- )
142- if mock is True :
143- mock = Mock () # create a new Mock if one not provided
144- if force_reconnect or not can_use_previous_connect :
145- self ._connect_mock_arg = uses_mock
146- if self ._connect_mock_arg :
147- _device_mocks [self ] = mock
148- coro = self ._connector .connect (
149- device = self , mock = mock , timeout = timeout , force_reconnect = force_reconnect
144+ if mock :
145+ # Always connect in mock mode serially
146+ if isinstance (mock , LazyMock ):
147+ # Use the provided mock
148+ self ._mock = mock
149+ elif not self ._mock :
150+ # Make one
151+ self ._mock = LazyMock ()
152+ await self ._connector .connect_mock (self , self ._mock )
153+ else :
154+ # Try to cache the connect in real mode
155+ can_use_previous_connect = (
156+ self ._mock is None
157+ and self ._connect_task
158+ and not (self ._connect_task .done () and self ._connect_task .exception ())
150159 )
151- self ._connect_task = asyncio .create_task (coro )
152-
153- assert self ._connect_task , "Connect task not created, this shouldn't happen"
154- # Wait for it to complete
155- await self ._connect_task
160+ if force_reconnect or not can_use_previous_connect :
161+ self ._mock = None
162+ coro = self ._connector .connect_real (self , timeout , force_reconnect )
163+ self ._connect_task = asyncio .create_task (coro )
164+ assert self ._connect_task , "Connect task not created, this shouldn't happen"
165+ # Wait for it to complete
166+ await self ._connect_task
167+
168+
169+ _not_device_attrs = {
170+ "_name" ,
171+ "_children" ,
172+ "_connector" ,
173+ "_timeout" ,
174+ "_mock" ,
175+ "_connect_task" ,
176+ }
156177
157178
158179DeviceT = TypeVar ("DeviceT" , bound = Device )
0 commit comments