diff --git a/src/ophyd_async/core/_device_filler.py b/src/ophyd_async/core/_device_filler.py index e4154c2599..b49d3fd585 100644 --- a/src/ophyd_async/core/_device_filler.py +++ b/src/ophyd_async/core/_device_filler.py @@ -1,5 +1,6 @@ from __future__ import annotations +import types from abc import abstractmethod from collections.abc import Callable, Iterator, Sequence from typing import ( @@ -9,8 +10,10 @@ NoReturn, Protocol, TypeVar, + Union, cast, get_args, + get_origin, get_type_hints, runtime_checkable, ) @@ -76,6 +79,7 @@ def __init__( self._extras: dict[UniqueName, Sequence[Any]] = {} self._signal_datatype: dict[LogicalName, type | None] = {} self._vector_device_type: dict[LogicalName, type[Device] | None] = {} + self._optional_devices: set[str] = set() self.ignored_signals: set[str] = set() # Backends and Connectors stored ready for the connection phase self._unfilled_backends: dict[ @@ -121,6 +125,20 @@ def _scan_for_annotations(self): self.ignored_signals.add(attr_name) name = UniqueName(attr_name) origin = get_origin_class(annotation) + args = get_args(annotation) + + if ( + get_origin(annotation) is Union + and types.NoneType in args + and len(args) == 2 + ): + # Annotation is an Union with two arguments, one of which is None + # Make this signal an optional parameter and set origin to T + # so the device is added to unfilled_connectors + self._optional_devices.add(name) + (annotation,) = [x for x in args if x is not types.NoneType] + origin = get_origin_class(annotation) + if ( name == "parent" or name.startswith("_") @@ -241,10 +259,17 @@ def check_filled(self, source: str): :param source: The source of the data that should have done the filling, for reporting as an error message """ - unfilled = sorted(set(self._unfilled_connectors).union(self._unfilled_backends)) - if unfilled: + unfilled = set(self._unfilled_connectors).union(self._unfilled_backends) + unfilled_optional = sorted(unfilled.intersection(self._optional_devices)) + + for name in unfilled_optional: + setattr(self._device, name, None) + + required = sorted(unfilled.difference(unfilled_optional)) + + if required: raise RuntimeError( - f"{self._device.name}: cannot provision {unfilled} from {source}" + f"{self._device.name}: cannot provision {required} from {source}" ) def _ensure_device_vector(self, name: LogicalName) -> DeviceVector: diff --git a/tests/unit_tests/core/test_device.py b/tests/unit_tests/core/test_device.py index 8b13373f79..195abadc86 100644 --- a/tests/unit_tests/core/test_device.py +++ b/tests/unit_tests/core/test_device.py @@ -8,6 +8,7 @@ from ophyd_async.core import ( DEFAULT_TIMEOUT, Device, + DeviceFiller, DeviceVector, NotConnectedError, Reference, @@ -266,3 +267,50 @@ def test_setitem_with_non_device_value(): device_vector = DeviceVector(children={}) with pytest.raises(TypeError, match="Expected Device, got"): device_vector[1] = "not_a_device" + + +def test_device_filler_check_filled_with_optional_signals(): + """Test DeviceFiller.check_filled with both mandatory and optional Signals.""" + + class TestDevice(Device): + mandatory_signal: SignalRW[int] + optional_signal: SignalRW[int] | None + + # Create a mock backend factory + def mock_backend_factory(datatype): + backend = Mock() + backend.datatype = datatype + return backend + + # Create a mock connector factory + def mock_connector_factory(): + return Mock() + + device = TestDevice() + filler = DeviceFiller( + device=device, + signal_backend_factory=mock_backend_factory, + device_connector_factory=mock_connector_factory, + ) + + # Create signals from annotations (unfilled) + list(filler.create_signals_from_annotations(filled=False)) + + assert hasattr(device, "optional_signal") + assert isinstance(device.optional_signal, SignalRW) + + # Test failure path: check_filled should fail when mandatory signal is unfilled + with pytest.raises(RuntimeError, match="cannot provision.*mandatory_signal"): + filler.check_filled("test_source") + + # Fill the mandatory signal + filler.fill_child_signal("mandatory_signal", SignalRW, None) + + # Test success path: check_filled should succeed and set optional_signal to None + filler.check_filled("test_source") + + # Verify mandatory signal exists and optional signal is None + assert hasattr(device, "mandatory_signal") + assert isinstance(device.mandatory_signal, SignalRW) + assert hasattr(device, "optional_signal") + assert device.optional_signal is None diff --git a/tests/unit_tests/epics/pvi/test_pvi.py b/tests/unit_tests/epics/pvi/test_pvi.py index 44330a3283..1513c70ec8 100644 --- a/tests/unit_tests/epics/pvi/test_pvi.py +++ b/tests/unit_tests/epics/pvi/test_pvi.py @@ -52,6 +52,10 @@ class Block4(StandardReadable): signal_rw: SignalRW[int] +class Block5(Device): + signal_rw: SignalRW[int] | None + + DeviceT = TypeVar("DeviceT", bound=Device) @@ -133,6 +137,19 @@ async def test_device_create_children_from_annotations(): assert device.signal_device is top_block_1_device +async def test_device_create_device_with_optional_signals(): + device = with_pvi_connector(Block5, "PREFIX:") + + # Makes sure before connecting we create a signal and it exists in the device + assert hasattr(device, "signal_rw") + assert isinstance(device.signal_rw, SignalRW) + + await device.connect(mock=True) + + # After connecting if the optional signal is not filled it's set to None + assert isinstance(device.signal_rw, SignalRW) + + async def test_device_create_children_from_annotations_with_device_vectors(): device = with_pvi_connector(Block4, "PREFIX:", name="test_device") await device.connect(mock=True)