Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
91d1094
add check raw_to_derived argument type
Nov 17, 2025
b3750f7
fix test name
Nov 17, 2025
060ba24
Merge branch 'main' into improve-error-when-raw2derived-has-no-type
Villtord Nov 17, 2025
f65ffe5
use inspect module
Dec 5, 2025
ffaf6aa
fix small typing issues pylance
Dec 5, 2025
674b378
Merge branch 'main' into improve-error-when-raw2derived-has-no-type
Villtord Dec 5, 2025
ce5704a
remove commented lines
Dec 5, 2025
b3b7798
Merge branch 'main' into improve-error-when-raw2derived-has-no-type
Villtord Dec 9, 2025
a892cf8
fix raw_to_derived signature and move empty hints check to factory
Dec 11, 2025
0570649
fix multi derived test
Dec 12, 2025
199ed8f
convert string annotations to class
Dec 12, 2025
8bc79cc
add lost test back
Dec 12, 2025
3c25ef2
add type hint test for many-to-many
Dec 15, 2025
1919565
remove unnecessary derived_to_raw functino
Dec 15, 2025
ec3c454
move helper dict_wrapper method out from class
Dec 15, 2025
2e2616c
add type hints
Dec 15, 2025
ca20b1f
Merge branch 'main' into improve-error-when-raw2derived-has-no-type
Villtord Dec 15, 2025
9bcadcf
Merge branch 'main' into improve-error-when-raw2derived-has-no-type
Villtord Dec 17, 2025
e463319
test_subclasses_in_hints
Dec 17, 2025
f7ea387
fix test
Dec 17, 2025
7939630
simplify logic
Dec 17, 2025
ee5359d
add typeVar checking
Dec 17, 2025
72886a3
make type checking happy
Dec 17, 2025
b87d586
remove noqa
Dec 17, 2025
f527c6d
add "cls" to excluded names
Dec 17, 2025
7cff97a
remove unnecessary get_type_hints
Dec 18, 2025
4955163
reply comments
Dec 19, 2025
d9adee5
Merge branch 'main' into improve-error-when-raw2derived-has-no-type
Villtord Dec 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 47 additions & 19 deletions src/ophyd_async/core/_derived_signal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from collections.abc import Awaitable, Callable
from typing import Any, Generic, get_args, get_origin, get_type_hints, is_typeddict
import functools
from collections.abc import Awaitable, Callable, Mapping
from inspect import Parameter, signature
from typing import (
Any,
Generic,
get_args,
get_origin,
get_type_hints,
is_typeddict,
)

from bluesky.protocols import Locatable

Expand Down Expand Up @@ -55,10 +64,16 @@ def __init__(
**{k: f.annotation for k, f in transform_cls.model_fields.items()},
**{
k: v
for k, v in get_type_hints(transform_cls.raw_to_derived).items()
if k not in {"self", "return"}
for k, v in _get_params_types_dict(
transform_cls.raw_to_derived
).items()
if k not in {"self"} # noqa: E501
},
}
if empty_keys := [k for k, v in expected.items() if v == Parameter.empty]:
raise TypeError(
f"{transform_cls.raw_to_derived} is missing a type hint for arguments: {empty_keys}" # noqa: E501
)

