Skip to content

Derived signal type hint checking should support subclasses #1166

@oliwenmandiamond

Description

@oliwenmandiamond

When creating a derived_signal_rw, it fails if the transform's type hints name a class and you then pass in something whose datatype is a subclass of that class. My example is as follows:

from typing import Generic, Protocol, TypeVar

from bluesky.protocols import Movable
from ophyd_async.core import (
    AsyncStatus,
    EnumTypes,
    Reference,
    SignalRW,
    StandardReadable,
    StrictEnum,
    derived_signal_rw,
)
from ophyd_async.epics.core import epics_signal_rw

from dodal.devices.selectable_source import SelectedSource, get_obj_from_selected_source


# Type variable for the enum datatype a fast shutter works with; its bound is
# EnumTypes (imported from ophyd_async.core above).
EnumTypesT = TypeVar("EnumTypesT", bound=EnumTypes)


class FastShutter(Movable[EnumTypesT], Protocol, Generic[EnumTypesT]):
    """Protocol for a fast shutter driven by an enum-valued signal.

    Implementations record which enum value opens the shutter and which closes
    it, so plans can call ``set(shutter.open_state)`` /
    ``set(shutter.close_state)`` without knowing the concrete enum type.
    """

    open_state: EnumTypesT  # enum value that corresponds with opening the shutter
    close_state: EnumTypesT  # enum value that corresponds with closing the shutter
    shutter_state: SignalRW[EnumTypesT]  # signal the requested state is written to

    @AsyncStatus.wrap
    async def set(self, state: EnumTypesT):
        """Write ``state`` to ``shutter_state`` and wait for the write to finish.

        Arguments:
            state: The enum value to move the shutter to.
        """
        # Await the inner set: without it the returned status completes before
        # the signal write does, and any error from the write is lost.
        await self.shutter_state.set(state)


class GenericFastShutter(
    StandardReadable, FastShutter[EnumTypesT], Generic[EnumTypesT]
):
    """
    Fast shutter backed by a single enum PV, which remembers which enum values
    mean "open" and "closed". Callers can therefore operate it without knowing
    which enum the underlying PV uses.

    Usage:
        await shutter.set(shutter.open_state)
        await shutter.set(shutter.close_state)
    or, inside a plan:
        run_engine(bps.mv(shutter, shutter.open_state))
        run_engine(bps.mv(shutter, shutter.close_state))
    """

    def __init__(
        self,
        pv: str,
        open_state: EnumTypesT,
        close_state: EnumTypesT,
        name: str = "",
    ):
        """
        Arguments:
            pv: PV used to control the shutter.
            open_state: Enum value written to open the shutter.
            close_state: Enum value written to close the shutter.
            name: Device name, forwarded to StandardReadable.
        """
        self.open_state = open_state
        self.close_state = close_state
        # The signal's datatype is the enum class the open state belongs to.
        state_enum = type(open_state)
        with self.add_children_as_readables():
            self.shutter_state = epics_signal_rw(state_enum, pv)
        super().__init__(name)

We have a protocol that defines a simple FastShutter. Fast shutters on different beamlines use different enums (e.g. i09 uses InOut, B07 uses OpenClose), and this protocol is an attempt to handle that.

For example, a plan does not need to know the underlying enum, because it can use the saved states:

@device_factory()
def fsi1() -> GenericFastShutter[InOut]:
    """Fast shutter that opens on InOut.OUT and closes on InOut.IN."""
    control_pv = f"{I_PREFIX.beamline_prefix}-EA-FSHTR-01:CTRL"
    return GenericFastShutter[InOut](
        control_pv, open_state=InOut.OUT, close_state=InOut.IN
    )

def my_plan(shutter: FastShutter = inject("fsi1")):
    """Open and then close ``shutter`` using its configured states.

    For fsi1 the open state is configured as OUT and the close state as IN.
    """
    for target_state in (shutter.open_state, shutter.close_state):
        yield from mv(shutter, target_state)

The issue we have run into is that we also need a device that coordinates opening and closing two shutters, e.g. i09 has PGM and DCM sources.

We created a DualFastShutter that implements the FastShutter protocol and takes two GenericFastShutters as references. It uses a derived signal to read the state of the active shutter and to coordinate set values depending on which shutter is selected.

