diff --git a/docs/explanations/what-is-pytac.md b/docs/explanations/what-is-pytac.md index b1030c37..e670b8f3 100644 --- a/docs/explanations/what-is-pytac.md +++ b/docs/explanations/what-is-pytac.md @@ -51,7 +51,7 @@ with EPICS, readback (``pytac.RB``) or setpoint (``pytac.SP``). Data may be set to or retrieved from different data sources, from the live machine (``pytac.LIVE``) or from a simulator (``pytac.SIM``). By default the 'live' data source is implemented using -`Cothread `_ to communicate with +`aioca `_ to communicate with EPICS, as described above. The 'simulation' data source is left unimplemented, as Pytac does not include a simulator. However, ATIP, a module designed to integrate the `Accelerator Toolbox `_ simulator diff --git a/docs/tutorials/basic-tutorial.rst b/docs/tutorials/basic-tutorial.rst index d9620bc0..41f97b90 100644 --- a/docs/tutorials/basic-tutorial.rst +++ b/docs/tutorials/basic-tutorial.rst @@ -5,10 +5,10 @@ In this tutorial we will go through some of the most common ways of using pytac. The aim is to give you an understanding of the interface and how to find out what is available. -The import of the cothread channel access library and epicscorelibs will +The import of the aioca channel access library and epicscorelibs will allow us to get some live values from the Diamond accelerators. - $ pip install cothread epicscorelibs + $ pip install aioca epicscorelibs These docs are able to be run and tested, but may return different values as accelerator conditions will have changed. diff --git a/pyproject.toml b/pyproject.toml index f787e9c5..8e48c847 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ description = "Python Toolkit for Accelerator Controls (Pytac) is a Python libra dependencies = [ "numpy", "scipy", - "cothread", + "aioca", "epicscorelibs", ] # Add project dependencies here, e.g. ["click", "numpy"] dynamic = ["version"] @@ -32,6 +32,7 @@ dev = [ "pre-commit", "pydata-sphinx-theme>=0.12", "pytest", + "pytest-asyncio>=0.17", "pytest-cov", "ruff", "sphinx-autobuild", @@ -68,6 +69,8 @@ addopts = """ filterwarnings = "error" # Doctest python code in docs, python code in src docstrings, test functions in tests testpaths = "src tests" +asyncio_mode = "auto" + [tool.coverage.run] patch = ["subprocess"] @@ -88,8 +91,9 @@ skipsdist=True # Don't create a virtualenv for the command, requires tox-direct plugin direct = True passenv = * -allowlist_externals = +allowlist_externals = pytest + pytest-asyncio pre-commit mypy sphinx-build diff --git a/src/pytac/cothread_cs.py b/src/pytac/aioca_cs.py similarity index 84% rename from src/pytac/cothread_cs.py rename to src/pytac/aioca_cs.py index ac5a9c06..2b6e3382 100644 --- a/src/pytac/cothread_cs.py +++ b/src/pytac/aioca_cs.py @@ -1,13 +1,13 @@ import logging -from cothread.catools import ca_nothing, caget, caput +from aioca import CANothing, caget, caput from pytac.cs import ControlSystem from pytac.exceptions import ControlSystemException -class CothreadControlSystem(ControlSystem): - """A control system using cothread to communicate with EPICS. +class AIOCAControlSystem(ControlSystem): + """A control system using aioca to communicate with EPICS. N.B. this is the default control system. It is used to communicate over channel access with the hardware in the ring. @@ -19,7 +19,7 @@ def __init__(self, timeout=1.0, wait=False): self._timeout = timeout self._wait = wait - def get_single(self, pv, throw=True): + async def get_single(self, pv, throw=True): """Get the value of a given PV. Args: @@ -35,8 +35,8 @@ def get_single(self, pv, throw=True): ControlSystemException: if it cannot connect to the specified PV. """ try: - return caget(pv, timeout=self._timeout, throw=True) - except ca_nothing: + return await caget(pv, timeout=self._timeout, throw=True) + except CANothing: error_msg = f"Cannot connect to {pv}." if throw: raise ControlSystemException(error_msg) # noqa: B904 @@ -44,7 +44,7 @@ def get_single(self, pv, throw=True): logging.warning(error_msg) return None - def get_multiple(self, pvs, throw=True): + async def get_multiple(self, pvs, throw=True): """Get the value for given PVs. Args: @@ -59,11 +59,11 @@ def get_multiple(self, pvs, throw=True): Raises: ControlSystemException: if it cannot connect to one or more PVs. """ - results = caget(pvs, timeout=self._timeout, throw=False) + results = await caget(pvs, timeout=self._timeout, throw=False) return_values = [] failures = [] for result in results: - if isinstance(result, ca_nothing): + if isinstance(result, CANothing): logging.warning(f"Cannot connect to {result.name}.") if throw: failures.append(result) @@ -75,7 +75,7 @@ def get_multiple(self, pvs, throw=True): raise ControlSystemException(f"{len(failures)} caget calls failed.") return return_values - def set_single(self, pv, value, throw=True): + async def set_single(self, pv, value, throw=True): """Set the value of a given PV. Args: @@ -91,9 +91,9 @@ def set_single(self, pv, value, throw=True): ControlSystemException: if it cannot connect to the specified PV. """ try: - caput(pv, value, timeout=self._timeout, throw=True, wait=self._wait) + await caput(pv, value, timeout=self._timeout, throw=True, wait=self._wait) return True - except ca_nothing: + except CANothing: error_msg = f"Cannot connect to {pv}." if throw: raise ControlSystemException(error_msg) # noqa: B904 @@ -101,7 +101,7 @@ def set_single(self, pv, value, throw=True): logging.warning(error_msg) return False - def set_multiple(self, pvs, values, throw=True): + async def set_multiple(self, pvs, values, throw=True): """Set the values for given PVs. Args: @@ -122,7 +122,9 @@ def set_multiple(self, pvs, values, throw=True): """ if len(pvs) != len(values): raise ValueError("Please enter the same number of values as PVs.") - status = caput(pvs, values, timeout=self._timeout, throw=False, wait=self._wait) + status = await caput( + pvs, values, timeout=self._timeout, throw=False, wait=self._wait + ) return_values = [] failures = [] for stat in status: diff --git a/src/pytac/data_source.py b/src/pytac/data_source.py index 928052e0..c32dea99 100644 --- a/src/pytac/data_source.py +++ b/src/pytac/data_source.py @@ -1,5 +1,7 @@ """Module containing pytac data source classes.""" +import inspect + import pytac from pytac.exceptions import DataSourceException, FieldException @@ -189,7 +191,7 @@ def set_unitconv(self, field, uc): """ self._uc[field] = uc - def get_value( + async def get_value( self, field: str, handle: str = pytac.RB, @@ -225,12 +227,12 @@ def get_value( if data_source_type == pytac.DEFAULT: data_source_type = self.default_data_source data_source = self.get_data_source(data_source_type) - value = data_source.get_value(field, handle, throw) + value = await data_source.get_value(field, handle, throw) return self.get_unitconv(field).convert( value, origin=data_source.units, target=units ) - def set_value( + async def set_value( self, field: str, value: float, @@ -264,7 +266,7 @@ def set_value( value = self.get_unitconv(field).convert( value, origin=units, target=data_source.units ) - data_source.set_value(field, value, throw) + await data_source.set_value(field, value, throw) class DeviceDataSource(DataSource): @@ -321,7 +323,7 @@ def get_fields(self): """ return self._devices.keys() - def get_value(self, field, handle, throw=True): + async def get_value(self, field, handle, throw=True): """Get the value of a readback or setpoint PV for a field from the data_source. @@ -337,9 +339,17 @@ def get_value(self, field, handle, throw=True): Raises: FieldException: if the device does not have the specified field. """ - return self.get_device(field).get_value(handle, throw) - - def set_value(self, field, value, throw=True): + device = self.get_device(field) + # TODO some devices dont need to be awaited as they are just retrieving stored + # data, but others get data from PVs so do, make this better + val = 0 + if inspect.iscoroutinefunction(device.get_value): + val = await device.get_value(handle, throw) + else: + val = device.get_value(handle, throw) + return val + + async def set_value(self, field, value, throw=True): """Set the value of a readback or setpoint PV for a field from the data_source. @@ -352,4 +362,10 @@ def set_value(self, field, value, throw=True): Raises: FieldException: if the device does not have the specified field. """ - self.get_device(field).set_value(value, throw) + device = self.get_device(field) + # TODO some devices dont need to be awaited as they are just setting local + # data, but others set data to PVs, so do, make this better + if inspect.iscoroutinefunction(device.set_value): + await device.set_value(value, throw) + else: + device.set_value(value, throw) diff --git a/src/pytac/device.py b/src/pytac/device.py index 4e351359..e8468ee2 100644 --- a/src/pytac/device.py +++ b/src/pytac/device.py @@ -168,7 +168,7 @@ def is_enabled(self): """ return bool(self._enabled) - def get_value(self, handle, throw=True): + async def get_value(self, handle, throw=True): """Read the value of a readback or setpoint PV. Args: @@ -182,9 +182,9 @@ def get_value(self, handle, throw=True): Raises: HandleException: if the requested PV doesn't exist. """ - return self._cs.get_single(self.get_pv_name(handle), throw) + return await self._cs.get_single(self.get_pv_name(handle), throw) - def set_value(self, value, throw=True): + async def set_value(self, value, throw=True): """Set the device value. Args: @@ -195,7 +195,7 @@ def set_value(self, value, throw=True): Raises: HandleException: if no setpoint PV exists. """ - self._cs.set_single(self.get_pv_name(pytac.SP), value, throw) + return await self._cs.set_single(self.get_pv_name(pytac.SP), value, throw) def get_pv_name(self, handle): """Get the PV name for the specified handle. @@ -220,9 +220,6 @@ def get_pv_name(self, handle): class PvEnabler: """A PvEnabler class to check whether a device is enabled. - The class will behave like True if the PV value equals enabled_value, - and False otherwise. - .. Private Attributes: _pv (str): The PV name. _enabled_value (str): The value for PV for which the device should @@ -244,11 +241,11 @@ def __init__(self, pv, enabled_value, cs): self._enabled_value = str(int(float(enabled_value))) self._cs = cs - def __bool__(self): + async def is_enabled(self): """Used to override the 'if object' clause. Returns: bool: True if the device should be considered enabled. """ - pv_value = self._cs.get_single(self._pv) - return self._enabled_value == str(int(float(pv_value))) + pv_value = await self._cs.get_single(self._pv) + return self._enabled_value == str(int(float(pv_value))) # ??? diff --git a/src/pytac/element.py b/src/pytac/element.py index d028f783..51d7cfcb 100644 --- a/src/pytac/element.py +++ b/src/pytac/element.py @@ -226,7 +226,7 @@ def is_in_family(self, family): """ return family.lower() in self._families - def get_value( + async def get_value( self, field, handle=pytac.RB, @@ -257,7 +257,7 @@ def get_value( FieldException: if the element does not have the specified field. """ try: - return self._data_source_manager.get_value( + return await self._data_source_manager.get_value( field, handle, units, data_source, throw ) except DataSourceException as e: @@ -265,7 +265,7 @@ def get_value( except FieldException as e: raise FieldException(f"{self}: {e}") from e - def set_value( + async def set_value( self, field, value, @@ -290,7 +290,9 @@ def set_value( FieldException: if the element does not have the specified field. """ try: - self._data_source_manager.set_value(field, value, units, data_source, throw) + await self._data_source_manager.set_value( + field, value, units, data_source, throw + ) except DataSourceException as e: raise DataSourceException(f"{self}: {e}") from e except FieldException as e: diff --git a/src/pytac/lattice.py b/src/pytac/lattice.py index 579525be..f5d84e7d 100644 --- a/src/pytac/lattice.py +++ b/src/pytac/lattice.py @@ -90,6 +90,7 @@ def __getitem__(self, n: int) -> Element: Returns: indexed element """ + # TODO: We should probably raise a custom exception if len(_elements) is zero return self._elements[n] def __len__(self) -> int: @@ -180,7 +181,7 @@ def set_unitconv(self, field, uc): """ self._data_source_manager.set_unitconv(field, uc) - def get_value( + async def get_value( self, field, handle=pytac.RB, @@ -210,11 +211,11 @@ def get_value( DataSourceException: if there is no data source on the given field. FieldException: if the lattice does not have the specified field. """ - return self._data_source_manager.get_value( + return await self._data_source_manager.get_value( field, handle, units, data_source, throw ) - def set_value( + async def set_value( self, field, value, @@ -238,7 +239,9 @@ def set_value( DataSourceException: if arguments are incorrect. FieldException: if the lattice does not have the specified field. """ - self._data_source_manager.set_value(field, value, units, data_source, throw) + await self._data_source_manager.set_value( + field, value, units, data_source, throw + ) def get_length(self): """Returns the length of the lattice, in meters. @@ -359,7 +362,7 @@ def get_element_device_names(self, family, field): devices = self.get_element_devices(family, field) return [device.name for device in devices] - def get_element_values( + async def get_element_values( self, family, field, @@ -389,14 +392,14 @@ def get_element_values( """ elements = self.get_elements(family) values = [ - element.get_value(field, handle, units, data_source, throw) + await element.get_value(field, handle, units, data_source, throw) for element in elements ] if dtype is not None: values = numpy.array(values, dtype=dtype) return values - def set_element_values( + async def set_element_values( self, family, field, @@ -430,7 +433,7 @@ def set_element_values( f"equal to the number of elements in the family({len(elements)})." ) for element, value in zip(elements, values, strict=False): - status = element.set_value( + status = await element.set_value( field, value, units=units, @@ -601,7 +604,7 @@ def get_element_pv_names(self, family, field, handle): pv_names.append(element.get_pv_name(field, handle)) return pv_names - def get_element_values( + async def get_element_values( self, family, field, @@ -635,20 +638,20 @@ def get_element_values( units = self.get_default_units() if data_source == pytac.LIVE: pv_names = self.get_element_pv_names(family, field, handle) - values = self._cs.get_multiple(pv_names, throw) + values = await self._cs.get_multiple(pv_names, throw) if units == pytac.PHYS: values = self.convert_family_values( family, field, values, pytac.ENG, pytac.PHYS ) else: - values = super().get_element_values( + values = await super().get_element_values( family, field, handle, units, data_source, throw ) if dtype is not None: values = numpy.array(values, dtype=dtype) return values - def set_element_values( + async def set_element_values( self, family, field, @@ -689,6 +692,8 @@ def set_element_values( "must be equal to the number of elements in " f"the family({len(pv_names)})." ) - self._cs.set_multiple(pv_names, values, throw) + await self._cs.set_multiple(pv_names, values, throw) else: - super().set_element_values(family, field, values, units, data_source, throw) + await super().set_element_values( + family, field, values, units, data_source, throw + ) diff --git a/src/pytac/load_csv.py b/src/pytac/load_csv.py index 7f9a3fc5..e9fa37f5 100644 --- a/src/pytac/load_csv.py +++ b/src/pytac/load_csv.py @@ -139,7 +139,7 @@ def resolve_unitconv( return uc -def load_unitconv(mode_dir: Path, lattice: Lattice) -> None: +async def load_unitconv(mode_dir: Path, lattice: Lattice) -> None: """Load the unit conversion objects from a file. Args: @@ -180,13 +180,15 @@ def load_unitconv(mode_dir: Path, lattice: Lattice) -> None: "bend", } if item["uc_type"] != "null" and element._families & rigidity_families: # noqa: SLF001 - energy = lattice.get_value("energy", units=pytac.ENG) + energy = await lattice.get_value("energy", units=pytac.ENG) uc.set_post_eng_to_phys(utils.get_div_rigidity(energy)) uc.set_pre_phys_to_eng(utils.get_mult_rigidity(energy)) element.set_unitconv(item["field"], uc) -def load(mode, control_system=None, directory=None, symmetry=None) -> EpicsLattice: +async def load( + mode, control_system=None, directory=None, symmetry=None +) -> EpicsLattice: """Load the elements of a lattice from a directory. Args: @@ -203,20 +205,20 @@ def load(mode, control_system=None, directory=None, symmetry=None) -> EpicsLatti Lattice: The lattice containing all elements. Raises: - ControlSystemException: if the default control system, cothread, is not + ControlSystemException: if the default control system, aioca, is not installed. """ try: if control_system is None: # Don't import epics unless we need it to avoid unnecessary - # installation of cothread - from pytac import cothread_cs + # installation of aioca + from pytac import aioca_cs - control_system = cothread_cs.CothreadControlSystem() + control_system = aioca_cs.AIOCAControlSystem() except ImportError: raise ControlSystemException( - "Please install cothread to load a lattice using the default control system" - " (found in cothread_cs.py)." + "Please install aioca to load a lattice using the default control system" + " (found in aioca_cs.py)." ) from ImportError if directory is None: directory = Path(__file__).resolve().parent / "data" @@ -275,7 +277,7 @@ def load(mode, control_system=None, directory=None, symmetry=None) -> EpicsLatti lat[int(item["el_id"]) - 1].add_to_family(item["family"]) unitconv_file = mode_dir / UNITCONV_FILENAME if unitconv_file.exists(): - load_unitconv(mode_dir, lat) + await load_unitconv(mode_dir, lat) return lat diff --git a/tests/conftest.py b/tests/conftest.py index b2b4500e..c307df0e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,3 @@ -import sys -import types from unittest import mock import pytest @@ -22,39 +20,15 @@ from pytac.units import PolyUnitConv -def pytest_sessionstart(): - """Create a dummy cothread module. - - cothread is not trivial to import, so it is better to mock it before any - tests run. In particular, we need catools (the module that pytac imports - from cothread), including the functions that pytac explicitly imports - (caget and caput). - """ - - class ca_nothing(Exception): # noqa: N801, N818 - """A minimal mock of the cothread ca_nothing exception class.""" - - def __init__(self, name, errorcode=True): - self.ok = errorcode - self.name = name - - cothread = types.ModuleType("cothread") - catools = types.ModuleType("catools") - catools.caget = mock.MagicMock() - catools.caput = mock.MagicMock() - catools.ca_nothing = ca_nothing - cothread.catools = catools - - sys.modules["cothread"] = cothread - sys.modules["cothread.catools"] = catools - - # Create mock devices and attach them to the element @pytest.fixture def x_device(): x_device = mock.MagicMock() x_device.name = "x_device" + x_device.get_value = mock.AsyncMock() x_device.get_value.return_value = DUMMY_VALUE_1 + x_device.set_value = mock.AsyncMock() + return x_device @@ -62,7 +36,10 @@ def x_device(): def y_device(): y_device = mock.MagicMock() y_device.name = "y_device" + y_device.get_value = mock.AsyncMock() y_device.get_pv_name.return_value = SP_PV + y_device.set_value = mock.AsyncMock() + return y_device @@ -71,7 +48,11 @@ def y_device(): def mock_sim_data_source(): mock_sim_data_source = mock.MagicMock() mock_sim_data_source.units = pytac.PHYS + + mock_sim_data_source.get_value = mock.AsyncMock() mock_sim_data_source.get_value.return_value = DUMMY_VALUE_2 + mock_sim_data_source.set_value = mock.AsyncMock() + return mock_sim_data_source @@ -123,19 +104,18 @@ def simple_data_source_manager( @pytest.fixture(scope="session") -def i04_ring(): - return pytac.load_csv.load("I04", mock.MagicMock, symmetry=24) +async def i04_ring(): + return await pytac.load_csv.load("I04", mock.MagicMock, symmetry=24) @pytest.fixture(scope="session") -def diad_ring(): - return pytac.load_csv.load("DIAD", mock.MagicMock, symmetry=24) +async def diad_ring(): + return await pytac.load_csv.load("DIAD", mock.MagicMock, symmetry=24) @pytest.fixture -def lattice(): - lat = load_csv.load("dummy", mock.MagicMock(), CURRENT_DIR_PATH / "data", 2) - return lat +async def lattice(): + return await load_csv.load("dummy", mock.MagicMock(), CURRENT_DIR_PATH / "data", 2) def set_func(pvs, values, throw=None): @@ -146,8 +126,11 @@ def set_func(pvs, values, throw=None): @pytest.fixture def mock_cs(): cs = mock.MagicMock() + cs.get_single = mock.AsyncMock() cs.get_single.return_value = DUMMY_VALUE_1 + cs.get_multiple = mock.AsyncMock() cs.get_multiple.return_value = DUMMY_ARRAY + cs.set_multiple = mock.AsyncMock() cs.set_multiple.side_effect = set_func return cs diff --git a/tests/test_aioca_cs.py b/tests/test_aioca_cs.py new file mode 100644 index 00000000..45a55f0a --- /dev/null +++ b/tests/test_aioca_cs.py @@ -0,0 +1,121 @@ +"""Tests for the AIOCAControlSystem class.""" + +from unittest.mock import MagicMock, patch + +import pytest +from testfixtures import LogCapture + +import pytac +from constants import RB_PV, SP_PV +from pytac.aioca_cs import AIOCAControlSystem + + +class CANothing(Exception): # noqa: N818 + """A minimal mock of the aioca CANothing exception class.""" + + def __init__(self, name, errorcode=True): + self.ok = errorcode + self.name = name + + +@pytest.fixture +def cs(): + return AIOCAControlSystem(wait=True, timeout=2.0) + + +@patch("pytac.aioca_cs.caget") +async def test_get_single_calls_caget_correctly(caget: MagicMock, cs): + caget.return_value = 42 + assert (await cs.get_single(RB_PV)) == 42 + caget.assert_called_with(RB_PV, throw=True, timeout=2.0) + + +@patch("pytac.aioca_cs.caget") +async def test_get_multiple_calls_caget_correctly(caget: MagicMock, cs): + """caget is called with throw=False despite throw=True being the default + for get_multiple as we always want our get operation to fully complete, + rather than being stopped halway through by an error raised from + aioca, so that even if one get operation to a PV fails the rest will + complete sucessfully. + """ + caget.return_value = [42, 6] + assert await cs.get_multiple([RB_PV, SP_PV]) == [42, 6] + caget.assert_called_with([RB_PV, SP_PV], throw=False, timeout=2.0) + + +@patch("pytac.aioca_cs.caput") +async def test_set_single_calls_caput_correctly(caput: MagicMock, cs): + assert await cs.set_single(SP_PV, 42) is True + caput.assert_called_with(SP_PV, 42, throw=True, timeout=2.0, wait=True) + + +@patch("pytac.aioca_cs.caput") +async def test_set_multiple_calls_caput_correctly(caput: MagicMock, cs): + """caput is called with throw=False despite throw=True being the default + for set_multiple as we always want our set operation to fully complete, + rather than being stopped halway through by an error raised from + aioca, so that even if one set operation to a PV fails the rest will + complete sucessfully. + """ + await cs.set_multiple([SP_PV, RB_PV], [42, 6]) + caput.assert_called_with( + [SP_PV, RB_PV], [42, 6], throw=False, timeout=2.0, wait=True + ) + + +@patch("pytac.aioca_cs.caget") +@patch("pytac.aioca_cs.CANothing", CANothing) +async def test_get_multiple_raises_control_system_exception(caget: MagicMock, cs): + """Here we check that errors are thrown, suppressed and logged correctly.""" + caget.return_value = [12, CANothing("pv", False)] + with pytest.raises(pytac.exceptions.ControlSystemException): + await cs.get_multiple([RB_PV, SP_PV]) + with LogCapture() as log: + assert await cs.get_multiple([RB_PV, SP_PV], throw=False) == [12, None] + log.check(("root", "WARNING", "Cannot connect to pv.")) + + +@patch("pytac.aioca_cs.caput") +@patch("pytac.aioca_cs.CANothing", CANothing) +async def test_set_multiple_raises_control_system_exception(caput: MagicMock, cs): + """Here we check that errors are thrown, suppressed and logged correctly.""" + caput.return_value = [CANothing("pv1", True), CANothing("pv2", False)] + with pytest.raises(pytac.exceptions.ControlSystemException): + await cs.set_multiple([RB_PV, SP_PV], [42, 6]) + with LogCapture() as log: + assert await cs.set_multiple([RB_PV, SP_PV], [42, 6], throw=False) == [ + True, + False, + ] + log.check(("root", "WARNING", "Cannot connect to pv2.")) + + +@patch("pytac.aioca_cs.caget") +@patch("pytac.aioca_cs.CANothing", CANothing) +async def test_get_single_raises_control_system_exception(caget: MagicMock, cs): + """Here we check that errors are thrown, suppressed and logged correctly.""" + caget.side_effect = CANothing("pv", False) + with LogCapture() as log: + assert await cs.get_single(RB_PV, throw=False) is None + with pytest.raises(pytac.exceptions.ControlSystemException): + await cs.get_single(RB_PV, throw=True) + log.check(("root", "WARNING", "Cannot connect to prefix:rb.")) + + +@patch("pytac.aioca_cs.caput") +@patch("pytac.aioca_cs.CANothing", CANothing) +async def test_set_single_raises_control_system_exception(caput: MagicMock, cs): + """Here we check that errors are thrown, suppressed and logged correctly.""" + caput.side_effect = CANothing("pv", False) + with LogCapture() as log: + assert await cs.set_single(SP_PV, 42, throw=False) is False + with pytest.raises(pytac.exceptions.ControlSystemException): + await cs.set_single(SP_PV, 42, throw=True) + log.check(("root", "WARNING", "Cannot connect to prefix:sp.")) + + +async def test_set_multiple_raises_value_error_on_input_length_mismatch(cs): + with pytest.raises(ValueError): + await cs.set_multiple([SP_PV], [42, 6]) + with pytest.raises(ValueError): + await cs.set_multiple([SP_PV, RB_PV], [42]) diff --git a/tests/test_cothread_cs.py b/tests/test_cothread_cs.py deleted file mode 100644 index 288ef7ff..00000000 --- a/tests/test_cothread_cs.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Tests for the CothreadControlSystem class. - -This module depends on the cothread module being mocked. - -See pytest_sessionstart() in conftest.py for more. -""" - -import pytest -from cothread.catools import ca_nothing, caget, caput -from testfixtures import LogCapture - -import pytac -from constants import RB_PV, SP_PV -from pytac.cothread_cs import CothreadControlSystem - - -@pytest.fixture -def cs(): - return CothreadControlSystem(wait=True, timeout=2.0) - - -def test_get_single_calls_caget_correctly(cs): - caget.return_value = 42 - assert cs.get_single(RB_PV) == 42 - caget.assert_called_with(RB_PV, throw=True, timeout=2.0) - - -def test_get_multiple_calls_caget_correctly(cs): - """caget is called with throw=False despite throw=True being the default - for get_multiple as we always want our get operation to fully complete, - rather than being stopped halway through by an error raised from - cothread, so that even if one get operation to a PV fails the rest will - complete sucessfully. - """ - caget.return_value = [42, 6] - assert cs.get_multiple([RB_PV, SP_PV]) == [42, 6] - caget.assert_called_with([RB_PV, SP_PV], throw=False, timeout=2.0) - - -def test_set_single_calls_caput_correctly(cs): - assert cs.set_single(SP_PV, 42) is True - caput.assert_called_with(SP_PV, 42, throw=True, timeout=2.0, wait=True) - - -def test_set_multiple_calls_caput_correctly(cs): - """caput is called with throw=False despite throw=True being the default - for set_multiple as we always want our set operation to fully complete, - rather than being stopped halway through by an error raised from - cothread, so that even if one set operation to a PV fails the rest will - complete sucessfully. - """ - cs.set_multiple([SP_PV, RB_PV], [42, 6]) - caput.assert_called_with( - [SP_PV, RB_PV], [42, 6], throw=False, timeout=2.0, wait=True - ) - - -def test_get_multiple_raises_control_system_exception(cs): - """Here we check that errors are thrown, suppressed and logged correctly.""" - caget.return_value = [12, ca_nothing("pv", False)] - with pytest.raises(pytac.exceptions.ControlSystemException): - cs.get_multiple([RB_PV, SP_PV]) - with LogCapture() as log: - assert cs.get_multiple([RB_PV, SP_PV], throw=False) == [12, None] - log.check(("root", "WARNING", "Cannot connect to pv.")) - - -def test_set_multiple_raises_control_system_exception(cs): - """Here we check that errors are thrown, suppressed and logged correctly.""" - caput.return_value = [ca_nothing("pv1", True), ca_nothing("pv2", False)] - with pytest.raises(pytac.exceptions.ControlSystemException): - cs.set_multiple([RB_PV, SP_PV], [42, 6]) - with LogCapture() as log: - assert cs.set_multiple([RB_PV, SP_PV], [42, 6], throw=False) == [True, False] - log.check(("root", "WARNING", "Cannot connect to pv2.")) - - -def test_get_single_raises_control_system_exception(cs): - """Here we check that errors are thrown, suppressed and logged correctly.""" - caget.side_effect = ca_nothing("pv", False) - with LogCapture() as log: - assert cs.get_single(RB_PV, throw=False) is None - with pytest.raises(pytac.exceptions.ControlSystemException): - cs.get_single(RB_PV, throw=True) - log.check(("root", "WARNING", "Cannot connect to prefix:rb.")) - - -def test_set_single_raises_control_system_exception(cs): - """Here we check that errors are thrown, suppressed and logged correctly.""" - caput.side_effect = ca_nothing("pv", False) - with LogCapture() as log: - assert cs.set_single(SP_PV, 42, throw=False) is False - with pytest.raises(pytac.exceptions.ControlSystemException): - cs.set_single(SP_PV, 42, throw=True) - log.check(("root", "WARNING", "Cannot connect to prefix:sp.")) - - -def test_set_multiple_raises_value_error_on_input_length_mismatch(cs): - with pytest.raises(ValueError): - cs.set_multiple([SP_PV], [42, 6]) - with pytest.raises(ValueError): - cs.set_multiple([SP_PV, RB_PV], [42]) diff --git a/tests/test_data_source.py b/tests/test_data_source.py index 14f522a6..95be4db1 100644 --- a/tests/test_data_source.py +++ b/tests/test_data_source.py @@ -32,26 +32,27 @@ def test_get_fields(simple_object, request): @pytest.mark.parametrize( "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_set_value(simple_object, request): +async def test_set_value(simple_object, request): simple_object = request.getfixturevalue(simple_object) - simple_object.set_value("x", DUMMY_VALUE_2, pytac.ENG, pytac.LIVE) + await simple_object.set_value("x", DUMMY_VALUE_2, pytac.ENG, pytac.LIVE) simple_object.get_device("x").set_value.assert_called_with(DUMMY_VALUE_2, True) @pytest.mark.parametrize( "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_get_value_sim(simple_object, request): +async def test_get_value_sim(simple_object, request): simple_object = request.getfixturevalue(simple_object) assert ( - simple_object.get_value("x", pytac.RB, pytac.PHYS, pytac.SIM) == DUMMY_VALUE_2 + await simple_object.get_value("x", pytac.RB, pytac.PHYS, pytac.SIM) + == DUMMY_VALUE_2 ) @pytest.mark.parametrize( "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_unit_conversion(simple_object, double_uc, request): +async def test_unit_conversion(simple_object, double_uc, request): simple_object = request.getfixturevalue(simple_object) - simple_object.set_value("y", DUMMY_VALUE_2, pytac.PHYS, pytac.LIVE) + await simple_object.set_value("y", DUMMY_VALUE_2, pytac.PHYS, pytac.LIVE) simple_object.get_device("y").set_value.assert_called_with(DUMMY_VALUE_2 / 2, True) diff --git a/tests/test_device.py b/tests/test_device.py index 245db4fc..9b7d79f8 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -10,6 +10,8 @@ def create_epics_device(prefix=PREFIX, rb_pv=RB_PV, sp_pv=SP_PV, enabled=True): mock_cs = mock.MagicMock() + mock_cs.set_single = mock.AsyncMock() + mock_cs.get_single = mock.AsyncMock() mock_cs.get_single.return_value = 40.0 device = EpicsDevice(prefix, mock_cs, enabled=enabled, rb_pv=rb_pv, sp_pv=sp_pv) return device @@ -21,27 +23,27 @@ def create_simple_device(value=1.0, enabled=True): # Epics device specific tests. -def test_set_epics_device_value(): +async def test_set_epics_device_value(): device = create_epics_device() - device.set_value(40) + await device.set_value(40) device._cs.set_single.assert_called_with(SP_PV, 40, True) -def test_get_epics_device_value(): +async def test_get_epics_device_value(): device = create_epics_device() - assert device.get_value(pytac.SP) == 40.0 + assert await device.get_value(pytac.SP) == 40.0 -def test_epics_device_invalid_sp_raises_exception(): +async def test_epics_device_invalid_sp_raises_exception(): device2 = create_epics_device(PREFIX, RB_PV, None) with pytest.raises(pytac.exceptions.HandleException): - device2.set_value(40) + await device2.set_value(40) -def test_get_epics_device_value_invalid_handle_raises_exception(): +async def test_get_epics_device_value_invalid_handle_raises_exception(): device = create_epics_device() with pytest.raises(pytac.exceptions.HandleException): - device.get_value("non_existent") + await device.get_value("non_existent") # Simple device specific tests. @@ -93,8 +95,8 @@ def test_device_is_enabled_returns_bool_value(device_creation_function): # PvEnabler test. -def test_pv_enabler(mock_cs): +async def test_pv_enabler(mock_cs): pve = PvEnabler("enable-pv", 40, mock_cs) - assert pve + assert await pve.is_enabled() mock_cs.get_single.return_value = 50 - assert not pve + assert not await pve.is_enabled() diff --git a/tests/test_element.py b/tests/test_element.py index 7efb4b2d..a248e2d5 100644 --- a/tests/test_element.py +++ b/tests/test_element.py @@ -89,16 +89,16 @@ def test_get_unitconv_raises_field_exception_if_device_not_present(simple_elemen simple_element.get_unitconv("not-a-device") -def test_get_value_uses_uc_if_necessary_for_cs_call(simple_element, double_uc): +async def test_get_value_uses_uc_if_necessary_for_cs_call(simple_element, double_uc): simple_element._data_source_manager._uc["x"] = double_uc - assert simple_element.get_value( + assert await simple_element.get_value( "x", handle=pytac.SP, units=pytac.PHYS, data_source=pytac.LIVE ) == (DUMMY_VALUE_1 * 2) -def test_get_value_uses_uc_if_necessary_for_sim_call(simple_element, double_uc): +async def test_get_value_uses_uc_if_necessary_for_sim_call(simple_element, double_uc): simple_element._data_source_manager._uc["x"] = double_uc - assert simple_element.get_value( + assert await simple_element.get_value( "x", handle=pytac.SP, units=pytac.ENG, data_source=pytac.SIM ) == (DUMMY_VALUE_2 / 2) simple_element._data_source_manager._data_sources[ @@ -106,39 +106,41 @@ def test_get_value_uses_uc_if_necessary_for_sim_call(simple_element, double_uc): ].get_value.assert_called_with("x", pytac.SP, True) -def test_set_value_eng(simple_element): - simple_element.set_value("x", DUMMY_VALUE_2) +async def test_set_value_eng(simple_element): + await simple_element.set_value("x", DUMMY_VALUE_2) # No conversion needed simple_element.get_device("x").set_value.assert_called_with(DUMMY_VALUE_2, True) -def test_set_value_phys(simple_element, double_uc): +async def test_set_value_phys(simple_element, double_uc): simple_element._data_source_manager._uc["x"] = double_uc - simple_element.set_value("x", DUMMY_VALUE_2, units=pytac.PHYS) + await simple_element.set_value("x", DUMMY_VALUE_2, units=pytac.PHYS) # Conversion fron physics to engineering units simple_element.get_device("x").set_value.assert_called_with(DUMMY_VALUE_2 / 2, True) -def test_set_exceptions(simple_element, unit_uc): +async def test_set_exceptions(simple_element, unit_uc): with pytest.raises(pytac.exceptions.FieldException): - simple_element.set_value("unknown_field", 40.0) + await simple_element.set_value("unknown_field", 40.0) with pytest.raises(pytac.exceptions.DataSourceException): - simple_element.set_value("y", 40.0, data_source="unknown_data_source") + await simple_element.set_value("y", 40.0, data_source="unknown_data_source") simple_element._data_source_manager._uc["uc_but_no_data_source"] = unit_uc with pytest.raises(pytac.exceptions.FieldException): - simple_element.set_value("uc_but_no_data_source", 40.0) + await simple_element.set_value("uc_but_no_data_source", 40.0) -def test_get_exceptions(simple_element): +async def test_get_exceptions(simple_element): with pytest.raises(pytac.exceptions.FieldException): - simple_element.get_value("unknown_field", "setpoint") + await simple_element.get_value("unknown_field", "setpoint") with pytest.raises(pytac.exceptions.DataSourceException): - simple_element.get_value("y", "setpoint", data_source="unknown_data_source") + await simple_element.get_value( + "y", "setpoint", data_source="unknown_data_source" + ) -def test_identity_conversion(simple_element): - value_physics = simple_element.get_value("x", "setpoint", pytac.PHYS) - value_machine = simple_element.get_value("x", "setpoint", pytac.ENG) +async def test_identity_conversion(simple_element): + value_physics = await simple_element.get_value("x", "setpoint", pytac.PHYS) + value_machine = await simple_element.get_value("x", "setpoint", pytac.ENG) assert value_machine == DUMMY_VALUE_1 assert value_physics == DUMMY_VALUE_1 diff --git a/tests/test_epics.py b/tests/test_epics.py index dd5462a9..e0b417f0 100644 --- a/tests/test_epics.py +++ b/tests/test_epics.py @@ -7,46 +7,46 @@ from constants import DUMMY_ARRAY, RB_PV, SP_PV -def test_get_values_live(simple_epics_lattice, mock_cs): - simple_epics_lattice.get_element_values("family", "x", pytac.RB, pytac.PHYS) +async def test_get_values_live(simple_epics_lattice, mock_cs): + await simple_epics_lattice.get_element_values("family", "x", pytac.RB, pytac.PHYS) mock_cs.get_multiple.assert_called_with([RB_PV], True) -def test_get_values_sim(simple_epics_lattice): - mock_ds = mock.Mock(units=pytac.PHYS) +async def test_get_values_sim(simple_epics_lattice): + mock_ds = mock.AsyncMock(units=pytac.PHYS) mock_uc = mock.Mock() simple_epics_lattice[0].set_data_source(mock_ds, pytac.SIM) simple_epics_lattice[0].set_unitconv("a_field", mock_uc) - simple_epics_lattice.get_element_values( + await simple_epics_lattice.get_element_values( "family", "a_field", pytac.RB, pytac.ENG, pytac.SIM ) mock_ds.get_value.assert_called_with("a_field", pytac.RB, True) mock_uc.convert.assert_called_once() -def test_set_element_values_live(simple_epics_lattice, mock_cs): - simple_epics_lattice.set_element_values("family", "x", [1], units=pytac.PHYS) +async def test_set_element_values_live(simple_epics_lattice, mock_cs): + await simple_epics_lattice.set_element_values("family", "x", [1], units=pytac.PHYS) mock_cs.set_multiple.assert_called_with([SP_PV], [1], True) -def test_set_element_values_sim(simple_epics_lattice): - mock_ds = mock.Mock(units=pytac.PHYS) +async def test_set_element_values_sim(simple_epics_lattice): + mock_ds = mock.AsyncMock(units=pytac.PHYS) mock_uc = mock.Mock() mock_uc.convert.return_value = 1 simple_epics_lattice[0].set_data_source(mock_ds, pytac.SIM) simple_epics_lattice[0].set_unitconv("a_field", mock_uc) - simple_epics_lattice.set_element_values( + await simple_epics_lattice.set_element_values( "family", "a_field", [1], pytac.ENG, pytac.SIM ) mock_ds.set_value.assert_called_with("a_field", 1, True) mock_uc.convert.assert_called_once_with(1, origin=pytac.ENG, target=pytac.PHYS) -def test_set_element_values_raises_correctly(simple_epics_lattice): +async def test_set_element_values_raises_correctly(simple_epics_lattice): with pytest.raises(IndexError): - simple_epics_lattice.set_element_values("family", "x", [1, 2]) + await simple_epics_lattice.set_element_values("family", "x", [1, 2]) with pytest.raises(IndexError): - simple_epics_lattice.set_element_values( + await simple_epics_lattice.set_element_values( "family", "x", [1, 2], data_source=pytac.SIM ) @@ -60,10 +60,10 @@ def test_set_element_values_raises_correctly(simple_epics_lattice): (None, DUMMY_ARRAY), ), ) -def test_get_values_returns_numpy_array_if_requested( +async def test_get_values_returns_numpy_array_if_requested( simple_epics_lattice, dtype, expected, mock_cs ): - values = simple_epics_lattice.get_element_values( + values = await simple_epics_lattice.get_element_values( "family", "x", pytac.RB, dtype=dtype ) numpy.testing.assert_equal(values, expected) @@ -86,16 +86,16 @@ def test_get_lattice_pv_name(pv_type, simple_epics_lattice): simple_epics_lattice.get_pv_name("not_a_field", pv_type) -def test_get_value_uses_cs_if_data_source_live(simple_epics_element, mock_cs): - simple_epics_element.get_value("x", handle=pytac.SP, data_source=pytac.LIVE) +async def test_get_value_uses_cs_if_data_source_live(simple_epics_element, mock_cs): + await simple_epics_element.get_value("x", handle=pytac.SP, data_source=pytac.LIVE) mock_cs.get_single.assert_called_with(SP_PV, True) - simple_epics_element.get_value("x", handle=pytac.RB, data_source=pytac.LIVE) + await simple_epics_element.get_value("x", handle=pytac.RB, data_source=pytac.LIVE) mock_cs.get_single.assert_called_with(RB_PV, True) -def test_get_value_raises_handle_exceptions(simple_epics_element): +async def test_get_value_raises_handle_exceptions(simple_epics_element): with pytest.raises(pytac.exceptions.HandleException): - simple_epics_element.get_value("y", "unknown_handle") + await simple_epics_element.get_value("y", "unknown_handle") def test_lattice_get_pv_name_raises_data_source_exception(simple_epics_lattice): @@ -107,22 +107,24 @@ def test_lattice_get_pv_name_raises_data_source_exception(simple_epics_lattice): basic_epics_lattice.get_pv_name("x", pytac.RB) -def test_set_element_values_length_mismatch_raises_index_error(simple_epics_lattice): +async def test_set_element_values_length_mismatch_raises_index_error( + simple_epics_lattice, +): with pytest.raises(IndexError): - simple_epics_lattice.set_element_values("family", "x", [1, 2]) + await simple_epics_lattice.set_element_values("family", "x", [1, 2]) with pytest.raises(IndexError): - simple_epics_lattice.set_element_values("family", "x", []) + await simple_epics_lattice.set_element_values("family", "x", []) -def test_element_get_pv_name_raises_exceptions(simple_epics_element): +async def test_element_get_pv_name_raises_exceptions(simple_epics_element): with pytest.raises(pytac.exceptions.FieldException): - simple_epics_element.get_pv_name("unknown_field", "setpoint") + await simple_epics_element.get_pv_name("unknown_field", "setpoint") basic_epics_element = simple_epics_element with pytest.raises(pytac.exceptions.DataSourceException): - basic_epics_element.get_pv_name("basic", pytac.RB) + await basic_epics_element.get_pv_name("basic", pytac.RB) del basic_epics_element._data_source_manager._data_sources[pytac.LIVE] with pytest.raises(pytac.exceptions.DataSourceException): - basic_epics_element.get_pv_name("x", pytac.RB) + await basic_epics_element.get_pv_name("x", pytac.RB) def test_create_epics_device_raises_data_source_exception_if_no_PVs_are_given(): # noqa: N802 diff --git a/tests/test_invalid_classes.py b/tests/test_invalid_classes.py index d662d798..dd4e4afc 100644 --- a/tests/test_invalid_classes.py +++ b/tests/test_invalid_classes.py @@ -3,7 +3,7 @@ from pytac import cs, data_source, device -def test_control_system_throws_not_implemented_error(): +async def test_control_system_throws_not_implemented_error(): test_cs = cs.ControlSystem() with pytest.raises(NotImplementedError): test_cs.get_single("dummy", "throw") @@ -15,7 +15,7 @@ def test_control_system_throws_not_implemented_error(): test_cs.set_multiple(["dummy_1", "dummy_2"], [1, 2], "throw") -def test_data_source_throws_not_implemented_error(): +async def test_data_source_throws_not_implemented_error(): test_ds = data_source.DataSource() with pytest.raises(NotImplementedError): test_ds.get_fields() @@ -25,7 +25,7 @@ def test_data_source_throws_not_implemented_error(): test_ds.set_value("field", 0.0, "throw") -def test_device_throws_not_implemented_error(): +async def test_device_throws_not_implemented_error(): test_d = device.Device() with pytest.raises(NotImplementedError): test_d.is_enabled() diff --git a/tests/test_lattice.py b/tests/test_lattice.py index 1605aaf7..310af98d 100644 --- a/tests/test_lattice.py +++ b/tests/test_lattice.py @@ -79,18 +79,18 @@ def test_get_and_set_unitconv(): assert lat.get_unitconv("field1") == uc -def test_get_value_raises_exceptions_correctly(simple_lattice): +async def test_get_value_raises_exceptions_correctly(simple_lattice): with pytest.raises(pytac.exceptions.DataSourceException): - simple_lattice.get_value("x", data_source="not_a_data_source") + await simple_lattice.get_value("x", data_source="not_a_data_source") with pytest.raises(pytac.exceptions.FieldException): - simple_lattice.get_value("not_a_field") + await simple_lattice.get_value("not_a_field") -def test_set_value_raises_exceptions_correctly(simple_lattice): +async def test_set_value_raises_exceptions_correctly(simple_lattice): with pytest.raises(pytac.exceptions.DataSourceException): - simple_lattice.set_value("x", 0, data_source="not_a_data_source") + await simple_lattice.set_value("x", 0, data_source="not_a_data_source") with pytest.raises(pytac.exceptions.FieldException): - simple_lattice.set_value("not_a_field", 0) + await simple_lattice.set_value("not_a_field", 0) def test_get_element_devices_raises_value_error_for_mismatched_family(simple_lattice): @@ -148,8 +148,8 @@ def test_get_all_families(simple_lattice): assert list(families) == ["family"] -def test_get_element_values(simple_lattice): - simple_lattice.get_element_values("family", "x", pytac.RB) +async def test_get_element_values(simple_lattice): + await simple_lattice.get_element_values("family", "x", pytac.RB) simple_lattice.get_element_devices("family", "x")[0].get_value.assert_called_with( pytac.RB, True ) @@ -164,25 +164,27 @@ def test_get_element_values(simple_lattice): (None, DUMMY_ARRAY), ), ) -def test_get_element_values_returns_numpy_array_if_requested( +async def test_get_element_values_returns_numpy_array_if_requested( simple_lattice, dtype, expected ): - values = simple_lattice.get_element_values("family", "x", pytac.RB, dtype=dtype) + values = await simple_lattice.get_element_values( + "family", "x", pytac.RB, dtype=dtype + ) numpy.testing.assert_equal(values, expected) -def test_set_element_values(simple_lattice): - simple_lattice.set_element_values("family", "x", [1]) +async def test_set_element_values(simple_lattice): + await simple_lattice.set_element_values("family", "x", [1]) simple_lattice.get_element_devices("family", "x")[0].set_value.assert_called_with( 1, True ) -def test_set_element_values_raises_exceptions_correctly(simple_lattice): +async def test_set_element_values_raises_exceptions_correctly(simple_lattice): with pytest.raises(IndexError): - simple_lattice.set_element_values("family", "x", [1, 2]) + await simple_lattice.set_element_values("family", "x", [1, 2]) with pytest.raises(IndexError): - simple_lattice.set_element_values("family", "x", []) + await simple_lattice.set_element_values("family", "x", []) def test_get_family_s(simple_lattice): diff --git a/tests/test_load.py b/tests/test_load.py index 9df124e1..4d329dff 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -11,42 +11,42 @@ @pytest.fixture def mock_cs_raises_import_error(): - """We create a mock control system to replace CothreadControlSystem, so + """We create a mock control system to replace AIOCAControlSystem, so that we can check that when it raises an ImportError load_csv.load - catches it and raises a ControlSystemException instead. - N.B. Our new CothreadControlSystem is nested inside a fixture so it can be - patched into pytac.cothread_cs to replace the existing - CothreadControlSystem class. The new CothreadControlSystem created here is + catches it and raises a AIOCASystemException instead. + N.B. Our new AIOCAControlSystem is nested inside a fixture so it can be + patched into pytac.aioca_cs to replace the existing + AIOCAControlSystem class. The new AIOCAControlSystem created here is a function not a class (like the original) to prevent it from raising the ImportError when the code is compiled. """ - class CothreadControlSystem: + class AIOCAControlSystem: def __init__(self): raise ImportError - return CothreadControlSystem + return AIOCAControlSystem -def test_default_control_system_import(): +async def test_default_control_system_import(): """In this test we: - assert that the lattice is indeed loaded if no execeptions are raised - - assert that the default control system is indeed cothread and that it + - assert that the default control system is indeed aioca and that it is loaded onto the lattice correctly """ - assert bool(load(TESTING_MODE)) - assert isinstance(load(TESTING_MODE)._cs, pytac.cothread_cs.CothreadControlSystem) + assert bool(await load(TESTING_MODE)) + assert isinstance((await load(TESTING_MODE))._cs, pytac.aioca_cs.AIOCAControlSystem) -def test_import_fail_raises_control_system_exception(mock_cs_raises_import_error): +async def test_import_fail_raises_control_system_exception(mock_cs_raises_import_error): """In this test we: - - check that load corectly fails if cothread cannot be imported - - check that when the import of the CothreadControlSystem fails the + - check that load corectly fails if aioca cannot be imported + - check that when the import of the AIOCAControlSystem fails the ImportError raised is replaced with a ControlSystemException """ - with patch("pytac.cothread_cs.CothreadControlSystem", mock_cs_raises_import_error): + with patch("pytac.aioca_cs.AIOCAControlSystem", mock_cs_raises_import_error): with pytest.raises(pytac.exceptions.ControlSystemException): - load(TESTING_MODE) + await load("TESTING_MODE") def test_elements_loaded(lattice): @@ -82,11 +82,11 @@ def test_families_loaded(lattice): assert lattice.get_elements("quad")[0].families == {"quad", "qf", "qs"} -def test_load_unitconv_warns_if_pchip_or_poly_data_file_not_found( +async def test_load_unitconv_warns_if_pchip_or_poly_data_file_not_found( lattice, mode_dir, polyconv_file, pchipconv_file ): with LogCapture() as log: - load_unitconv(mode_dir, lattice) + await load_unitconv(mode_dir, lattice) log.check( ( "root", diff --git a/tests/test_machine.py b/tests/test_machine.py index 7481fd9d..fca8dbcf 100644 --- a/tests/test_machine.py +++ b/tests/test_machine.py @@ -21,8 +21,8 @@ def get_lattice(ring_mode): return lattice -def test_load_lattice_using_default_dir(): - lat = pytac.load_csv.load(TESTING_MODE, mock.MagicMock()) +async def test_load_lattice_using_default_dir(): + lat = await pytac.load_csv.load(TESTING_MODE, mock.MagicMock()) assert len(lat) == 2190