Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions ophyd_devices/utils/signal_monitoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Utility class for monitoring signals in ophyd devices."""

from __future__ import annotations

import threading
import uuid
from typing import Callable

from bec_lib.logger import bec_logger
from ophyd import Signal

logger = bec_logger.logger


class SignalMonitoring:
"""
This class allows you to register Signal instances or callables that will be polled
with a specified interval. The interval can be customized. The monitoring may be started
and stopped when needed, and will happen in a separate thread.

In general, it should be used with ophyd.Signal instances, but it also accepts a callable
in case a certain 'script' method needs to be called in order to update a signal.

"""

def __init__(self, name: str = "SignalMonitoring"):
self.name = name
self._signal_instances = {}
self._callables = {}
self._lock = threading.RLock()
self._poll_thread = threading.Thread(target=self._poll_signals, daemon=True)
self._kill_event = threading.Event()
self._start_poll_event = threading.Event()
self._polling_interval_event = threading.Event()
self._polling_interval = 0.1 # seconds
self._poll_thread.start()
Comment thread
cappel89 marked this conversation as resolved.

@property
def polling_interval(self):
"""Polling interval in seconds."""
return self._polling_interval

@polling_interval.setter
def polling_interval(self, value: float):
if value <= 0:
raise ValueError("Polling interval must be positive.")
self._polling_interval = value

def _poll_signals(self):
"""Poll loop that checks registered signals and callables at the specified interval."""
while (
self._start_poll_event.wait() and not self._kill_event.is_set()
): # Wait until polling is started
self._polling_interval_event.wait(
timeout=self._polling_interval
) # Poll at the specified interval
Comment thread
cappel89 marked this conversation as resolved.
with self._lock:
try:
for signal in self._signal_instances.values():
signal.get()
for call in self._callables.values():
Comment thread
cappel89 marked this conversation as resolved.
call()
except Exception as e:
logger.error(f"Error while polling signals: {e}")
Comment thread
cappel89 marked this conversation as resolved.

def register_signal(self, signal: Signal | Callable[[], None]) -> str:
"""
Register a Signal instance or a callable to be monitored.

Args:
signal (Signal | Callable[[], None]): The Signal instance or callable to register.
"""
callback_id = str(uuid.uuid4())
with self._lock:
if isinstance(signal, Signal):
self._signal_instances[callback_id] = signal
elif callable(signal):
self._callables[callback_id] = signal
else:
raise ValueError(
f"Only Signal instances or callables can be registered, got {type(signal)}."
)
return callback_id

def remove_signal(self, callback_id: str):
"""
Remove a registered signal or callable by its callback ID.

Args:
callback_id (str): The unique ID of the signal or callable to remove.
"""
with self._lock:
if callback_id in self._signal_instances:
del self._signal_instances[callback_id]
elif callback_id in self._callables:
del self._callables[callback_id]
else:
logger.warning(
f"Callback ID {callback_id} not found in registered signals or callables."
)

def start(self):
"""Start the polling thread to monitor registered signals and callables."""
self._start_poll_event.set()

def stop(self):
"""Stop the polling thread without shutting it down."""
self._start_poll_event.clear()

def shutdown(self):
"""Shutdown the monitoring thread and clean up resources."""
with self._lock:
self._callables.clear()
self._signal_instances.clear()
self._kill_event.set()
self._start_poll_event.set() # Ensure the polling thread is not waiting
self._polling_interval_event.set() # Ensure the polling thread is not waiting
self._poll_thread.join()
89 changes: 89 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
TaskStatus,
TransitionStatus,
)
from ophyd_devices.utils.signal_monitoring import SignalMonitoring

# pylint: disable=protected-access
# pylint: disable=redefined-outer-name
Expand Down Expand Up @@ -227,6 +228,94 @@ def cb2():
assert status1.exception().__class__ == TaskKilledError


@pytest.mark.timeout(10)
def test_utils_signal_monitoring_of_script():
"""Verify polling triggers repeatedly and approximately at the configured cadence."""
monitoring = SignalMonitoring(name="test_signal_monitoring")
monitoring.polling_interval = 0.05

callback_timestamps = []
reached_target_calls = threading.Event()

def monitored_callback():
callback_timestamps.append(time.perf_counter())
if len(callback_timestamps) >= 6:
reached_target_calls.set()

callback_id = monitoring.register_signal(monitored_callback)

try:
monitoring.start()
assert reached_target_calls.wait(
timeout=1.5
), "SignalMonitoring did not poll the registered callback enough times in time."
monitoring.stop()

# Use a subset of intervals to reduce startup/shutdown jitter influence.
intervals = [
callback_timestamps[idx + 1] - callback_timestamps[idx]
for idx in range(len(callback_timestamps) - 1)
]
stable_intervals = intervals[1:-1] if len(intervals) > 4 else intervals

assert len(callback_timestamps) >= 6
assert stable_intervals

mean_interval = float(np.mean(stable_intervals))
assert mean_interval == pytest.approx(monitoring.polling_interval, abs=0.03)

# Guard against pathological bursts or stalls.
assert min(stable_intervals) > 0.02
assert max(stable_intervals) < 0.15
finally:
monitoring.remove_signal(callback_id)
monitoring.shutdown()


def test_utils_signal_monitoring_of_ophyd_signal():
"""Test that SignalMonitoring can monitor an ophyd Signal and trigger callbacks on value changes."""

class MockSignalWithCounter(Signal):

_target_event = threading.Event()
_get_counter = 0

def get(self):
self._get_counter += 1
if self._get_counter == 20: # After 20 polls, trigger the event to stop the test
self._target_event.set()
return self._readback

monitoring = SignalMonitoring(name="test_signal_monitoring")
monitoring.polling_interval = 0.05 # 20 times per second

signal = MockSignalWithCounter(name="test_signal", value=0)
assert signal._get_counter == 0
assert signal._target_event.is_set() is False

monitoring.register_signal(signal=signal)
monitoring.start()
assert signal._target_event.wait(
timeout=1.5
), "SignalMonitoring did not poll the Signal enough times in time."
monitoring.stop()
signal_counter = signal._get_counter
assert signal_counter >= 20, f"Expected at least 20 polls, got {signal_counter}"
time.sleep(0.2) # Wait to ensure no more polls happen after stopping
assert np.isclose(
signal._get_counter, signal_counter, atol=1
), ( # Allow for 1 additional poll due to timing uncertainty
"SignalMonitoring continued polling after stopping"
)

monitoring.shutdown()
timer = time.time()
while monitoring._poll_thread.is_alive():
time.sleep(0.1)
if time.time() - timer > 2:
raise TimeoutError("Polling thread did not shut down within expected time.")


##########################################
######### Test PSI cusomt signals ######
##########################################
Expand Down
Loading