class DualFastShutter(
    StandardReadable, FastShutter[EnumTypesT], Generic[EnumTypesT]
):
    """
    Fast shutter that coordinates two GenericFastShutters sharing the same
    open_state and close_state.

    Which shutter is active is chosen by the ``selected_source`` signal.
    Reading ``shutter_state`` returns the active shutter's state. Setting it
    first closes the inactive shutter and then moves the active one.
    """

    def __init__(
        self,
        shutter1: GenericFastShutter[EnumTypesT],
        shutter2: GenericFastShutter[EnumTypesT],
        selected_source: SignalRW[SelectedSource],
        name: str = "",
    ):
        """
        Arguments:
            shutter1: First shutter to coordinate.
            shutter2: Second shutter to coordinate.
            selected_source: Signal whose value selects the active shutter.
            name: Device name, forwarded to StandardReadable.

        Raises:
            ValueError: If the shutters' open_state or close_state differ.
        """
        self._validate_shutter_states(shutter1.open_state, shutter2.open_state)
        self._validate_shutter_states(shutter1.close_state, shutter2.close_state)
        # Validated identical above, so shutter1's states stand for both.
        self.open_state = shutter1.open_state
        self.close_state = shutter1.close_state

        # NOTE(review): wrapped in Reference presumably so these devices are
        # not registered as children of this one; confirm with ophyd-async docs.
        self.shutter1_ref = Reference(shutter1)
        self.shutter2_ref = Reference(shutter2)
        self.selected_shutter_ref = Reference(selected_source)
        # Register the derived signal as a readable, as GenericFastShutter does
        # for its shutter_state, so it is included when this device is read.
        # NOTE(review): derived_signal_rw checks the datatypes of the signals
        # passed below against the parameter type hints of _read_shutter_state.
        # Those hints are ~EnumTypesT, which currently do not match a concrete
        # enum such as InOut, so connecting raises TypeError.
        with self.add_children_as_readables():
            self.shutter_state = derived_signal_rw(
                self._read_shutter_state,
                self._set_shutter_state,
                selected_shutter=selected_source,
                shutter1=shutter1.shutter_state,
                shutter2=shutter2.shutter_state,
            )
        super().__init__(name)

    def _validate_shutter_states(
        self,
        state1: EnumTypesT,
        state2: EnumTypesT,
    ) -> None:
        """Raise ValueError unless ``state1`` and ``state2`` are the same member.

        Raises:
            ValueError: If ``state1`` is not ``state2``.
        """
        # Identity check: members of two different enum classes are distinct
        # objects even if their underlying values happen to be equal.
        if state1 is not state2:
            raise ValueError(
                f"{state1} is not same value as {state2}. They must be the same to "
                "be compatible. "
            )

    def _read_shutter_state(
        self,
        selected_shutter: SelectedSource,
        shutter1: EnumTypesT,
        shutter2: EnumTypesT,
    ) -> EnumTypesT:
        """Return the state of whichever shutter ``selected_shutter`` selects."""
        return get_obj_from_selected_source(selected_shutter, shutter1, shutter2)

    async def _set_shutter_state(self, value: EnumTypesT):
        """Close the inactive shutter, then set the active shutter to ``value``."""
        selected_shutter = await self.selected_shutter_ref().get_value()
        active_shutter = get_obj_from_selected_source(
            selected_shutter,
            self.shutter1_ref(),
            self.shutter2_ref(),
        )
        # Same lookup with the candidates swapped yields the other shutter.
        inactive_shutter = get_obj_from_selected_source(
            selected_shutter,
            self.shutter2_ref(),
            self.shutter1_ref(),
        )
        await inactive_shutter.set(inactive_shutter.close_state)
        await active_shutter.set(value)

So our shutter will use one of the ophyd-async enum types, but which one is decided when the device is configured, so we don't know in advance which enum it will use. The problem is that when we try to use this device, we get the following error:

@device_factory()
def source_selector() -> SourceSelector:
    """Create the SourceSelector device."""
    selector = SourceSelector()
    return selector


@device_factory()
def fsi1() -> GenericFastShutter[InOut]:
    """Fast shutter that opens on InOut.OUT and closes on InOut.IN."""
    control_pv = f"{I_PREFIX.beamline_prefix}-EA-FSHTR-01:CTRL"
    return GenericFastShutter[InOut](
        control_pv, open_state=InOut.OUT, close_state=InOut.IN
    )


@device_factory()
def fsj1() -> GenericFastShutter[InOut]:
    """Fast shutter that opens on InOut.OUT and closes on InOut.IN."""
    control_pv = f"{J_PREFIX.beamline_prefix}-EA-FSHTR-01:CTRL"
    return GenericFastShutter[InOut](
        control_pv, open_state=InOut.OUT, close_state=InOut.IN
    )


@device_factory()
def dual_fast_shutter() -> DualFastShutter[InOut]:
    """Coordinate fsi1 and fsj1, choosing between them via the source selector."""
    return DualFastShutter[InOut](
        shutter1=fsi1(),
        shutter2=fsj1(),
        selected_source=source_selector().selected_source,
    )

Then I run dodal connect i09 and get:

  File "/workspaces/dodal/src/dodal/cli.py", line 96, in connect
    raise NotConnectedError(exceptions)
ophyd_async.core._utils.NotConnectedError: 
dual_fast_shutter: TypeError: Expected the following to be passed as keyword arguments {'selected_shutter': <enum 'SelectedSource'>, 'shutter1': ~EnumTypesT, 'shutter2': ~EnumTypesT}, got {'selected_shutter': <enum 'SelectedSource'>, 'shutter1': <enum 'InOut'>, 'shutter2': <enum 'InOut'>}

It doesn't recognise that InOut satisfies the bound of EnumTypesT (InOut is one of the EnumTypes), so it fails to create the signal when it shouldn't.

Acceptance Criteria

  • derived_signal_rw type-hint checking accepts signals whose datatype is a subclass of the hinted type (including a TypeVar's bound)

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions