Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ requires-python = ">=3.11"
sim = ["h5py"]
ca = ["aioca>=2.0a4"]
pva = ["p4p>=4.2.0"]
tango = ["pytango==10.0.3"]
tango = ["pytango==10.1.3"]
demo = ["ipython", "matplotlib", "pyqt6"]

[dependency-groups]
Expand Down
4 changes: 2 additions & 2 deletions src/ophyd_async/tango/core/_base_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,12 @@ async def connect_mock(self, device: Device, mock: LazyMock):
async def connect_real(self, device: Device, timeout: float, force_reconnect: bool):
if not self.trl:
raise RuntimeError(f"Could not created Device Proxy for TRL {self.trl}")
self.proxy = await AsyncDeviceProxy(self.trl)
self.proxy = await AsyncDeviceProxy(self.trl) # type: ignore
children = sorted(
set()
.union(self.proxy.get_attribute_list())
.union(self.proxy.get_command_list())
)
) # type: ignore

children = [
child for child in children if child not in self.filler.ignored_signals
Expand Down
4 changes: 2 additions & 2 deletions src/ophyd_async/tango/core/_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ def _write_convert(self, value):
def _convert(self, value):
return self._labels[int(value)]

def write_value(self, value: NDArray[np.str_]) -> NDArray[DevState]:
def write_value(self, value: NDArray[np.str_]) -> NDArray[np.int_]:
vfunc = np.vectorize(self._write_convert, otypes=[DevState])
new_array = vfunc(value)
return new_array

def value(self, value: NDArray[DevState]) -> NDArray[np.str_]:
def value(self, value: NDArray[np.int_]) -> NDArray[np.str_]:
vfunc = np.vectorize(self._convert)
new_array = vfunc(value)
return new_array
21 changes: 11 additions & 10 deletions src/ophyd_async/tango/core/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@
from __future__ import annotations

import logging
from enum import IntEnum

import numpy.typing as npt
from tango import (
AttrWriteType,
DeviceProxy,
DevState,
)
from tango.asyncio import DeviceProxy as AsyncDeviceProxy

from ophyd_async.core import (
DEFAULT_TIMEOUT,
Signal,
SignalDatatype,
SignalDatatypeT,
SignalR,
SignalRW,
Expand Down Expand Up @@ -140,20 +138,23 @@ def tango_signal_x(

async def infer_python_type(
trl: str = "", proxy: DeviceProxy | None = None
) -> object | npt.NDArray | type[DevState] | IntEnum:
) -> type[SignalDatatype] | None:
"""Infers the python type from the TRL."""
# TODO: work out if this is still needed
device_trl, tr_name = get_device_trl_and_attr(trl)
if proxy is None:
dev_proxy = await AsyncDeviceProxy(device_trl)
dev_proxy = await AsyncDeviceProxy(device_trl) # type: ignore
else:
dev_proxy = proxy

if tr_name in dev_proxy.get_command_list():
config = await dev_proxy.get_command_config(tr_name)
# A Device proxy instantiated by awaiting
# tango.asyncio.DeviceProxy is typed the same as the sync
# despite having awaitable methods.
config = await dev_proxy.get_command_config(tr_name) # type: ignore
py_type = get_python_type(config)
elif tr_name in dev_proxy.get_attribute_list():
config = await dev_proxy.get_attribute_config(tr_name)
config = await dev_proxy.get_attribute_config(tr_name) # type: ignore
py_type = get_python_type(config)
else:
raise RuntimeError(f"Cannot find {tr_name} in {device_trl}")
Expand All @@ -165,7 +166,7 @@ async def infer_signal_type(
) -> type[Signal] | None:
device_trl, tr_name = get_device_trl_and_attr(trl)
if proxy is None:
dev_proxy = await AsyncDeviceProxy(device_trl)
dev_proxy = await AsyncDeviceProxy(device_trl) # type: ignore
else:
dev_proxy = proxy

Expand All @@ -174,7 +175,7 @@ async def infer_signal_type(
raise RuntimeError(f"Cannot find {tr_name} in {device_trl}")

if tr_name in dev_proxy.get_attribute_list():
config = await dev_proxy.get_attribute_config(tr_name)
config = await dev_proxy.get_attribute_config(tr_name) # type: ignore
if config.writable in [AttrWriteType.READ_WRITE, AttrWriteType.READ_WITH_WRITE]:
return SignalRW
elif config.writable == AttrWriteType.READ:
Expand All @@ -183,7 +184,7 @@ async def infer_signal_type(
return SignalW

if tr_name in dev_proxy.get_command_list():
config = await dev_proxy.get_command_config(tr_name)
config = await dev_proxy.get_command_config(tr_name) # type: ignore
command_character = get_command_character(config)
if command_character == CommandProxyReadCharacter.READ:
return SignalR
Expand Down
77 changes: 42 additions & 35 deletions src/ophyd_async/tango/core/_tango_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
Callback,
NotConnectedError,
SignalBackend,
SignalDatatype,
SignalDatatypeT,
SignalMetadata,
StrictEnum,
Expand Down Expand Up @@ -85,7 +86,7 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
set_global_executor(AsyncioExecutor())
return await func(*args, **kwargs)

return wrapper
return wrapper # type: ignore[return-value]


class TangoLongStringTable(Table):
Expand All @@ -98,7 +99,9 @@ class TangoDoubleStringTable(Table):
string: Sequence[str]


def get_python_type(config: AttributeInfoEx | CommandInfo | TestConfig) -> object:
def get_python_type(
config: AttributeInfoEx | CommandInfo | TestConfig,
) -> type[SignalDatatype] | None:
"""For converting between recieved tango types and python primatives."""
tango_type = None
tango_format = None
Expand Down Expand Up @@ -129,7 +132,7 @@ def get_python_type(config: AttributeInfoEx | CommandInfo | TestConfig) -> objec
if tango_type is CmdArgType.DevVarDoubleStringArray:
return TangoDoubleStringTable

def _get_type(cls: type) -> object:
def _get_type(cls: type) -> type[SignalDatatype]:
if tango_format == AttrDataFormat.SCALAR:
return cls
elif tango_format == AttrDataFormat.SPECTRUM:
Expand All @@ -138,7 +141,7 @@ def _get_type(cls: type) -> object:
return Array1D[cls]
elif tango_format == AttrDataFormat.IMAGE:
if cls is str or issubclass(cls, StrictEnum):
return Sequence[Sequence[str]]
raise TypeError("Images of type str or enum are not supported")
return npt.NDArray[cls]
else:
return cls
Expand All @@ -154,7 +157,7 @@ def _get_type(cls: type) -> object:
elif is_binary(tango_type, True):
return _get_type(str)
elif tango_type == CmdArgType.DevEnum:
if hasattr(config, "enum_labels"):
if hasattr(config, "enum_labels") and not isinstance(config, CommandInfo):
enum_dict = {label: str(label) for label in config.enum_labels}
return _get_type(StrictEnum("TangoEnum", enum_dict))
else:
Expand Down Expand Up @@ -236,15 +239,17 @@ def set_converter(self, converter: "TangoConverter"):
class AttributeProxy(TangoProxy):
"""Used by the tango transport."""

_callback: Callback | None = None
_eid: int | None = None
_poll_task: asyncio.Task | None = None
_abs_change: float | None = None
_rel_change: float | None = 0.1
_polling_period: float = 0.1
_allow_polling: bool = False
exception: BaseException | None = None
_last_reading: Reading = Reading(value=None, timestamp=0, alarm_severity=0)
def __init__(self, device_proxy: DeviceProxy, name: str):
self._callback: Callback | None = None
self._eid: int | None = None
self._poll_task: asyncio.Task | None = None
self._abs_change: float | None = None
self._rel_change: float | None = None
self._polling_period: float = 0.1
self._allow_polling: bool = False
self.exception: BaseException | None = None
self._last_reading: Reading = Reading(value=None, timestamp=0, alarm_severity=0)
super().__init__(device_proxy, name)

async def connect(self) -> None:
try:
Expand All @@ -255,19 +260,19 @@ async def connect(self) -> None:
eid = await self._proxy.subscribe_event( # type: ignore
self._name, EventType.CHANGE_EVENT, self._event_processor
)
await self._proxy.unsubscribe_event(eid)
await self._proxy.unsubscribe_event(eid) # type: ignore
self.support_events = True
except Exception:
pass

@ensure_proper_executor
async def get(self) -> object: # type: ignore
attr = await self._proxy.read_attribute(self._name)
attr = await self._proxy.read_attribute(self._name) # type: ignore
return self._converter.value(attr.value)

@ensure_proper_executor
async def get_w_value(self) -> object: # type: ignore
attr = await self._proxy.read_attribute(self._name)
attr = await self._proxy.read_attribute(self._name) # type: ignore
return self._converter.value(attr.w_value)

@ensure_proper_executor
Expand All @@ -279,7 +284,7 @@ async def put( # type: ignore
try:

async def _write():
return await self._proxy.write_attribute(self._name, value)
return await self._proxy.write_attribute(self._name, value) # type: ignore

task = asyncio.create_task(_write())
await asyncio.wait_for(task, timeout)
Expand All @@ -292,11 +297,11 @@ async def _write():

@ensure_proper_executor
async def get_config(self) -> AttributeInfoEx: # type: ignore
return await self._proxy.get_attribute_config(self._name)
return await self._proxy.get_attribute_config(self._name) # type: ignore

@ensure_proper_executor
async def get_reading(self) -> Reading: # type: ignore
attr = await self._proxy.read_attribute(self._name)
attr = await self._proxy.read_attribute(self._name) # type: ignore
reading = Reading(
value=self._converter.value(attr.value),
timestamp=attr.time.totime(),
Expand All @@ -310,14 +315,16 @@ def has_subscription(self) -> bool:

@ensure_proper_executor
async def _subscribe_to_event(self):
if not self._eid:
self._eid = await self._proxy.subscribe_event(
self._name,
EventType.CHANGE_EVENT,
self._event_processor,
stateless=True,
green_mode=GreenMode.Asyncio,
)
try:
if not self._eid:
self._eid = await self._proxy.subscribe_event(
self._name,
EventType.CHANGE_EVENT,
self._event_processor,
green_mode=GreenMode.Asyncio,
) # type: ignore
except Exception as exc:
logger.debug(f"Subscribe to event failed: {exc}")

def subscribe_callback(self, callback: Callback | None):
# If the attribute supports events, then we can subscribe to them
Expand Down Expand Up @@ -352,7 +359,7 @@ async def _poll():
def unsubscribe_callback(self):
if self._eid:
try:
self._proxy.unsubscribe_event(self._eid, green_mode=False)
self._proxy.unsubscribe_event(self._eid, green_mode=False) # type: ignore
except Exception as exc:
logger.warning(f"Could not unsubscribe from event: {exc}")
finally:
Expand Down Expand Up @@ -524,7 +531,7 @@ async def get_w_value(self) -> object:
return self._last_w_value

async def connect(self) -> None:
self._config = await self.device_proxy.get_command_config(self.name)
self._config = await self.device_proxy.get_command_config(self.name) # type: ignore
self._read_character = get_command_character(self._config)

@ensure_proper_executor
Expand Down Expand Up @@ -554,7 +561,7 @@ async def _put():

@ensure_proper_executor
async def get_config(self) -> CommandInfo: # type: ignore
return await self._proxy.get_command_config(self._name)
return await self._proxy.get_command_config(self._name) # type: ignore

async def get_reading(self) -> Reading:
if self._read_character == CommandProxyReadCharacter.READ:
Expand Down Expand Up @@ -594,7 +601,7 @@ def get_dtype_extended(datatype) -> object | None:

def get_source_metadata(
tango_resource: str,
tr_configs: dict[str, AttributeInfoEx],
tr_configs: dict[str, AttributeInfoEx | CommandInfo],
) -> SignalMetadata:
metadata = {}
for _, config in tr_configs.items():
Expand Down Expand Up @@ -637,7 +644,7 @@ def get_source_metadata(
tr_dtype = get_python_type(config)

if tr_dtype == CmdArgType.DevState:
_choices = list(DevState.names.keys())
_choices = list(DevState.names.keys()) # type: ignore

_precision = parse_precision(config)

Expand Down Expand Up @@ -667,7 +674,7 @@ async def get_tango_trl(
device_trl, trl_name = get_device_trl_and_attr(full_trl)
trl_name = trl_name.lower()
if device_proxy is None:
device_proxy = await AsyncDeviceProxy(device_trl, timeout=timeout)
device_proxy = await AsyncDeviceProxy(device_trl, timeout=timeout) # type: ignore
# all attributes can be always accessible with low register
if isinstance(device_proxy, DeviceProxy):
all_attrs = [
Expand Down Expand Up @@ -744,7 +751,7 @@ def __init__(
read_trl: self.device_proxy,
write_trl: self.device_proxy,
}
self.trl_configs: dict[str, AttributeInfoEx] = {}
self.trl_configs: dict[str, AttributeInfoEx | CommandInfo] = {}
self._polling: tuple[bool, float, float | None, float | None] = (
False,
0.1,
Expand Down
5 changes: 3 additions & 2 deletions src/ophyd_async/tango/core/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ class DevStateEnum(StrictEnum):
UNKNOWN = "UNKNOWN"


def get_full_attr_trl(device_trl: str, attr_name: str):
def get_full_attr_trl(device_trl: str, attr_name: str) -> str:
device_parts = device_trl.split("#", 1)
# my/device/name#dbase=no splits into my/device/name and dbase=no
# my/device/name#dbase=no splits into my/device/name and
# dbase=no
full_trl = device_parts[0] + "/" + attr_name
if len(device_parts) > 1:
full_trl += "#" + device_parts[1]
Expand Down
2 changes: 1 addition & 1 deletion src/ophyd_async/tango/testing/_one_of_everything.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class AttributeData(Generic[T]):
"my_state",
"DevState",
DevState.INIT,
np.array(list(DevState.names.values()), dtype=DevState),
np.array(list(DevState.names.values()), dtype=DevState), # type: ignore
),
]

Expand Down
10 changes: 10 additions & 0 deletions tests/system_tests_tango/test_tango_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
TangoDevice,
TangoSignalBackend,
get_full_attr_trl,
infer_python_type,
parse_precision,
tango_signal_r,
tango_signal_rw,
Expand Down Expand Up @@ -566,3 +567,12 @@ async def test_parse_precision(everything_device_trl):
assert precision == 2
else:
assert precision is None


@pytest.mark.asyncio
async def test_infer_python_type(everything_device_trl):
proxy = await DeviceProxy(everything_device_trl)
bad_attr = everything_device_trl + "/this_does_not_exist"
with pytest.raises(RuntimeError) as exc:
await infer_python_type(trl=bad_attr, proxy=proxy)
assert "Cannot find" in str(exc.value)
Loading