diff --git a/examples/python/cancellation/worker.py b/examples/python/cancellation/worker.py index 47d84f6cf2..bd31fe9fab 100644 --- a/examples/python/cancellation/worker.py +++ b/examples/python/cancellation/worker.py @@ -1,7 +1,7 @@ import asyncio import time -from hatchet_sdk import Context, EmptyModel, Hatchet +from hatchet_sdk import CancellationReason, CancelledError, Context, EmptyModel, Hatchet hatchet = Hatchet(debug=True) @@ -40,6 +40,25 @@ def check_flag(input: EmptyModel, ctx: Context) -> dict[str, str]: +# > Handling cancelled error +@cancellation_workflow.task() +def my_task(input: EmptyModel, ctx: Context) -> dict: + try: + result = ctx.playground("test", "default") + except CancelledError as e: + # Handle parent cancellation - i.e. perform cleanup, then re-raise + print(f"Parent Task cancelled: {e.reason}") + # Always re-raise CancelledError so Hatchet can properly handle the cancellation + raise + except Exception as e: + # This will NOT catch CancelledError + print(f"Other error: {e}") + raise + return result + + + + def main() -> None: worker = hatchet.worker("cancellation-worker", workflows=[cancellation_workflow]) worker.start() diff --git a/examples/python/simple/worker.py b/examples/python/simple/worker.py index 85bb98a8ce..3dbb7d1cca 100644 --- a/examples/python/simple/worker.py +++ b/examples/python/simple/worker.py @@ -1,5 +1,5 @@ # > Simple -from hatchet_sdk import Context, EmptyModel, Hatchet +from hatchet_sdk import Context, DurableContext, EmptyModel, Hatchet hatchet = Hatchet(debug=True) @@ -10,7 +10,9 @@ def simple(input: EmptyModel, ctx: Context) -> dict[str, str]: @hatchet.durable_task() -def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]: +async def simple_durable(input: EmptyModel, ctx: DurableContext) -> dict[str, str]: + res = await simple.aio_run(input) + print(res) return {"result": "Hello, world!"} diff --git a/frontend/docs/pages/home/cancellation.mdx b/frontend/docs/pages/home/cancellation.mdx index 7fce04cd7e..88675a180b 100644 --- a/frontend/docs/pages/home/cancellation.mdx +++ b/frontend/docs/pages/home/cancellation.mdx @@ -22,15 +22,38 @@ When a task is canceled, Hatchet sends a cancellation signal to the task. The ta /> +### CancelledError Exception + +When a sync task is cancelled while waiting for a child workflow or during a cancellation-aware operation, a `CancelledError` exception is raised. + + + **Important:** `CancelledError` inherits from `BaseException`, not + `Exception`. This means it will **not** be caught by bare `except Exception:` + handlers. This is intentional and mirrors the behavior of Python's + `asyncio.CancelledError`. 
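+
+A minimal, illustrative handler (the `do_work` and `cleanup` helpers below are placeholders, not SDK functions):
+
+```python
+from hatchet_sdk import CancelledError
+
+try:
+    do_work()
+except CancelledError:
+    # clean up, then re-raise so Hatchet can finish cancelling the task
+    cleanup()
+    raise
+```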
+ + + + +### Cancellation Reasons + +The `CancelledError` includes a `reason` attribute that indicates why the cancellation occurred: + +| Reason | Description | +| --------------------------------------- | --------------------------------------------------------------------- | +| `CancellationReason.USER_REQUESTED` | The user explicitly requested cancellation via `ctx.cancel()` | +| `CancellationReason.WORKFLOW_CANCELLED` | The workflow run was cancelled (e.g., via API or concurrency control) | +| `CancellationReason.PARENT_CANCELLED` | The parent workflow was cancelled while waiting for a child | +| `CancellationReason.TIMEOUT` | The operation timed out | +| `CancellationReason.UNKNOWN` | Unknown or unspecified reason | + diff --git a/sdks/python/CHANGELOG.md b/sdks/python/CHANGELOG.md index 00f232a091..047f210a96 100644 --- a/sdks/python/CHANGELOG.md +++ b/sdks/python/CHANGELOG.md @@ -5,6 +5,22 @@ All notable changes to Hatchet's Python SDK will be documented in this changelog The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.25.0] - 2026-02-17 + +### Added + +- Adds a `CancellationToken` class for coordinating cancellation across async and sync operations. The token provides both `asyncio.Event` and `threading.Event` primitives, and supports registering child workflow run IDs and callbacks. +- Adds a `CancellationReason` enum with structured reasons for cancellation (`user_requested`, `timeout`, `parent_cancelled`, `workflow_cancelled`, `token_cancelled`). +- Adds a `CancelledError` exception (inherits from `BaseException`, mirroring `asyncio.CancelledError`) for sync code paths. +- Adds `cancellation_grace_period` and `cancellation_warning_threshold` configuration options to `ClientConfig` for controlling cancellation timing behavior. +- Adds `await_with_cancellation` and `race_against_token` utility functions for racing awaitables against cancellation tokens. +- The `Context` now exposes a `cancellation_token` property, allowing tasks to observe and react to cancellation signals directly. + +### Changed + +- The `Context.exit_flag` is now backed by a `CancellationToken` instead of a plain boolean. The property is maintained for backwards compatibility. +- Durable context `aio_wait_for` now respects the cancellation token, raising `asyncio.CancelledError` if the task is cancelled while waiting. + ## [1.24.0] - 2026-02-13 ### Added diff --git a/sdks/python/examples/cancellation/worker.py b/sdks/python/examples/cancellation/worker.py index e758e6f59b..d49d4e760b 100644 --- a/sdks/python/examples/cancellation/worker.py +++ b/sdks/python/examples/cancellation/worker.py @@ -1,7 +1,7 @@ import asyncio import time -from hatchet_sdk import Context, EmptyModel, Hatchet +from hatchet_sdk import CancelledError, Context, EmptyModel, Hatchet hatchet = Hatchet(debug=True) @@ -42,6 +42,26 @@ def check_flag(input: EmptyModel, ctx: Context) -> dict[str, str]: # !! +# > Handling cancelled error +@cancellation_workflow.task() +async def my_task(input: EmptyModel, ctx: Context) -> dict[str, str]: + try: + await asyncio.sleep(10) + except CancelledError as e: + # Handle parent cancellation - i.e. 
perform cleanup, then re-raise + print(f"Parent Task cancelled: {e.reason}") + # Always re-raise CancelledError so Hatchet can properly handle the cancellation + raise + except Exception as e: + # This will NOT catch CancelledError + print(f"Other error: {e}") + raise + return {"error": "Task should have been cancelled"} + + +# !! + + def main() -> None: worker = hatchet.worker("cancellation-worker", workflows=[cancellation_workflow]) worker.start() diff --git a/sdks/python/examples/simple/worker.py b/sdks/python/examples/simple/worker.py index 686742c4fb..f1082a9ba5 100644 --- a/sdks/python/examples/simple/worker.py +++ b/sdks/python/examples/simple/worker.py @@ -10,7 +10,7 @@ def simple(input: EmptyModel, ctx: Context) -> dict[str, str]: @hatchet.durable_task() -def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]: +async def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]: return {"result": "Hello, world!"} diff --git a/sdks/python/hatchet_sdk/__init__.py b/sdks/python/hatchet_sdk/__init__.py index 6fe497718a..d9dc46e41d 100644 --- a/sdks/python/hatchet_sdk/__init__.py +++ b/sdks/python/hatchet_sdk/__init__.py @@ -1,3 +1,4 @@ +from hatchet_sdk.cancellation import CancellationToken from hatchet_sdk.clients.admin import ( RunStatus, ScheduleTriggerWorkflowOptions, @@ -155,6 +156,8 @@ WorkerLabelComparator, ) from hatchet_sdk.exceptions import ( + CancellationReason, + CancelledError, DedupeViolationError, FailedTaskRunExceptionGroup, NonRetryableException, @@ -194,6 +197,9 @@ "CELEvaluationResult", "CELFailure", "CELSuccess", + "CancellationReason", + "CancellationToken", + "CancelledError", "ClientConfig", "ClientTLSConfig", "ConcurrencyExpression", diff --git a/sdks/python/hatchet_sdk/cancellation.py b/sdks/python/hatchet_sdk/cancellation.py new file mode 100644 index 0000000000..5998677695 --- /dev/null +++ b/sdks/python/hatchet_sdk/cancellation.py @@ -0,0 +1,197 @@ +"""Cancellation token for coordinating cancellation across async and sync operations.""" + +from __future__ import annotations + +import asyncio +import threading +from collections.abc import Callable +from typing import TYPE_CHECKING + +from hatchet_sdk.exceptions import CancellationReason +from hatchet_sdk.logger import logger + +if TYPE_CHECKING: + pass + + +class CancellationToken: + """ + A token that can be used to signal cancellation across async and sync operations. + + The token provides both asyncio and threading event primitives, allowing it to work + seamlessly in both async and sync code paths. Child workflow run IDs can be registered + with the token so they can be cancelled when the parent is cancelled. 
+ + Example: + ```python + token = CancellationToken() + + # In async code + await token.aio_wait() # Blocks until cancelled + + # In sync code + token.wait(timeout=1.0) # Returns True if cancelled within timeout + + # Check if cancelled + if token.is_cancelled: + raise CancelledError("Operation was cancelled") + + # Trigger cancellation + token.cancel() + ``` + """ + + def __init__(self) -> None: + self._cancelled = False + self._reason: CancellationReason | None = None + self._async_event: asyncio.Event | None = None + self._sync_event = threading.Event() + self._child_run_ids: list[str] = [] + self._callbacks: list[Callable[[], None]] = [] + self._lock = threading.Lock() + + def _get_async_event(self) -> asyncio.Event: + """Lazily create the asyncio event to avoid requiring an event loop at init time.""" + if self._async_event is None: + self._async_event = asyncio.Event() + # If already cancelled, set the event + if self._cancelled: + self._async_event.set() + return self._async_event + + def cancel( + self, reason: CancellationReason = CancellationReason.TOKEN_CANCELLED + ) -> None: + """ + Trigger cancellation. + + This will: + - Set the cancelled flag and reason + - Signal both async and sync events + - Invoke all registered callbacks + + Args: + reason: The reason for cancellation. + """ + with self._lock: + if self._cancelled: + logger.debug( + f"CancellationToken: cancel() called but already cancelled, " + f"reason={self._reason.value if self._reason else 'none'}" + ) + return + + logger.debug( + f"CancellationToken: cancel() called, reason={reason.value}, " + f"{len(self._child_run_ids)} children registered" + ) + + self._cancelled = True + self._reason = reason + + # Signal both event types + if self._async_event is not None: + self._async_event.set() + self._sync_event.set() + + # Snapshot callbacks under the lock, invoke outside to avoid deadlocks + callbacks = list(self._callbacks) + + for callback in callbacks: + try: + logger.debug(f"CancellationToken: invoking callback {callback}") + callback() + except Exception as e: # noqa: PERF203 + logger.warning(f"CancellationToken: callback raised exception: {e}") + + logger.debug(f"CancellationToken: cancel() complete, reason={reason.value}") + + @property + def is_cancelled(self) -> bool: + """Check if cancellation has been triggered.""" + return self._cancelled + + @property + def reason(self) -> CancellationReason | None: + """Get the reason for cancellation, or None if not cancelled.""" + return self._reason + + async def aio_wait(self) -> None: + """ + Await until cancelled (for use in asyncio). + + This will block until cancel() is called. + """ + await self._get_async_event().wait() + logger.debug( + f"CancellationToken: async wait completed (cancelled), " + f"reason={self._reason.value if self._reason else 'none'}" + ) + + def wait(self, timeout: float | None = None) -> bool: + """ + Block until cancelled (for use in sync code). + + Args: + timeout: Maximum time to wait in seconds. None means wait forever. + + Returns: + True if the token was cancelled (event was set), False if timeout expired. + """ + result = self._sync_event.wait(timeout) + if result: + logger.debug( + f"CancellationToken: sync wait interrupted by cancellation, " + f"reason={self._reason.value if self._reason else 'none'}" + ) + return result + + def register_child(self, run_id: str) -> None: + """ + Register a child workflow run ID with this token. 
+ + When the parent is cancelled, these child run IDs can be used to cancel + the child workflows as well. + + Args: + run_id: The workflow run ID of the child workflow. + """ + with self._lock: + logger.debug(f"CancellationToken: registering child workflow {run_id}") + self._child_run_ids.append(run_id) + + @property + def child_run_ids(self) -> list[str]: + """The registered child workflow run IDs.""" + return self._child_run_ids + + def add_callback(self, callback: Callable[[], None]) -> None: + """ + Register a callback to be invoked when cancellation is triggered. + + If the token is already cancelled, the callback will be invoked immediately. + + Args: + callback: A callable that takes no arguments. + """ + with self._lock: + if self._cancelled: + invoke_now = True + else: + invoke_now = False + self._callbacks.append(callback) + + if invoke_now: + logger.debug( + f"CancellationToken: invoking callback immediately (already cancelled): {callback}" + ) + try: + callback() + except Exception as e: + logger.warning(f"CancellationToken: callback raised exception: {e}") + + def __repr__(self) -> str: + return ( + f"CancellationToken(cancelled={self._cancelled}, " + f"children={len(self._child_run_ids)}, callbacks={len(self._callbacks)})" + ) diff --git a/sdks/python/hatchet_sdk/clients/listeners/pooled_listener.py b/sdks/python/hatchet_sdk/clients/listeners/pooled_listener.py index 8a99d8fdce..1d05ba77a5 100644 --- a/sdks/python/hatchet_sdk/clients/listeners/pooled_listener.py +++ b/sdks/python/hatchet_sdk/clients/listeners/pooled_listener.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import asyncio from abc import ABC, abstractmethod from collections.abc import AsyncIterator -from typing import Generic, Literal, TypeVar +from typing import TYPE_CHECKING, Generic, Literal, TypeVar import grpc import grpc.aio @@ -14,6 +16,10 @@ from hatchet_sdk.config import ClientConfig from hatchet_sdk.logger import logger from hatchet_sdk.metadata import get_metadata +from hatchet_sdk.utils.cancellation import race_against_token + +if TYPE_CHECKING: + from hatchet_sdk.cancellation import CancellationToken DEFAULT_LISTENER_RETRY_INTERVAL = 3 # seconds DEFAULT_LISTENER_RETRY_COUNT = 5 @@ -36,7 +42,7 @@ def __init__(self, id: int) -> None: self.id = id self.queue: asyncio.Queue[T | SentinelValue] = asyncio.Queue() - async def __aiter__(self) -> "Subscription[T]": + async def __aiter__(self) -> Subscription[T]: return self async def __anext__(self) -> T | SentinelValue: @@ -199,7 +205,17 @@ def cleanup_subscription(self, subscription_id: int) -> None: del self.from_subscriptions[subscription_id] del self.events[subscription_id] - async def subscribe(self, id: str) -> T: + async def subscribe( + self, id: str, cancellation_token: CancellationToken | None = None + ) -> T: + """ + Subscribe to events for the given ID. + + :param id: The ID to subscribe to (e.g., workflow run ID). + :param cancellation_token: Optional cancellation token to abort the subscription wait. + :return: The event received for this ID. + :raises asyncio.CancelledError: If the cancellation token is triggered or if externally cancelled. 
+ """ subscription_id: int | None = None try: @@ -221,8 +237,17 @@ async def subscribe(self, id: str) -> T: if not self.listener_task or self.listener_task.done(): self.listener_task = asyncio.create_task(self._init_producer()) + logger.debug( + f"PooledListener.subscribe: waiting for event on id={id}, " + f"subscription_id={subscription_id}, token={cancellation_token is not None}" + ) + + if cancellation_token: + result_task = asyncio.create_task(self.events[subscription_id].get()) + return await race_against_token(result_task, cancellation_token) return await self.events[subscription_id].get() except asyncio.CancelledError: + logger.debug(f"PooledListener.subscribe: externally cancelled for id={id}") raise finally: if subscription_id: diff --git a/sdks/python/hatchet_sdk/config.py b/sdks/python/hatchet_sdk/config.py index e49d2c9eac..1d315bff23 100644 --- a/sdks/python/hatchet_sdk/config.py +++ b/sdks/python/hatchet_sdk/config.py @@ -52,7 +52,7 @@ def validate_event_loop_block_threshold_seconds( if isinstance(value, timedelta): return value - if isinstance(value, int | float): + if isinstance(value, (int, float)): return timedelta(seconds=float(value)) v = value.strip() @@ -135,6 +135,37 @@ class ClientConfig(BaseSettings): force_shutdown_on_shutdown_signal: bool = False tenacity: TenacityConfig = TenacityConfig() + # Cancellation configuration + cancellation_grace_period: timedelta = Field( + default=timedelta(milliseconds=1000), + description="The maximum time to wait for a task to complete after cancellation is triggered before force-cancelling. Value is interpreted as seconds when provided as int/float.", + ) + cancellation_warning_threshold: timedelta = Field( + default=timedelta(milliseconds=300), + description="If a task has not completed cancellation within this duration, a warning will be logged. Value is interpreted as seconds when provided as int/float.", + ) + + @field_validator( + "cancellation_grace_period", "cancellation_warning_threshold", mode="before" + ) + @classmethod + def validate_cancellation_timedelta( + cls, value: timedelta | int | float | str + ) -> timedelta: + """Convert int/float/string to timedelta, interpreting as seconds.""" + if isinstance(value, timedelta): + return value + + if isinstance(value, (int, float)): + return timedelta(seconds=float(value)) + + v = value.strip() + # Allow a small convenience suffix, but keep "seconds" as the contract. 
+ if v.endswith("s"): + v = v[:-1].strip() + + return timedelta(seconds=float(v)) + @model_validator(mode="after") def validate_token_and_tenant(self) -> "ClientConfig": if not self.token: diff --git a/sdks/python/hatchet_sdk/context/context.py b/sdks/python/hatchet_sdk/context/context.py index 549d7dc7f0..01d83d6f87 100644 --- a/sdks/python/hatchet_sdk/context/context.py +++ b/sdks/python/hatchet_sdk/context/context.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, cast from warnings import warn +from hatchet_sdk.cancellation import CancellationToken from hatchet_sdk.clients.admin import AdminClient from hatchet_sdk.clients.dispatcher.dispatcher import ( # type: ignore[attr-defined] Action, @@ -21,9 +22,10 @@ flatten_conditions, ) from hatchet_sdk.context.worker_context import WorkerContext -from hatchet_sdk.exceptions import TaskRunError +from hatchet_sdk.exceptions import CancellationReason, TaskRunError from hatchet_sdk.features.runs import RunsClient from hatchet_sdk.logger import logger +from hatchet_sdk.utils.cancellation import await_with_cancellation from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_expr from hatchet_sdk.utils.typing import JSONSerializableMapping, LogLevel from hatchet_sdk.worker.runner.utils.capture_logs import AsyncLogSender, LogRecord @@ -56,7 +58,7 @@ def __init__( self.action = action self.step_run_id = action.step_run_id - self.exit_flag = False + self.cancellation_token = CancellationToken() self.dispatcher_client = dispatcher_client self.admin_client = admin_client self.event_client = event_client @@ -74,6 +76,31 @@ def __init__( self._workflow_name = workflow_name self._task_name = task_name + @property + def exit_flag(self) -> bool: + """ + Check if the cancellation flag has been set. + + This property is maintained for backwards compatibility. + Use `cancellation_token.is_cancelled` for new code. + + :return: True if the task has been cancelled, False otherwise. + """ + return self.cancellation_token.is_cancelled + + @exit_flag.setter + def exit_flag(self, value: bool) -> None: + """ + Set the cancellation flag. + + This setter is maintained for backwards compatibility. + Setting to True will trigger the cancellation token. + + :param value: True to trigger cancellation, False is a no-op. + """ + if value: + self.cancellation_token.cancel(CancellationReason.USER_REQUESTED) + def _increment_stream_index(self) -> int: index = self.stream_index self.stream_index += 1 @@ -169,8 +196,25 @@ def workflow_run_id(self) -> str: """ return self.action.workflow_run_id - def _set_cancellation_flag(self) -> None: - self.exit_flag = True + def _set_cancellation_flag( + self, reason: CancellationReason = CancellationReason.WORKFLOW_CANCELLED + ) -> None: + """ + Internal method to trigger cancellation. + + This triggers the cancellation token, which will: + - Signal all waiters (async and sync) + - Set the exit_flag property to True + - Allow child workflow cancellation + + Args: + reason: The reason for cancellation. 
+ """ + logger.debug( + f"Context: setting cancellation flag for step_run_id={self.step_run_id}, " + f"reason={reason.value}" + ) + self.cancellation_token.cancel(reason) def cancel(self) -> None: """ @@ -178,9 +222,11 @@ def cancel(self) -> None: :return: None """ - logger.debug("cancelling step...") + logger.debug( + f"Context: cancel() called for task_run_external_id={self.step_run_id}" + ) self.runs_client.cancel(self.step_run_id) - self._set_cancellation_flag() + self._set_cancellation_flag(CancellationReason.USER_REQUESTED) async def aio_cancel(self) -> None: """ @@ -188,9 +234,11 @@ async def aio_cancel(self) -> None: :return: None """ - logger.debug("cancelling step...") + logger.debug( + f"Context: aio_cancel() called for task_run_external_id={self.step_run_id}" + ) await self.runs_client.aio_cancel(self.step_run_id) - self._set_cancellation_flag() + self._set_cancellation_flag(CancellationReason.USER_REQUESTED) def done(self) -> bool: """ @@ -482,8 +530,11 @@ async def aio_wait_for( """ Durably wait for either a sleep or an event. + This method respects the context's cancellation token. If the task is cancelled + while waiting, an asyncio.CancelledError will be raised. + :param signal_key: The key to use for the durable event. This is used to identify the event in the Hatchet API. - :param *conditions: The conditions to wait for. Can be a SleepCondition or UserEventCondition. + :param \\*conditions: The conditions to wait for. Can be a SleepCondition or UserEventCondition. :return: A dictionary containing the results of the wait. :raises ValueError: If the durable event listener is not available. @@ -493,6 +544,10 @@ async def aio_wait_for( task_id = self.step_run_id + logger.debug( + f"DurableContext.aio_wait_for: waiting for signal_key={signal_key}, task_id={task_id}" + ) + request = RegisterDurableEventRequest( task_id=task_id, signal_key=signal_key, @@ -502,20 +557,30 @@ async def aio_wait_for( self.durable_event_listener.register_durable_event(request) - return await self.durable_event_listener.result( - task_id, - signal_key, + # Use await_with_cancellation to respect the cancellation token + return await await_with_cancellation( + self.durable_event_listener.result(task_id, signal_key), + self.cancellation_token, ) async def aio_sleep_for(self, duration: Duration) -> dict[str, Any]: """ Lightweight wrapper for durable sleep. Allows for shorthand usage of `ctx.aio_wait_for` when specifying a sleep condition. + This method respects the context's cancellation token. If the task is cancelled + while sleeping, an asyncio.CancelledError will be raised. + For more complicated conditions, use `ctx.aio_wait_for` directly. - """ + :param duration: The duration to sleep for. + :return: A dictionary containing the results of the wait. 
+ """ wait_index = self._increment_wait_index() + logger.debug( + f"DurableContext.aio_sleep_for: sleeping for {duration}, wait_index={wait_index}" + ) + return await self.aio_wait_for( f"sleep:{timedelta_to_expr(duration)}-{wait_index}", SleepCondition(duration=duration), diff --git a/sdks/python/hatchet_sdk/exceptions.py b/sdks/python/hatchet_sdk/exceptions.py index 3ecc0c3e66..f13a7d35b9 100644 --- a/sdks/python/hatchet_sdk/exceptions.py +++ b/sdks/python/hatchet_sdk/exceptions.py @@ -1,5 +1,6 @@ import json import traceback +from enum import Enum from typing import cast @@ -170,3 +171,54 @@ class IllegalTaskOutputError(Exception): class LifespanSetupError(Exception): pass + + +class CancellationReason(Enum): + """Reason for cancellation of an operation.""" + + USER_REQUESTED = "user_requested" + """The user explicitly requested cancellation.""" + + TIMEOUT = "timeout" + """The operation timed out.""" + + PARENT_CANCELLED = "parent_cancelled" + """The parent workflow or task was cancelled.""" + + WORKFLOW_CANCELLED = "workflow_cancelled" + """The workflow run was cancelled.""" + + TOKEN_CANCELLED = "token_cancelled" + """The cancellation token was cancelled.""" + + +class CancelledError(BaseException): + """ + Raised when an operation is cancelled via CancellationToken. + + This exception inherits from BaseException (not Exception) so that it + won't be caught by bare `except Exception:` handlers. This mirrors the + behavior of asyncio.CancelledError in Python 3.8+. + + To catch this exception, use: + - `except CancelledError:` (recommended) + - `except BaseException:` (catches all exceptions) + + This exception is used for sync code paths. For async code paths, + asyncio.CancelledError is used instead. + + :param message: Optional message describing the cancellation. + :param reason: Optional enum indicating the reason for cancellation. 
+ """ + + def __init__( + self, + message: str = "Operation cancelled", + reason: CancellationReason | None = None, + ) -> None: + self.reason = reason + super().__init__(message) + + @property + def message(self) -> str: + return str(self.args[0]) if self.args else "Operation cancelled" diff --git a/sdks/python/hatchet_sdk/runnables/contextvars.py b/sdks/python/hatchet_sdk/runnables/contextvars.py index 0d3c9d4904..dc0c9b2244 100644 --- a/sdks/python/hatchet_sdk/runnables/contextvars.py +++ b/sdks/python/hatchet_sdk/runnables/contextvars.py @@ -1,11 +1,17 @@ +from __future__ import annotations + import asyncio import threading from collections import Counter from contextvars import ContextVar +from typing import TYPE_CHECKING from hatchet_sdk.runnables.action import ActionKey from hatchet_sdk.utils.typing import JSONSerializableMapping +if TYPE_CHECKING: + from hatchet_sdk.cancellation import CancellationToken + ctx_workflow_run_id: ContextVar[str | None] = ContextVar( "ctx_workflow_run_id", default=None ) @@ -20,6 +26,9 @@ ctx_task_retry_count: ContextVar[int | None] = ContextVar( "ctx_task_retry_count", default=0 ) +ctx_cancellation_token: ContextVar[CancellationToken | None] = ContextVar( + "ctx_cancellation_token", default=None +) workflow_spawn_indices = Counter[ActionKey]() spawn_index_lock = asyncio.Lock() diff --git a/sdks/python/hatchet_sdk/runnables/workflow.py b/sdks/python/hatchet_sdk/runnables/workflow.py index 6eed17bbe7..927375db17 100644 --- a/sdks/python/hatchet_sdk/runnables/workflow.py +++ b/sdks/python/hatchet_sdk/runnables/workflow.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import json from collections.abc import Callable @@ -37,8 +39,11 @@ ) from hatchet_sdk.contracts.v1.workflows_pb2 import StickyStrategy as StickyStrategyProto from hatchet_sdk.contracts.workflows_pb2 import WorkflowVersion +from hatchet_sdk.exceptions import CancellationReason, CancelledError from hatchet_sdk.labels import DesiredWorkerLabel +from hatchet_sdk.logger import logger from hatchet_sdk.rate_limit import RateLimit +from hatchet_sdk.runnables.contextvars import ctx_cancellation_token from hatchet_sdk.runnables.task import Task from hatchet_sdk.runnables.types import ( ConcurrencyExpression, @@ -52,6 +57,7 @@ normalize_validator, ) from hatchet_sdk.serde import HATCHET_PYDANTIC_SENTINEL +from hatchet_sdk.utils.cancellation import await_with_cancellation from hatchet_sdk.utils.proto_enums import convert_python_enum_to_proto from hatchet_sdk.utils.timedelta_to_expression import Duration from hatchet_sdk.utils.typing import CoroutineLike, JSONSerializableMapping @@ -59,6 +65,7 @@ if TYPE_CHECKING: from hatchet_sdk import Hatchet + from hatchet_sdk.cancellation import CancellationToken T = TypeVar("T") @@ -88,7 +95,7 @@ class ComputedTaskParameters(BaseModel): task_defaults: TaskDefaults @model_validator(mode="after") - def validate_params(self) -> "ComputedTaskParameters": + def validate_params(self) -> ComputedTaskParameters: self.execution_timeout = fall_back_to_default( value=self.execution_timeout, param_default=timedelta(seconds=60), @@ -136,7 +143,7 @@ class TypedTriggerWorkflowRunConfig(BaseModel, Generic[TWorkflowInput]): class BaseWorkflow(Generic[TWorkflowInput]): - def __init__(self, config: WorkflowConfig, client: "Hatchet") -> None: + def __init__(self, config: WorkflowConfig, client: Hatchet) -> None: self.config = config self._default_tasks: list[Task[TWorkflowInput, Any]] = [] self._durable_tasks: list[Task[TWorkflowInput, Any]] = [] @@ -625,6 
+632,38 @@ def greet(input, ctx): and can be arranged into complex dependency patterns. """ + def _resolve_check_cancellation_token(self) -> CancellationToken | None: + cancellation_token = ctx_cancellation_token.get() + + if cancellation_token and cancellation_token.is_cancelled: + raise CancelledError( + "Operation cancelled by cancellation token", + reason=CancellationReason.TOKEN_CANCELLED, + ) + + return cancellation_token + + def _register_child_with_token( + self, + cancellation_token: CancellationToken | None, + workflow_run_id: str, + ) -> None: + if not cancellation_token: + return + + cancellation_token.register_child(workflow_run_id) + + def _register_children_with_token( + self, + cancellation_token: CancellationToken | None, + refs: list[WorkflowRunRef], + ) -> None: + if not cancellation_token: + return + + for ref in refs: + cancellation_token.register_child(ref.workflow_run_id) + def run_no_wait( self, input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()), @@ -634,17 +673,34 @@ def run_no_wait( Synchronously trigger a workflow run without waiting for it to complete. This method is useful for starting a workflow run and immediately returning a reference to the run without blocking while the workflow runs. + If a cancellation token is available via context, the child workflow will be registered + with the token. + :param input: The input data for the workflow. :param options: Additional options for workflow execution. :returns: A `WorkflowRunRef` object representing the reference to the workflow run. """ - return self.client._client.admin.run_workflow( + cancellation_token = self._resolve_check_cancellation_token() + + logger.debug( + f"Workflow.run_no_wait: triggering {self.config.name}, " + f"token={cancellation_token is not None}" + ) + + ref = self.client._client.admin.run_workflow( workflow_name=self.config.name, input=self._serialize_input(input), options=self._create_options_with_combined_additional_meta(options), ) + self._register_child_with_token( + cancellation_token, + ref.workflow_run_id, + ) + + return ref + def run( self, input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()), @@ -654,12 +710,19 @@ def run( Run the workflow synchronously and wait for it to complete. This method triggers a workflow run, blocks until completion, and returns the final result. + If a cancellation token is available via context, the wait can be interrupted. :param input: The input data for the workflow, must match the workflow's input type. :param options: Additional options for workflow execution like metadata and parent workflow ID. :returns: The result of the workflow execution as a dictionary. """ + cancellation_token = self._resolve_check_cancellation_token() + + logger.debug( + f"Workflow.run: triggering {self.config.name}, " + f"token={cancellation_token is not None}" + ) ref = self.client._client.admin.run_workflow( workflow_name=self.config.name, @@ -667,7 +730,14 @@ def run( options=self._create_options_with_combined_additional_meta(options), ) - return ref.result() + self._register_child_with_token( + cancellation_token, + ref.workflow_run_id, + ) + + logger.debug(f"Workflow.run: awaiting result for {ref.workflow_run_id}") + + return ref.result(cancellation_token=cancellation_token) async def aio_run_no_wait( self, @@ -678,18 +748,34 @@ async def aio_run_no_wait( Asynchronously trigger a workflow run without waiting for it to complete. 
This method is useful for starting a workflow run and immediately returning a reference to the run without blocking while the workflow runs. + If a cancellation token is available via context, the child workflow will be registered + with the token. + :param input: The input data for the workflow. :param options: Additional options for workflow execution. :returns: A `WorkflowRunRef` object representing the reference to the workflow run. """ + cancellation_token = self._resolve_check_cancellation_token() - return await self.client._client.admin.aio_run_workflow( + logger.debug( + f"Workflow.aio_run_no_wait: triggering {self.config.name}, " + f"token={cancellation_token is not None}" + ) + + ref = await self.client._client.admin.aio_run_workflow( workflow_name=self.config.name, input=self._serialize_input(input), options=self._create_options_with_combined_additional_meta(options), ) + self._register_child_with_token( + cancellation_token, + ref.workflow_run_id, + ) + + return ref + async def aio_run( self, input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()), @@ -699,25 +785,47 @@ async def aio_run( Run the workflow asynchronously and wait for it to complete. This method triggers a workflow run, awaits until completion, and returns the final result. + If a cancellation token is available via context, the wait can be interrupted. :param input: The input data for the workflow, must match the workflow's input type. :param options: Additional options for workflow execution like metadata and parent workflow ID. :returns: The result of the workflow execution as a dictionary. """ + cancellation_token = self._resolve_check_cancellation_token() + + logger.debug( + f"Workflow.aio_run: triggering {self.config.name}, " + f"token={cancellation_token is not None}" + ) + ref = await self.client._client.admin.aio_run_workflow( workflow_name=self.config.name, input=self._serialize_input(input), options=self._create_options_with_combined_additional_meta(options), ) - return await ref.aio_result() + self._register_child_with_token( + cancellation_token, + ref.workflow_run_id, + ) + + logger.debug(f"Workflow.aio_run: awaiting result for {ref.workflow_run_id}") + + return await await_with_cancellation( + ref.aio_result(), + cancellation_token, + ) def _get_result( - self, ref: WorkflowRunRef, return_exceptions: bool + self, + ref: WorkflowRunRef, + return_exceptions: bool, ) -> dict[str, Any] | BaseException: try: - return ref.result() + return ref.result( + cancellation_token=self._resolve_check_cancellation_token() + ) except Exception as e: if return_exceptions: return e @@ -746,15 +854,52 @@ def run_many( Run a workflow in bulk and wait for all runs to complete. This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results. + If a cancellation token is available via context, all child workflows will be registered + with the token and the wait can be interrupted. + :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered. :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them. :returns: A list of results for each workflow run. + :raises CancelledError: If the cancellation token is triggered (and return_exceptions is False). + :raises Exception: If a workflow run fails (and return_exceptions is False). 
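+
+        Example (illustrative; `configs` is a pre-built list of `WorkflowRunTriggerConfig`)::
+
+            # with return_exceptions=True, failures (including a CancelledError
+            # raised by cancellation) are returned in the list instead of raised
+            results = workflow.run_many(configs, return_exceptions=True)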
""" + cancellation_token = self._resolve_check_cancellation_token() + refs = self.client._client.admin.run_workflows( workflows=workflows, ) - return [self._get_result(ref, return_exceptions) for ref in refs] + self._register_children_with_token( + cancellation_token, + refs, + ) + + # Pass cancellation_token through to each result() call + # The cancellation check happens INSIDE result()'s polling loop + results: list[dict[str, Any] | BaseException] = [] + for ref in refs: + try: + results.append(ref.result(cancellation_token=cancellation_token)) + except CancelledError: # noqa: PERF203 + logger.debug( + f"Workflow.run_many: cancellation detected, stopping wait, " + f"reason={CancellationReason.PARENT_CANCELLED.value}" + ) + if return_exceptions: + results.append( + CancelledError( + "Operation cancelled by cancellation token", + reason=CancellationReason.PARENT_CANCELLED, + ) + ) + break + raise + except Exception as e: + if return_exceptions: + results.append(e) + else: + raise + return results @overload async def aio_run_many( @@ -779,16 +924,34 @@ async def aio_run_many( Run a workflow in bulk and wait for all runs to complete. This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results. + If a cancellation token is available via context, all child workflows will be registered + with the token and the wait can be interrupted. + :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered. :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them. :returns: A list of results for each workflow run. """ + cancellation_token = self._resolve_check_cancellation_token() + + logger.debug( + f"Workflow.aio_run_many: triggering {len(workflows)} workflows, " + f"token={cancellation_token is not None}" + ) + refs = await self.client._client.admin.aio_run_workflows( workflows=workflows, ) - return await asyncio.gather( - *[ref.aio_result() for ref in refs], return_exceptions=return_exceptions + self._register_children_with_token( + cancellation_token, + refs, + ) + + return await await_with_cancellation( + asyncio.gather( + *[ref.aio_result() for ref in refs], return_exceptions=return_exceptions + ), + cancellation_token, ) def run_many_no_wait( @@ -800,13 +963,30 @@ def run_many_no_wait( This method triggers multiple workflow runs and immediately returns a list of references to the runs without blocking while the workflows run. + If a cancellation token is available via context, all child workflows will be registered + with the token. + :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered. :returns: A list of `WorkflowRunRef` objects, each representing a reference to a workflow run. """ - return self.client._client.admin.run_workflows( + cancellation_token = self._resolve_check_cancellation_token() + + logger.debug( + f"Workflow.run_many_no_wait: triggering {len(workflows)} workflows, " + f"token={cancellation_token is not None}" + ) + + refs = self.client._client.admin.run_workflows( workflows=workflows, ) + self._register_children_with_token( + cancellation_token, + refs, + ) + + return refs + async def aio_run_many_no_wait( self, workflows: list[WorkflowRunTriggerConfig], @@ -816,14 +996,31 @@ async def aio_run_many_no_wait( This method triggers multiple workflow runs and immediately returns a list of references to the runs without blocking while the workflows run. 
+ If a cancellation token is available via context, all child workflows will be registered + with the token. + :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered. :returns: A list of `WorkflowRunRef` objects, each representing a reference to a workflow run. """ - return await self.client._client.admin.aio_run_workflows( + cancellation_token = self._resolve_check_cancellation_token() + + logger.debug( + f"Workflow.aio_run_many_no_wait: triggering {len(workflows)} workflows, " + f"token={cancellation_token is not None}" + ) + + refs = await self.client._client.admin.aio_run_workflows( workflows=workflows, ) + self._register_children_with_token( + cancellation_token, + refs, + ) + + return refs + def _parse_task_name( self, name: str | None, @@ -1168,7 +1365,7 @@ def inner( return inner - def add_task(self, task: "Standalone[TWorkflowInput, Any]") -> None: + def add_task(self, task: Standalone[TWorkflowInput, Any]) -> None: """ Add a task to a workflow. Intended to be used with a previously existing task (a Standalone), such as one created with `@hatchet.task()`, which has been converted to a `Task` object using `to_task`. @@ -1207,7 +1404,7 @@ def my_task(input, ctx) -> None: class TaskRunRef(Generic[TWorkflowInput, R]): def __init__( self, - standalone: "Standalone[TWorkflowInput, R]", + standalone: Standalone[TWorkflowInput, R], workflow_run_ref: WorkflowRunRef, ): self._s = standalone @@ -1366,7 +1563,9 @@ def run_many( ) -> list[R]: ... def run_many( - self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False + self, + workflows: list[WorkflowRunTriggerConfig], + return_exceptions: bool = False, ) -> list[R] | list[R | BaseException]: """ Run a workflow in bulk and wait for all runs to complete. @@ -1400,7 +1599,9 @@ async def aio_run_many( ) -> list[R]: ... async def aio_run_many( - self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False + self, + workflows: list[WorkflowRunTriggerConfig], + return_exceptions: bool = False, ) -> list[R] | list[R | BaseException]: """ Run a workflow in bulk and wait for all runs to complete. @@ -1420,7 +1621,8 @@ async def aio_run_many( ] def run_many_no_wait( - self, workflows: list[WorkflowRunTriggerConfig] + self, + workflows: list[WorkflowRunTriggerConfig], ) -> list[TaskRunRef[TWorkflowInput, R]]: """ Run a workflow in bulk without waiting for all runs to complete. @@ -1435,7 +1637,8 @@ def run_many_no_wait( return [TaskRunRef[TWorkflowInput, R](self, ref) for ref in refs] async def aio_run_many_no_wait( - self, workflows: list[WorkflowRunTriggerConfig] + self, + workflows: list[WorkflowRunTriggerConfig], ) -> list[TaskRunRef[TWorkflowInput, R]]: """ Run a workflow in bulk without waiting for all runs to complete. 
diff --git a/sdks/python/hatchet_sdk/utils/cancellation.py b/sdks/python/hatchet_sdk/utils/cancellation.py new file mode 100644 index 0000000000..eefe7a4d6d --- /dev/null +++ b/sdks/python/hatchet_sdk/utils/cancellation.py @@ -0,0 +1,149 @@ +"""Utilities for cancellation-aware operations.""" + +from __future__ import annotations + +import asyncio +import contextlib +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, TypeVar + +from hatchet_sdk.logger import logger + +if TYPE_CHECKING: + from hatchet_sdk.cancellation import CancellationToken + +T = TypeVar("T") + + +async def _invoke_cancel_callback( + cancel_callback: Callable[[], Awaitable[None]] | None, +) -> None: + """Invoke a cancel callback.""" + if not cancel_callback: + return + + await cancel_callback() + + +async def race_against_token( + main_task: asyncio.Task[T], + token: CancellationToken, +) -> T: + """ + Race an asyncio task against a cancellation token. + + Waits for either the task to complete or the token to be cancelled. Cleans up + whichever side loses the race. + + Args: + main_task: The asyncio task to race. + token: The cancellation token to race against. + + Returns: + The result of the main task if it completes first. + + Raises: + asyncio.CancelledError: If the token fires before the task completes. + """ + cancel_task = asyncio.create_task(token.aio_wait()) + + try: + done, pending = await asyncio.wait( + [main_task, cancel_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + # Cancel pending tasks + for task in pending: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + + if cancel_task in done: + raise asyncio.CancelledError("Operation cancelled by cancellation token") + + return main_task.result() + + except asyncio.CancelledError: + # Ensure both tasks are cleaned up on any cancellation (external or token) + main_task.cancel() + cancel_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await main_task + with contextlib.suppress(asyncio.CancelledError): + await cancel_task + raise + + +async def await_with_cancellation( + coro: Awaitable[T], + token: CancellationToken | None, + cancel_callback: Callable[[], Awaitable[None]] | None = None, +) -> T: + """ + Await an awaitable with cancellation support. + + This function races the given awaitable against a cancellation token. If the + token is cancelled before the awaitable completes, the awaitable is cancelled + and an asyncio.CancelledError is raised. + + Args: + coro: The awaitable to await (coroutine, Future, or asyncio.Task). + token: The cancellation token to check. If None, the coroutine is awaited directly. + cancel_callback: An optional async callback to invoke when cancellation occurs + (e.g., to cancel child workflows). + + Returns: + The result of the coroutine. + + Raises: + asyncio.CancelledError: If the token is cancelled before the coroutine completes. 
+ + Example: + ```python + async def cleanup() -> None: + print("cleaning up...") + + async def long_running_task(): + await asyncio.sleep(10) + return "done" + + token = CancellationToken() + + # This will raise asyncio.CancelledError if token.cancel() is called + result = await await_with_cancellation( + long_running_task(), + token, + cancel_callback=cleanup, + ) + ``` + """ + + if token is None: + logger.debug("await_with_cancellation: no token provided, awaiting directly") + return await coro + + logger.debug("await_with_cancellation: starting with cancellation token") + + # Check if already cancelled + if token.is_cancelled: + logger.debug("await_with_cancellation: token already cancelled") + if cancel_callback: + logger.debug("await_with_cancellation: invoking cancel callback") + await _invoke_cancel_callback(cancel_callback) + raise asyncio.CancelledError("Operation cancelled by cancellation token") + + main_task = asyncio.ensure_future(coro) + + try: + result = await race_against_token(main_task, token) + logger.debug("await_with_cancellation: completed successfully") + return result + + except asyncio.CancelledError: + logger.debug("await_with_cancellation: cancelled") + if cancel_callback: + logger.debug("await_with_cancellation: invoking cancel callback") + with contextlib.suppress(asyncio.CancelledError): + await asyncio.shield(_invoke_cancel_callback(cancel_callback)) + raise diff --git a/sdks/python/hatchet_sdk/worker/runner/runner.py b/sdks/python/hatchet_sdk/worker/runner/runner.py index d853858b40..fde8fba734 100644 --- a/sdks/python/hatchet_sdk/worker/runner/runner.py +++ b/sdks/python/hatchet_sdk/worker/runner/runner.py @@ -2,6 +2,7 @@ import ctypes import functools import json +import time from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict, is_dataclass @@ -29,6 +30,7 @@ STEP_EVENT_TYPE_STARTED, ) from hatchet_sdk.exceptions import ( + CancellationReason, IllegalTaskOutputError, NonRetryableException, TaskRunError, @@ -39,6 +41,7 @@ from hatchet_sdk.runnables.contextvars import ( ctx_action_key, ctx_additional_metadata, + ctx_cancellation_token, ctx_step_run_id, ctx_task_retry_count, ctx_worker_id, @@ -60,6 +63,7 @@ ContextVarToCopyDict, ContextVarToCopyInt, ContextVarToCopyStr, + ContextVarToCopyToken, copy_context_vars, ) @@ -251,6 +255,7 @@ async def async_wrapped_action_func( ctx_action_key.set(action.key) ctx_additional_metadata.set(action.additional_metadata) ctx_task_retry_count.set(action.retry_count) + ctx_cancellation_token.set(ctx.cancellation_token) async with task._unpack_dependencies_with_cleanup(ctx) as dependencies: try: @@ -298,6 +303,12 @@ async def async_wrapped_action_func( value=action.retry_count, ) ), + ContextVarToCopy( + var=ContextVarToCopyToken( + name="ctx_cancellation_token", + value=ctx.cancellation_token, + ) + ), ], self.thread_action_func, ctx, @@ -480,28 +491,95 @@ def force_kill_thread(self, thread: Thread) -> None: ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor async def handle_cancel_action(self, action: Action) -> None: key = action.key + start_time = time.monotonic() + + logger.info( + f"Cancellation: received cancel action for {action.action_id}, " + f"reason={CancellationReason.WORKFLOW_CANCELLED.value}" + ) + try: - # call cancel to signal the context to stop + # Trigger the cancellation token to signal the context to stop if key in self.contexts: - self.contexts[key]._set_cancellation_flag() + ctx = 
self.contexts[key] + child_count = len(ctx.cancellation_token.child_run_ids) + logger.debug( + f"Cancellation: triggering token for {action.action_id}, " + f"reason={CancellationReason.WORKFLOW_CANCELLED.value}, " + f"{child_count} children registered" + ) + ctx._set_cancellation_flag(CancellationReason.WORKFLOW_CANCELLED) self.cancellations[key] = True + # Note: Child workflows are not cancelled here - they run independently + # and are managed by Hatchet's normal cancellation mechanisms + else: + logger.debug(f"Cancellation: no context found for {action.action_id}") + + # Wait with supervision (using timedelta configs) + grace_period = self.config.cancellation_grace_period.total_seconds() + warning_threshold = ( + self.config.cancellation_warning_threshold.total_seconds() + ) + grace_period_ms = round(grace_period * 1000) + warning_threshold_ms = round(warning_threshold * 1000) - await asyncio.sleep(1) - - if key in self.tasks: - self.tasks[key].cancel() - - # check if thread is still running, if so, print a warning - if key in self.threads: - thread = self.threads[key] + # Wait until warning threshold + await asyncio.sleep(warning_threshold) + elapsed = time.monotonic() - start_time + elapsed_ms = round(elapsed * 1000) - if self.config.enable_force_kill_sync_threads: - self.force_kill_thread(thread) - await asyncio.sleep(1) + # Check if the task has not yet exited despite the cancellation signal. + task_still_running = key in self.tasks and not self.tasks[key].done() + if task_still_running: logger.warning( - f"thread {self.threads[key].ident} with key {key} is still running after cancellation. This could cause the thread pool to get blocked and prevent new tasks from running." + f"Cancellation: task {action.action_id} has not cancelled after " + f"{elapsed_ms}ms (warning threshold {warning_threshold_ms}ms). " + f"Consider checking for blocking operations. " + f"See https://docs.hatchet.run/home/cancellation" ) + + remaining = grace_period - elapsed + if remaining > 0: + await asyncio.sleep(remaining) + + if key in self.tasks and not self.tasks[key].done(): + logger.debug( + f"Cancellation: force-cancelling task {action.action_id} " + f"after grace period ({grace_period_ms}ms)" + ) + self.tasks[key].cancel() + + if key in self.threads: + thread = self.threads[key] + + if self.config.enable_force_kill_sync_threads: + logger.debug( + f"Cancellation: force-killing thread for {action.action_id}" + ) + self.force_kill_thread(thread) + await asyncio.sleep(1) + + if thread.is_alive(): + logger.warning( + f"Cancellation: thread {thread.ident} with key {key} is still running " + f"after cancellation. This could cause the thread pool to get blocked " + f"and prevent new tasks from running." 
+ ) + + total_elapsed = time.monotonic() - start_time + total_elapsed_ms = round(total_elapsed * 1000) + if total_elapsed > grace_period: + logger.warning( + f"Cancellation: cancellation of {action.action_id} took {total_elapsed_ms}ms " + f"(exceeded grace period of {grace_period_ms}ms)" + ) + else: + logger.debug( + f"Cancellation: task {action.action_id} eventually completed in {total_elapsed_ms}ms" + ) + else: + logger.info(f"Cancellation: task {action.action_id} completed") finally: self.cleanup_run_id(key) diff --git a/sdks/python/hatchet_sdk/worker/runner/utils/capture_logs.py b/sdks/python/hatchet_sdk/worker/runner/utils/capture_logs.py index 6fd8b52ee1..4b105f843b 100644 --- a/sdks/python/hatchet_sdk/worker/runner/utils/capture_logs.py +++ b/sdks/python/hatchet_sdk/worker/runner/utils/capture_logs.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import functools import logging @@ -5,13 +7,15 @@ from io import StringIO from typing import Literal, ParamSpec, TypeVar -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field +from hatchet_sdk.cancellation import CancellationToken from hatchet_sdk.clients.events import EventClient from hatchet_sdk.logger import logger from hatchet_sdk.runnables.contextvars import ( ctx_action_key, ctx_additional_metadata, + ctx_cancellation_token, ctx_step_run_id, ctx_task_retry_count, ctx_worker_id, @@ -48,10 +52,22 @@ class ContextVarToCopyDict(BaseModel): value: JSONSerializableMapping | None +class ContextVarToCopyToken(BaseModel): + """Special type for copying CancellationToken to threads.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: Literal["ctx_cancellation_token"] + value: CancellationToken | None + + class ContextVarToCopy(BaseModel): - var: ContextVarToCopyStr | ContextVarToCopyDict | ContextVarToCopyInt = Field( - discriminator="name" - ) + var: ( + ContextVarToCopyStr + | ContextVarToCopyDict + | ContextVarToCopyInt + | ContextVarToCopyToken + ) = Field(discriminator="name") def copy_context_vars( @@ -73,6 +89,8 @@ def copy_context_vars( ctx_worker_id.set(var.var.value) elif var.var.name == "ctx_additional_metadata": ctx_additional_metadata.set(var.var.value or {}) + elif var.var.name == "ctx_cancellation_token": + ctx_cancellation_token.set(var.var.value) else: raise ValueError(f"Unknown context variable name: {var.var.name}") diff --git a/sdks/python/hatchet_sdk/workflow_run.py b/sdks/python/hatchet_sdk/workflow_run.py index 5760eef8f9..e1d0c3cc25 100644 --- a/sdks/python/hatchet_sdk/workflow_run.py +++ b/sdks/python/hatchet_sdk/workflow_run.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time from typing import TYPE_CHECKING, Any @@ -6,9 +8,17 @@ RunEventListenerClient, ) from hatchet_sdk.clients.listeners.workflow_listener import PooledWorkflowRunListener -from hatchet_sdk.exceptions import FailedTaskRunExceptionGroup, TaskRunError +from hatchet_sdk.exceptions import ( + CancellationReason, + CancelledError, + FailedTaskRunExceptionGroup, + TaskRunError, +) +from hatchet_sdk.logger import logger +from hatchet_sdk.utils.cancellation import await_with_cancellation if TYPE_CHECKING: + from hatchet_sdk.cancellation import CancellationToken from hatchet_sdk.clients.admin import AdminClient @@ -18,7 +28,7 @@ def __init__( workflow_run_id: str, workflow_run_listener: PooledWorkflowRunListener, workflow_run_event_listener: RunEventListenerClient, - admin_client: "AdminClient", + admin_client: AdminClient, ): self.workflow_run_id = workflow_run_id 
self.workflow_run_listener = workflow_run_listener @@ -31,7 +41,25 @@ def __str__(self) -> str: def stream(self) -> RunEventListener: return self.workflow_run_event_listener.stream(self.workflow_run_id) - async def aio_result(self) -> dict[str, Any]: + async def aio_result( + self, cancellation_token: CancellationToken | None = None + ) -> dict[str, Any]: + """ + Asynchronously wait for the workflow run to complete and return the result. + + :param cancellation_token: Optional cancellation token to abort the wait. + :return: A dictionary mapping task names to their outputs. + """ + logger.debug( + f"WorkflowRunRef.aio_result: waiting for {self.workflow_run_id}, " + f"token={cancellation_token is not None}" + ) + + if cancellation_token: + return await await_with_cancellation( + self.workflow_run_listener.aio_result(self.workflow_run_id), + cancellation_token, + ) return await self.workflow_run_listener.aio_result(self.workflow_run_id) def _safely_get_action_name(self, action_id: str | None) -> str | None: @@ -43,12 +71,42 @@ def _safely_get_action_name(self, action_id: str | None) -> str | None: except IndexError: return None - def result(self) -> dict[str, Any]: + def result( + self, cancellation_token: CancellationToken | None = None + ) -> dict[str, Any]: + """ + Synchronously wait for the workflow run to complete and return the result. + + This method polls the API for the workflow run status. If a cancellation token + is provided, the polling will be interrupted when cancellation is triggered. + + :param cancellation_token: Optional cancellation token to abort the wait. + :return: A dictionary mapping task names to their outputs. + :raises CancelledError: If the cancellation token is triggered. + :raises FailedTaskRunExceptionGroup: If the workflow run fails. + :raises ValueError: If the workflow run is not found. 
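+
+        Example (illustrative, inside a task)::
+
+            ref = child_workflow.run_no_wait(input)
+            # pass the context's token so the wait can be interrupted on cancellation
+            result = ref.result(cancellation_token=ctx.cancellation_token)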
+ """ from hatchet_sdk.clients.admin import RunStatus + logger.debug( + f"WorkflowRunRef.result: waiting for {self.workflow_run_id}, " + f"token={cancellation_token is not None}" + ) + retries = 0 while True: + # Check cancellation at start of each iteration + if cancellation_token and cancellation_token.is_cancelled: + logger.debug( + f"WorkflowRunRef.result: cancellation detected for {self.workflow_run_id}, " + f"reason={CancellationReason.PARENT_CANCELLED.value}" + ) + raise CancelledError( + "Operation cancelled by cancellation token", + reason=CancellationReason.PARENT_CANCELLED, + ) + try: details = self.admin_client.get_details(self.workflow_run_id) except Exception as e: @@ -59,14 +117,42 @@ def result(self) -> dict[str, Any]: f"Workflow run {self.workflow_run_id} not found" ) from e - time.sleep(1) + # Use interruptible sleep via token.wait() + if cancellation_token: + if cancellation_token.wait(timeout=1.0): + logger.debug( + f"WorkflowRunRef.result: cancellation during retry sleep for {self.workflow_run_id}, " + f"reason={CancellationReason.PARENT_CANCELLED.value}" + ) + raise CancelledError( + "Operation cancelled by cancellation token", + reason=CancellationReason.PARENT_CANCELLED, + ) from None + else: + time.sleep(1) continue + logger.debug( + f"WorkflowRunRef.result: {self.workflow_run_id} status={details.status}" + ) + if ( details.status in [RunStatus.QUEUED, RunStatus.RUNNING] or details.done is False ): - time.sleep(1) + # Use interruptible sleep via token.wait() + if cancellation_token: + if cancellation_token.wait(timeout=1.0): + logger.debug( + f"WorkflowRunRef.result: cancellation during poll sleep for {self.workflow_run_id}, " + f"reason={CancellationReason.PARENT_CANCELLED.value}" + ) + raise CancelledError( + "Operation cancelled by cancellation token", + reason=CancellationReason.PARENT_CANCELLED, + ) + else: + time.sleep(1) continue if details.status == RunStatus.FAILED: @@ -80,6 +166,9 @@ def result(self) -> dict[str, Any]: ) if details.status == RunStatus.COMPLETED: + logger.debug( + f"WorkflowRunRef.result: {self.workflow_run_id} completed successfully" + ) return { readable_id: run.output for readable_id, run in details.task_runs.items() diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index 89c49d4efd..3121ffdebe 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "hatchet-sdk" -version = "1.24.0" +version = "1.25.0" description = "This is the official Python SDK for Hatchet, a distributed, fault-tolerant task queue. The SDK allows you to easily integrate Hatchet's task scheduling and workflow orchestration capabilities into your Python applications." 
authors = [ "Alexander Belanger ", diff --git a/sdks/python/tests/test_cancellation.py b/sdks/python/tests/test_cancellation.py new file mode 100644 index 0000000000..1b6046962e --- /dev/null +++ b/sdks/python/tests/test_cancellation.py @@ -0,0 +1,461 @@ +"""Unit tests for CancellationToken and cancellation utilities.""" + +import asyncio +import threading +import time + +import pytest + +from hatchet_sdk.cancellation import CancellationToken +from hatchet_sdk.exceptions import CancellationReason, CancelledError +from hatchet_sdk.runnables.contextvars import ctx_cancellation_token +from hatchet_sdk.utils.cancellation import await_with_cancellation + +# CancellationToken + + +def test_initial_state() -> None: + """Token should start in non-cancelled state.""" + token = CancellationToken() + assert token.is_cancelled is False + + +def test_cancel_sets_flag() -> None: + """cancel() should set is_cancelled to True.""" + token = CancellationToken() + token.cancel() + assert token.is_cancelled is True + + +def test_cancel_sets_reason() -> None: + """cancel() should set the reason.""" + token = CancellationToken() + token.cancel(CancellationReason.USER_REQUESTED) + assert token.reason == CancellationReason.USER_REQUESTED + + +def test_reason_is_none_before_cancel() -> None: + """reason should be None before cancellation.""" + token = CancellationToken() + assert token.reason is None + + +def test_cancel_idempotent() -> None: + """Multiple calls to cancel() should be safe.""" + token = CancellationToken() + token.cancel() + token.cancel() # Should not raise + assert token.is_cancelled is True + + +def test_cancel_idempotent_preserves_reason() -> None: + """Multiple calls to cancel() should preserve the original reason.""" + token = CancellationToken() + token.cancel(CancellationReason.USER_REQUESTED) + token.cancel(CancellationReason.TIMEOUT) # Second call should be ignored + assert token.reason == CancellationReason.USER_REQUESTED + + +def test_sync_wait_returns_true_when_cancelled() -> None: + """wait() should return True immediately if already cancelled.""" + token = CancellationToken() + token.cancel() + result = token.wait(timeout=0.1) + assert result is True + + +def test_sync_wait_timeout_returns_false() -> None: + """wait() should return False when timeout expires without cancellation.""" + token = CancellationToken() + start = time.monotonic() + result = token.wait(timeout=0.1) + elapsed = time.monotonic() - start + assert result is False + assert elapsed >= 0.1 + + +def test_sync_wait_interrupted_by_cancel() -> None: + """wait() should return True when cancelled during wait.""" + token = CancellationToken() + + def cancel_after_delay() -> None: + time.sleep(0.1) + token.cancel() + + thread = threading.Thread(target=cancel_after_delay) + thread.start() + + start = time.monotonic() + result = token.wait(timeout=1.0) + elapsed = time.monotonic() - start + + thread.join() + + assert result is True + assert elapsed < 0.5 # Should be much faster than timeout + + +@pytest.mark.asyncio +async def test_aio_wait_returns_when_cancelled() -> None: + """aio_wait() should return when cancelled.""" + token = CancellationToken() + + async def cancel_after_delay() -> None: + await asyncio.sleep(0.1) + token.cancel() + + asyncio.create_task(cancel_after_delay()) + + start = time.monotonic() + await token.aio_wait() + elapsed = time.monotonic() - start + + assert elapsed < 0.5 # Should be fast + + +def test_register_child() -> None: + """register_child() should add run IDs to the list.""" + token = 
CancellationToken() + token.register_child("run-1") + token.register_child("run-2") + + assert token.child_run_ids == ["run-1", "run-2"] + + +def test_callback_invoked_on_cancel() -> None: + """Callbacks should be invoked when cancel() is called.""" + token = CancellationToken() + called = [] + + def callback() -> None: + called.append(True) + + token.add_callback(callback) + token.cancel() + + assert called == [True] + + +def test_callback_invoked_immediately_if_already_cancelled() -> None: + """Callbacks added after cancellation should be invoked immediately.""" + token = CancellationToken() + token.cancel() + + called = [] + + def callback() -> None: + called.append(True) + + token.add_callback(callback) + + assert called == [True] + + +def test_multiple_callbacks() -> None: + """Multiple callbacks should all be invoked.""" + token = CancellationToken() + results: list[int] = [] + + token.add_callback(lambda: results.append(1)) + token.add_callback(lambda: results.append(2)) + token.add_callback(lambda: results.append(3)) + + token.cancel() + + assert results == [1, 2, 3] + + +def test_repr() -> None: + """__repr__ should provide useful debugging info.""" + token = CancellationToken() + token.register_child("run-1") + + repr_str = repr(token) + assert "cancelled=False" in repr_str + assert "children=1" in repr_str + + +# await_with_cancellation + + +@pytest.mark.asyncio +async def test_no_token_awaits_directly() -> None: + """Without a token, coroutine should be awaited directly.""" + + async def simple_coro() -> str: + return "result" + + result = await await_with_cancellation(simple_coro(), None) + assert result == "result" + + +@pytest.mark.asyncio +async def test_token_not_cancelled_returns_result() -> None: + """With a non-cancelled token, should return coroutine result.""" + token = CancellationToken() + + async def simple_coro() -> str: + await asyncio.sleep(0.01) + return "result" + + result = await await_with_cancellation(simple_coro(), token) + assert result == "result" + + +@pytest.mark.asyncio +async def test_already_cancelled_raises_immediately() -> None: + """With an already-cancelled token, should raise immediately.""" + token = CancellationToken() + token.cancel() + + async def simple_coro() -> str: + await asyncio.sleep(10) # Would block if actually awaited + return "result" + + with pytest.raises(asyncio.CancelledError): + await await_with_cancellation(simple_coro(), token) + + +@pytest.mark.asyncio +async def test_cancellation_during_await_raises() -> None: + """Should raise CancelledError when token is cancelled during await.""" + token = CancellationToken() + + async def slow_coro() -> str: + await asyncio.sleep(10) + return "result" + + async def cancel_after_delay() -> None: + await asyncio.sleep(0.1) + token.cancel() + + asyncio.create_task(cancel_after_delay()) + + start = time.monotonic() + with pytest.raises(asyncio.CancelledError): + await await_with_cancellation(slow_coro(), token) + elapsed = time.monotonic() - start + + assert elapsed < 0.5 # Should be cancelled quickly + + +@pytest.mark.asyncio +async def test_cancel_callback_invoked() -> None: + """Cancel callback should be invoked on cancellation.""" + token = CancellationToken() + callback_called = [] + + async def cancel_callback() -> None: + callback_called.append(True) + + async def slow_coro() -> str: + await asyncio.sleep(10) + return "result" + + async def cancel_after_delay() -> None: + await asyncio.sleep(0.1) + token.cancel() + + asyncio.create_task(cancel_after_delay()) + + with 
pytest.raises(asyncio.CancelledError): + await await_with_cancellation( + slow_coro(), token, cancel_callback=cancel_callback + ) + + assert callback_called == [True] + + +@pytest.mark.asyncio +async def test_sync_cancel_callback_invoked() -> None: + """Cancel callback should be invoked on cancellation.""" + token = CancellationToken() + callback_called = [] + + async def cancel_callback() -> None: + callback_called.append(True) + + async def slow_coro() -> str: + await asyncio.sleep(10) + return "result" + + async def cancel_after_delay() -> None: + await asyncio.sleep(0.1) + token.cancel() + + asyncio.create_task(cancel_after_delay()) + + with pytest.raises(asyncio.CancelledError): + await await_with_cancellation( + slow_coro(), token, cancel_callback=cancel_callback + ) + + assert callback_called == [True] + + +@pytest.mark.asyncio +async def test_cancel_callback_invoked_on_external_task_cancel() -> None: + """Cancel callback should be invoked if the awaiting task is cancelled externally.""" + token = CancellationToken() + callback_called = asyncio.Event() + + async def cancel_callback() -> None: + callback_called.set() + + async def slow_coro() -> str: + await asyncio.sleep(10) + return "result" + + task = asyncio.create_task( + await_with_cancellation(slow_coro(), token, cancel_callback=cancel_callback) + ) + + await asyncio.sleep(0.1) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + await asyncio.wait_for(callback_called.wait(), timeout=1.0) + + +@pytest.mark.asyncio +async def test_cancel_callback_not_invoked_on_success() -> None: + """Cancel callback should NOT be invoked when coroutine completes normally.""" + token = CancellationToken() + callback_called = [] + + async def cancel_callback() -> None: + callback_called.append(True) + + async def fast_coro() -> str: + await asyncio.sleep(0.01) + return "result" + + result = await await_with_cancellation( + fast_coro(), token, cancel_callback=cancel_callback + ) + + assert result == "result" + assert callback_called == [] + + +# CancellationReason + + +def test_all_reasons_exist() -> None: + """All expected cancellation reasons should exist.""" + assert CancellationReason.USER_REQUESTED.value == "user_requested" + assert CancellationReason.TIMEOUT.value == "timeout" + assert CancellationReason.PARENT_CANCELLED.value == "parent_cancelled" + assert CancellationReason.WORKFLOW_CANCELLED.value == "workflow_cancelled" + assert CancellationReason.TOKEN_CANCELLED.value == "token_cancelled" + + +def test_reasons_are_strings() -> None: + """Cancellation reason values should be strings.""" + for reason in CancellationReason: + assert isinstance(reason.value, str) + + +# CancelledError + + +def test_cancelled_error_is_base_exception() -> None: + """CancelledError should be a BaseException (not Exception).""" + err = CancelledError("test message") + assert isinstance(err, BaseException) + assert not isinstance(err, Exception) # Should NOT be caught by except Exception + assert str(err) == "test message" + + +def test_cancelled_error_not_caught_by_except_exception() -> None: + """CancelledError should NOT be caught by except Exception.""" + caught_by_exception = False + caught_by_cancelled_error = False + + try: + raise CancelledError("test") + except Exception: + caught_by_exception = True + except CancelledError: + caught_by_cancelled_error = True + + assert not caught_by_exception + assert caught_by_cancelled_error + + +def test_cancelled_error_with_reason() -> None: + """CancelledError should accept and store a 
reason.""" + err = CancelledError("test message", reason=CancellationReason.TIMEOUT) + assert err.reason == CancellationReason.TIMEOUT + + +def test_cancelled_error_reason_defaults_to_none() -> None: + """CancelledError reason should default to None.""" + err = CancelledError("test message") + assert err.reason is None + + +def test_cancelled_error_message_property() -> None: + """CancelledError should have a message property.""" + err = CancelledError("test message") + assert err.message == "test message" + + +def test_cancelled_error_default_message() -> None: + """CancelledError should have a default message.""" + err = CancelledError() + assert err.message == "Operation cancelled" + + +def test_can_be_raised_and_caught() -> None: + """CancelledError should be raisable and catchable.""" + with pytest.raises(CancelledError) as exc_info: + raise CancelledError("Operation cancelled") + + assert "Operation cancelled" in str(exc_info.value) + + +def test_can_be_raised_with_reason() -> None: + """CancelledError should be raisable with a reason.""" + with pytest.raises(CancelledError) as exc_info: + raise CancelledError( + "Parent was cancelled", reason=CancellationReason.PARENT_CANCELLED + ) + + assert exc_info.value.reason == CancellationReason.PARENT_CANCELLED + + +# Context var propagation + + +def test_context_var_default_is_none() -> None: + """ctx_cancellation_token should default to None.""" + assert ctx_cancellation_token.get() is None + + +def test_context_var_can_be_set_and_retrieved() -> None: + """ctx_cancellation_token should be settable and retrievable.""" + token = CancellationToken() + ctx_cancellation_token.set(token) + try: + assert ctx_cancellation_token.get() is token + finally: + ctx_cancellation_token.set(None) + + +@pytest.mark.asyncio +async def test_context_var_propagates_in_async() -> None: + """ctx_cancellation_token should propagate in async context.""" + token = CancellationToken() + ctx_cancellation_token.set(token) + + async def check_token() -> CancellationToken | None: + return ctx_cancellation_token.get() + + try: + retrieved = await check_token() + assert retrieved is token + finally: + ctx_cancellation_token.set(None)