Skip to content

Commit 147fd94

Browse files
oliwenmandiamondDiamondJosephcoretlgithub-advanced-security[bot]
authored
Add SupersetEnum and use it in ADBaseDataType (#864)
* Add SupersetEnum, rework Enums to not inherit from another and adjust framework to change * Correct supported values check * Organise imports and correct string docs * Update src/ophyd_async/core/_signal_backend.py grammar correction Co-authored-by: Joseph Ware <[email protected]> * Update src/ophyd_async/core/_signal_backend.py grammar correction Co-authored-by: Joseph Ware <[email protected]> * Added tests and improved handling * Fixed linting issue * Improved test coverage and updated check for SupersetEnum * Update src/ophyd_async/core/_signal_backend.py Co-authored-by: Tom C (DLS) <[email protected]> * Added UNDEFINED to ADBaseDataType * Fix import * Fix liniting issue and circular imports * Further linting correction * Added error message for UNDEFINED ADBaseDataType * Potential fix for code scanning alert no. 186: First parameter of a class method is not named 'cls' Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> * Applied feedback * Removed empty string values in datatypes * Reverted suggest support_values check as broke tests * Wrote SupersetEnum check in a more succinct way Had to widen some types to do this * Remove unused code --------- Co-authored-by: Joseph Ware <[email protected]> Co-authored-by: Tom C (DLS) <[email protected]> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: Tom Cobb <[email protected]>
1 parent 0b00d49 commit 147fd94

File tree

7 files changed

+126
-40
lines changed

7 files changed

+126
-40
lines changed

src/ophyd_async/core/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,13 @@
7575
DEFAULT_TIMEOUT,
7676
CalculatableTimeout,
7777
Callback,
78+
EnumTypes,
7879
LazyMock,
7980
NotConnected,
8081
Reference,
8182
StrictEnum,
8283
SubsetEnum,
84+
SupersetEnum,
8385
WatcherUpdate,
8486
gather_dict,
8587
get_dtype,
@@ -122,6 +124,8 @@
122124
"Array1D",
123125
"StrictEnum",
124126
"SubsetEnum",
127+
"SupersetEnum",
128+
"EnumTypes",
125129
"Table",
126130
"SignalMetadata",
127131
# Soft signal

src/ophyd_async/core/_signal_backend.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,16 @@
66
from bluesky.protocols import Reading
77
from event_model import DataKey, Dtype, Limits
88

9+
from ophyd_async.core._utils import (
10+
Callback,
11+
EnumTypes,
12+
StrictEnum,
13+
SubsetEnum,
14+
SupersetEnum,
15+
get_enum_cls,
16+
)
17+
918
from ._table import Table
10-
from ._utils import Callback, StrictEnum, get_enum_cls
1119

1220
DTypeScalar_co = TypeVar("DTypeScalar_co", covariant=True, bound=np.generic)
1321
"""A numpy dtype like [](#numpy.float64)."""
@@ -24,7 +32,7 @@
2432
Primitive = bool | int | float | str
2533
SignalDatatype = (
2634
Primitive
27-
| StrictEnum
35+
| EnumTypes
2836
| Array1D[np.bool_]
2937
| Array1D[np.int8]
3038
| Array1D[np.uint8]
@@ -39,16 +47,18 @@
3947
| np.ndarray
4048
| Sequence[str]
4149
| Sequence[StrictEnum]
50+
| Sequence[SubsetEnum]
51+
| Sequence[SupersetEnum]
4252
| Table
4353
)
4454
"""The supported [](#Signal) datatypes:
4555
4656
- A python primitive [](#bool), [](#int), [](#float), [](#str)
47-
- A [](#StrictEnum) or [](#SubsetEnum) subclass
57+
- An [](#EnumTypes) subclass
4858
- A fixed datatype [](#Array1D) of numpy bool, signed and unsigned integers or float
4959
- A [](#numpy.ndarray) which can change dimensions and datatype at runtime
5060
- A sequence of [](#str)
51-
- A sequence of [](#StrictEnum) or [](#SubsetEnum) subclass
61+
- A sequence of [](#EnumTypes) subclasses
5262
- A [](#Table) subclass
5363
"""
5464
# TODO: These typevars will not be needed when we drop python 3.11
@@ -58,7 +68,7 @@
5868
SignalDatatypeT = TypeVar("SignalDatatypeT", bound=SignalDatatype)
5969
"""A typevar for a [](#SignalDatatype)."""
6070
SignalDatatypeV = TypeVar("SignalDatatypeV", bound=SignalDatatype)
61-
EnumT = TypeVar("EnumT", bound=StrictEnum)
71+
EnumT = TypeVar("EnumT", bound=EnumTypes)
6272
TableT = TypeVar("TableT", bound=Table)
6373

6474

@@ -136,7 +146,7 @@ def _datakey_dtype(datatype: type[SignalDatatype]) -> Dtype:
136146
or issubclass(datatype, Table)
137147
):
138148
return "array"
139-
elif issubclass(datatype, StrictEnum):
149+
elif issubclass(datatype, EnumTypes):
140150
return "string"
141151
elif issubclass(datatype, Primitive):
142152
return _primitive_dtype[datatype]
@@ -153,7 +163,7 @@ def _datakey_dtype_numpy(
153163
elif (
154164
get_origin(datatype) is Sequence
155165
or datatype is str
156-
or issubclass(datatype, StrictEnum)
166+
or issubclass(datatype, EnumTypes)
157167
):
158168
# TODO: use np.dtypes.StringDType when we can use in structured arrays
159169
# https://github.com/numpy/numpy/issues/25693
@@ -167,7 +177,7 @@ def _datakey_dtype_numpy(
167177

168178

169179
def _datakey_shape(value: SignalDatatype) -> list[int]:
170-
if type(value) in _primitive_dtype or isinstance(value, StrictEnum):
180+
if type(value) in _primitive_dtype or isinstance(value, EnumTypes):
171181
return []
172182
elif isinstance(value, np.ndarray):
173183
return list(value.shape)

src/ophyd_async/core/_utils.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,15 @@
55
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
66
from dataclasses import dataclass
77
from enum import Enum, EnumMeta
8-
from typing import Any, Generic, Literal, ParamSpec, TypeVar, get_args, get_origin
8+
from typing import (
9+
Any,
10+
Generic,
11+
Literal,
12+
ParamSpec,
13+
TypeVar,
14+
get_args,
15+
get_origin,
16+
)
917
from unittest.mock import Mock
1018

1119
import numpy as np
@@ -19,20 +27,16 @@
1927
logger = logging.getLogger("ophyd_async")
2028

2129

22-
class StrictEnumMeta(EnumMeta):
23-
def __new__(metacls, *args, **kwargs):
24-
ret = super().__new__(metacls, *args, **kwargs)
30+
class UppercaseNameEnumMeta(EnumMeta):
31+
def __new__(cls, *args, **kwargs):
32+
ret = super().__new__(cls, *args, **kwargs)
2533
lowercase_names = [x.name for x in ret if not x.name.isupper()] # type: ignore
2634
if lowercase_names:
2735
raise TypeError(f"Names {lowercase_names} should be uppercase")
2836
return ret
2937

3038

31-
class StrictEnum(str, Enum, metaclass=StrictEnumMeta):
32-
"""All members should exist in the Backend, and there will be no extras."""
33-
34-
35-
class SubsetEnumMeta(StrictEnumMeta):
39+
class AnyStringUppercaseNameEnumMeta(UppercaseNameEnumMeta):
3640
def __call__(self, value, *args, **kwargs): # type: ignore
3741
"""Return given value if it is a string and not a member of the enum.
3842
@@ -54,10 +58,21 @@ def __call__(self, value, *args, **kwargs): # type: ignore
5458
return super().__call__(value, *args, **kwargs)
5559

5660

57-
class SubsetEnum(StrictEnum, metaclass=SubsetEnumMeta):
61+
class StrictEnum(str, Enum, metaclass=UppercaseNameEnumMeta):
62+
"""All members should exist in the Backend, and there will be no extras."""
63+
64+
65+
class SubsetEnum(str, Enum, metaclass=AnyStringUppercaseNameEnumMeta):
5866
"""All members should exist in the Backend, but there may be extras."""
5967

6068

69+
class SupersetEnum(str, Enum, metaclass=UppercaseNameEnumMeta):
70+
"""Some members should exist in the Backend, and there should be no extras."""
71+
72+
73+
EnumTypes = StrictEnum | SubsetEnum | SupersetEnum
74+
75+
6176
CALCULATE_TIMEOUT = "CALCULATE_TIMEOUT"
6277
"""Sentinel used to implement ``myfunc(timeout=CalculateTimeout)``
6378
@@ -207,10 +222,11 @@ def get_dtype(datatype: type) -> np.dtype:
207222
return np.dtype(get_args(get_args(datatype)[1])[0])
208223

209224

210-
def get_enum_cls(datatype: type | None) -> type[StrictEnum] | None:
225+
def get_enum_cls(datatype: type | None) -> type[EnumTypes] | None:
211226
"""Get the enum class from a datatype.
212227
213-
:raises TypeError: if type is not a [](#StrictEnum) or [](#SubsetEnum) subclass
228+
:raises TypeError: if type is not a [](#StrictEnum) or [](#SubsetEnum)
229+
or [](#SupersetEnum) subclass
214230
```python
215231
>>> from ophyd_async.core import StrictEnum
216232
>>> from collections.abc import Sequence
@@ -227,10 +243,11 @@ def get_enum_cls(datatype: type | None) -> type[StrictEnum] | None:
227243
if get_origin(datatype) is Sequence:
228244
datatype = get_args(datatype)[0]
229245
if datatype and issubclass(datatype, Enum):
230-
if not issubclass(datatype, StrictEnum):
246+
if not issubclass(datatype, EnumTypes):
231247
raise TypeError(
232248
f"{datatype} should inherit from ophyd_async.core.SubsetEnum "
233-
"or ophyd_async.core.StrictEnum"
249+
"or ophyd_async.core.StrictEnum "
250+
"or ophyd_async.core.SupersetEnum."
234251
)
235252
return datatype
236253
return None

src/ophyd_async/epics/adcore/_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
SignalRW,
88
StrictEnum,
99
SubsetEnum,
10+
SupersetEnum,
1011
wait_for_value,
1112
)
1213

1314

14-
class ADBaseDataType(StrictEnum):
15+
class ADBaseDataType(SupersetEnum):
1516
INT8 = "Int8"
1617
UINT8 = "UInt8"
1718
INT16 = "Int16"
@@ -22,6 +23,9 @@ class ADBaseDataType(StrictEnum):
2223
UINT64 = "UInt64"
2324
FLOAT32 = "Float32"
2425
FLOAT64 = "Float64"
26+
# Driver database override will blank the enum string if it doesn't
27+
# support a datatype
28+
UNDEFINED = ""
2529

2630

2731
def convert_ad_dtype_to_np(ad_dtype: ADBaseDataType) -> str:
@@ -37,7 +41,12 @@ def convert_ad_dtype_to_np(ad_dtype: ADBaseDataType) -> str:
3741
ADBaseDataType.FLOAT32: "<f4",
3842
ADBaseDataType.FLOAT64: "<f8",
3943
}
40-
return ad_dtype_to_np_dtype[ad_dtype]
44+
np_type = ad_dtype_to_np_dtype.get(ad_dtype)
45+
if np_type is None:
46+
raise ValueError(
47+
"Areadetector driver has a blank DataType, this is not supported"
48+
)
49+
return np_type
4150

4251

4352
def convert_pv_dtype_to_np(datatype: str) -> str:

src/ophyd_async/epics/core/_aioca.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import sys
33
import typing
4-
from collections.abc import Sequence
4+
from collections.abc import Mapping, Sequence
55
from functools import cache
66
from math import isnan, nan
77
from typing import Any, Generic, cast
@@ -146,7 +146,7 @@ def value(self, value: AugmentedValue) -> bool:
146146

147147

148148
class CaEnumConverter(CaConverter[str]):
149-
def __init__(self, supported_values: dict[str, str]):
149+
def __init__(self, supported_values: Mapping[str, str]):
150150
self.supported_values = supported_values
151151
super().__init__(
152152
str, dbr.DBR_STRING, metadata=SignalMetadata(choices=list(supported_values))

src/ophyd_async/epics/core/_util.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1-
from collections.abc import Sequence
2-
from typing import Any, get_args, get_origin
1+
from collections.abc import Mapping, Sequence
2+
from typing import Any, TypeVar, get_args, get_origin
33

44
import numpy as np
55

66
from ophyd_async.core import (
77
SignalBackend,
88
SignalDatatypeT,
9+
StrictEnum,
910
SubsetEnum,
11+
SupersetEnum,
1012
get_dtype,
1113
get_enum_cls,
1214
)
1315

16+
T = TypeVar("T")
17+
1418

1519
def get_pv_basename_and_field(pv: str) -> tuple[str, str | None]:
1620
"""Split PV into record name and field."""
@@ -23,26 +27,30 @@ def get_pv_basename_and_field(pv: str) -> tuple[str, str | None]:
2327

2428
def get_supported_values(
2529
pv: str,
26-
datatype: type,
30+
datatype: type[T],
2731
pv_choices: Sequence[str],
28-
) -> dict[str, str]:
32+
) -> Mapping[str, T | str]:
2933
enum_cls = get_enum_cls(datatype)
3034
if not enum_cls:
3135
raise TypeError(f"{datatype} is not an Enum")
3236
choices = [v.value for v in enum_cls]
3337
error_msg = f"{pv} has choices {pv_choices}, but {datatype} requested {choices} "
34-
if issubclass(enum_cls, SubsetEnum):
38+
if issubclass(enum_cls, StrictEnum):
39+
if set(choices) != set(pv_choices):
40+
raise TypeError(error_msg + "to be strictly equal to them.")
41+
elif issubclass(enum_cls, SubsetEnum):
3542
if not set(choices).issubset(pv_choices):
3643
raise TypeError(error_msg + "to be a subset of them.")
44+
elif issubclass(enum_cls, SupersetEnum):
45+
if not set(pv_choices).issubset(choices):
46+
raise TypeError(error_msg + "to be a superset of them.")
3747
else:
38-
if set(choices) != set(pv_choices):
39-
raise TypeError(error_msg + "to be strictly equal to them.")
40-
41-
# Take order from the pv choices
42-
supported_values = {x: x for x in pv_choices}
43-
# But override those that we specify via the datatype
44-
for v in enum_cls:
45-
supported_values[v.value] = v
48+
raise TypeError(f"{datatype} is not a StrictEnum, SubsetEnum, or SupersetEnum")
49+
# Create a map from the string value to the enum instance
50+
# For StrictEnum and SupersetEnum, all values here will be enum values
51+
# For SubsetEnum, only the values in choices will be enum values, the rest will be
52+
# strings
53+
supported_values = {x: enum_cls(x) for x in pv_choices}
4654
return supported_values
4755

4856

tests/epics/signal/test_common.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from ophyd_async.core import StrictEnum
5+
from ophyd_async.core import StrictEnum, SupersetEnum
66
from ophyd_async.epics.core._util import get_supported_values # noqa: PLC2701
77

88

@@ -45,3 +45,41 @@ class MyEnum(StrictEnum):
4545
assert len(supported_vals) == 2
4646
assert "test_1" in supported_vals
4747
assert "test_2" in supported_vals
48+
49+
50+
def test_given_supersetenum_that_partial_matches_are_valid():
51+
class MyEnum(SupersetEnum):
52+
TEST_1 = "test_1"
53+
TEST_2 = "test_2"
54+
55+
supported_vals = get_supported_values("", MyEnum, ("test_1",))
56+
57+
assert "test_1" in supported_vals
58+
assert supported_vals.get("test_2") is None
59+
60+
61+
class MySupersetEnum(SupersetEnum):
62+
TEST_1 = "test_1"
63+
TEST_2 = "test_2"
64+
65+
66+
SUPERSETENUM_ERROR_MESSAGE = "to be a superset of them."
67+
68+
69+
def test_given_supersetenum_that_all_values_plus_extra_values_are_invalid():
70+
with pytest.raises(TypeError, match=SUPERSETENUM_ERROR_MESSAGE):
71+
get_supported_values(
72+
"",
73+
MySupersetEnum,
74+
(MySupersetEnum.TEST_1, MySupersetEnum.TEST_2, "extra_1"),
75+
)
76+
77+
78+
def test_given_supersetenum_that_partial_values_plus_extra_values_are_invalid():
79+
with pytest.raises(TypeError, match=SUPERSETENUM_ERROR_MESSAGE):
80+
get_supported_values("", MySupersetEnum, (MySupersetEnum.TEST_1, "extra_1"))
81+
82+
83+
def test_given_supersetenum_that_no_matches_is_invalid():
84+
with pytest.raises(TypeError, match=SUPERSETENUM_ERROR_MESSAGE):
85+
get_supported_values("", MySupersetEnum, ("no_match_1", "no_match_2"))

0 commit comments

Comments
 (0)