diff --git a/src/dodal/devices/bimorph_mirror.py b/src/dodal/devices/bimorph_mirror.py index 768c138b2a..c1961327ba 100644 --- a/src/dodal/devices/bimorph_mirror.py +++ b/src/dodal/devices/bimorph_mirror.py @@ -1,10 +1,8 @@ import asyncio -from collections.abc import Mapping from typing import Annotated as A from bluesky.protocols import Movable from ophyd_async.core import ( - DEFAULT_TIMEOUT, AsyncStatus, DeviceVector, SignalR, @@ -12,6 +10,7 @@ SignalW, StandardReadable, StrictEnum, + set_and_wait_for_other_value, wait_for_value, ) from ophyd_async.core import StandardReadableFormat as Format @@ -23,6 +22,8 @@ epics_signal_x, ) +DEFAULT_TIMEOUT = 60 + class BimorphMirrorOnOff(StrictEnum): ON = "ON" @@ -41,7 +42,7 @@ class BimorphMirrorStatus(StrictEnum): ERROR = "Error" -class BimorphMirrorChannel(StandardReadable, Movable[float], EpicsDevice): +class BimorphMirrorChannel(StandardReadable, EpicsDevice): """Collection of PVs comprising a single bimorph channel. Attributes: @@ -56,23 +57,13 @@ class BimorphMirrorChannel(StandardReadable, Movable[float], EpicsDevice): status: A[SignalR[BimorphMirrorOnOff], PvSuffix("STATUS"), Format.CONFIG_SIGNAL] shift: A[SignalW[float], PvSuffix("SHIFT")] - @AsyncStatus.wrap - async def set(self, value: float): - """Sets channel's VOUT to given value. - - Args: - value: float to set VOUT to - """ - await self.output_voltage.set(value) - -class BimorphMirror(StandardReadable, Movable[Mapping[int, float]]): +class BimorphMirror(StandardReadable, Movable): """Class to represent CAENels Bimorph Mirrors. Attributes: channels: DeviceVector of BimorphMirrorChannel, indexed from 1, for each channel enabled: Writeable BimorphOnOff - commit_target_voltages: Procable signal that writes values in each channel's VTRGT to VOUT status: Readable BimorphMirrorStatus Busy/Idle status err: Alarm status""" @@ -103,49 +94,51 @@ def __init__(self, prefix: str, number_of_channels: int, name=""): super().__init__(name=name) @AsyncStatus.wrap - async def set(self, value: Mapping[int, float], tolerance: float = 0.0001) -> None: - """Sets bimorph voltages in parrallel via target voltage and all proc. + async def set(self, value: list[float]) -> None: + """Sets bimorph voltages in parallel via target voltage and all proc. Args: - value: Dict of channel numbers to target voltages + value: List of float target voltages Raises: ValueError: On set to non-existent channel""" - if any(key not in self.channels for key in value): + if len(value) != len(self.channels): raise ValueError( - f"Attempting to put to non-existent channels: {[key for key in value if (key not in self.channels)]}" + f"Length of value input array does not match number of \ + channels: {len(value)} and {len(self.channels)}" ) - # Write target voltages: - await asyncio.gather( - *[ - self.channels[i].target_voltage.set(target, wait=True) - for i, target in value.items() - ] - ) + # Write target voltages in serial + # Voltages are written in serial as bimorph PSU cannot handle simultaneous sets + for i, target in enumerate(value): + await wait_for_value( + self.status, BimorphMirrorStatus.IDLE, timeout=DEFAULT_TIMEOUT + ) + await set_and_wait_for_other_value( + self.channels[i + 1].target_voltage, + target, + self.status, + BimorphMirrorStatus.BUSY, + ) # Trigger set target voltages: + await wait_for_value( + self.status, BimorphMirrorStatus.IDLE, timeout=DEFAULT_TIMEOUT + ) await self.commit_target_voltages.trigger() # Wait for values to propogate to voltage out rbv: await asyncio.gather( *[ wait_for_value( - self.channels[i].output_voltage, - tolerance_func_builder(tolerance, target), + self.channels[i + 1].output_voltage, + target, timeout=DEFAULT_TIMEOUT, ) - for i, target in value.items() + for i, target in enumerate(value) ], wait_for_value( self.status, BimorphMirrorStatus.IDLE, timeout=DEFAULT_TIMEOUT ), ) - - -def tolerance_func_builder(tolerance: float, target_value: float): - def is_within_value(x): - return abs(x - target_value) <= tolerance - - return is_within_value diff --git a/src/dodal/plans/__init__.py b/src/dodal/plans/__init__.py index fb40245969..7483048b8b 100644 --- a/src/dodal/plans/__init__.py +++ b/src/dodal/plans/__init__.py @@ -1,4 +1,5 @@ +from .bimorph import bimorph_optimisation from .scanspec import spec_scan from .wrapped import count -__all__ = ["count", "spec_scan"] +__all__ = ["count", "spec_scan", "bimorph_optimisation"] diff --git a/src/dodal/plans/bimorph.py b/src/dodal/plans/bimorph.py new file mode 100644 index 0000000000..7ea3a76e92 --- /dev/null +++ b/src/dodal/plans/bimorph.py @@ -0,0 +1,315 @@ +from collections.abc import Generator +from dataclasses import dataclass +from enum import Enum + +import bluesky.plan_stubs as bps +import bluesky.preprocessors as bpp +from bluesky.protocols import Preparable, Readable +from bluesky.utils import MsgGenerator +from numpy import linspace +from ophyd_async.core import TriggerInfo + +from dodal.devices.bimorph_mirror import BimorphMirror +from dodal.devices.slits import Slits + + +class SlitDimension(Enum, str): + """Enum representing the dimensions of a 2d slit + + Used to describe which dimension the pencil beam scan should move across. + The other dimension will be held constant. + + Attributes: + X: Represents X dimension + Y: Represents Y dimension + """ + + X = "X" + Y = "Y" + + +def move_slits(slits: Slits, dimension: SlitDimension, gap: float, center: float): + """Moves ones dimension of Slits object to given position. + + Args: + slits: Slits to move + dimension: SlitDimension (X or Y) + gap: float size of gap + center: float position of center + """ + if dimension == SlitDimension.X: + yield from bps.mv(slits.x_gap, gap) # type: ignore + yield from bps.mv(slits.x_centre, center) # type: ignore + else: + yield from bps.mv(slits.y_gap, gap) # type: ignore + yield from bps.mv(slits.y_centre, center) # type: ignore + + +def check_valid_bimorph_state( + voltage_list: list[float], abs_range: float, abs_diff: float +) -> bool: + """Checks that a set of bimorph voltages is valid. + Args: + voltage_list: float amount each actuator will be increased by per scan + abs_range: float absolute value of maximum possible voltage of each actuator + abs_diff: float absolute maximum difference between two consecutive actuators + + Returns: + Bool representing state validity + """ + for voltage in voltage_list: + if abs(voltage) > abs_range: + return False + + for i in range(len(voltage_list) - 1): + if abs(voltage_list[i] - voltage_list[i - 1]) > abs_diff: + return False + + return True + + +def validate_bimorph_plan( + initial_voltage_list: list[float], + voltage_increment: float, + abs_range: float, + abs_diff: float, +) -> bool: + """Checks that every position the bimorph will move through will not error. + + Args: + initial_voltage_list: float list starting position + voltage_increment: float amount each actuator will be increased by per scan + abs_range: float absolute value of maximum possible voltage of each actuator + abs_diff: float absolute maximum difference between two consecutive actuators + + Raises: + Exception if the plan will lead to an error state""" + voltage_list = initial_voltage_list.copy() + + if not check_valid_bimorph_state(voltage_list, abs_range, abs_diff): + raise Exception(f"Bimorph plan reaches invalid state at: {voltage_list}") + + for i in range(len(initial_voltage_list)): + voltage_list[i] += voltage_increment + + if not check_valid_bimorph_state(voltage_list, abs_range, abs_diff): + raise Exception(f"Bimorph plan reaches invalid state at: {voltage_list}") + + return True + + +@dataclass +class BimorphState: + """Data class containing positions of BimorphMirror and Slits""" + + voltages: list[float] + x_gap: float + y_gap: float + x_center: float + y_center: float + + +def capture_bimorph_state(mirror: BimorphMirror, slits: Slits): + """Plan stub that captures current position of BimorphMirror and Slits. + + Args: + mirror: BimorphMirror to read from + slits: Slits to read from + + Returns: + A BimorphState containing BimorphMirror and Slits positions""" + original_voltage_list = [] + + for channel in mirror.channels.values(): + position = yield from bps.rd(channel.output_voltage) + original_voltage_list.append(position) + + original_x_gap = yield from bps.rd(slits.x_gap) + original_y_gap = yield from bps.rd(slits.y_gap) + original_x_center = yield from bps.rd(slits.x_centre) + original_y_center = yield from bps.rd(slits.y_centre) + return BimorphState( + original_voltage_list, + original_x_gap, + original_y_gap, + original_x_center, + original_y_center, + ) + + +def restore_bimorph_state(mirror: BimorphMirror, slits: Slits, state: BimorphState): + """Moves BimorphMirror and Slits to state given in BirmophState. + + Args: + mirror: BimorphMirror to move + slits: Slits to move + state: BimorphState to move to. + """ + yield from move_slits(slits, SlitDimension.X, state.x_gap, state.x_center) + yield from move_slits(slits, SlitDimension.Y, state.y_gap, state.y_center) + + yield from bps.mv(mirror, state.voltages) # type: ignore + + +def bimorph_position_generator( + initial_voltage_list: list[float], voltage_increment: float +) -> Generator[list[float], None, None]: + """Generator that produces bimorph positions, starting with the initial_voltage_list. + + Args: + initial_voltage_list: list starting position for bimorph + voltage_increment: float amount to increase each actuator by in turn + + Yields: + List bimorph positions, starting with initial_voltage_list + """ + voltage_list = initial_voltage_list.copy() + + for i in range(-1, len(initial_voltage_list)): + yield [ + voltage + voltage_increment if i >= j else voltage + for (j, voltage) in enumerate(voltage_list) + ] + + +def bimorph_optimisation( + detectors: list[Readable], + mirror: BimorphMirror, + slits: Slits, + voltage_increment: float, + active_dimension: SlitDimension, + active_slit_center_start: float, + active_slit_center_end: float, + active_slit_size: float, + inactive_slit_center: float, + inactive_slit_size: float, + number_of_slit_positions: int, + bimorph_settle_time: float, + slit_settle_time: float, + initial_voltage_list: list | None = None, +) -> MsgGenerator: + """Plan for performing bimorph mirror optimisation. + + Bluesky plan that performs a series of pencil beam scans across one axis of a + bimorph mirror, of using a 2-dimensional slit. + + Args: + bimorph: BimorphMirror to move + slit: Slits + oav: oav on-axis viewer + voltage_increment: float voltage increment applied to each bimorph electrode + active_dimension: SlitDimension that slit will move in (X or Y) + active_slit_center_start: float start position of center of slit in active dimension + active_slit_center_end: float final position of center of slit in active dimension + active_slit_size: float size of slit in active dimension + inactive_slit_center: float center of slit in inactive dimension + inactive_slit_size: float size of slit in inactive dimension + number_of_slit_positions: int number of slit positions per pencil beam scan + bimorph_settle_time: float time in seconds to wait after bimorph move + slit_settle_time: float time in seconds to wait after slit move + initial_voltage_list: optional list[float] starting voltages for bimorph (defaults to current voltages) + """ + + state = yield from capture_bimorph_state(mirror, slits) + + # If a starting set of voltages is not provided, default to current: + initial_voltage_list = initial_voltage_list or state.voltages + + bimorph_positions = bimorph_position_generator( + initial_voltage_list, voltage_increment + ) + + validate_bimorph_plan(initial_voltage_list, voltage_increment, 1000, 500) + + inactive_dimension = ( + SlitDimension.Y if active_dimension == SlitDimension.X else SlitDimension.X + ) + + metadata = { + "voltage_increment": voltage_increment, + "dimension": active_dimension, + "slit_positions": number_of_slit_positions, + "channels": len(mirror.channels), + } + + @bpp.run_decorator(md=metadata) + @bpp.stage_decorator((*detectors, mirror, slits)) + def outer_scan(): + """Outer plan stub, which moves mirror and calls inner_scan.""" + for detector in detectors: + if isinstance(detector, Preparable): + yield from bps.prepare(detector, TriggerInfo(), wait=True) + + stream_name = "0" + yield from bps.declare_stream(*detectors, mirror, slits, name=stream_name) + + # Move slits into starting position: + yield from move_slits( + slits, active_dimension, active_slit_size, active_slit_center_start + ) + yield from move_slits( + slits, inactive_dimension, inactive_slit_size, inactive_slit_center + ) + yield from bps.sleep(slit_settle_time) + + for bimorph_position in bimorph_positions: + yield from bps.mv( + mirror, # type: ignore + bimorph_position, # type: ignore + ) + yield from bps.sleep(bimorph_settle_time) + + yield from bps.declare_stream(*detectors, mirror, slits, name=stream_name) + + yield from inner_scan( + detectors, + mirror, + slits, + active_dimension, + active_slit_center_start, + active_slit_center_end, + active_slit_size, + number_of_slit_positions, + slit_settle_time, + stream_name, + ) + + stream_name = str(int(stream_name) + 1) + + yield from outer_scan() + + yield from restore_bimorph_state(mirror, slits, state) + + +def inner_scan( + detectors: list[Readable], + mirror: BimorphMirror, + slits: Slits, + active_dimension: SlitDimension, + active_slit_center_start: float, + active_slit_center_end: float, + active_slit_size: float, + number_of_slit_positions: int, + slit_settle_time: float, + stream_name: str, +): + """Inner plan stub, which moves Slits and performs a read. + + Args: + mirror: BimorphMirror to move + slit: Slits + oav: oav on-axis viewer + active_dimension: SlitDimension that slit will move in (X or Y) + active_slit_center_start: float start position of center of slit in active dimension + active_slit_center_end: float final position of center of slit in active dimension + active_slit_size: float size of slit in active dimension + number_of_slit_positions: int number of slit positions per pencil beam scan + slit_settle_time: float time in seconds to wait after slit move + stream_name: str name to pass to trigger_and_read + """ + for value in linspace( + active_slit_center_start, active_slit_center_end, number_of_slit_positions + ): + yield from move_slits(slits, active_dimension, active_slit_size, value) + yield from bps.sleep(slit_settle_time) + yield from bps.trigger_and_read([*detectors, mirror, slits], name=stream_name) diff --git a/tests/devices/unit_tests/test_bimorph_mirror.py b/tests/devices/unit_tests/test_bimorph_mirror.py index 7c1b5d39ee..18b880b915 100644 --- a/tests/devices/unit_tests/test_bimorph_mirror.py +++ b/tests/devices/unit_tests/test_bimorph_mirror.py @@ -1,16 +1,23 @@ +import asyncio +from collections.abc import Callable +from typing import Any from unittest.mock import ANY, call, patch import pytest from bluesky.run_engine import RunEngine -from ophyd_async.core import init_devices -from ophyd_async.testing import get_mock_put +from ophyd_async.core import init_devices, walk_rw_signals +from ophyd_async.testing import callback_on_mock_put, get_mock_put, set_mock_value -from dodal.devices.bimorph_mirror import BimorphMirror, BimorphMirrorStatus +from dodal.devices.bimorph_mirror import ( + BimorphMirror, + BimorphMirrorChannel, + BimorphMirrorStatus, +) VALID_BIMORPH_CHANNELS = [8, 12, 16, 24] -@pytest.fixture +@pytest.fixture(params=VALID_BIMORPH_CHANNELS) def mirror(request, RE: RunEngine) -> BimorphMirror: number_of_channels = request.param @@ -24,122 +31,117 @@ def mirror(request, RE: RunEngine) -> BimorphMirror: @pytest.fixture -def valid_bimorph_values(mirror: BimorphMirror) -> dict[int, float]: - return {i: float(i) for i in range(1, len(mirror.channels) + 1)} +def valid_bimorph_values(mirror: BimorphMirror) -> list[float]: + return [float(i) for i in range(1, len(mirror.channels) + 1)] @pytest.fixture -def mock_vtrgt_vout_propogation(mirror: BimorphMirror): - for channel in mirror.channels.values(): +def mirror_with_mocked_put(mirror: BimorphMirror): + """Returns BimorphMirror with some simulated behaviour. + + BimorphMirror that simulates BimorphMirrorStatus BUSY/IDLE behaviour on all + rw_signals, and propogation from target_voltage to output_voltage on each + channel. + + Args: + mirror: BimorphMirror fixture + """ + + async def busy_idle(): + await asyncio.sleep(0) + set_mock_value(mirror.status, BimorphMirrorStatus.BUSY) + await asyncio.sleep(0) + set_mock_value(mirror.status, BimorphMirrorStatus.IDLE) + + async def start_busy_idle(*_: Any, **__: Any): + asyncio.create_task(busy_idle()) + + for signal in walk_rw_signals(mirror).values(): + callback_on_mock_put(signal, start_busy_idle) + + def callback_function( + channel: BimorphMirrorChannel, + ) -> Callable[[float, bool], None]: + def output_voltage_propogation_and_status( + value: float, + wait: bool = False, + ): + channel.output_voltage.set(value, wait=wait) + asyncio.create_task(busy_idle()) - def effect(value: float, wait=False, signal=channel.output_voltage): - signal.set(value, wait=wait) + return output_voltage_propogation_and_status - get_mock_put(channel.target_voltage).side_effect = effect + for channel in mirror.channels.values(): + callback_on_mock_put(channel.target_voltage, callback_function(channel)) + + return mirror -@pytest.mark.parametrize("mirror", VALID_BIMORPH_CHANNELS, indirect=True) async def test_set_channels_waits_for_readback( - mirror: BimorphMirror, - valid_bimorph_values: dict[int, float], - mock_vtrgt_vout_propogation, + mirror_with_mocked_put: BimorphMirror, + valid_bimorph_values: list[float], ): - await mirror.set(valid_bimorph_values) + await mirror_with_mocked_put.set(valid_bimorph_values) - assert { - key: await mirror.channels[key].target_voltage.get_value() - for key in valid_bimorph_values - } == valid_bimorph_values + assert [ + await mirror_with_mocked_put.channels[i].target_voltage.get_value() + for i in range(1, len(valid_bimorph_values) + 1) + ] == valid_bimorph_values -@pytest.mark.parametrize("mirror", VALID_BIMORPH_CHANNELS, indirect=True) async def test_set_channels_triggers_alltrgt_proc( - mirror: BimorphMirror, - valid_bimorph_values: dict[int, float], - mock_vtrgt_vout_propogation, + mirror_with_mocked_put: BimorphMirror, + valid_bimorph_values: list[float], ): - mock_alltrgt_proc = get_mock_put(mirror.commit_target_voltages) + mock_alltrgt_proc = get_mock_put(mirror_with_mocked_put.commit_target_voltages) mock_alltrgt_proc.assert_not_called() - await mirror.set(valid_bimorph_values) + await mirror_with_mocked_put.set(valid_bimorph_values) mock_alltrgt_proc.assert_called_once() -@pytest.mark.parametrize("mirror", VALID_BIMORPH_CHANNELS, indirect=True) -async def test_set_channels_waits_for_vout_readback( - mirror: BimorphMirror, - valid_bimorph_values: dict[int, float], - mock_vtrgt_vout_propogation, +async def test_set_channels_waits_for_output_voltage_readback( + mirror_with_mocked_put: BimorphMirror, + valid_bimorph_values: list[float], ): with patch("dodal.devices.bimorph_mirror.wait_for_value") as mock_wait_for_value: mock_wait_for_value.assert_not_called() - await mirror.set(valid_bimorph_values) + await mirror_with_mocked_put.set(valid_bimorph_values) expected_call_arg_list = [ - call(mirror.channels[i].output_voltage, ANY, timeout=ANY) - for i, val in valid_bimorph_values.items() + call( + mirror_with_mocked_put.channels[i + 1].output_voltage, ANY, timeout=ANY + ) + for i, val in enumerate(valid_bimorph_values) ] - expected_call_arg_list.append( - call(mirror.status, BimorphMirrorStatus.IDLE, timeout=ANY) - ) - assert expected_call_arg_list == mock_wait_for_value.call_args_list - - -@pytest.mark.parametrize("mirror", VALID_BIMORPH_CHANNELS, indirect=True) -async def test_set_channels_allows_tolerance( - mirror: BimorphMirror, - valid_bimorph_values: dict[int, float], -): - for channel in mirror.channels.values(): - def out_by_a_little(value: float, wait=False, signal=channel.output_voltage): - signal.set(value + 0.00001, wait=wait) - - get_mock_put(channel.target_voltage).side_effect = out_by_a_little - - await mirror.set(valid_bimorph_values) - - -@pytest.mark.parametrize("mirror", VALID_BIMORPH_CHANNELS, indirect=True) -async def test_set_one_channel(mirror: BimorphMirror, mock_vtrgt_vout_propogation): - values = {1: 1} - - await mirror.set(values) - - read = await mirror.read() - - assert [ - await mirror.channels[key].target_voltage.get_value() for key in values - ] == list(values) - - assert [ - read[f"{mirror.name}-channels-{key}-output_voltage"]["value"] for key in values - ] == list(values) + assert all( + c in mock_wait_for_value.call_args_list for c in expected_call_arg_list + ) -@pytest.mark.parametrize("mirror", VALID_BIMORPH_CHANNELS, indirect=True) async def test_read( - mirror: BimorphMirror, - valid_bimorph_values: dict[int, float], - mock_vtrgt_vout_propogation, + mirror_with_mocked_put: BimorphMirror, + valid_bimorph_values: list[float], ): - await mirror.set(valid_bimorph_values) + await mirror_with_mocked_put.set(valid_bimorph_values) - read = await mirror.read() + read = await mirror_with_mocked_put.read() assert [ - read[f"{mirror.name}-channels-{i}-output_voltage"]["value"] - for i in range(1, len(mirror.channels) + 1) - ] == list(valid_bimorph_values.values()) + read[f"{mirror_with_mocked_put.name}-channels-{i}-output_voltage"]["value"] + for i in range(1, len(mirror_with_mocked_put.channels) + 1) + ] == list(valid_bimorph_values) -@pytest.mark.parametrize("mirror", VALID_BIMORPH_CHANNELS, indirect=True) -async def test_set_invalid_channel_throws_error(mirror: BimorphMirror): +async def test_set_invalid_value_throws_error(mirror_with_mocked_put: BimorphMirror): with pytest.raises(ValueError): - await mirror.set({len(mirror.channels) + 1: 0.0}) + await mirror_with_mocked_put.set( + list(range(len(mirror_with_mocked_put.channels) + 1)) + ) @pytest.mark.parametrize("number_of_channels", [-1]) @@ -150,18 +152,7 @@ async def test_init_mirror_with_invalid_channels_throws_error(number_of_channels @pytest.mark.parametrize("number_of_channels", [0]) async def test_init_mirror_with_zero_channels(number_of_channels): - mirror = BimorphMirror(prefix="FAKE-PREFIX", number_of_channels=number_of_channels) - assert len(mirror.channels) == 0 - - -@pytest.mark.parametrize("mirror", VALID_BIMORPH_CHANNELS, indirect=True) -async def test_bimorph_mirror_channel_set( - mirror: BimorphMirror, - valid_bimorph_values: dict[int, float], -): - for value, channel in zip( - valid_bimorph_values.values(), mirror.channels.values(), strict=True - ): - assert await channel.output_voltage.get_value() != value - await channel.set(value) - assert await channel.output_voltage.get_value() == value + mirror_with_mocked_put = BimorphMirror( + prefix="FAKE-PREFIX", number_of_channels=number_of_channels + ) + assert len(mirror_with_mocked_put.channels) == 0 diff --git a/tests/plans/test_bimorph.py b/tests/plans/test_bimorph.py new file mode 100644 index 0000000000..efea5c4f48 --- /dev/null +++ b/tests/plans/test_bimorph.py @@ -0,0 +1,810 @@ +import asyncio +import unittest +import unittest.mock +from collections.abc import Generator +from typing import Any +from unittest.mock import ANY, Mock, call + +import bluesky.plan_stubs as bps +import pytest +from bluesky.protocols import Readable +from bluesky.run_engine import RunEngine +from bluesky.utils import Msg +from numpy import linspace +from ophyd_async.core import ( + PathProvider, + StandardDetector, + init_devices, + walk_rw_signals, +) +from ophyd_async.sim import SimBlobDetector +from ophyd_async.testing import callback_on_mock_put, get_mock_put, set_mock_value + +from dodal.devices.bimorph_mirror import BimorphMirror, BimorphMirrorStatus +from dodal.devices.slits import Slits +from dodal.plans.bimorph import ( + BimorphState, + SlitDimension, + bimorph_optimisation, + bimorph_position_generator, + capture_bimorph_state, + check_valid_bimorph_state, + inner_scan, + move_slits, + restore_bimorph_state, + validate_bimorph_plan, +) + +VALID_BIMORPH_CHANNELS = [2] + + +@pytest.fixture(params=VALID_BIMORPH_CHANNELS) +def mirror(request, RE: RunEngine) -> BimorphMirror: + number_of_channels = request.param + + with init_devices(mock=True): + bm = BimorphMirror( + prefix="FAKE-PREFIX:", + number_of_channels=number_of_channels, + ) + + return bm + + +@pytest.fixture +def mirror_with_mocked_put(mirror: BimorphMirror) -> BimorphMirror: + async def busy_idle(): + await asyncio.sleep(0) + set_mock_value(mirror.status, BimorphMirrorStatus.BUSY) + await asyncio.sleep(0) + set_mock_value(mirror.status, BimorphMirrorStatus.IDLE) + + async def status(*_, **__): + asyncio.create_task(busy_idle()) + + for signal in walk_rw_signals(mirror).values(): + callback_on_mock_put(signal, status) + + for channel in mirror.channels.values(): + + def vout_propogation_and_status( + value: float, wait=False, signal=channel.output_voltage + ): + signal.set(value, wait=wait) + asyncio.create_task(busy_idle()) + + callback_on_mock_put(channel.target_voltage, vout_propogation_and_status) + + return mirror + + +@pytest.fixture +def slits(RE: RunEngine) -> Slits: + """Mock slits with propagation from setpoint to readback.""" + with init_devices(mock=True): + slits = Slits("FAKE-PREFIX:") + + for motor in [slits.x_gap, slits.y_gap, slits.x_centre, slits.y_centre]: + # Set velocity to avoid zero velocity error: + set_mock_value(motor.velocity, 1) + + def callback(value, wait=False, signal=motor.user_readback): + set_mock_value(signal, value) + + callback_on_mock_put(motor.user_setpoint, callback) + return slits + + +@pytest.fixture +async def oav(RE: RunEngine, static_path_provider: PathProvider) -> StandardDetector: + with init_devices(): + det = SimBlobDetector(static_path_provider) + return det + + +@pytest.fixture(params=[0, 1]) +async def detectors(request, oav: StandardDetector) -> list[Readable]: + return [oav] * request.param + + +@pytest.fixture(params=[True, False]) +def initial_voltage_list(request, mirror) -> list[float] | None: + if request.param: + return [0.0 for _ in range(len(mirror.channels))] + else: + return None + + +@pytest.mark.parametrize("dimension", [SlitDimension.X, SlitDimension.Y]) +@pytest.mark.parametrize("gap", [1.0]) +@pytest.mark.parametrize("center", [2.0]) +async def test_move_slits( + slits: Slits, + dimension: SlitDimension, + gap: float, + center: float, +): + messages = list(move_slits(slits, dimension, gap, center)) + + if dimension == SlitDimension.X: + gap_signal = slits.x_gap + centre_signal = slits.x_centre + else: + gap_signal = slits.y_gap + centre_signal = slits.y_centre + + assert [ + Msg("set", gap_signal, gap, group=ANY), + Msg("wait", None, group=ANY), + Msg("set", centre_signal, center, group=ANY), + Msg("wait", None, group=ANY), + ] == messages + + +async def test_save_and_restore( + RE: RunEngine, mirror_with_mocked_put: BimorphMirror, slits: Slits +): + signals = [ + slits.x_gap.user_setpoint, + slits.y_gap.user_setpoint, + slits.x_centre.user_setpoint, + slits.y_centre.user_setpoint, + mirror_with_mocked_put.channels[1].output_voltage, + ] + puts = [get_mock_put(signal) for signal in signals] + + def plan(): + state = yield from capture_bimorph_state(mirror_with_mocked_put, slits) + + for signal in signals: + yield from bps.abs_set(signal, 4.0, wait=True) + + yield from restore_bimorph_state(mirror_with_mocked_put, slits, state) + + RE(plan()) + + for put in puts: + assert put.call_args_list == [call(4.0, wait=True), call(0.0, wait=True)] + + +@pytest.mark.parametrize("voltage_list", [[0.0 for _ in range(8)]]) +@pytest.mark.parametrize("abs_range", [1000.0]) +@pytest.mark.parametrize("abs_diff", [200.0]) +class TestPlanValidation: + def test_valid_bimorph_state( + self, voltage_list: list[float], abs_range: float, abs_diff: float + ): + assert check_valid_bimorph_state(voltage_list, abs_range, abs_diff) + + def test_invalid_range_bimorph_state( + self, voltage_list: list[float], abs_range: float, abs_diff: float + ): + assert not check_valid_bimorph_state([abs_range + 1], abs_range, abs_diff) + + def test_invalid_diff_bimorph_state( + self, voltage_list: list[float], abs_range: float, abs_diff: float + ): + assert not check_valid_bimorph_state([abs_diff, -abs_diff], abs_range, abs_diff) + + def test_invalid_plan( + self, voltage_list: list[float], abs_range: float, abs_diff: float + ): + with pytest.raises(Exception): # noqa: B017 + validate_bimorph_plan([1000.0, 0.0], 200.0, abs_range, abs_diff) + + def test_valid_plan( + self, voltage_list: list[float], abs_range: float, abs_diff: float + ): + assert validate_bimorph_plan(voltage_list, 200.0, abs_range, abs_diff) + + +@pytest.mark.parametrize("initial_voltage_list", [[0 for _ in range(2)]]) +@pytest.mark.parametrize("voltage_increment", [200.0]) +class TestBimorphPositionGenerator: + def test_copies_list( + self, initial_voltage_list: list[float], voltage_increment: float + ): + list_copy = initial_voltage_list.copy() + + positions = list( + bimorph_position_generator(initial_voltage_list, voltage_increment) + ) + + assert positions[1] != initial_voltage_list + + assert initial_voltage_list == list_copy + + def test_generated_positions( + self, initial_voltage_list: list[float], voltage_increment: float + ): + positions = list( + bimorph_position_generator(initial_voltage_list, voltage_increment) + ) + + assert positions == [ + [0.0, 0.0], + [200.0, 0.0], + [200.0, 200.0], + ] + + +@pytest.mark.parametrize("active_dimension", [SlitDimension.X, SlitDimension.Y]) +@pytest.mark.parametrize("active_slit_center_start", [0.0]) +@pytest.mark.parametrize("active_slit_center_end", [200]) +@pytest.mark.parametrize("active_slit_size", [0.05]) +@pytest.mark.parametrize("number_of_slit_positions", [2]) +@pytest.mark.parametrize("slit_settle_time", [0.0]) +@pytest.mark.parametrize("stream_name", [0]) +@unittest.mock.patch("dodal.plans.bimorph.bps.sleep") +@unittest.mock.patch("dodal.plans.bimorph.bps.trigger_and_read") +@unittest.mock.patch("dodal.plans.bimorph.move_slits") +class TestInnerScan: + def test_inner_scan_moves_slits( + self, + mock_move_slits: Mock, + mock_bps_trigger_and_read: Mock, + mock_bps_sleep: Mock, + detectors: list[Readable], + RE: RunEngine, + mirror: BimorphMirror, + slits: Slits, + active_dimension: SlitDimension, + active_slit_center_start: float, + active_slit_center_end: float, + active_slit_size: float, + number_of_slit_positions: int, + slit_settle_time: float, + stream_name: str, + ): + RE( + inner_scan( + detectors, + mirror, + slits, + active_dimension, + active_slit_center_start, + active_slit_center_end, + active_slit_size, + number_of_slit_positions, + slit_settle_time, + stream_name, + ) + ) + + call_list = [ + call(slits, active_dimension, active_slit_size, value) + for value in linspace( + active_slit_center_start, + active_slit_center_end, + number_of_slit_positions, + ) + ] + + assert mock_move_slits.call_args_list == call_list + + def test_inner_scan_triggers_and_reads( + self, + mock_move_slits: Mock, + mock_bps_trigger_and_read: Mock, + mock_bps_sleep: Mock, + detectors: list[Readable], + RE: RunEngine, + mirror: BimorphMirror, + slits: Slits, + active_dimension: SlitDimension, + active_slit_center_start: float, + active_slit_center_end: float, + active_slit_size: float, + number_of_slit_positions: int, + slit_settle_time: float, + stream_name: str, + ): + RE( + inner_scan( + detectors, + mirror, + slits, + active_dimension, + active_slit_center_start, + active_slit_center_end, + active_slit_size, + number_of_slit_positions, + slit_settle_time, + stream_name, + ) + ) + + call_list = [ + call([*detectors, mirror, slits], name=stream_name) + for _ in linspace( + active_slit_center_start, + active_slit_center_end, + number_of_slit_positions, + ) + ] + assert mock_bps_trigger_and_read.call_args_list == call_list + + def test_inner_scan_slit_settle( + self, + mock_move_slits: Mock, + mock_bps_trigger_and_read: Mock, + mock_bps_sleep: Mock, + detectors: list[Readable], + RE: RunEngine, + mirror: BimorphMirror, + slits: Slits, + active_dimension: SlitDimension, + active_slit_center_start: float, + active_slit_center_end: float, + active_slit_size: float, + number_of_slit_positions: int, + slit_settle_time: float, + stream_name: str, + ): + RE( + inner_scan( + detectors, + mirror, + slits, + active_dimension, + active_slit_center_start, + active_slit_center_end, + active_slit_size, + number_of_slit_positions, + slit_settle_time, + stream_name, + ) + ) + assert [ + call(slit_settle_time) for _ in range(number_of_slit_positions) + ] == mock_bps_sleep.call_args_list + + +@pytest.mark.parametrize("voltage_increment", [100.0]) +@pytest.mark.parametrize("active_dimension", [SlitDimension.X, SlitDimension.Y]) +@pytest.mark.parametrize("active_slit_center_start", [0.0]) +@pytest.mark.parametrize("active_slit_center_end", [200.0]) +@pytest.mark.parametrize("active_slit_size", [0.05]) +@pytest.mark.parametrize("inactive_slit_center", [0.0]) +@pytest.mark.parametrize("inactive_slit_size", [0.05]) +@pytest.mark.parametrize("number_of_slit_positions", [3]) +@pytest.mark.parametrize("bimorph_settle_time", [0.0]) +@pytest.mark.parametrize("slit_settle_time", [0.0]) +@unittest.mock.patch("dodal.plans.bimorph.bps.sleep") +@unittest.mock.patch("dodal.plans.bimorph.restore_bimorph_state") +@unittest.mock.patch("dodal.plans.bimorph.move_slits") +@unittest.mock.patch("dodal.plans.bimorph.inner_scan") +class TestBimorphOptimisation: + """Run full bimorph_optimisation plan with mocked devices and plan stubs.""" + + @pytest.fixture + def start_state(self, mirror_with_mocked_put: BimorphMirror) -> BimorphState: + return BimorphState( + [10.0 for _ in range(len(mirror_with_mocked_put.channels))], + 0.0, + 0.0, + 0.0, + 0.0, + ) + + @pytest.fixture + def mock_capture_bimorph_state( + self, start_state: BimorphState + ) -> Generator[Mock, None, None]: + with unittest.mock.patch( + "dodal.plans.bimorph.capture_bimorph_state" + ) as mock_obj: + + def mock_capture_plan_stub( + *args: Any, **kwargs: Any + ) -> Generator[None, None, BimorphState]: + # return start_state without yielding Msg to RE: + yield from iter([]) + return start_state + + mock_obj.side_effect = mock_capture_plan_stub + + yield mock_obj + + async def test_metadata( + self, + mock_inner_scan: Mock, + mock_move_slits: Mock, + mock_restore_bimorph_state: Mock, + mock_bps_sleep: Mock, + mock_capture_bimorph_state: Mock, + detectors: list[Readable], + RE: RunEngine, + mirror_with_mocked_put: BimorphMirror, + slits: Slits, + voltage_increment: float, + active_dimension: SlitDimension, + active_slit_center_start: float, + active_slit_center_end: float, + active_slit_size: float, + inactive_slit_center: float, + inactive_slit_size: float, + number_of_slit_positions: int, + bimorph_settle_time: float, + slit_settle_time: float, + initial_voltage_list: list[float], + ): + def start_subscription(name, doc): + assert { + "voltage_increment": voltage_increment, + "dimension": active_dimension, + "slit_positions": number_of_slit_positions, + "channels": len(mirror_with_mocked_put.channels), + }.items() <= doc.items() + + RE( + bimorph_optimisation( + detectors, + mirror_with_mocked_put, + slits, + voltage_increment, + active_dimension, + active_slit_center_start, + active_slit_center_end, + active_slit_size, + inactive_slit_center, + inactive_slit_size, + number_of_slit_positions, + bimorph_settle_time, + slit_settle_time, + initial_voltage_list, + ), + {"start": start_subscription}, + ) + + async def test_settle_time_called( + self, + mock_inner_scan: Mock, + mock_move_slits: Mock, + mock_restore_bimorph_state: Mock, + mock_bps_sleep: Mock, + mock_capture_bimorph_state: Mock, + detectors: list[Readable], + RE: RunEngine, + mirror_with_mocked_put: BimorphMirror, + slits: Slits, + voltage_increment: float, + active_dimension: SlitDimension, + active_slit_center_start: float, + active_slit_center_end: float, + active_slit_size: float, + inactive_slit_center: float, + inactive_slit_size: float, + number_of_slit_positions: int, + bimorph_settle_time: float, + slit_settle_time: float, + initial_voltage_list: list[float], + ): + RE( + bimorph_optimisation( + detectors, + mirror_with_mocked_put, + slits, + voltage_increment, + active_dimension, + active_slit_center_start, + active_slit_center_end, + active_slit_size, + inactive_slit_center, + inactive_slit_size, + number_of_slit_positions, + bimorph_settle_time, + slit_settle_time, + initial_voltage_list, + ) + ) + # Once for each bimorph position, once for initial slit position: + assert [ + call(bimorph_settle_time) + for _ in range(len(mirror_with_mocked_put.channels) + 2) + ] == mock_bps_sleep.call_args_list + + async def test_bimorph_state_captured( + self, + mock_inner_scan: Mock, + mock_move_slits: Mock, + mock_restore_bimorph_state: Mock, + mock_bps_sleep: Mock, + mock_capture_bimorph_state: Mock, + detectors: list[Readable], + RE: RunEngine, + mirror_with_mocked_put: BimorphMirror, + slits: Slits, + voltage_increment: float, + active_dimension: SlitDimension, + active_slit_center_start: float, + active_slit_center_end: float, + active_slit_size: float, + inactive_slit_center: float, + inactive_slit_size: float, + number_of_slit_positions: int, + bimorph_settle_time: float, + slit_settle_time: float, + initial_voltage_list: list[float], + start_state, + ): + RE( + bimorph_optimisation( + detectors, + mirror_with_mocked_put, + slits, + voltage_increment, + active_dimension, + active_slit_center_start, + active_slit_center_end, + active_slit_size, + inactive_slit_center, + inactive_slit_size, + number_of_slit_positions, + bimorph_settle_time, + slit_settle_time, + initial_voltage_list, + ) + ) + + assert mock_capture_bimorph_state.call_args == call( + mirror_with_mocked_put, slits + ) + + async def test_plan_sets_mirror_start_position( + self, + mock_inner_scan: Mock, + mock_move_slits: Mock, + mock_restore_bimorph_state: Mock, + mock_bps_sleep: Mock, + mock_capture_bimorph_state: Mock, + detectors: list[Readable], + RE: RunEngine, + mirror_with_mocked_put: BimorphMirror, + slits: Slits, + voltage_increment: float, + active_dimension: SlitDimension, + active_slit_center_start: float, + active_slit_center_end: float, + active_slit_size: float, + inactive_slit_center: float, + inactive_slit_size: float, + number_of_slit_positions: int, + bimorph_settle_time: float, + slit_settle_time: float, + initial_voltage_list: list[float], + start_state: BimorphState, + ): + inactive_dimension = ( + SlitDimension.Y if active_dimension == SlitDimension.X else SlitDimension.X + ) + + RE( + bimorph_optimisation( + detectors, + mirror_with_mocked_put, + slits, + voltage_increment, + active_dimension, + active_slit_center_start, + active_slit_center_end, + active_slit_size, + inactive_slit_center, + inactive_slit_size, + number_of_slit_positions, + bimorph_settle_time, + slit_settle_time, + initial_voltage_list, + ) + ) + assert [ + call(slits, active_dimension, active_slit_size, active_slit_center_start), + call(slits, inactive_dimension, inactive_slit_size, inactive_slit_center), + ] == mock_move_slits.call_args_list + + async def test_plan_calls_inner_scan( + self, + mock_inner_scan: Mock, + mock_move_slits: Mock, + mock_restore_bimorph_state: Mock, + mock_bps_sleep: Mock, + mock_capture_bimorph_state: Mock, + detectors: list[Readable], + RE: RunEngine, + mirror_with_mocked_put: BimorphMirror, + slits: Slits, + voltage_increment: float, + active_dimension: SlitDimension, + active_slit_center_start: float, + active_slit_center_end: float, + active_slit_size: float, + inactive_slit_center: float, + inactive_slit_size: float, + number_of_slit_positions: int, + bimorph_settle_time: float, + slit_settle_time: float, + initial_voltage_list: list[float], + start_state: BimorphState, + ): + RE( + bimorph_optimisation( + detectors, + mirror_with_mocked_put, + slits, + voltage_increment, + active_dimension, + active_slit_center_start, + active_slit_center_end, + active_slit_size, + inactive_slit_center, + inactive_slit_size, + number_of_slit_positions, + bimorph_settle_time, + slit_settle_time, + initial_voltage_list, + ) + ) + assert [ + call( + detectors, + mirror_with_mocked_put, + slits, + active_dimension, + active_slit_center_start, + active_slit_center_end, + active_slit_size, + number_of_slit_positions, + slit_settle_time, + str(i), + ) + for i in range(len(mirror_with_mocked_put.channels) + 1) + ] == mock_inner_scan.call_args_list + + async def test_plan_puts_to_bimorph( + self, + mock_inner_scan: Mock, + mock_move_slits: Mock, + mock_restore_bimorph_state: Mock, + mock_bps_sleep: Mock, + mock_capture_bimorph_state: Mock, + detectors: list[Readable], + RE: RunEngine, + mirror_with_mocked_put: BimorphMirror, + slits: Slits, + voltage_increment: float, + active_dimension: SlitDimension, + active_slit_center_start: float, + active_slit_center_end: float, + active_slit_size: float, + inactive_slit_center: float, + inactive_slit_size: float, + number_of_slit_positions: int, + bimorph_settle_time: float, + slit_settle_time: float, + initial_voltage_list: list[float], + start_state: BimorphState, + ): + RE( + bimorph_optimisation( + detectors, + mirror_with_mocked_put, + slits, + voltage_increment, + active_dimension, + active_slit_center_start, + active_slit_center_end, + active_slit_size, + inactive_slit_center, + inactive_slit_size, + number_of_slit_positions, + bimorph_settle_time, + slit_settle_time, + initial_voltage_list, + ) + ) + + initial_voltage_list = initial_voltage_list or start_state.voltages + + assert [ + call(initial_voltage_list[i] + voltage_increment) + == get_mock_put(channel.target_voltage).call_args + for i, channel in enumerate(mirror_with_mocked_put.channels.values()) + ] + + async def test_bimorph_state_restored( + self, + mock_inner_scan: Mock, + mock_move_slits: Mock, + mock_restore_bimorph_state: Mock, + mock_bps_sleep: Mock, + mock_capture_bimorph_state: Mock, + detectors: list[Readable], + RE: RunEngine, + mirror_with_mocked_put: BimorphMirror, + slits: Slits, + voltage_increment: float, + active_dimension: SlitDimension, + active_slit_center_start: float, + active_slit_center_end: float, + active_slit_size: float, + inactive_slit_center: float, + inactive_slit_size: float, + number_of_slit_positions: int, + bimorph_settle_time: float, + slit_settle_time: float, + initial_voltage_list: list[float], + start_state: BimorphState, + ): + RE( + bimorph_optimisation( + detectors, + mirror_with_mocked_put, + slits, + voltage_increment, + active_dimension, + active_slit_center_start, + active_slit_center_end, + active_slit_size, + inactive_slit_center, + inactive_slit_size, + number_of_slit_positions, + bimorph_settle_time, + slit_settle_time, + initial_voltage_list, + ) + ) + + assert [ + call(mirror_with_mocked_put, slits, start_state) + ] == mock_restore_bimorph_state.call_args_list + + +@pytest.mark.parametrize("voltage_increment", [100.0]) +@pytest.mark.parametrize("active_dimension", [SlitDimension.X, SlitDimension.Y]) +@pytest.mark.parametrize("active_slit_center_start", [0.0]) +@pytest.mark.parametrize("active_slit_center_end", [200.0]) +@pytest.mark.parametrize("active_slit_size", [0.05]) +@pytest.mark.parametrize("inactive_slit_center", [0.0]) +@pytest.mark.parametrize("inactive_slit_size", [0.05]) +@pytest.mark.parametrize("number_of_slit_positions", [3]) +@pytest.mark.parametrize("bimorph_settle_time", [0.0]) +@pytest.mark.parametrize("slit_settle_time", [0.0]) +class TestIntegration: + def test_full_plan( + self, + detectors: list[Readable], + RE: RunEngine, + mirror_with_mocked_put: BimorphMirror, + slits: Slits, + voltage_increment: float, + active_dimension: SlitDimension, + active_slit_center_start: float, + active_slit_center_end: float, + active_slit_size: float, + inactive_slit_center: float, + inactive_slit_size: float, + number_of_slit_positions: int, + bimorph_settle_time: float, + slit_settle_time: float, + initial_voltage_list: list[float], + ): + RE( + bimorph_optimisation( + detectors, + mirror_with_mocked_put, + slits, + voltage_increment, + active_dimension, + active_slit_center_start, + active_slit_center_end, + active_slit_size, + inactive_slit_center, + inactive_slit_size, + number_of_slit_positions, + bimorph_settle_time, + slit_settle_time, + initial_voltage_list, + ), + ) + + assert True