Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Bug #688: `Machine.remove_transitions` did not work with `State` and `Enum` even though the signature implied this (thanks @hookokoko)
- PR #696: Improve handling of `StrEnum` which were previously confused as string (thanks @jbrocher)
- Bug #697: An empty string as a destination made a transition internal but only `dest=None` should do this (thanks @rudy-lath-vizio)
- Bug #704: `AsyncMachine` processed all `CancelledErrors` but will from now on only do so if the error message is equal to `asyncio.CANCELLED_MSG`; this should make bypassing catch clauses easier; requires Python 3.11+ (thanks @Salier13)

## 0.9.3 (July 2024)

Expand Down
47 changes: 47 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from asyncio import CancelledError
import sys

from transitions.extensions.factory import AsyncGraphMachine, HierarchicalAsyncGraphMachine
from transitions.extensions.states import add_state_features
Expand Down Expand Up @@ -402,6 +403,30 @@ async def run():
assert machine.is_A()
asyncio.run(run())

@skipIf(sys.version_info < (3, 11), "Cancel requires Python 3.11+")
def test_user_cancel(self):
machine = self.machine_cls(states=['A', 'B'], initial='A', before_state_change=self.cancel_soon)

async def run1():
try:
await asyncio.wait_for(machine.to_B(), timeout=0.5)
except asyncio.TimeoutError:
return # expected case
assert False, "Expected a TimeoutError"

async def run2():
async def raise_timeout():
raise asyncio.TimeoutError("My custom timeout")
try:
machine.add_transition('cancelled', 'A', 'B', before=raise_timeout)
await machine.cancelled()
except asyncio.TimeoutError:
return # expected case
assert False, "Expected a TimeoutError"
asyncio.run(run1())
assert machine.is_A()
asyncio.run(run2())

def test_queued_timeout_cancel(self):
error_mock = MagicMock()
timout_mock = MagicMock()
Expand Down Expand Up @@ -696,6 +721,28 @@ async def run():

asyncio.run(run())

def test_deprecation_warnings(self):
import warnings

async def run():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
machine = self.machine_cls(states=['A', 'B'], initial='A')
await machine.cancel_running_transitions(self)
self.assertEqual(len(w), 0)
# msg is deprecated, should not be used
await machine.cancel_running_transitions(self, msg="Custom message")
self.assertEqual(len(w), 1)

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.assertEqual(len(w), 0)
# should use cancel_running_transitions instead
await machine.switch_model_context(self)
self.assertEqual(len(w), 1)

asyncio.run(run())


@skipIf(asyncio is None or (pgv is None and gv is None), "AsyncGraphMachine requires asyncio and (py)gaphviz")
class TestAsyncGraphMachine(TestAsync):
Expand Down
33 changes: 23 additions & 10 deletions transitions/extensions/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import asyncio
import contextvars
import inspect
import sys
import warnings
from collections import deque
from functools import partial, reduce
Expand All @@ -32,6 +33,10 @@
_LOGGER.addHandler(logging.NullHandler())


CANCELLED_MSG = "_transition"
"""A message passed to a cancelled task to indicate that the cancellation was caused by transitions."""


class AsyncState(State):
"""A persistent representation of a state managed by a ``Machine``. Callback execution is done asynchronously."""

Expand Down Expand Up @@ -117,7 +122,7 @@ async def execute(self, event_data):

machine = event_data.machine
# cancel running tasks since the transition will happen
await machine.cancel_running_transitions(event_data.model, event_data.event.name)
await machine.cancel_running_transitions(event_data.model)

await event_data.machine.callbacks(event_data.machine.before_state_change, event_data)
await event_data.machine.callbacks(self.before, event_data)
Expand Down Expand Up @@ -386,14 +391,20 @@ async def cancel_running_transitions(self, model, msg=None):
and the transition will happen. This requires already running tasks to be cancelled.
Args:
model (object): The currently processed model
msg (str): Optional message to pass to a running task's cancel request
msg (str): Optional message to pass to a running task's cancel request (deprecated).
"""
if msg is not None:
warnings.warn(
"When you call cancel_running_transitions with a custom message "
"transitions will re-raise all raised CancelledError. "
"Make sure to catch them in your code. "
"The parameter 'msg' will likely be removed in a future release.", category=DeprecationWarning)
for running_task in self.async_tasks.get(id(model), []):
if self.current_context.get() == running_task or running_task in self.protected_tasks:
continue
if running_task.done() is False:
_LOGGER.debug("Cancel running tasks...")
running_task.cancel(msg)
running_task.cancel(msg or CANCELLED_MSG)

async def process_context(self, func, model):
"""
Expand All @@ -414,7 +425,12 @@ async def process_context(self, func, model):
self.async_tasks[id(model)] = [asyncio.current_task()]
try:
res = await self._process_async(func, model)
except asyncio.CancelledError:
except asyncio.CancelledError as err:
# raise CancelledError only if the task was not cancelled by internal processes
# we indicate internal cancellation by passing CANCELLED_MSG to cancel()
if CANCELLED_MSG not in err.args and sys.version_info >= (3, 11):
_LOGGER.debug("%sExternal cancellation of task. Raise CancelledError...", self.name)
raise
res = False
finally:
self.async_tasks[id(model)].remove(asyncio.current_task())
Expand Down Expand Up @@ -681,11 +697,8 @@ def create_timer(self, event_data):
Returns (cancellable): A running timer with a cancel method
"""
async def _timeout():
try:
await asyncio.sleep(self.timeout)
await asyncio.shield(self._process_timeout(event_data))
except asyncio.CancelledError:
pass
await asyncio.sleep(self.timeout)
await asyncio.shield(self._process_timeout(event_data))
return asyncio.create_task(_timeout())

async def _process_timeout(self, event_data):
Expand All @@ -707,7 +720,7 @@ async def _process_timeout(self, event_data):
except BaseException as err2:
_LOGGER.error("%sHandling timeout exception '%s' caused another exception: %s. "
"Cancel running transitions...", event_data.machine.name, repr(err), repr(err2))
await event_data.machine.cancel_running_transitions(event_data.model, "timeout")
await event_data.machine.cancel_running_transitions(event_data.model)
finally:
AsyncMachine.current_context.reset(token)
_LOGGER.info("%sTimeout state %s processed.", event_data.machine.name, self.name)
Expand Down
2 changes: 2 additions & 0 deletions transitions/extensions/asyncio.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ from ..core import StateIdentifier, CallbackList

_LOGGER: Logger

CANCELLED_MSG: str = ...

AsyncCallbackFunc = Callable[..., Coroutine[Any, Any, Optional[bool]]]
AsyncCallback = Union[str, AsyncCallbackFunc]
AsyncCallbacksArg = Optional[Union[Callback, Iterable[Callback], AsyncCallback, Iterable[AsyncCallback]]]
Expand Down