diff --git a/Changelog.md b/Changelog.md index 2b22fa5b..d7bbd980 100644 --- a/Changelog.md +++ b/Changelog.md @@ -2,6 +2,7 @@ ## 0.9.3 () +- Bug #682: `AsyncTimeout` did not stop execution (thanks @matt3o) - Bug #683: Typing wrongly suggested that `Transition` instances can be passed to `Machine.__init__` and/or `Machine.add_transition(s)` (thanks @antonio-antuan) - Bug #692: When adding an already constructed `NestedState`, FunctionWrapper was not properly initialized (thanks drpjm) - Typing should be more precise now diff --git a/tests/test_async.py b/tests/test_async.py index 03e25302..71346a81 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -1,9 +1,12 @@ +from asyncio import CancelledError + from transitions.extensions.factory import AsyncGraphMachine, HierarchicalAsyncGraphMachine +from transitions.extensions.states import add_state_features try: import asyncio from transitions.extensions.asyncio import AsyncMachine, HierarchicalAsyncMachine, AsyncEventData, \ - AsyncTransition + AsyncTransition, AsyncTimeout except (ImportError, SyntaxError): asyncio = None # type: ignore @@ -342,9 +345,6 @@ async def run(): asyncio.run(run()) def test_async_timeout(self): - from transitions.extensions.states import add_state_features - from transitions.extensions.asyncio import AsyncTimeout - timeout_called = MagicMock() @add_state_features(AsyncTimeout) @@ -376,6 +376,72 @@ async def run(): asyncio.run(run()) + def test_timeout_cancel(self): + error_mock = MagicMock() + timout_mock = MagicMock() + long_op_mock = MagicMock() + + @add_state_features(AsyncTimeout) + class TimeoutMachine(self.machine_cls): # type: ignore + async def on_enter_B(self): + await asyncio.sleep(0.2) + long_op_mock() # should never be called + + async def handle_timeout(self): + timout_mock() + await self.to_A() + + machine = TimeoutMachine(states=["A", {"name": "B", "timeout": 0.1, "on_timeout": "handle_timeout"}], + initial="A", on_exception=error_mock) + + async def run(): + await machine.to_B() + assert timout_mock.called + assert error_mock.call_count == 1 # should only be one CancelledError + assert not long_op_mock.called + assert machine.is_A() + asyncio.run(run()) + + def test_queued_timeout_cancel(self): + error_mock = MagicMock() + timout_mock = MagicMock() + long_op_mock = MagicMock() + + @add_state_features(AsyncTimeout) + class TimeoutMachine(self.machine_cls): # type: ignore + async def long_op(self, event_data): + await self.to_C() + await self.to_D() + await self.to_E() + await asyncio.sleep(1) + long_op_mock() + + async def handle_timeout(self, event_data): + timout_mock() + raise TimeoutError() + + async def handle_error(self, event_data): + if isinstance(event_data.error, CancelledError): + if error_mock.called: + raise RuntimeError() + error_mock() + raise event_data.error + + machine = TimeoutMachine(states=["A", "C", "D", "E", + {"name": "B", "timeout": 0.1, "on_timeout": "handle_timeout", + "on_enter": "long_op"}], + initial="A", queued=True, send_event=True, on_exception="handle_error") + + async def run(): + await machine.to_B() + assert timout_mock.called + assert error_mock.called + assert not long_op_mock.called + assert machine.is_B() + with self.assertRaises(RuntimeError): + await machine.to_B() + asyncio.run(run()) + def test_callback_order(self): finished = [] diff --git a/transitions/extensions/asyncio.py b/transitions/extensions/asyncio.py index 161d419b..6155caf7 100644 --- a/transitions/extensions/asyncio.py +++ b/transitions/extensions/asyncio.py @@ -18,6 +18,7 @@ import asyncio import contextvars import inspect +import warnings from collections import deque from functools import partial, reduce import copy @@ -116,7 +117,7 @@ async def execute(self, event_data): machine = event_data.machine # cancel running tasks since the transition will happen - await machine.switch_model_context(event_data.model) + await machine.cancel_running_transitions(event_data.model, event_data.event.name) await event_data.machine.callbacks(event_data.machine.before_state_change, event_data) await event_data.machine.callbacks(self.before, event_data) @@ -189,7 +190,8 @@ async def _trigger(self, event_data): if self._is_valid_source(event_data.state): await self._process(event_data) except BaseException as err: # pylint: disable=broad-except; Exception will be handled elsewhere - _LOGGER.error("%sException was raised while processing the trigger: %s", self.machine.name, err) + _LOGGER.error("%sException was raised while processing the trigger '%s': %s", + self.machine.name, event_data.event.name, repr(err)) event_data.error = err if self.machine.on_exception: await self.machine.callbacks(self.machine.on_exception, event_data) @@ -374,18 +376,24 @@ async def await_all(callables): return await asyncio.gather(*[func() for func in callables]) async def switch_model_context(self, model): + warnings.warn("Please replace 'AsyncMachine.switch_model_context' with " + "'AsyncMachine.cancel_running_transitions'.", category=DeprecationWarning) + await self.cancel_running_transitions(model) + + async def cancel_running_transitions(self, model, msg=None): """ This method is called by an `AsyncTransition` when all conditional tests have passed and the transition will happen. This requires already running tasks to be cancelled. Args: model (object): The currently processed model + msg (str): Optional message to pass to a running task's cancel request """ for running_task in self.async_tasks.get(id(model), []): if self.current_context.get() == running_task or running_task in self.protected_tasks: continue if running_task.done() is False: _LOGGER.debug("Cancel running tasks...") - running_task.cancel() + running_task.cancel(msg) async def process_context(self, func, model): """ @@ -399,7 +407,7 @@ async def process_context(self, func, model): bool: returns the success state of the triggered event """ if self.current_context.get() is None: - self.current_context.set(asyncio.current_task()) + token = self.current_context.set(asyncio.current_task()) if id(model) in self.async_tasks: self.async_tasks[id(model)].append(asyncio.current_task()) else: @@ -410,6 +418,7 @@ async def process_context(self, func, model): res = False finally: self.async_tasks[id(model)].remove(asyncio.current_task()) + self.current_context.reset(token) if len(self.async_tasks[id(model)]) == 0: del self.async_tasks[id(model)] else: @@ -677,12 +686,30 @@ async def _timeout(): await asyncio.shield(self._process_timeout(event_data)) except asyncio.CancelledError: pass - - return asyncio.ensure_future(_timeout()) + return asyncio.create_task(_timeout()) async def _process_timeout(self, event_data): _LOGGER.debug("%sTimeout state %s. Processing callbacks...", event_data.machine.name, self.name) - await event_data.machine.callbacks(self.on_timeout, event_data) + event_data = AsyncEventData(event_data.state, AsyncEvent("timeout", event_data.machine), + event_data.machine, event_data.model, args=tuple(), kwargs={}) + token = AsyncMachine.current_context.set(None) + try: + await event_data.machine.callbacks(self.on_timeout, event_data) + except BaseException as err: + _LOGGER.warning("%sException raised while processing timeout!", + event_data.machine.name) + event_data.error = err + try: + if event_data.machine.on_exception: + await event_data.machine.callbacks(event_data.machine.on_exception, event_data) + else: + raise + except BaseException as err2: + _LOGGER.error("%sHandling timeout exception '%s' caused another exception: %s. " + "Cancel running transitions...", event_data.machine.name, repr(err), repr(err2)) + await event_data.machine.cancel_running_transitions(event_data.model, "timeout") + finally: + AsyncMachine.current_context.reset(token) _LOGGER.info("%sTimeout state %s processed.", event_data.machine.name, self.name) @property diff --git a/transitions/extensions/asyncio.pyi b/transitions/extensions/asyncio.pyi index 34b32476..00b3c972 100644 --- a/transitions/extensions/asyncio.pyi +++ b/transitions/extensions/asyncio.pyi @@ -112,6 +112,7 @@ class AsyncMachine(Machine): async def callback(self, func: AsyncCallback, event_data: AsyncEventData) -> None: ... # type: ignore[override] @staticmethod async def await_all(callables: List[AsyncCallbackFunc]) -> List[Optional[bool]]: ... + async def cancel_running_transitions(self, model: object, msg: Optional[str] = ...) -> None: ... async def switch_model_context(self, model: object) -> None: ... def get_state(self, state: Union[str, Enum]) -> AsyncState: ... async def process_context(self, func: Callable[[], Awaitable[None]], model: object) -> bool: ...