Skip to content

Allow constants in kwargs of derived_signal #859

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/how-to/derive-one-signal-from-others.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ The simplest API involves mapping a single Derived Signal to many low level Sign
- [`derived_signal_rw`](#ophyd_async.core.derived_signal_rw)
- [`derived_signal_w`](#ophyd_async.core.derived_signal_w)

If a signal is readable, then it requires a `raw_to_derived` function that maps the raw values of low level Signals into the datatype of the Derived Signal and the `raw_devices` that will be read/monitored to give those values.
If a signal is readable, then it requires a `raw_to_derived` function that maps the raw values of low level Signals into the datatype of the Derived Signal and the `raw_devices_and_constants` that will be read/monitored to give those values.

If a signal is writeable, then it requires a `set_derived` async function that sets the raw signals based on the derived value.

Expand All @@ -27,7 +27,7 @@ These examples show the low level Signals and Derived Signals in the same Device

The more general API involves a two way mapping between many Derived Signals and many low level Signals. This is done by implementing a `Raw` [](#typing.TypedDict) subclass with the names and datatypes of the low level Signals, a `Derived` [](#typing.TypedDict) subclass with the names and datatypes of the derived Signals, and [](#Transform) class with `raw_to_derived` and `derived_to_raw` methods to convert between the two. Some transforms will also require parameters which get their values from other Signals for both methods. These should be put in as type hints on the `Transform` subclass.

To create the derived signals, we make a [](#DerivedSignalFactory) instance that knows about the `Transform` class, the `raw_devices` that will be read/monitored to provide the raw values for the transform, and optionally the `set_derived` method to set them. The methods like [](#DerivedSignalFactory.derived_signal_rw) allow Derived signals to be created for each attribute in the `Derived` TypedDict subclass.
To create the derived signals, we make a [](#DerivedSignalFactory) instance that knows about the `Transform` class, the `raw_devices_and_constants` that will be read/monitored to provide the raw values for the transform, and optionally the `set_derived` method to set them. The methods like [](#DerivedSignalFactory.derived_signal_rw) allow Derived signals to be created for each attribute in the `Derived` TypedDict subclass.

In the below example we see this is action:

Expand Down
2 changes: 2 additions & 0 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from ._signal_backend import (
Array1D,
DTypeScalar_co,
Primitive,
SignalBackend,
SignalDatatype,
SignalDatatypeT,
Expand Down Expand Up @@ -124,6 +125,7 @@
"SubsetEnum",
"Table",
"SignalMetadata",
"Primitive",
# Soft signal
"SoftSignalBackend",
"soft_signal_r_and_setter",
Expand Down
90 changes: 68 additions & 22 deletions src/ophyd_async/core/_derived_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from ._device import Device
from ._signal import Signal, SignalR, SignalRW, SignalT, SignalW
from ._signal_backend import SignalDatatypeT
from ._signal_backend import Primitive, SignalDatatypeT


class DerivedSignalFactory(Generic[TransformT]):
Expand All @@ -23,18 +23,31 @@ class DerivedSignalFactory(Generic[TransformT]):
:param set_derived:
An optional async function that takes the output of
`transform_cls.raw_to_derived` and applies it to the raw devices.
:param raw_and_transform_devices:
Devices whose values will be passed as parameters to the `transform_cls`,
and as arguments to `transform_cls.raw_to_derived`.
:param raw_and_transform_devices_and_constants:
Devices and Constants whose values will be passed as parameters
to the `transform_cls`, and as arguments to `transform_cls.raw_to_derived`.
"""

def __init__(
self,
transform_cls: type[TransformT],
set_derived: Callable[..., Awaitable[None]] | None = None,
**raw_and_transform_devices,
**raw_and_transform_devices_and_constants,
):
self._set_derived = set_derived
_raw_and_transform_devices, _raw_and_transform_constants = (
{
k: v
for k, v in raw_and_transform_devices_and_constants.items()
if isinstance(v, Device)
},
{
k: v
for k, v in raw_and_transform_devices_and_constants.items()
if isinstance(v, Primitive)
},
)

# Check the raw and transform devices match the input arguments of the Transform
if transform_cls is not Transform:
# Populate expected parameters and types
Expand All @@ -48,26 +61,42 @@ def __init__(
}

# Populate received parameters and types
# Use Signal datatype, or Locatable datatype, or set type as None
# Use Primitive's type, Signal's datatype,
# Locatable's datatype, or set type as None
received = {
k: v.datatype if isinstance(v, Signal) else get_locatable_type(v)
for k, v in raw_and_transform_devices.items()
**{
k: v.datatype if isinstance(v, Signal) else get_locatable_type(v)
for k, v in _raw_and_transform_devices.items()
},
**{k: type(v) for k, v in _raw_and_transform_constants.items()},
}

if expected != received:
msg = (
f"Expected devices to be passed as keyword arguments "
f"Expected the following to be passed as keyword arguments "
f"{expected}, got {received}"
)
raise TypeError(msg)
self._set_derived_takes_dict = (
is_typeddict(_get_first_arg_datatype(set_derived)) if set_derived else False
)

_raw_constants, _transform_constants = _partition_by_keys(
_raw_and_transform_constants, set(transform_cls.model_fields)
)

_raw_devices, _transform_devices = _partition_by_keys(
_raw_and_transform_devices, set(transform_cls.model_fields)
)

self._transformer = SignalTransformer(
transform_cls,
set_derived,
self._set_derived_takes_dict,
**raw_and_transform_devices,
_raw_devices,
_raw_constants,
_transform_devices,
_transform_constants,
)

def _make_signal(
Expand Down Expand Up @@ -177,7 +206,7 @@ def _get_first_arg_datatype(
def _make_factory(
raw_to_derived: Callable[..., SignalDatatypeT] | None = None,
set_derived: Callable[[SignalDatatypeT], Awaitable[None]] | None = None,
raw_devices: dict[str, Device] | None = None,
raw_devices_and_constants: dict[str, Device | Primitive] | None = None,
) -> DerivedSignalFactory:
if raw_to_derived:

Expand All @@ -190,7 +219,9 @@ def raw_to_derived(self, **kwargs) -> dict[str, SignalDatatypeT]:
DerivedTransform.raw_to_derived.__annotations__ = get_type_hints(raw_to_derived)

return DerivedSignalFactory(
DerivedTransform, set_derived=set_derived, **(raw_devices or {})
DerivedTransform,
set_derived=set_derived,
**(raw_devices_and_constants or {}),
)
else:
return DerivedSignalFactory(Transform, set_derived=set_derived)
Expand All @@ -200,7 +231,7 @@ def derived_signal_r(
raw_to_derived: Callable[..., SignalDatatypeT],
derived_units: str | None = None,
derived_precision: int | None = None,
**raw_devices: Device,
**raw_devices_and_constants: Device | Primitive,
) -> SignalR[SignalDatatypeT]:
"""Create a read only derived signal.

Expand All @@ -209,11 +240,14 @@ def derived_signal_r(
returns the derived value.
:param derived_units: Engineering units for the derived signal
:param derived_precision: Number of digits after the decimal place to display
:param raw_devices:
A dictionary of Devices to provide the values for raw_to_derived. The names
of these arguments must match the arguments of raw_to_derived.
:param raw_devices_and_constants:
A dictionary of Devices and Constants to provide the values for raw_to_derived.
The names of these arguments must match the arguments of raw_to_derived.
"""
factory = _make_factory(raw_to_derived=raw_to_derived, raw_devices=raw_devices)
factory = _make_factory(
raw_to_derived=raw_to_derived,
raw_devices_and_constants=raw_devices_and_constants,
)
return factory.derived_signal_r(
datatype=_get_return_datatype(raw_to_derived),
name="value",
Expand All @@ -227,7 +261,7 @@ def derived_signal_rw(
set_derived: Callable[[SignalDatatypeT], Awaitable[None]],
derived_units: str | None = None,
derived_precision: int | None = None,
**raw_devices: Device,
**raw_devices_and_constants: Device | Primitive,
) -> SignalRW[SignalDatatypeT]:
"""Create a read-write derived signal.

Expand All @@ -239,9 +273,9 @@ def derived_signal_rw(
either be an async function, or return an [](#AsyncStatus)
:param derived_units: Engineering units for the derived signal
:param derived_precision: Number of digits after the decimal place to display
:param raw_devices:
A dictionary of Devices to provide the values for raw_to_derived. The names
of these arguments must match the arguments of raw_to_derived.
:param raw_devices_and_constants:
A dictionary of Devices and Constants to provide the values for raw_to_derived.
The names of these arguments must match the arguments of raw_to_derived.
"""
raw_to_derived_datatype = _get_return_datatype(raw_to_derived)
set_derived_datatype = _get_first_arg_datatype(set_derived)
Expand All @@ -253,7 +287,9 @@ def derived_signal_rw(
raise TypeError(msg)

factory = _make_factory(
raw_to_derived=raw_to_derived, set_derived=set_derived, raw_devices=raw_devices
raw_to_derived=raw_to_derived,
set_derived=set_derived,
raw_devices_and_constants=raw_devices_and_constants,
)
return factory.derived_signal_rw(
datatype=raw_to_derived_datatype,
Expand Down Expand Up @@ -297,3 +333,13 @@ def get_locatable_type(obj: object) -> type | None:
if args:
return args[0]
return None


def _partition_by_keys(data: dict, keys: set) -> tuple[dict, dict]:
group_excluded, group_included = {}, {}
for k, v in data.items():
if k in keys:
group_included[k] = v
else:
group_excluded[k] = v
return group_excluded, group_included
26 changes: 18 additions & 8 deletions src/ophyd_async/core/_derived_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,20 @@ def __init__(
transform_cls: type[TransformT],
set_derived: Callable[..., Awaitable[None]] | None,
set_derived_takes_dict: bool,
**raw_and_transform_devices,
raw_devices,
raw_constants,
transform_devices,
transform_constants,
):
self._transform_cls = transform_cls
self._set_derived = set_derived
self._set_derived_takes_dict = set_derived_takes_dict
self._transform_devices = {
k: raw_and_transform_devices.pop(k) for k in transform_cls.model_fields
}
self._raw_devices = raw_and_transform_devices

self._transform_devices = transform_devices
self._transform_constants = transform_constants
self._raw_devices = raw_devices
self._raw_constants = raw_constants

self._derived_callbacks: dict[str, Callback[Reading]] = {}
self._cached_readings: dict[str, Reading] | None = None

Expand Down Expand Up @@ -122,7 +127,7 @@ def _make_transform_from_readings(
k: transform_readings[sig.name]["value"]
for k, sig in self.transform_readables.items()
}
return self._transform_cls(**transform_args)
return self._transform_cls(**(transform_args | self._transform_constants))

def _make_derived_readings(
self, raw_and_transform_readings: dict[str, Reading]
Expand All @@ -140,10 +145,15 @@ def _make_derived_readings(
transform = self._make_transform_from_readings(raw_and_transform_readings)
# Create the raw values from the rest then calculate the derived readings
# using the transform
# Extend dictionary with values of any Constants passed as arguments
raw_values = {
k: raw_and_transform_readings[sig.name]["value"]
for k, sig in self._raw_devices.items()
**{
k: raw_and_transform_readings[sig.name]["value"]
for k, sig in self._raw_devices.items()
},
**self._raw_constants,
}

derived_readings = {
name: Reading(
value=derived, timestamp=timestamp, alarm_severity=alarm_severity
Expand Down
5 changes: 3 additions & 2 deletions src/ophyd_async/sim/_mirror_horizontal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from bluesky.protocols import Movable

from ophyd_async.core import AsyncStatus, DerivedSignalFactory, Device, soft_signal_rw
from ophyd_async.core import AsyncStatus, DerivedSignalFactory, Device

from ._mirror_vertical import TwoJackDerived, TwoJackTransform
from ._motor import SimMotor
Expand All @@ -20,7 +20,8 @@ def __init__(self, name=""):
self.x1 = SimMotor()
self.x2 = SimMotor()
# Parameter
self.x1_x2_distance = soft_signal_rw(float, initial_value=1)
# This could also be set as 'soft_signal_rw(float, initial_value=1)'
self.x1_x2_distance = 1.0
# Derived signals
self._factory = DerivedSignalFactory(
TwoJackTransform,
Expand Down
1 change: 1 addition & 0 deletions src/ophyd_async/sim/_mirror_vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self, name=""):
self.y1 = SimMotor()
self.y2 = SimMotor()
# Parameter
# This could also be set as '1.0', if constant.
self.y1_y2_distance = soft_signal_rw(float, initial_value=1)
# Derived signals
self._factory = DerivedSignalFactory(
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_multi_derived_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_mismatching_args():
with pytest.raises(
TypeError,
match=re.escape(
"Expected devices to be passed as keyword arguments"
"Expected the following to be passed as keyword arguments"
" {'distance': <class 'float'>, 'jack1': <class 'float'>, "
"'jack2': <class 'float'>}, "
"got {'jack1': <class 'float'>, 'jack22': <class 'float'>, "
Expand Down
35 changes: 29 additions & 6 deletions tests/core/test_single_derived_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def _get_position(foo: float, bar: float) -> BeamstopPosition:
return BeamstopPosition.OUT_OF_POSITION


def _get_position_wrong_args(x: float, y: float) -> BeamstopPosition:
if abs(x) < 1 and abs(y) < 2:
return BeamstopPosition.IN_POSITION
else:
return BeamstopPosition.OUT_OF_POSITION


@pytest.mark.parametrize(
"x, y, position",
[
Expand Down Expand Up @@ -114,15 +121,15 @@ async def test_setting_all():
"func, expected_msg, args",
[
(
_get_position,
"Expected devices to be passed as keyword arguments "
"{'foo': <class 'float'>, 'bar': <class 'float'>}, "
"got {'x': <class 'float'>, 'y': <class 'float'>}",
{"x": soft_signal_rw(float), "y": soft_signal_rw(float)},
_get_position_wrong_args,
"Expected the following to be passed as keyword arguments "
"{'x': <class 'float'>, 'y': <class 'float'>}, "
"got {'foo': <class 'float'>, 'bar': <class 'float'>}",
{"foo": soft_signal_rw(float), "bar": soft_signal_rw(float)},
),
(
_get_position,
"Expected devices to be passed as keyword arguments "
"Expected the following to be passed as keyword arguments "
"{'foo': <class 'float'>, 'bar': <class 'float'>}, "
"got {'foo': <class 'int'>, 'bar': <class 'int'>}",
{
Expand All @@ -148,3 +155,19 @@ async def _put(value: float) -> None:

derived = derived_signal_rw(_get, _put, ts=signal_r)
assert await derived.get_value() == 4


async def test_derived_signal_allows_literals():
signal_rw = soft_signal_rw(int, 0, "TEST")

def _add_const_to_value(signal: int, const: int) -> int:
return const + signal

signal_r = derived_signal_r(
_add_const_to_value,
signal=signal_rw,
const=24,
)
assert await signal_r.get_value() == 24
await signal_rw.set(10)
assert await signal_r.get_value() == 34
Loading