# Populate received parameters and types
# Use Primitive's type, Signal's datatype,
Expand Down Expand Up @@ -195,28 +210,33 @@ def _get_return_datatype(func: Callable[..., SignalDatatypeT]) -> type[SignalDat
def _get_first_arg_datatype(
func: Callable[[SignalDatatypeT], Any],
) -> type[SignalDatatypeT]:
args = get_type_hints(func)
args.pop("return", None)
args = _get_params_types_dict(func)
if not args:
msg = f"{func} does not have a type hinted argument"
raise TypeError(msg)
return list(args.values())[0]


def _get_params_types_dict(inspected_function: Callable) -> Mapping[str, Any]:
hints = get_type_hints(inspected_function)
sig = signature(inspected_function)
normalized = {}
# convert string annotations to class
for name, param in sig.parameters.items():
if name not in ["self", "args", "kwargs"]:
normalized[name] = hints.get(name, param.annotation)
return normalized


def _make_factory(
raw_to_derived: Callable[..., SignalDatatypeT] | None = None,
raw_to_derived_func: Callable[..., SignalDatatypeT] | None = None,
set_derived: Callable[[SignalDatatypeT], Awaitable[None]] | None = None,
raw_devices_and_constants: dict[str, Device | Primitive] | None = None,
) -> DerivedSignalFactory:
if raw_to_derived:
if raw_to_derived_func:

class DerivedTransform(Transform):
def raw_to_derived(self, **kwargs) -> dict[str, SignalDatatypeT]:
return {"value": raw_to_derived(**kwargs)}

# Update the signature for raw_to_derived to match what we are passed as this
# will be checked in DerivedSignalFactory
DerivedTransform.raw_to_derived.__annotations__ = get_type_hints(raw_to_derived)
raw_to_derived = dict_wrapper(raw_to_derived_func)

return DerivedSignalFactory(
DerivedTransform,
Expand All @@ -227,6 +247,14 @@ def raw_to_derived(self, **kwargs) -> dict[str, SignalDatatypeT]:
return DerivedSignalFactory(Transform, set_derived=set_derived)


def dict_wrapper(fn):
@functools.wraps(fn)
def wrapped(self, **kwargs):
return {"value": fn(**kwargs)}

return wrapped


def derived_signal_r(
raw_to_derived: Callable[..., SignalDatatypeT],
derived_units: str | None = None,
Expand All @@ -245,7 +273,7 @@ def derived_signal_r(
The names of these arguments must match the arguments of raw_to_derived.
"""
factory = _make_factory(
raw_to_derived=raw_to_derived,
raw_to_derived_func=raw_to_derived,
raw_devices_and_constants=raw_devices_and_constants,
)
return factory.derived_signal_r(
Expand Down Expand Up @@ -278,16 +306,16 @@ def derived_signal_rw(
The names of these arguments must match the arguments of raw_to_derived.
"""
raw_to_derived_datatype = _get_return_datatype(raw_to_derived)
set_derived_datatype = _get_first_arg_datatype(set_derived)
if raw_to_derived_datatype != set_derived_datatype:
set_derived_arg_datatype = _get_first_arg_datatype(set_derived)
if raw_to_derived_datatype != set_derived_arg_datatype:
msg = (
f"{raw_to_derived} has datatype {raw_to_derived_datatype} "
f"!= {set_derived_datatype} datatype {set_derived_datatype}"
f"!= {set_derived_arg_datatype} datatype {set_derived_arg_datatype}"
)
raise TypeError(msg)

factory = _make_factory(
raw_to_derived=raw_to_derived,
raw_to_derived_func=raw_to_derived,
set_derived=set_derived,
raw_devices_and_constants=raw_devices_and_constants,
)
Expand Down
25 changes: 24 additions & 1 deletion tests/unit_tests/core/test_single_derived_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def test_get_returns_right_position(

@pytest.mark.parametrize("cls", [ReadOnlyBeamstop, MovableBeamstop])
async def test_monitoring_position(cls: type[ReadOnlyBeamstop | MovableBeamstop]):
results = asyncio.Queue[BeamstopPosition]()
results: asyncio.Queue[dict[str, Reading]] = asyncio.Queue()
inst = cls("inst")
inst.position.subscribe_reading(results.put_nowait)
assert (await results.get())["inst-position"][
Expand Down Expand Up @@ -181,6 +181,29 @@ async def _put(value: float) -> None:
pass


# function without type hint on first argument
def _get_no_type(ts) -> float:
return ts


async def test_derived_signal_rw_get_method_no_param_type():
signal_rw = soft_signal_rw(int, initial_value=4)
with pytest.raises(
TypeError,
match=re.escape(" is missing a type hint for arguments: ['ts']"),
):
derived_signal_rw(_get_no_type, _put, ts=signal_rw)


async def test_derived_signal_r_get_method_no_param_type():
signal_rw = soft_signal_rw(int, initial_value=4)
with pytest.raises(
TypeError,
match=re.escape(" is missing a type hint for arguments: ['ts']"),
):
derived_signal_r(_get_no_type, ts=signal_rw)


@pytest.fixture
def derived_signal_backend() -> SignalBackend[SignalDatatype]:
signal_rw = soft_signal_rw(int, initial_value=4)
Expand Down