diff --git a/src/ophyd_async/epics/adcore/__init__.py b/src/ophyd_async/epics/adcore/__init__.py index 16587efa01..fab5fe5ff5 100644 --- a/src/ophyd_async/epics/adcore/__init__.py +++ b/src/ophyd_async/epics/adcore/__init__.py @@ -1,5 +1,6 @@ from ._core_detector import AreaDetector from ._core_io import ( + ADBaseColorMode, ADBaseDatasetDescriber, ADBaseIO, DetectorState, @@ -26,6 +27,7 @@ ) __all__ = [ + "ADBaseColorMode", "ADBaseIO", "AreaDetector", "DetectorState", diff --git a/src/ophyd_async/epics/adcore/_core_io.py b/src/ophyd_async/epics/adcore/_core_io.py index 52332de994..4bb013ffec 100644 --- a/src/ophyd_async/epics/adcore/_core_io.py +++ b/src/ophyd_async/epics/adcore/_core_io.py @@ -1,6 +1,6 @@ import asyncio -from ophyd_async.core import Device, StrictEnum +from ophyd_async.core import Device, StrictEnum, SubsetEnum from ophyd_async.core._providers import DatasetDescriber from ophyd_async.epics.core import ( epics_signal_r, @@ -16,6 +16,12 @@ class Callback(StrictEnum): DISABLE = "Disable" +# For now, only support mono and RGB1 +class ADBaseColorMode(SubsetEnum): + MONO = "Mono" + RGB = "RGB1" + + class NDArrayBaseIO(Device): def __init__(self, prefix: str, name: str = "") -> None: self.unique_id = epics_signal_r(int, prefix + "UniqueId_RBV") @@ -24,6 +30,7 @@ def __init__(self, prefix: str, name: str = "") -> None: self.array_size_x = epics_signal_r(int, prefix + "ArraySizeX_RBV") self.array_size_y = epics_signal_r(int, prefix + "ArraySizeY_RBV") self.data_type = epics_signal_r(ADBaseDataType, prefix + "DataType_RBV") + self.color_mode = epics_signal_r(ADBaseColorMode, prefix + "ColorMode_RBV") self.array_counter = epics_signal_rw_rbv(int, prefix + "ArrayCounter") # There is no _RBV for this one self.wait_for_plugins = epics_signal_rw(bool, prefix + "WaitForPlugins") @@ -37,11 +44,18 @@ def __init__(self, driver: NDArrayBaseIO) -> None: async def np_datatype(self) -> str: return convert_ad_dtype_to_np(await self._driver.data_type.get_value()) - async def shape(self) -> tuple[int, int]: + async def shape(self) -> tuple[int, int] | tuple[int, int, int]: + current_color_mode = await self._driver.color_mode.get_value() + if current_color_mode not in ADBaseColorMode: + raise ValueError(f"Current color mode {current_color_mode} not currently supported!") + shape = await asyncio.gather( self._driver.array_size_y.get_value(), self._driver.array_size_x.get_value(), ) + if current_color_mode == ADBaseColorMode.RGB: + shape = (3, *shape) + return shape diff --git a/tests/epics/adsimdetector/test_sim.py b/tests/epics/adsimdetector/test_sim.py index 11c06b9dd5..f8a22cd720 100644 --- a/tests/epics/adsimdetector/test_sim.py +++ b/tests/epics/adsimdetector/test_sim.py @@ -41,7 +41,7 @@ def two_test_adsimdetectors( ad_standard_det_factory: Callable, ) -> Sequence[adsimdetector.SimDetector]: deta = ad_standard_det_factory(adsimdetector.SimDetector) - detb = ad_standard_det_factory(adsimdetector.SimDetector, number=2) + detb = ad_standard_det_factory(adsimdetector.SimDetector, number=2, color=True) return deta, detb @@ -212,7 +212,7 @@ def plan(): "test_adsim2-driver-acquire_time" ] == pytest.approx(1.8) assert descriptor["data_keys"]["test_adsim1"]["shape"] == [10, 10] - assert descriptor["data_keys"]["test_adsim2"]["shape"] == [11, 11] + assert descriptor["data_keys"]["test_adsim2"]["shape"] == [3, 11, 11] assert sda["stream_resource"] == sra["uid"] assert sdb["stream_resource"] == srb["uid"] assert ( @@ -277,6 +277,15 @@ async def test_detector_writes_to_file( ] +async def test_invalid_color_mode( + test_adsimdetector: adsimdetector.SimDetector, +): + set_mock_value(test_adsimdetector.driver.color_mode, "Bayer") + with pytest.raises(ValueError) as exc_info: + await test_adsimdetector._writer._dataset_describer.shape() + assert "not currently supported!" in str(exc_info.value) + + async def test_read_and_describe_detector( test_adsimdetector: adsimdetector.SimDetector, ): diff --git a/tests/epics/conftest.py b/tests/epics/conftest.py index 5422e0f388..caa46254d5 100644 --- a/tests/epics/conftest.py +++ b/tests/epics/conftest.py @@ -20,6 +20,7 @@ def generate_ad_standard_det( detector_cls: type[adcore.AreaDetector], writer_cls: type[adcore.ADWriter] = adcore.ADHDFWriter, number=1, + color=False, **kwargs, ) -> adcore.AreaDetector: # Dynamically generate a name based on the class of controller @@ -64,6 +65,9 @@ def on_set_file_path_callback(value: str, wait: bool = True): if isinstance(test_adstandard_det.fileio, adcore.NDFileHDFIO): set_mock_value(test_adstandard_det.fileio.num_frames_chunks, 1) + if color: + set_mock_value(test_adstandard_det.driver.color_mode, adcore.ADBaseColorMode.RGB) + return test_adstandard_det return generate_ad_standard_det