diff --git a/Changelog.md b/Changelog.md index 4eecc721..02bdceb3 100644 --- a/Changelog.md +++ b/Changelog.md @@ -5,6 +5,7 @@ - Bug #688: `Machine.remove_transitions` did not work with `State` and `Enum` even though the signature implied this (thanks @hookokoko) - PR #696: Improve handling of `StrEnum` which were previously confused as string (thanks @jbrocher) - Bug #697: An empty string as a destination made a transition internal but only `dest=None` should do this (thanks @rudy-lath-vizio) +- Bug #704: `AsyncMachine` processed all `CancelledErrors` but will from now on only do so if the error message is equal to `asyncio.CANCELLED_MSG`; this should make bypassing catch clauses easier; requires Python 3.11+ (thanks @Salier13) ## 0.9.3 (July 2024) diff --git a/tests/test_async.py b/tests/test_async.py index 71346a81..db4d67b4 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -1,4 +1,5 @@ from asyncio import CancelledError +import sys from transitions.extensions.factory import AsyncGraphMachine, HierarchicalAsyncGraphMachine from transitions.extensions.states import add_state_features @@ -402,6 +403,30 @@ async def run(): assert machine.is_A() asyncio.run(run()) + @skipIf(sys.version_info < (3, 11), "Cancel requires Python 3.11+") + def test_user_cancel(self): + machine = self.machine_cls(states=['A', 'B'], initial='A', before_state_change=self.cancel_soon) + + async def run1(): + try: + await asyncio.wait_for(machine.to_B(), timeout=0.5) + except asyncio.TimeoutError: + return # expected case + assert False, "Expected a TimeoutError" + + async def run2(): + async def raise_timeout(): + raise asyncio.TimeoutError("My custom timeout") + try: + machine.add_transition('cancelled', 'A', 'B', before=raise_timeout) + await machine.cancelled() + except asyncio.TimeoutError: + return # expected case + assert False, "Expected a TimeoutError" + asyncio.run(run1()) + assert machine.is_A() + asyncio.run(run2()) + def test_queued_timeout_cancel(self): error_mock = MagicMock() timout_mock = MagicMock() @@ -696,6 +721,28 @@ async def run(): asyncio.run(run()) + def test_deprecation_warnings(self): + import warnings + + async def run(): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + machine = self.machine_cls(states=['A', 'B'], initial='A') + await machine.cancel_running_transitions(self) + self.assertEqual(len(w), 0) + # msg is deprecated, should not be used + await machine.cancel_running_transitions(self, msg="Custom message") + self.assertEqual(len(w), 1) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.assertEqual(len(w), 0) + # should use cancel_running_transitions instead + await machine.switch_model_context(self) + self.assertEqual(len(w), 1) + + asyncio.run(run()) + @skipIf(asyncio is None or (pgv is None and gv is None), "AsyncGraphMachine requires asyncio and (py)gaphviz") class TestAsyncGraphMachine(TestAsync): diff --git a/transitions/extensions/asyncio.py b/transitions/extensions/asyncio.py index 7dbba231..d2525e53 100644 --- a/transitions/extensions/asyncio.py +++ b/transitions/extensions/asyncio.py @@ -18,6 +18,7 @@ import asyncio import contextvars import inspect +import sys import warnings from collections import deque from functools import partial, reduce @@ -32,6 +33,10 @@ _LOGGER.addHandler(logging.NullHandler()) +CANCELLED_MSG = "_transition" +"""A message passed to a cancelled task to indicate that the cancellation was caused by transitions.""" + + class AsyncState(State): """A persistent representation of a state managed by a ``Machine``. Callback execution is done asynchronously.""" @@ -117,7 +122,7 @@ async def execute(self, event_data): machine = event_data.machine # cancel running tasks since the transition will happen - await machine.cancel_running_transitions(event_data.model, event_data.event.name) + await machine.cancel_running_transitions(event_data.model) await event_data.machine.callbacks(event_data.machine.before_state_change, event_data) await event_data.machine.callbacks(self.before, event_data) @@ -386,14 +391,20 @@ async def cancel_running_transitions(self, model, msg=None): and the transition will happen. This requires already running tasks to be cancelled. Args: model (object): The currently processed model - msg (str): Optional message to pass to a running task's cancel request + msg (str): Optional message to pass to a running task's cancel request (deprecated). """ + if msg is not None: + warnings.warn( + "When you call cancel_running_transitions with a custom message " + "transitions will re-raise all raised CancelledError. " + "Make sure to catch them in your code. " + "The parameter 'msg' will likely be removed in a future release.", category=DeprecationWarning) for running_task in self.async_tasks.get(id(model), []): if self.current_context.get() == running_task or running_task in self.protected_tasks: continue if running_task.done() is False: _LOGGER.debug("Cancel running tasks...") - running_task.cancel(msg) + running_task.cancel(msg or CANCELLED_MSG) async def process_context(self, func, model): """ @@ -414,7 +425,12 @@ async def process_context(self, func, model): self.async_tasks[id(model)] = [asyncio.current_task()] try: res = await self._process_async(func, model) - except asyncio.CancelledError: + except asyncio.CancelledError as err: + # raise CancelledError only if the task was not cancelled by internal processes + # we indicate internal cancellation by passing CANCELLED_MSG to cancel() + if CANCELLED_MSG not in err.args and sys.version_info >= (3, 11): + _LOGGER.debug("%sExternal cancellation of task. Raise CancelledError...", self.name) + raise res = False finally: self.async_tasks[id(model)].remove(asyncio.current_task()) @@ -681,11 +697,8 @@ def create_timer(self, event_data): Returns (cancellable): A running timer with a cancel method """ async def _timeout(): - try: - await asyncio.sleep(self.timeout) - await asyncio.shield(self._process_timeout(event_data)) - except asyncio.CancelledError: - pass + await asyncio.sleep(self.timeout) + await asyncio.shield(self._process_timeout(event_data)) return asyncio.create_task(_timeout()) async def _process_timeout(self, event_data): @@ -707,7 +720,7 @@ async def _process_timeout(self, event_data): except BaseException as err2: _LOGGER.error("%sHandling timeout exception '%s' caused another exception: %s. " "Cancel running transitions...", event_data.machine.name, repr(err), repr(err2)) - await event_data.machine.cancel_running_transitions(event_data.model, "timeout") + await event_data.machine.cancel_running_transitions(event_data.model) finally: AsyncMachine.current_context.reset(token) _LOGGER.info("%sTimeout state %s processed.", event_data.machine.name, self.name) diff --git a/transitions/extensions/asyncio.pyi b/transitions/extensions/asyncio.pyi index 00b3c972..b02374b1 100644 --- a/transitions/extensions/asyncio.pyi +++ b/transitions/extensions/asyncio.pyi @@ -13,6 +13,8 @@ from ..core import StateIdentifier, CallbackList _LOGGER: Logger +CANCELLED_MSG: str = ... + AsyncCallbackFunc = Callable[..., Coroutine[Any, Any, Optional[bool]]] AsyncCallback = Union[str, AsyncCallbackFunc] AsyncCallbacksArg = Optional[Union[Callback, Iterable[Callback], AsyncCallback, Iterable[AsyncCallback]]]