diff --git a/ophyd_devices/utils/signal_monitoring.py b/ophyd_devices/utils/signal_monitoring.py new file mode 100644 index 0000000..02c7355 --- /dev/null +++ b/ophyd_devices/utils/signal_monitoring.py @@ -0,0 +1,118 @@ +"""Utility class for monitoring signals in ophyd devices.""" + +from __future__ import annotations + +import threading +import uuid +from typing import Callable + +from bec_lib.logger import bec_logger +from ophyd import Signal + +logger = bec_logger.logger + + +class SignalMonitoring: + """ + This class allows you to register Signal instances or callables that will be polled + with a specified interval. The interval can be customized. The monitoring may be started + and stopped when needed, and will happen in a separate thread. + + In general, it should be used with ophyd.Signal instances, but it also accepts a callable + in case a certain 'script' method needs to be called in order to update a signal. + + """ + + def __init__(self, name: str = "SignalMonitoring"): + self.name = name + self._signal_instances = {} + self._callables = {} + self._lock = threading.RLock() + self._poll_thread = threading.Thread(target=self._poll_signals, daemon=True) + self._kill_event = threading.Event() + self._start_poll_event = threading.Event() + self._polling_interval_event = threading.Event() + self._polling_interval = 0.1 # seconds + self._poll_thread.start() + + @property + def polling_interval(self): + """Polling interval in seconds.""" + return self._polling_interval + + @polling_interval.setter + def polling_interval(self, value: float): + if value <= 0: + raise ValueError("Polling interval must be positive.") + self._polling_interval = value + + def _poll_signals(self): + """Poll loop that checks registered signals and callables at the specified interval.""" + while ( + self._start_poll_event.wait() and not self._kill_event.is_set() + ): # Wait until polling is started + self._polling_interval_event.wait( + timeout=self._polling_interval + ) # Poll at the specified interval + with self._lock: + try: + for signal in self._signal_instances.values(): + signal.get() + for call in self._callables.values(): + call() + except Exception as e: + logger.error(f"Error while polling signals: {e}") + + def register_signal(self, signal: Signal | Callable[[], None]) -> str: + """ + Register a Signal instance or a callable to be monitored. + + Args: + signal (Signal | Callable[[], None]): The Signal instance or callable to register. + """ + callback_id = str(uuid.uuid4()) + with self._lock: + if isinstance(signal, Signal): + self._signal_instances[callback_id] = signal + elif callable(signal): + self._callables[callback_id] = signal + else: + raise ValueError( + f"Only Signal instances or callables can be registered, got {type(signal)}." + ) + return callback_id + + def remove_signal(self, callback_id: str): + """ + Remove a registered signal or callable by its callback ID. + + Args: + callback_id (str): The unique ID of the signal or callable to remove. + """ + with self._lock: + if callback_id in self._signal_instances: + del self._signal_instances[callback_id] + elif callback_id in self._callables: + del self._callables[callback_id] + else: + logger.warning( + f"Callback ID {callback_id} not found in registered signals or callables." + ) + + def start(self): + """Start the polling thread to monitor registered signals and callables.""" + self._start_poll_event.set() + + def stop(self): + """Stop the polling thread without shutting it down.""" + self._start_poll_event.clear() + + def shutdown(self): + """Shutdown the monitoring thread and clean up resources.""" + with self._lock: + self._callables.clear() + self._signal_instances.clear() + self._kill_event.set() + self._start_poll_event.set() # Ensure the polling thread is not waiting + self._polling_interval_event.set() # Ensure the polling thread is not waiting + self._poll_thread.join() diff --git a/tests/test_utils.py b/tests/test_utils.py index f0023ea..7da000c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -38,6 +38,7 @@ TaskStatus, TransitionStatus, ) +from ophyd_devices.utils.signal_monitoring import SignalMonitoring # pylint: disable=protected-access # pylint: disable=redefined-outer-name @@ -227,6 +228,94 @@ def cb2(): assert status1.exception().__class__ == TaskKilledError +@pytest.mark.timeout(10) +def test_utils_signal_monitoring_of_script(): + """Verify polling triggers repeatedly and approximately at the configured cadence.""" + monitoring = SignalMonitoring(name="test_signal_monitoring") + monitoring.polling_interval = 0.05 + + callback_timestamps = [] + reached_target_calls = threading.Event() + + def monitored_callback(): + callback_timestamps.append(time.perf_counter()) + if len(callback_timestamps) >= 6: + reached_target_calls.set() + + callback_id = monitoring.register_signal(monitored_callback) + + try: + monitoring.start() + assert reached_target_calls.wait( + timeout=1.5 + ), "SignalMonitoring did not poll the registered callback enough times in time." + monitoring.stop() + + # Use a subset of intervals to reduce startup/shutdown jitter influence. + intervals = [ + callback_timestamps[idx + 1] - callback_timestamps[idx] + for idx in range(len(callback_timestamps) - 1) + ] + stable_intervals = intervals[1:-1] if len(intervals) > 4 else intervals + + assert len(callback_timestamps) >= 6 + assert stable_intervals + + mean_interval = float(np.mean(stable_intervals)) + assert mean_interval == pytest.approx(monitoring.polling_interval, abs=0.03) + + # Guard against pathological bursts or stalls. + assert min(stable_intervals) > 0.02 + assert max(stable_intervals) < 0.15 + finally: + monitoring.remove_signal(callback_id) + monitoring.shutdown() + + +def test_utils_signal_monitoring_of_ophyd_signal(): + """Test that SignalMonitoring can monitor an ophyd Signal and trigger callbacks on value changes.""" + + class MockSignalWithCounter(Signal): + + _target_event = threading.Event() + _get_counter = 0 + + def get(self): + self._get_counter += 1 + if self._get_counter == 20: # After 20 polls, trigger the event to stop the test + self._target_event.set() + return self._readback + + monitoring = SignalMonitoring(name="test_signal_monitoring") + monitoring.polling_interval = 0.05 # 20 times per second + + signal = MockSignalWithCounter(name="test_signal", value=0) + assert signal._get_counter == 0 + assert signal._target_event.is_set() is False + + monitoring.register_signal(signal=signal) + monitoring.start() + assert signal._target_event.wait( + timeout=1.5 + ), "SignalMonitoring did not poll the Signal enough times in time." + monitoring.stop() + signal_counter = signal._get_counter + assert signal_counter >= 20, f"Expected at least 20 polls, got {signal_counter}" + time.sleep(0.2) # Wait to ensure no more polls happen after stopping + assert np.isclose( + signal._get_counter, signal_counter, atol=1 + ), ( # Allow for 1 additional poll due to timing uncertainty + "SignalMonitoring continued polling after stopping" + ) + + monitoring.shutdown() + timer = time.time() + while monitoring._poll_thread.is_alive(): + time.sleep(0.1) + if time.time() - timer > 2: + raise TimeoutError("Polling thread did not shut down within expected time.") + + ########################################## ######### Test PSI cusomt signals ###### ##########################################