diff --git a/examples/python/cancellation/worker.py b/examples/python/cancellation/worker.py
index 47d84f6cf2..bd31fe9fab 100644
--- a/examples/python/cancellation/worker.py
+++ b/examples/python/cancellation/worker.py
@@ -1,7 +1,7 @@
import asyncio
import time
-from hatchet_sdk import Context, EmptyModel, Hatchet
+from hatchet_sdk import CancelledError, Context, EmptyModel, Hatchet
hatchet = Hatchet(debug=True)
@@ -40,6 +40,25 @@ def check_flag(input: EmptyModel, ctx: Context) -> dict[str, str]:
+# > Handling cancelled error
+@cancellation_workflow.task()
+async def my_task(input: EmptyModel, ctx: Context) -> dict[str, str]:
+    try:
+        await asyncio.sleep(10)
+    except CancelledError as e:
+        # Handle parent cancellation - i.e. perform cleanup, then re-raise
+        print(f"Parent Task cancelled: {e.reason}")
+        # Always re-raise CancelledError so Hatchet can properly handle the cancellation
+        raise
+    except Exception as e:
+        # This will NOT catch CancelledError
+        print(f"Other error: {e}")
+        raise
+    return {"error": "Task should have been cancelled"}
+
+
+# !!
+
def main() -> None:
worker = hatchet.worker("cancellation-worker", workflows=[cancellation_workflow])
worker.start()
diff --git a/examples/python/simple/worker.py b/examples/python/simple/worker.py
index 85bb98a8ce..3dbb7d1cca 100644
--- a/examples/python/simple/worker.py
+++ b/examples/python/simple/worker.py
@@ -1,5 +1,5 @@
# > Simple
-from hatchet_sdk import Context, EmptyModel, Hatchet
+from hatchet_sdk import Context, DurableContext, EmptyModel, Hatchet
hatchet = Hatchet(debug=True)
@@ -10,7 +10,9 @@ def simple(input: EmptyModel, ctx: Context) -> dict[str, str]:
@hatchet.durable_task()
-def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]:
+async def simple_durable(input: EmptyModel, ctx: DurableContext) -> dict[str, str]:
+ res = await simple.aio_run(input)
+ print(res)
return {"result": "Hello, world!"}
diff --git a/frontend/docs/pages/home/cancellation.mdx b/frontend/docs/pages/home/cancellation.mdx
index 7fce04cd7e..88675a180b 100644
--- a/frontend/docs/pages/home/cancellation.mdx
+++ b/frontend/docs/pages/home/cancellation.mdx
@@ -22,15 +22,38 @@ When a task is canceled, Hatchet sends a cancellation signal to the task. The ta
/>
+### CancelledError Exception
+
+When a sync task is cancelled while it is waiting for a child workflow or performing another cancellation-aware operation, a `CancelledError` exception is raised. Async code paths receive `asyncio.CancelledError` instead.
+
+
+ **Important:** `CancelledError` inherits from `BaseException`, not
+ `Exception`. This means it will **not** be caught by bare `except Exception:`
+ handlers. This is intentional and mirrors the behavior of Python's
+ `asyncio.CancelledError`.
+
+
+
+
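+For example, a task can catch `CancelledError`, run its cleanup, and then re-raise so that Hatchet can finish handling the cancellation. A minimal sketch (`cancellation_workflow` and `child_workflow` are illustrative stand-ins for your own workflows):
+
+```python
+from hatchet_sdk import CancelledError, Context, EmptyModel
+
+
+@cancellation_workflow.task()
+def my_task(input: EmptyModel, ctx: Context) -> dict:
+    try:
+        # A blocking, cancellation-aware call, e.g. waiting on a child workflow
+        result = child_workflow.run(EmptyModel())
+    except CancelledError as e:
+        # Perform cleanup, then re-raise so Hatchet can finish handling the cancellation
+        print(f"Task cancelled: {e.reason}")
+        raise
+    except Exception as e:
+        # A bare `except Exception:` does NOT catch CancelledError
+        print(f"Other error: {e}")
+        raise
+    return result
+```
+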
+### Cancellation Reasons
+
+The `CancelledError` includes a `reason` attribute that indicates why the cancellation occurred:
+
+| Reason | Description |
+| --------------------------------------- | --------------------------------------------------------------------- |
+| `CancellationReason.USER_REQUESTED` | The user explicitly requested cancellation via `ctx.cancel()` |
+| `CancellationReason.WORKFLOW_CANCELLED` | The workflow run was cancelled (e.g., via API or concurrency control) |
+| `CancellationReason.PARENT_CANCELLED` | The parent workflow was cancelled while waiting for a child |
+| `CancellationReason.TIMEOUT` | The operation timed out |
+| `CancellationReason.TOKEN_CANCELLED`    | The cancellation token was cancelled directly (e.g., via `token.cancel()`) |
+
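+Inside an `except CancelledError` handler you can branch on the reason. A short sketch (the cleanup bodies are illustrative):
+
+```python
+from hatchet_sdk import CancellationReason, CancelledError
+
+try:
+    ...  # a cancellation-aware operation, e.g. waiting on a child workflow
+except CancelledError as e:
+    if e.reason == CancellationReason.TIMEOUT:
+        ...  # e.g. release resources held for the timed-out operation
+    elif e.reason == CancellationReason.WORKFLOW_CANCELLED:
+        ...  # e.g. record that the run was cancelled externally
+    raise
+```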
diff --git a/sdks/python/CHANGELOG.md b/sdks/python/CHANGELOG.md
index 00f232a091..047f210a96 100644
--- a/sdks/python/CHANGELOG.md
+++ b/sdks/python/CHANGELOG.md
@@ -5,6 +5,22 @@ All notable changes to Hatchet's Python SDK will be documented in this changelog
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [1.25.0] - 2026-02-17
+
+### Added
+
+- Adds a `CancellationToken` class for coordinating cancellation across async and sync operations. The token provides both `asyncio.Event` and `threading.Event` primitives, and supports registering child workflow run IDs and callbacks.
+- Adds a `CancellationReason` enum with structured reasons for cancellation (`user_requested`, `timeout`, `parent_cancelled`, `workflow_cancelled`, `token_cancelled`).
+- Adds a `CancelledError` exception (inherits from `BaseException`, mirroring `asyncio.CancelledError`) for sync code paths.
+- Adds `cancellation_grace_period` and `cancellation_warning_threshold` configuration options to `ClientConfig` for controlling cancellation timing behavior.
+- Adds `await_with_cancellation` and `race_against_token` utility functions for racing awaitables against cancellation tokens.
+- The `Context` now exposes a `cancellation_token` property, allowing tasks to observe and react to cancellation signals directly.
+
+### Changed
+
+- The `Context.exit_flag` is now backed by a `CancellationToken` instead of a plain boolean. The property is maintained for backwards compatibility.
+- Durable context `aio_wait_for` now respects the cancellation token, raising `asyncio.CancelledError` if the task is cancelled while waiting.
+
## [1.24.0] - 2026-02-13
### Added
diff --git a/sdks/python/examples/cancellation/worker.py b/sdks/python/examples/cancellation/worker.py
index e758e6f59b..d49d4e760b 100644
--- a/sdks/python/examples/cancellation/worker.py
+++ b/sdks/python/examples/cancellation/worker.py
@@ -1,7 +1,7 @@
import asyncio
import time
-from hatchet_sdk import Context, EmptyModel, Hatchet
+from hatchet_sdk import CancelledError, Context, EmptyModel, Hatchet
hatchet = Hatchet(debug=True)
@@ -42,6 +42,26 @@ def check_flag(input: EmptyModel, ctx: Context) -> dict[str, str]:
# !!
+# > Handling cancelled error
+@cancellation_workflow.task()
+async def my_task(input: EmptyModel, ctx: Context) -> dict[str, str]:
+ try:
+ await asyncio.sleep(10)
+ except CancelledError as e:
+ # Handle parent cancellation - i.e. perform cleanup, then re-raise
+ print(f"Parent Task cancelled: {e.reason}")
+ # Always re-raise CancelledError so Hatchet can properly handle the cancellation
+ raise
+ except Exception as e:
+ # This will NOT catch CancelledError
+ print(f"Other error: {e}")
+ raise
+ return {"error": "Task should have been cancelled"}
+
+
+# !!
+
+
def main() -> None:
worker = hatchet.worker("cancellation-worker", workflows=[cancellation_workflow])
worker.start()
diff --git a/sdks/python/examples/simple/worker.py b/sdks/python/examples/simple/worker.py
index 686742c4fb..f1082a9ba5 100644
--- a/sdks/python/examples/simple/worker.py
+++ b/sdks/python/examples/simple/worker.py
@@ -10,7 +10,7 @@ def simple(input: EmptyModel, ctx: Context) -> dict[str, str]:
@hatchet.durable_task()
-def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]:
+async def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]:
return {"result": "Hello, world!"}
diff --git a/sdks/python/hatchet_sdk/__init__.py b/sdks/python/hatchet_sdk/__init__.py
index 6fe497718a..d9dc46e41d 100644
--- a/sdks/python/hatchet_sdk/__init__.py
+++ b/sdks/python/hatchet_sdk/__init__.py
@@ -1,3 +1,4 @@
+from hatchet_sdk.cancellation import CancellationToken
from hatchet_sdk.clients.admin import (
RunStatus,
ScheduleTriggerWorkflowOptions,
@@ -155,6 +156,8 @@
WorkerLabelComparator,
)
from hatchet_sdk.exceptions import (
+ CancellationReason,
+ CancelledError,
DedupeViolationError,
FailedTaskRunExceptionGroup,
NonRetryableException,
@@ -194,6 +197,9 @@
"CELEvaluationResult",
"CELFailure",
"CELSuccess",
+ "CancellationReason",
+ "CancellationToken",
+ "CancelledError",
"ClientConfig",
"ClientTLSConfig",
"ConcurrencyExpression",
diff --git a/sdks/python/hatchet_sdk/cancellation.py b/sdks/python/hatchet_sdk/cancellation.py
new file mode 100644
index 0000000000..5998677695
--- /dev/null
+++ b/sdks/python/hatchet_sdk/cancellation.py
@@ -0,0 +1,197 @@
+"""Cancellation token for coordinating cancellation across async and sync operations."""
+
+from __future__ import annotations
+
+import asyncio
+import threading
+from collections.abc import Callable
+from typing import TYPE_CHECKING
+
+from hatchet_sdk.exceptions import CancellationReason
+from hatchet_sdk.logger import logger
+
+if TYPE_CHECKING:
+ pass
+
+
+class CancellationToken:
+ """
+ A token that can be used to signal cancellation across async and sync operations.
+
+ The token provides both asyncio and threading event primitives, allowing it to work
+ seamlessly in both async and sync code paths. Child workflow run IDs can be registered
+ with the token so they can be cancelled when the parent is cancelled.
+
+ Example:
+ ```python
+ token = CancellationToken()
+
+ # In async code
+ await token.aio_wait() # Blocks until cancelled
+
+ # In sync code
+ token.wait(timeout=1.0) # Returns True if cancelled within timeout
+
+ # Check if cancelled
+ if token.is_cancelled:
+ raise CancelledError("Operation was cancelled")
+
+ # Trigger cancellation
+ token.cancel()
+ ```
+ """
+
+ def __init__(self) -> None:
+ self._cancelled = False
+ self._reason: CancellationReason | None = None
+ self._async_event: asyncio.Event | None = None
+ self._sync_event = threading.Event()
+ self._child_run_ids: list[str] = []
+ self._callbacks: list[Callable[[], None]] = []
+ self._lock = threading.Lock()
+
+ def _get_async_event(self) -> asyncio.Event:
+ """Lazily create the asyncio event to avoid requiring an event loop at init time."""
+ if self._async_event is None:
+ self._async_event = asyncio.Event()
+ # If already cancelled, set the event
+ if self._cancelled:
+ self._async_event.set()
+ return self._async_event
+
+ def cancel(
+ self, reason: CancellationReason = CancellationReason.TOKEN_CANCELLED
+ ) -> None:
+ """
+ Trigger cancellation.
+
+ This will:
+ - Set the cancelled flag and reason
+ - Signal both async and sync events
+ - Invoke all registered callbacks
+
+ Args:
+ reason: The reason for cancellation.
+ """
+ with self._lock:
+ if self._cancelled:
+ logger.debug(
+ f"CancellationToken: cancel() called but already cancelled, "
+ f"reason={self._reason.value if self._reason else 'none'}"
+ )
+ return
+
+ logger.debug(
+ f"CancellationToken: cancel() called, reason={reason.value}, "
+ f"{len(self._child_run_ids)} children registered"
+ )
+
+ self._cancelled = True
+ self._reason = reason
+
+ # Signal both event types
+ if self._async_event is not None:
+ self._async_event.set()
+ self._sync_event.set()
+
+ # Snapshot callbacks under the lock, invoke outside to avoid deadlocks
+ callbacks = list(self._callbacks)
+
+ for callback in callbacks:
+ try:
+ logger.debug(f"CancellationToken: invoking callback {callback}")
+ callback()
+ except Exception as e: # noqa: PERF203
+ logger.warning(f"CancellationToken: callback raised exception: {e}")
+
+ logger.debug(f"CancellationToken: cancel() complete, reason={reason.value}")
+
+ @property
+ def is_cancelled(self) -> bool:
+ """Check if cancellation has been triggered."""
+ return self._cancelled
+
+ @property
+ def reason(self) -> CancellationReason | None:
+ """Get the reason for cancellation, or None if not cancelled."""
+ return self._reason
+
+ async def aio_wait(self) -> None:
+ """
+ Await until cancelled (for use in asyncio).
+
+ This will block until cancel() is called.
+ """
+ await self._get_async_event().wait()
+ logger.debug(
+ f"CancellationToken: async wait completed (cancelled), "
+ f"reason={self._reason.value if self._reason else 'none'}"
+ )
+
+ def wait(self, timeout: float | None = None) -> bool:
+ """
+ Block until cancelled (for use in sync code).
+
+ Args:
+ timeout: Maximum time to wait in seconds. None means wait forever.
+
+ Returns:
+ True if the token was cancelled (event was set), False if timeout expired.
+ """
+ result = self._sync_event.wait(timeout)
+ if result:
+ logger.debug(
+ f"CancellationToken: sync wait interrupted by cancellation, "
+ f"reason={self._reason.value if self._reason else 'none'}"
+ )
+ return result
+
+ def register_child(self, run_id: str) -> None:
+ """
+ Register a child workflow run ID with this token.
+
+ When the parent is cancelled, these child run IDs can be used to cancel
+ the child workflows as well.
+
+ Args:
+ run_id: The workflow run ID of the child workflow.
+ """
+ with self._lock:
+ logger.debug(f"CancellationToken: registering child workflow {run_id}")
+ self._child_run_ids.append(run_id)
+
+ @property
+ def child_run_ids(self) -> list[str]:
+ """The registered child workflow run IDs."""
+ return self._child_run_ids
+
+ def add_callback(self, callback: Callable[[], None]) -> None:
+ """
+ Register a callback to be invoked when cancellation is triggered.
+
+ If the token is already cancelled, the callback will be invoked immediately.
+
+ Args:
+ callback: A callable that takes no arguments.
+ """
+ with self._lock:
+ if self._cancelled:
+ invoke_now = True
+ else:
+ invoke_now = False
+ self._callbacks.append(callback)
+
+ if invoke_now:
+ logger.debug(
+ f"CancellationToken: invoking callback immediately (already cancelled): {callback}"
+ )
+ try:
+ callback()
+ except Exception as e:
+ logger.warning(f"CancellationToken: callback raised exception: {e}")
+
+ def __repr__(self) -> str:
+ return (
+ f"CancellationToken(cancelled={self._cancelled}, "
+ f"children={len(self._child_run_ids)}, callbacks={len(self._callbacks)})"
+ )
diff --git a/sdks/python/hatchet_sdk/clients/listeners/pooled_listener.py b/sdks/python/hatchet_sdk/clients/listeners/pooled_listener.py
index 8a99d8fdce..1d05ba77a5 100644
--- a/sdks/python/hatchet_sdk/clients/listeners/pooled_listener.py
+++ b/sdks/python/hatchet_sdk/clients/listeners/pooled_listener.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
import asyncio
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
-from typing import Generic, Literal, TypeVar
+from typing import TYPE_CHECKING, Generic, Literal, TypeVar
import grpc
import grpc.aio
@@ -14,6 +16,10 @@
from hatchet_sdk.config import ClientConfig
from hatchet_sdk.logger import logger
from hatchet_sdk.metadata import get_metadata
+from hatchet_sdk.utils.cancellation import race_against_token
+
+if TYPE_CHECKING:
+ from hatchet_sdk.cancellation import CancellationToken
DEFAULT_LISTENER_RETRY_INTERVAL = 3 # seconds
DEFAULT_LISTENER_RETRY_COUNT = 5
@@ -36,7 +42,7 @@ def __init__(self, id: int) -> None:
self.id = id
self.queue: asyncio.Queue[T | SentinelValue] = asyncio.Queue()
- async def __aiter__(self) -> "Subscription[T]":
+ async def __aiter__(self) -> Subscription[T]:
return self
async def __anext__(self) -> T | SentinelValue:
@@ -199,7 +205,17 @@ def cleanup_subscription(self, subscription_id: int) -> None:
del self.from_subscriptions[subscription_id]
del self.events[subscription_id]
- async def subscribe(self, id: str) -> T:
+ async def subscribe(
+ self, id: str, cancellation_token: CancellationToken | None = None
+ ) -> T:
+ """
+ Subscribe to events for the given ID.
+
+ :param id: The ID to subscribe to (e.g., workflow run ID).
+ :param cancellation_token: Optional cancellation token to abort the subscription wait.
+ :return: The event received for this ID.
+ :raises asyncio.CancelledError: If the cancellation token is triggered or if externally cancelled.
+ """
subscription_id: int | None = None
try:
@@ -221,8 +237,17 @@ async def subscribe(self, id: str) -> T:
if not self.listener_task or self.listener_task.done():
self.listener_task = asyncio.create_task(self._init_producer())
+ logger.debug(
+ f"PooledListener.subscribe: waiting for event on id={id}, "
+ f"subscription_id={subscription_id}, token={cancellation_token is not None}"
+ )
+
+ if cancellation_token:
+ result_task = asyncio.create_task(self.events[subscription_id].get())
+ return await race_against_token(result_task, cancellation_token)
return await self.events[subscription_id].get()
except asyncio.CancelledError:
+ logger.debug(f"PooledListener.subscribe: externally cancelled for id={id}")
raise
finally:
if subscription_id:
diff --git a/sdks/python/hatchet_sdk/config.py b/sdks/python/hatchet_sdk/config.py
index e49d2c9eac..1d315bff23 100644
--- a/sdks/python/hatchet_sdk/config.py
+++ b/sdks/python/hatchet_sdk/config.py
@@ -52,7 +52,7 @@ def validate_event_loop_block_threshold_seconds(
if isinstance(value, timedelta):
return value
- if isinstance(value, int | float):
+ if isinstance(value, (int, float)):
return timedelta(seconds=float(value))
v = value.strip()
@@ -135,6 +135,37 @@ class ClientConfig(BaseSettings):
force_shutdown_on_shutdown_signal: bool = False
tenacity: TenacityConfig = TenacityConfig()
+ # Cancellation configuration
+ cancellation_grace_period: timedelta = Field(
+ default=timedelta(milliseconds=1000),
+ description="The maximum time to wait for a task to complete after cancellation is triggered before force-cancelling. Value is interpreted as seconds when provided as int/float.",
+ )
+ cancellation_warning_threshold: timedelta = Field(
+ default=timedelta(milliseconds=300),
+ description="If a task has not completed cancellation within this duration, a warning will be logged. Value is interpreted as seconds when provided as int/float.",
+ )
+
+ @field_validator(
+ "cancellation_grace_period", "cancellation_warning_threshold", mode="before"
+ )
+ @classmethod
+ def validate_cancellation_timedelta(
+ cls, value: timedelta | int | float | str
+ ) -> timedelta:
+ """Convert int/float/string to timedelta, interpreting as seconds."""
+ if isinstance(value, timedelta):
+ return value
+
+ if isinstance(value, (int, float)):
+ return timedelta(seconds=float(value))
+
+ v = value.strip()
+ # Allow a small convenience suffix, but keep "seconds" as the contract.
+ if v.endswith("s"):
+ v = v[:-1].strip()
+
+ return timedelta(seconds=float(v))
+
@model_validator(mode="after")
def validate_token_and_tenant(self) -> "ClientConfig":
if not self.token:
diff --git a/sdks/python/hatchet_sdk/context/context.py b/sdks/python/hatchet_sdk/context/context.py
index 549d7dc7f0..01d83d6f87 100644
--- a/sdks/python/hatchet_sdk/context/context.py
+++ b/sdks/python/hatchet_sdk/context/context.py
@@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Any, cast
from warnings import warn
+from hatchet_sdk.cancellation import CancellationToken
from hatchet_sdk.clients.admin import AdminClient
from hatchet_sdk.clients.dispatcher.dispatcher import ( # type: ignore[attr-defined]
Action,
@@ -21,9 +22,10 @@
flatten_conditions,
)
from hatchet_sdk.context.worker_context import WorkerContext
-from hatchet_sdk.exceptions import TaskRunError
+from hatchet_sdk.exceptions import CancellationReason, TaskRunError
from hatchet_sdk.features.runs import RunsClient
from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.cancellation import await_with_cancellation
from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_expr
from hatchet_sdk.utils.typing import JSONSerializableMapping, LogLevel
from hatchet_sdk.worker.runner.utils.capture_logs import AsyncLogSender, LogRecord
@@ -56,7 +58,7 @@ def __init__(
self.action = action
self.step_run_id = action.step_run_id
- self.exit_flag = False
+ self.cancellation_token = CancellationToken()
self.dispatcher_client = dispatcher_client
self.admin_client = admin_client
self.event_client = event_client
@@ -74,6 +76,31 @@ def __init__(
self._workflow_name = workflow_name
self._task_name = task_name
+ @property
+ def exit_flag(self) -> bool:
+ """
+ Check if the cancellation flag has been set.
+
+ This property is maintained for backwards compatibility.
+ Use `cancellation_token.is_cancelled` for new code.
+
+ :return: True if the task has been cancelled, False otherwise.
+ """
+ return self.cancellation_token.is_cancelled
+
+ @exit_flag.setter
+ def exit_flag(self, value: bool) -> None:
+ """
+ Set the cancellation flag.
+
+ This setter is maintained for backwards compatibility.
+ Setting to True will trigger the cancellation token.
+
+ :param value: True to trigger cancellation, False is a no-op.
+ """
+ if value:
+ self.cancellation_token.cancel(CancellationReason.USER_REQUESTED)
+
def _increment_stream_index(self) -> int:
index = self.stream_index
self.stream_index += 1
@@ -169,8 +196,25 @@ def workflow_run_id(self) -> str:
"""
return self.action.workflow_run_id
- def _set_cancellation_flag(self) -> None:
- self.exit_flag = True
+ def _set_cancellation_flag(
+ self, reason: CancellationReason = CancellationReason.WORKFLOW_CANCELLED
+ ) -> None:
+ """
+ Internal method to trigger cancellation.
+
+ This triggers the cancellation token, which will:
+ - Signal all waiters (async and sync)
+ - Set the exit_flag property to True
+ - Allow child workflow cancellation
+
+ Args:
+ reason: The reason for cancellation.
+ """
+ logger.debug(
+ f"Context: setting cancellation flag for step_run_id={self.step_run_id}, "
+ f"reason={reason.value}"
+ )
+ self.cancellation_token.cancel(reason)
def cancel(self) -> None:
"""
@@ -178,9 +222,11 @@ def cancel(self) -> None:
:return: None
"""
- logger.debug("cancelling step...")
+ logger.debug(
+ f"Context: cancel() called for task_run_external_id={self.step_run_id}"
+ )
self.runs_client.cancel(self.step_run_id)
- self._set_cancellation_flag()
+ self._set_cancellation_flag(CancellationReason.USER_REQUESTED)
async def aio_cancel(self) -> None:
"""
@@ -188,9 +234,11 @@ async def aio_cancel(self) -> None:
:return: None
"""
- logger.debug("cancelling step...")
+ logger.debug(
+ f"Context: aio_cancel() called for task_run_external_id={self.step_run_id}"
+ )
await self.runs_client.aio_cancel(self.step_run_id)
- self._set_cancellation_flag()
+ self._set_cancellation_flag(CancellationReason.USER_REQUESTED)
def done(self) -> bool:
"""
@@ -482,8 +530,11 @@ async def aio_wait_for(
"""
Durably wait for either a sleep or an event.
+ This method respects the context's cancellation token. If the task is cancelled
+ while waiting, an asyncio.CancelledError will be raised.
+
:param signal_key: The key to use for the durable event. This is used to identify the event in the Hatchet API.
- :param *conditions: The conditions to wait for. Can be a SleepCondition or UserEventCondition.
+ :param \\*conditions: The conditions to wait for. Can be a SleepCondition or UserEventCondition.
:return: A dictionary containing the results of the wait.
:raises ValueError: If the durable event listener is not available.
@@ -493,6 +544,10 @@ async def aio_wait_for(
task_id = self.step_run_id
+ logger.debug(
+ f"DurableContext.aio_wait_for: waiting for signal_key={signal_key}, task_id={task_id}"
+ )
+
request = RegisterDurableEventRequest(
task_id=task_id,
signal_key=signal_key,
@@ -502,20 +557,30 @@ async def aio_wait_for(
self.durable_event_listener.register_durable_event(request)
- return await self.durable_event_listener.result(
- task_id,
- signal_key,
+ # Use await_with_cancellation to respect the cancellation token
+ return await await_with_cancellation(
+ self.durable_event_listener.result(task_id, signal_key),
+ self.cancellation_token,
)
async def aio_sleep_for(self, duration: Duration) -> dict[str, Any]:
"""
Lightweight wrapper for durable sleep. Allows for shorthand usage of `ctx.aio_wait_for` when specifying a sleep condition.
+ This method respects the context's cancellation token. If the task is cancelled
+ while sleeping, an asyncio.CancelledError will be raised.
+
For more complicated conditions, use `ctx.aio_wait_for` directly.
- """
+ :param duration: The duration to sleep for.
+ :return: A dictionary containing the results of the wait.
+ """
wait_index = self._increment_wait_index()
+ logger.debug(
+ f"DurableContext.aio_sleep_for: sleeping for {duration}, wait_index={wait_index}"
+ )
+
return await self.aio_wait_for(
f"sleep:{timedelta_to_expr(duration)}-{wait_index}",
SleepCondition(duration=duration),
diff --git a/sdks/python/hatchet_sdk/exceptions.py b/sdks/python/hatchet_sdk/exceptions.py
index 3ecc0c3e66..f13a7d35b9 100644
--- a/sdks/python/hatchet_sdk/exceptions.py
+++ b/sdks/python/hatchet_sdk/exceptions.py
@@ -1,5 +1,6 @@
import json
import traceback
+from enum import Enum
from typing import cast
@@ -170,3 +171,54 @@ class IllegalTaskOutputError(Exception):
class LifespanSetupError(Exception):
pass
+
+
+class CancellationReason(Enum):
+ """Reason for cancellation of an operation."""
+
+ USER_REQUESTED = "user_requested"
+ """The user explicitly requested cancellation."""
+
+ TIMEOUT = "timeout"
+ """The operation timed out."""
+
+ PARENT_CANCELLED = "parent_cancelled"
+ """The parent workflow or task was cancelled."""
+
+ WORKFLOW_CANCELLED = "workflow_cancelled"
+ """The workflow run was cancelled."""
+
+ TOKEN_CANCELLED = "token_cancelled"
+ """The cancellation token was cancelled."""
+
+
+class CancelledError(BaseException):
+ """
+ Raised when an operation is cancelled via CancellationToken.
+
+ This exception inherits from BaseException (not Exception) so that it
+ won't be caught by bare `except Exception:` handlers. This mirrors the
+ behavior of asyncio.CancelledError in Python 3.8+.
+
+ To catch this exception, use:
+ - `except CancelledError:` (recommended)
+ - `except BaseException:` (catches all exceptions)
+
+ This exception is used for sync code paths. For async code paths,
+ asyncio.CancelledError is used instead.
+
+ :param message: Optional message describing the cancellation.
+ :param reason: Optional enum indicating the reason for cancellation.
+ """
+
+ def __init__(
+ self,
+ message: str = "Operation cancelled",
+ reason: CancellationReason | None = None,
+ ) -> None:
+ self.reason = reason
+ super().__init__(message)
+
+ @property
+ def message(self) -> str:
+ return str(self.args[0]) if self.args else "Operation cancelled"
diff --git a/sdks/python/hatchet_sdk/runnables/contextvars.py b/sdks/python/hatchet_sdk/runnables/contextvars.py
index 0d3c9d4904..dc0c9b2244 100644
--- a/sdks/python/hatchet_sdk/runnables/contextvars.py
+++ b/sdks/python/hatchet_sdk/runnables/contextvars.py
@@ -1,11 +1,17 @@
+from __future__ import annotations
+
import asyncio
import threading
from collections import Counter
from contextvars import ContextVar
+from typing import TYPE_CHECKING
from hatchet_sdk.runnables.action import ActionKey
from hatchet_sdk.utils.typing import JSONSerializableMapping
+if TYPE_CHECKING:
+ from hatchet_sdk.cancellation import CancellationToken
+
ctx_workflow_run_id: ContextVar[str | None] = ContextVar(
"ctx_workflow_run_id", default=None
)
@@ -20,6 +26,9 @@
ctx_task_retry_count: ContextVar[int | None] = ContextVar(
"ctx_task_retry_count", default=0
)
+ctx_cancellation_token: ContextVar[CancellationToken | None] = ContextVar(
+ "ctx_cancellation_token", default=None
+)
workflow_spawn_indices = Counter[ActionKey]()
spawn_index_lock = asyncio.Lock()
diff --git a/sdks/python/hatchet_sdk/runnables/workflow.py b/sdks/python/hatchet_sdk/runnables/workflow.py
index 6eed17bbe7..927375db17 100644
--- a/sdks/python/hatchet_sdk/runnables/workflow.py
+++ b/sdks/python/hatchet_sdk/runnables/workflow.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import asyncio
import json
from collections.abc import Callable
@@ -37,8 +39,11 @@
)
from hatchet_sdk.contracts.v1.workflows_pb2 import StickyStrategy as StickyStrategyProto
from hatchet_sdk.contracts.workflows_pb2 import WorkflowVersion
+from hatchet_sdk.exceptions import CancellationReason, CancelledError
from hatchet_sdk.labels import DesiredWorkerLabel
+from hatchet_sdk.logger import logger
from hatchet_sdk.rate_limit import RateLimit
+from hatchet_sdk.runnables.contextvars import ctx_cancellation_token
from hatchet_sdk.runnables.task import Task
from hatchet_sdk.runnables.types import (
ConcurrencyExpression,
@@ -52,6 +57,7 @@
normalize_validator,
)
from hatchet_sdk.serde import HATCHET_PYDANTIC_SENTINEL
+from hatchet_sdk.utils.cancellation import await_with_cancellation
from hatchet_sdk.utils.proto_enums import convert_python_enum_to_proto
from hatchet_sdk.utils.timedelta_to_expression import Duration
from hatchet_sdk.utils.typing import CoroutineLike, JSONSerializableMapping
@@ -59,6 +65,7 @@
if TYPE_CHECKING:
from hatchet_sdk import Hatchet
+ from hatchet_sdk.cancellation import CancellationToken
T = TypeVar("T")
@@ -88,7 +95,7 @@ class ComputedTaskParameters(BaseModel):
task_defaults: TaskDefaults
@model_validator(mode="after")
- def validate_params(self) -> "ComputedTaskParameters":
+ def validate_params(self) -> ComputedTaskParameters:
self.execution_timeout = fall_back_to_default(
value=self.execution_timeout,
param_default=timedelta(seconds=60),
@@ -136,7 +143,7 @@ class TypedTriggerWorkflowRunConfig(BaseModel, Generic[TWorkflowInput]):
class BaseWorkflow(Generic[TWorkflowInput]):
- def __init__(self, config: WorkflowConfig, client: "Hatchet") -> None:
+ def __init__(self, config: WorkflowConfig, client: Hatchet) -> None:
self.config = config
self._default_tasks: list[Task[TWorkflowInput, Any]] = []
self._durable_tasks: list[Task[TWorkflowInput, Any]] = []
@@ -625,6 +632,38 @@ def greet(input, ctx):
and can be arranged into complex dependency patterns.
"""
+ def _resolve_check_cancellation_token(self) -> CancellationToken | None:
+ cancellation_token = ctx_cancellation_token.get()
+
+ if cancellation_token and cancellation_token.is_cancelled:
+ raise CancelledError(
+ "Operation cancelled by cancellation token",
+ reason=CancellationReason.TOKEN_CANCELLED,
+ )
+
+ return cancellation_token
+
+ def _register_child_with_token(
+ self,
+ cancellation_token: CancellationToken | None,
+ workflow_run_id: str,
+ ) -> None:
+ if not cancellation_token:
+ return
+
+ cancellation_token.register_child(workflow_run_id)
+
+ def _register_children_with_token(
+ self,
+ cancellation_token: CancellationToken | None,
+ refs: list[WorkflowRunRef],
+ ) -> None:
+ if not cancellation_token:
+ return
+
+ for ref in refs:
+ cancellation_token.register_child(ref.workflow_run_id)
+
def run_no_wait(
self,
input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
@@ -634,17 +673,34 @@ def run_no_wait(
Synchronously trigger a workflow run without waiting for it to complete.
This method is useful for starting a workflow run and immediately returning a reference to the run without blocking while the workflow runs.
+ If a cancellation token is available via context, the child workflow will be registered
+ with the token.
+
:param input: The input data for the workflow.
:param options: Additional options for workflow execution.
:returns: A `WorkflowRunRef` object representing the reference to the workflow run.
"""
- return self.client._client.admin.run_workflow(
+ cancellation_token = self._resolve_check_cancellation_token()
+
+ logger.debug(
+ f"Workflow.run_no_wait: triggering {self.config.name}, "
+ f"token={cancellation_token is not None}"
+ )
+
+ ref = self.client._client.admin.run_workflow(
workflow_name=self.config.name,
input=self._serialize_input(input),
options=self._create_options_with_combined_additional_meta(options),
)
+ self._register_child_with_token(
+ cancellation_token,
+ ref.workflow_run_id,
+ )
+
+ return ref
+
def run(
self,
input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
@@ -654,12 +710,19 @@ def run(
Run the workflow synchronously and wait for it to complete.
This method triggers a workflow run, blocks until completion, and returns the final result.
+ If a cancellation token is available via context, the wait can be interrupted.
:param input: The input data for the workflow, must match the workflow's input type.
:param options: Additional options for workflow execution like metadata and parent workflow ID.
:returns: The result of the workflow execution as a dictionary.
"""
+ cancellation_token = self._resolve_check_cancellation_token()
+
+ logger.debug(
+ f"Workflow.run: triggering {self.config.name}, "
+ f"token={cancellation_token is not None}"
+ )
ref = self.client._client.admin.run_workflow(
workflow_name=self.config.name,
@@ -667,7 +730,14 @@ def run(
options=self._create_options_with_combined_additional_meta(options),
)
- return ref.result()
+ self._register_child_with_token(
+ cancellation_token,
+ ref.workflow_run_id,
+ )
+
+ logger.debug(f"Workflow.run: awaiting result for {ref.workflow_run_id}")
+
+ return ref.result(cancellation_token=cancellation_token)
async def aio_run_no_wait(
self,
@@ -678,18 +748,34 @@ async def aio_run_no_wait(
Asynchronously trigger a workflow run without waiting for it to complete.
This method is useful for starting a workflow run and immediately returning a reference to the run without blocking while the workflow runs.
+ If a cancellation token is available via context, the child workflow will be registered
+ with the token.
+
:param input: The input data for the workflow.
:param options: Additional options for workflow execution.
:returns: A `WorkflowRunRef` object representing the reference to the workflow run.
"""
+ cancellation_token = self._resolve_check_cancellation_token()
- return await self.client._client.admin.aio_run_workflow(
+ logger.debug(
+ f"Workflow.aio_run_no_wait: triggering {self.config.name}, "
+ f"token={cancellation_token is not None}"
+ )
+
+ ref = await self.client._client.admin.aio_run_workflow(
workflow_name=self.config.name,
input=self._serialize_input(input),
options=self._create_options_with_combined_additional_meta(options),
)
+ self._register_child_with_token(
+ cancellation_token,
+ ref.workflow_run_id,
+ )
+
+ return ref
+
async def aio_run(
self,
input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
@@ -699,25 +785,47 @@ async def aio_run(
Run the workflow asynchronously and wait for it to complete.
This method triggers a workflow run, awaits until completion, and returns the final result.
+ If a cancellation token is available via context, the wait can be interrupted.
:param input: The input data for the workflow, must match the workflow's input type.
:param options: Additional options for workflow execution like metadata and parent workflow ID.
:returns: The result of the workflow execution as a dictionary.
"""
+ cancellation_token = self._resolve_check_cancellation_token()
+
+ logger.debug(
+ f"Workflow.aio_run: triggering {self.config.name}, "
+ f"token={cancellation_token is not None}"
+ )
+
ref = await self.client._client.admin.aio_run_workflow(
workflow_name=self.config.name,
input=self._serialize_input(input),
options=self._create_options_with_combined_additional_meta(options),
)
- return await ref.aio_result()
+ self._register_child_with_token(
+ cancellation_token,
+ ref.workflow_run_id,
+ )
+
+ logger.debug(f"Workflow.aio_run: awaiting result for {ref.workflow_run_id}")
+
+ return await await_with_cancellation(
+ ref.aio_result(),
+ cancellation_token,
+ )
def _get_result(
- self, ref: WorkflowRunRef, return_exceptions: bool
+ self,
+ ref: WorkflowRunRef,
+ return_exceptions: bool,
) -> dict[str, Any] | BaseException:
try:
- return ref.result()
+ return ref.result(
+ cancellation_token=self._resolve_check_cancellation_token()
+ )
except Exception as e:
if return_exceptions:
return e
@@ -746,15 +854,52 @@ def run_many(
Run a workflow in bulk and wait for all runs to complete.
This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
+ If a cancellation token is available via context, all child workflows will be registered
+ with the token and the wait can be interrupted.
+
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
:param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
:returns: A list of results for each workflow run.
+ :raises CancelledError: If the cancellation token is triggered (and return_exceptions is False).
+ :raises Exception: If a workflow run fails (and return_exceptions is False).
"""
+ cancellation_token = self._resolve_check_cancellation_token()
+
refs = self.client._client.admin.run_workflows(
workflows=workflows,
)
- return [self._get_result(ref, return_exceptions) for ref in refs]
+ self._register_children_with_token(
+ cancellation_token,
+ refs,
+ )
+
+ # Pass cancellation_token through to each result() call
+ # The cancellation check happens INSIDE result()'s polling loop
+ results: list[dict[str, Any] | BaseException] = []
+ for ref in refs:
+ try:
+ results.append(ref.result(cancellation_token=cancellation_token))
+ except CancelledError: # noqa: PERF203
+ logger.debug(
+ f"Workflow.run_many: cancellation detected, stopping wait, "
+ f"reason={CancellationReason.PARENT_CANCELLED.value}"
+ )
+ if return_exceptions:
+ results.append(
+ CancelledError(
+ "Operation cancelled by cancellation token",
+ reason=CancellationReason.PARENT_CANCELLED,
+ )
+ )
+ break
+ raise
+ except Exception as e:
+ if return_exceptions:
+ results.append(e)
+ else:
+ raise
+ return results
@overload
async def aio_run_many(
@@ -779,16 +924,34 @@ async def aio_run_many(
Run a workflow in bulk and wait for all runs to complete.
This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
+ If a cancellation token is available via context, all child workflows will be registered
+ with the token and the wait can be interrupted.
+
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
:param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
:returns: A list of results for each workflow run.
"""
+ cancellation_token = self._resolve_check_cancellation_token()
+
+ logger.debug(
+ f"Workflow.aio_run_many: triggering {len(workflows)} workflows, "
+ f"token={cancellation_token is not None}"
+ )
+
refs = await self.client._client.admin.aio_run_workflows(
workflows=workflows,
)
- return await asyncio.gather(
- *[ref.aio_result() for ref in refs], return_exceptions=return_exceptions
+ self._register_children_with_token(
+ cancellation_token,
+ refs,
+ )
+
+ return await await_with_cancellation(
+ asyncio.gather(
+ *[ref.aio_result() for ref in refs], return_exceptions=return_exceptions
+ ),
+ cancellation_token,
)
def run_many_no_wait(
@@ -800,13 +963,30 @@ def run_many_no_wait(
This method triggers multiple workflow runs and immediately returns a list of references to the runs without blocking while the workflows run.
+ If a cancellation token is available via context, all child workflows will be registered
+ with the token.
+
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
:returns: A list of `WorkflowRunRef` objects, each representing a reference to a workflow run.
"""
- return self.client._client.admin.run_workflows(
+ cancellation_token = self._resolve_check_cancellation_token()
+
+ logger.debug(
+ f"Workflow.run_many_no_wait: triggering {len(workflows)} workflows, "
+ f"token={cancellation_token is not None}"
+ )
+
+ refs = self.client._client.admin.run_workflows(
workflows=workflows,
)
+ self._register_children_with_token(
+ cancellation_token,
+ refs,
+ )
+
+ return refs
+
async def aio_run_many_no_wait(
self,
workflows: list[WorkflowRunTriggerConfig],
@@ -816,14 +996,31 @@ async def aio_run_many_no_wait(
This method triggers multiple workflow runs and immediately returns a list of references to the runs without blocking while the workflows run.
+ If a cancellation token is available via context, all child workflows will be registered
+ with the token.
+
:param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
:returns: A list of `WorkflowRunRef` objects, each representing a reference to a workflow run.
"""
- return await self.client._client.admin.aio_run_workflows(
+ cancellation_token = self._resolve_check_cancellation_token()
+
+ logger.debug(
+ f"Workflow.aio_run_many_no_wait: triggering {len(workflows)} workflows, "
+ f"token={cancellation_token is not None}"
+ )
+
+ refs = await self.client._client.admin.aio_run_workflows(
workflows=workflows,
)
+ self._register_children_with_token(
+ cancellation_token,
+ refs,
+ )
+
+ return refs
+
def _parse_task_name(
self,
name: str | None,
@@ -1168,7 +1365,7 @@ def inner(
return inner
- def add_task(self, task: "Standalone[TWorkflowInput, Any]") -> None:
+ def add_task(self, task: Standalone[TWorkflowInput, Any]) -> None:
"""
Add a task to a workflow. Intended to be used with a previously existing task (a Standalone),
such as one created with `@hatchet.task()`, which has been converted to a `Task` object using `to_task`.
@@ -1207,7 +1404,7 @@ def my_task(input, ctx) -> None:
class TaskRunRef(Generic[TWorkflowInput, R]):
def __init__(
self,
- standalone: "Standalone[TWorkflowInput, R]",
+ standalone: Standalone[TWorkflowInput, R],
workflow_run_ref: WorkflowRunRef,
):
self._s = standalone
@@ -1366,7 +1563,9 @@ def run_many(
) -> list[R]: ...
def run_many(
- self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
+ self,
+ workflows: list[WorkflowRunTriggerConfig],
+ return_exceptions: bool = False,
) -> list[R] | list[R | BaseException]:
"""
Run a workflow in bulk and wait for all runs to complete.
@@ -1400,7 +1599,9 @@ async def aio_run_many(
) -> list[R]: ...
async def aio_run_many(
- self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
+ self,
+ workflows: list[WorkflowRunTriggerConfig],
+ return_exceptions: bool = False,
) -> list[R] | list[R | BaseException]:
"""
Run a workflow in bulk and wait for all runs to complete.
@@ -1420,7 +1621,8 @@ async def aio_run_many(
]
def run_many_no_wait(
- self, workflows: list[WorkflowRunTriggerConfig]
+ self,
+ workflows: list[WorkflowRunTriggerConfig],
) -> list[TaskRunRef[TWorkflowInput, R]]:
"""
Run a workflow in bulk without waiting for all runs to complete.
@@ -1435,7 +1637,8 @@ def run_many_no_wait(
return [TaskRunRef[TWorkflowInput, R](self, ref) for ref in refs]
async def aio_run_many_no_wait(
- self, workflows: list[WorkflowRunTriggerConfig]
+ self,
+ workflows: list[WorkflowRunTriggerConfig],
) -> list[TaskRunRef[TWorkflowInput, R]]:
"""
Run a workflow in bulk without waiting for all runs to complete.
diff --git a/sdks/python/hatchet_sdk/utils/cancellation.py b/sdks/python/hatchet_sdk/utils/cancellation.py
new file mode 100644
index 0000000000..eefe7a4d6d
--- /dev/null
+++ b/sdks/python/hatchet_sdk/utils/cancellation.py
@@ -0,0 +1,149 @@
+"""Utilities for cancellation-aware operations."""
+
+from __future__ import annotations
+
+import asyncio
+import contextlib
+from collections.abc import Awaitable, Callable
+from typing import TYPE_CHECKING, TypeVar
+
+from hatchet_sdk.logger import logger
+
+if TYPE_CHECKING:
+ from hatchet_sdk.cancellation import CancellationToken
+
+T = TypeVar("T")
+
+
+async def _invoke_cancel_callback(
+ cancel_callback: Callable[[], Awaitable[None]] | None,
+) -> None:
+ """Invoke a cancel callback."""
+ if not cancel_callback:
+ return
+
+ await cancel_callback()
+
+
+async def race_against_token(
+ main_task: asyncio.Task[T],
+ token: CancellationToken,
+) -> T:
+ """
+ Race an asyncio task against a cancellation token.
+
+ Waits for either the task to complete or the token to be cancelled. Cleans up
+ whichever side loses the race.
+
+ Args:
+ main_task: The asyncio task to race.
+ token: The cancellation token to race against.
+
+ Returns:
+ The result of the main task if it completes first.
+
+ Raises:
+ asyncio.CancelledError: If the token fires before the task completes.
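+
+    Example:
+        ```python
+        # `do_work` is any coroutine producing a result (illustrative name).
+        token = CancellationToken()
+        main_task = asyncio.create_task(do_work())
+
+        # Raises asyncio.CancelledError if token.cancel() fires before completion
+        result = await race_against_token(main_task, token)
+        ```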
+ """
+ cancel_task = asyncio.create_task(token.aio_wait())
+
+ try:
+ done, pending = await asyncio.wait(
+ [main_task, cancel_task],
+ return_when=asyncio.FIRST_COMPLETED,
+ )
+
+ # Cancel pending tasks
+ for task in pending:
+ task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await task
+
+ if cancel_task in done:
+ raise asyncio.CancelledError("Operation cancelled by cancellation token")
+
+ return main_task.result()
+
+ except asyncio.CancelledError:
+ # Ensure both tasks are cleaned up on any cancellation (external or token)
+ main_task.cancel()
+ cancel_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await main_task
+ with contextlib.suppress(asyncio.CancelledError):
+ await cancel_task
+ raise
+
+
+async def await_with_cancellation(
+ coro: Awaitable[T],
+ token: CancellationToken | None,
+ cancel_callback: Callable[[], Awaitable[None]] | None = None,
+) -> T:
+ """
+ Await an awaitable with cancellation support.
+
+ This function races the given awaitable against a cancellation token. If the
+ token is cancelled before the awaitable completes, the awaitable is cancelled
+ and an asyncio.CancelledError is raised.
+
+ Args:
+ coro: The awaitable to await (coroutine, Future, or asyncio.Task).
+ token: The cancellation token to check. If None, the coroutine is awaited directly.
+ cancel_callback: An optional async callback to invoke when cancellation occurs
+ (e.g., to cancel child workflows).
+
+ Returns:
+ The result of the coroutine.
+
+ Raises:
+ asyncio.CancelledError: If the token is cancelled before the coroutine completes.
+
+ Example:
+ ```python
+ async def cleanup() -> None:
+ print("cleaning up...")
+
+ async def long_running_task():
+ await asyncio.sleep(10)
+ return "done"
+
+ token = CancellationToken()
+
+ # This will raise asyncio.CancelledError if token.cancel() is called
+ result = await await_with_cancellation(
+ long_running_task(),
+ token,
+ cancel_callback=cleanup,
+ )
+ ```
+ """
+
+ if token is None:
+ logger.debug("await_with_cancellation: no token provided, awaiting directly")
+ return await coro
+
+ logger.debug("await_with_cancellation: starting with cancellation token")
+
+ # Check if already cancelled
+ if token.is_cancelled:
+ logger.debug("await_with_cancellation: token already cancelled")
+ if cancel_callback:
+ logger.debug("await_with_cancellation: invoking cancel callback")
+ await _invoke_cancel_callback(cancel_callback)
+ raise asyncio.CancelledError("Operation cancelled by cancellation token")
+
+ main_task = asyncio.ensure_future(coro)
+
+ try:
+ result = await race_against_token(main_task, token)
+ logger.debug("await_with_cancellation: completed successfully")
+ return result
+
+ except asyncio.CancelledError:
+ logger.debug("await_with_cancellation: cancelled")
+ if cancel_callback:
+ logger.debug("await_with_cancellation: invoking cancel callback")
+ with contextlib.suppress(asyncio.CancelledError):
+ await asyncio.shield(_invoke_cancel_callback(cancel_callback))
+ raise
diff --git a/sdks/python/hatchet_sdk/worker/runner/runner.py b/sdks/python/hatchet_sdk/worker/runner/runner.py
index d853858b40..fde8fba734 100644
--- a/sdks/python/hatchet_sdk/worker/runner/runner.py
+++ b/sdks/python/hatchet_sdk/worker/runner/runner.py
@@ -2,6 +2,7 @@
import ctypes
import functools
import json
+import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, is_dataclass
@@ -29,6 +30,7 @@
STEP_EVENT_TYPE_STARTED,
)
from hatchet_sdk.exceptions import (
+ CancellationReason,
IllegalTaskOutputError,
NonRetryableException,
TaskRunError,
@@ -39,6 +41,7 @@
from hatchet_sdk.runnables.contextvars import (
ctx_action_key,
ctx_additional_metadata,
+ ctx_cancellation_token,
ctx_step_run_id,
ctx_task_retry_count,
ctx_worker_id,
@@ -60,6 +63,7 @@
ContextVarToCopyDict,
ContextVarToCopyInt,
ContextVarToCopyStr,
+ ContextVarToCopyToken,
copy_context_vars,
)
@@ -251,6 +255,7 @@ async def async_wrapped_action_func(
ctx_action_key.set(action.key)
ctx_additional_metadata.set(action.additional_metadata)
ctx_task_retry_count.set(action.retry_count)
+ ctx_cancellation_token.set(ctx.cancellation_token)
async with task._unpack_dependencies_with_cleanup(ctx) as dependencies:
try:
@@ -298,6 +303,12 @@ async def async_wrapped_action_func(
value=action.retry_count,
)
),
+ ContextVarToCopy(
+ var=ContextVarToCopyToken(
+ name="ctx_cancellation_token",
+ value=ctx.cancellation_token,
+ )
+ ),
],
self.thread_action_func,
ctx,
@@ -480,28 +491,95 @@ def force_kill_thread(self, thread: Thread) -> None:
## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
async def handle_cancel_action(self, action: Action) -> None:
key = action.key
+ start_time = time.monotonic()
+
+ logger.info(
+ f"Cancellation: received cancel action for {action.action_id}, "
+ f"reason={CancellationReason.WORKFLOW_CANCELLED.value}"
+ )
+
try:
- # call cancel to signal the context to stop
+ # Trigger the cancellation token to signal the context to stop
if key in self.contexts:
- self.contexts[key]._set_cancellation_flag()
+ ctx = self.contexts[key]
+ child_count = len(ctx.cancellation_token.child_run_ids)
+ logger.debug(
+ f"Cancellation: triggering token for {action.action_id}, "
+ f"reason={CancellationReason.WORKFLOW_CANCELLED.value}, "
+ f"{child_count} children registered"
+ )
+ ctx._set_cancellation_flag(CancellationReason.WORKFLOW_CANCELLED)
self.cancellations[key] = True
+ # Note: Child workflows are not cancelled here - they run independently
+ # and are managed by Hatchet's normal cancellation mechanisms
+ else:
+ logger.debug(f"Cancellation: no context found for {action.action_id}")
+
+ # Wait with supervision (using timedelta configs)
+ grace_period = self.config.cancellation_grace_period.total_seconds()
+ warning_threshold = (
+ self.config.cancellation_warning_threshold.total_seconds()
+ )
+ grace_period_ms = round(grace_period * 1000)
+ warning_threshold_ms = round(warning_threshold * 1000)
- await asyncio.sleep(1)
-
- if key in self.tasks:
- self.tasks[key].cancel()
-
- # check if thread is still running, if so, print a warning
- if key in self.threads:
- thread = self.threads[key]
+ # Wait until warning threshold
+ await asyncio.sleep(warning_threshold)
+ elapsed = time.monotonic() - start_time
+ elapsed_ms = round(elapsed * 1000)
- if self.config.enable_force_kill_sync_threads:
- self.force_kill_thread(thread)
- await asyncio.sleep(1)
+ # Check if the task has not yet exited despite the cancellation signal.
+ task_still_running = key in self.tasks and not self.tasks[key].done()
+ if task_still_running:
logger.warning(
- f"thread {self.threads[key].ident} with key {key} is still running after cancellation. This could cause the thread pool to get blocked and prevent new tasks from running."
+ f"Cancellation: task {action.action_id} has not cancelled after "
+ f"{elapsed_ms}ms (warning threshold {warning_threshold_ms}ms). "
+ f"Consider checking for blocking operations. "
+ f"See https://docs.hatchet.run/home/cancellation"
)
+
+ remaining = grace_period - elapsed
+ if remaining > 0:
+ await asyncio.sleep(remaining)
+
+ if key in self.tasks and not self.tasks[key].done():
+ logger.debug(
+ f"Cancellation: force-cancelling task {action.action_id} "
+ f"after grace period ({grace_period_ms}ms)"
+ )
+ self.tasks[key].cancel()
+
+ if key in self.threads:
+ thread = self.threads[key]
+
+ if self.config.enable_force_kill_sync_threads:
+ logger.debug(
+ f"Cancellation: force-killing thread for {action.action_id}"
+ )
+ self.force_kill_thread(thread)
+ await asyncio.sleep(1)
+
+ if thread.is_alive():
+ logger.warning(
+ f"Cancellation: thread {thread.ident} with key {key} is still running "
+ f"after cancellation. This could cause the thread pool to get blocked "
+ f"and prevent new tasks from running."
+ )
+
+ total_elapsed = time.monotonic() - start_time
+ total_elapsed_ms = round(total_elapsed * 1000)
+ if total_elapsed > grace_period:
+ logger.warning(
+ f"Cancellation: cancellation of {action.action_id} took {total_elapsed_ms}ms "
+ f"(exceeded grace period of {grace_period_ms}ms)"
+ )
+ else:
+ logger.debug(
+ f"Cancellation: task {action.action_id} eventually completed in {total_elapsed_ms}ms"
+ )
+ else:
+ logger.info(f"Cancellation: task {action.action_id} completed")
finally:
self.cleanup_run_id(key)
diff --git a/sdks/python/hatchet_sdk/worker/runner/utils/capture_logs.py b/sdks/python/hatchet_sdk/worker/runner/utils/capture_logs.py
index 6fd8b52ee1..4b105f843b 100644
--- a/sdks/python/hatchet_sdk/worker/runner/utils/capture_logs.py
+++ b/sdks/python/hatchet_sdk/worker/runner/utils/capture_logs.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import asyncio
import functools
import logging
@@ -5,13 +7,15 @@
from io import StringIO
from typing import Literal, ParamSpec, TypeVar
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
+from hatchet_sdk.cancellation import CancellationToken
from hatchet_sdk.clients.events import EventClient
from hatchet_sdk.logger import logger
from hatchet_sdk.runnables.contextvars import (
ctx_action_key,
ctx_additional_metadata,
+ ctx_cancellation_token,
ctx_step_run_id,
ctx_task_retry_count,
ctx_worker_id,
@@ -48,10 +52,22 @@ class ContextVarToCopyDict(BaseModel):
value: JSONSerializableMapping | None
+class ContextVarToCopyToken(BaseModel):
+ """Special type for copying CancellationToken to threads."""
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ name: Literal["ctx_cancellation_token"]
+ value: CancellationToken | None
+
+
class ContextVarToCopy(BaseModel):
- var: ContextVarToCopyStr | ContextVarToCopyDict | ContextVarToCopyInt = Field(
- discriminator="name"
- )
+ var: (
+ ContextVarToCopyStr
+ | ContextVarToCopyDict
+ | ContextVarToCopyInt
+ | ContextVarToCopyToken
+ ) = Field(discriminator="name")
def copy_context_vars(
@@ -73,6 +89,8 @@ def copy_context_vars(
ctx_worker_id.set(var.var.value)
elif var.var.name == "ctx_additional_metadata":
ctx_additional_metadata.set(var.var.value or {})
+ elif var.var.name == "ctx_cancellation_token":
+ ctx_cancellation_token.set(var.var.value)
else:
raise ValueError(f"Unknown context variable name: {var.var.name}")
diff --git a/sdks/python/hatchet_sdk/workflow_run.py b/sdks/python/hatchet_sdk/workflow_run.py
index 5760eef8f9..e1d0c3cc25 100644
--- a/sdks/python/hatchet_sdk/workflow_run.py
+++ b/sdks/python/hatchet_sdk/workflow_run.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import time
from typing import TYPE_CHECKING, Any
@@ -6,9 +8,17 @@
RunEventListenerClient,
)
from hatchet_sdk.clients.listeners.workflow_listener import PooledWorkflowRunListener
-from hatchet_sdk.exceptions import FailedTaskRunExceptionGroup, TaskRunError
+from hatchet_sdk.exceptions import (
+ CancellationReason,
+ CancelledError,
+ FailedTaskRunExceptionGroup,
+ TaskRunError,
+)
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.cancellation import await_with_cancellation
if TYPE_CHECKING:
+ from hatchet_sdk.cancellation import CancellationToken
from hatchet_sdk.clients.admin import AdminClient
@@ -18,7 +28,7 @@ def __init__(
workflow_run_id: str,
workflow_run_listener: PooledWorkflowRunListener,
workflow_run_event_listener: RunEventListenerClient,
- admin_client: "AdminClient",
+ admin_client: AdminClient,
):
self.workflow_run_id = workflow_run_id
self.workflow_run_listener = workflow_run_listener
@@ -31,7 +41,25 @@ def __str__(self) -> str:
def stream(self) -> RunEventListener:
return self.workflow_run_event_listener.stream(self.workflow_run_id)
- async def aio_result(self) -> dict[str, Any]:
+ async def aio_result(
+ self, cancellation_token: CancellationToken | None = None
+ ) -> dict[str, Any]:
+ """
+ Asynchronously wait for the workflow run to complete and return the result.
+
+ :param cancellation_token: Optional cancellation token to abort the wait.
+ :return: A dictionary mapping task names to their outputs.
+ """
+ logger.debug(
+ f"WorkflowRunRef.aio_result: waiting for {self.workflow_run_id}, "
+ f"token={cancellation_token is not None}"
+ )
+
+ if cancellation_token:
+ return await await_with_cancellation(
+ self.workflow_run_listener.aio_result(self.workflow_run_id),
+ cancellation_token,
+ )
return await self.workflow_run_listener.aio_result(self.workflow_run_id)
def _safely_get_action_name(self, action_id: str | None) -> str | None:
@@ -43,12 +71,42 @@ def _safely_get_action_name(self, action_id: str | None) -> str | None:
except IndexError:
return None
- def result(self) -> dict[str, Any]:
+ def result(
+ self, cancellation_token: CancellationToken | None = None
+ ) -> dict[str, Any]:
+ """
+ Synchronously wait for the workflow run to complete and return the result.
+
+ This method polls the API for the workflow run status. If a cancellation token
+ is provided, the polling will be interrupted when cancellation is triggered.
+
+ :param cancellation_token: Optional cancellation token to abort the wait.
+ :return: A dictionary mapping task names to their outputs.
+ :raises CancelledError: If the cancellation token is triggered.
+ :raises FailedTaskRunExceptionGroup: If the workflow run fails.
+ :raises ValueError: If the workflow run is not found.
+ """
from hatchet_sdk.clients.admin import RunStatus
+ logger.debug(
+ f"WorkflowRunRef.result: waiting for {self.workflow_run_id}, "
+ f"token={cancellation_token is not None}"
+ )
+
retries = 0
while True:
+ # Check cancellation at start of each iteration
+ if cancellation_token and cancellation_token.is_cancelled:
+ logger.debug(
+ f"WorkflowRunRef.result: cancellation detected for {self.workflow_run_id}, "
+ f"reason={CancellationReason.PARENT_CANCELLED.value}"
+ )
+ raise CancelledError(
+ "Operation cancelled by cancellation token",
+ reason=CancellationReason.PARENT_CANCELLED,
+ )
+
try:
details = self.admin_client.get_details(self.workflow_run_id)
except Exception as e:
@@ -59,14 +117,42 @@ def result(self) -> dict[str, Any]:
f"Workflow run {self.workflow_run_id} not found"
) from e
- time.sleep(1)
+ # Use interruptible sleep via token.wait()
+ if cancellation_token:
+ if cancellation_token.wait(timeout=1.0):
+ logger.debug(
+ f"WorkflowRunRef.result: cancellation during retry sleep for {self.workflow_run_id}, "
+ f"reason={CancellationReason.PARENT_CANCELLED.value}"
+ )
+ raise CancelledError(
+ "Operation cancelled by cancellation token",
+ reason=CancellationReason.PARENT_CANCELLED,
+ ) from None
+ else:
+ time.sleep(1)
continue
+ logger.debug(
+ f"WorkflowRunRef.result: {self.workflow_run_id} status={details.status}"
+ )
+
if (
details.status in [RunStatus.QUEUED, RunStatus.RUNNING]
or details.done is False
):
- time.sleep(1)
+ # Use interruptible sleep via token.wait()
+ if cancellation_token:
+ if cancellation_token.wait(timeout=1.0):
+ logger.debug(
+ f"WorkflowRunRef.result: cancellation during poll sleep for {self.workflow_run_id}, "
+ f"reason={CancellationReason.PARENT_CANCELLED.value}"
+ )
+ raise CancelledError(
+ "Operation cancelled by cancellation token",
+ reason=CancellationReason.PARENT_CANCELLED,
+ )
+ else:
+ time.sleep(1)
continue
if details.status == RunStatus.FAILED:
@@ -80,6 +166,9 @@ def result(self) -> dict[str, Any]:
)
if details.status == RunStatus.COMPLETED:
+ logger.debug(
+ f"WorkflowRunRef.result: {self.workflow_run_id} completed successfully"
+ )
return {
readable_id: run.output
for readable_id, run in details.task_runs.items()
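The `result()` changes above replace the fixed one-second sleeps with `token.wait(timeout=1.0)`, so a synchronous parent waiting on a child run now observes cancellation within roughly one polling interval and surfaces it as a `CancelledError` with `reason=PARENT_CANCELLED`. A minimal usage sketch, assuming the new optional `cancellation_token` parameter as added above (the wrapper function and variable names are illustrative):

```python
# Sketch: waiting on a child run with the new cancellation_token parameter.
# `ref` is a WorkflowRunRef, `token` a CancellationToken; wait_for_child is
# an illustrative wrapper, not SDK API.
from hatchet_sdk.exceptions import CancelledError


def wait_for_child(ref, token):
    try:
        # Polls run details once per second; each sleep is interruptible via
        # token.wait(timeout=1.0), so cancellation surfaces within ~1 second.
        return ref.result(cancellation_token=token)
    except CancelledError as e:
        print(f"child wait cancelled: {e.reason}")  # PARENT_CANCELLED here
        raise  # re-raise so the cancellation can propagate
```

The async path, `aio_result()`, accepts the same optional token and delegates to `await_with_cancellation`, which the new tests below show aborting the wait shortly after the token is cancelled.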
diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml
index 89c49d4efd..3121ffdebe 100644
--- a/sdks/python/pyproject.toml
+++ b/sdks/python/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "hatchet-sdk"
-version = "1.24.0"
+version = "1.25.0"
description = "This is the official Python SDK for Hatchet, a distributed, fault-tolerant task queue. The SDK allows you to easily integrate Hatchet's task scheduling and workflow orchestration capabilities into your Python applications."
authors = [
"Alexander Belanger ",
diff --git a/sdks/python/tests/test_cancellation.py b/sdks/python/tests/test_cancellation.py
new file mode 100644
index 0000000000..1b6046962e
--- /dev/null
+++ b/sdks/python/tests/test_cancellation.py
@@ -0,0 +1,461 @@
+"""Unit tests for CancellationToken and cancellation utilities."""
+
+import asyncio
+import threading
+import time
+
+import pytest
+
+from hatchet_sdk.cancellation import CancellationToken
+from hatchet_sdk.exceptions import CancellationReason, CancelledError
+from hatchet_sdk.runnables.contextvars import ctx_cancellation_token
+from hatchet_sdk.utils.cancellation import await_with_cancellation
+
+# CancellationToken
+
+
+def test_initial_state() -> None:
+ """Token should start in non-cancelled state."""
+ token = CancellationToken()
+ assert token.is_cancelled is False
+
+
+def test_cancel_sets_flag() -> None:
+ """cancel() should set is_cancelled to True."""
+ token = CancellationToken()
+ token.cancel()
+ assert token.is_cancelled is True
+
+
+def test_cancel_sets_reason() -> None:
+ """cancel() should set the reason."""
+ token = CancellationToken()
+ token.cancel(CancellationReason.USER_REQUESTED)
+ assert token.reason == CancellationReason.USER_REQUESTED
+
+
+def test_reason_is_none_before_cancel() -> None:
+ """reason should be None before cancellation."""
+ token = CancellationToken()
+ assert token.reason is None
+
+
+def test_cancel_idempotent() -> None:
+ """Multiple calls to cancel() should be safe."""
+ token = CancellationToken()
+ token.cancel()
+ token.cancel() # Should not raise
+ assert token.is_cancelled is True
+
+
+def test_cancel_idempotent_preserves_reason() -> None:
+ """Multiple calls to cancel() should preserve the original reason."""
+ token = CancellationToken()
+ token.cancel(CancellationReason.USER_REQUESTED)
+ token.cancel(CancellationReason.TIMEOUT) # Second call should be ignored
+ assert token.reason == CancellationReason.USER_REQUESTED
+
+
+def test_sync_wait_returns_true_when_cancelled() -> None:
+ """wait() should return True immediately if already cancelled."""
+ token = CancellationToken()
+ token.cancel()
+ result = token.wait(timeout=0.1)
+ assert result is True
+
+
+def test_sync_wait_timeout_returns_false() -> None:
+ """wait() should return False when timeout expires without cancellation."""
+ token = CancellationToken()
+ start = time.monotonic()
+ result = token.wait(timeout=0.1)
+ elapsed = time.monotonic() - start
+ assert result is False
+ assert elapsed >= 0.1
+
+
+def test_sync_wait_interrupted_by_cancel() -> None:
+ """wait() should return True when cancelled during wait."""
+ token = CancellationToken()
+
+ def cancel_after_delay() -> None:
+ time.sleep(0.1)
+ token.cancel()
+
+ thread = threading.Thread(target=cancel_after_delay)
+ thread.start()
+
+ start = time.monotonic()
+ result = token.wait(timeout=1.0)
+ elapsed = time.monotonic() - start
+
+ thread.join()
+
+ assert result is True
+ assert elapsed < 0.5 # Should be much faster than timeout
+
+
+@pytest.mark.asyncio
+async def test_aio_wait_returns_when_cancelled() -> None:
+ """aio_wait() should return when cancelled."""
+ token = CancellationToken()
+
+ async def cancel_after_delay() -> None:
+ await asyncio.sleep(0.1)
+ token.cancel()
+
+ asyncio.create_task(cancel_after_delay())
+
+ start = time.monotonic()
+ await token.aio_wait()
+ elapsed = time.monotonic() - start
+
+ assert elapsed < 0.5 # Should be fast
+
+
+def test_register_child() -> None:
+ """register_child() should add run IDs to the list."""
+ token = CancellationToken()
+ token.register_child("run-1")
+ token.register_child("run-2")
+
+ assert token.child_run_ids == ["run-1", "run-2"]
+
+
+def test_callback_invoked_on_cancel() -> None:
+ """Callbacks should be invoked when cancel() is called."""
+ token = CancellationToken()
+ called = []
+
+ def callback() -> None:
+ called.append(True)
+
+ token.add_callback(callback)
+ token.cancel()
+
+ assert called == [True]
+
+
+def test_callback_invoked_immediately_if_already_cancelled() -> None:
+ """Callbacks added after cancellation should be invoked immediately."""
+ token = CancellationToken()
+ token.cancel()
+
+ called = []
+
+ def callback() -> None:
+ called.append(True)
+
+ token.add_callback(callback)
+
+ assert called == [True]
+
+
+def test_multiple_callbacks() -> None:
+ """Multiple callbacks should all be invoked."""
+ token = CancellationToken()
+ results: list[int] = []
+
+ token.add_callback(lambda: results.append(1))
+ token.add_callback(lambda: results.append(2))
+ token.add_callback(lambda: results.append(3))
+
+ token.cancel()
+
+ assert results == [1, 2, 3]
+
+
+def test_repr() -> None:
+ """__repr__ should provide useful debugging info."""
+ token = CancellationToken()
+ token.register_child("run-1")
+
+ repr_str = repr(token)
+ assert "cancelled=False" in repr_str
+ assert "children=1" in repr_str
+
+
+# await_with_cancellation
+
+
+@pytest.mark.asyncio
+async def test_no_token_awaits_directly() -> None:
+ """Without a token, coroutine should be awaited directly."""
+
+ async def simple_coro() -> str:
+ return "result"
+
+ result = await await_with_cancellation(simple_coro(), None)
+ assert result == "result"
+
+
+@pytest.mark.asyncio
+async def test_token_not_cancelled_returns_result() -> None:
+ """With a non-cancelled token, should return coroutine result."""
+ token = CancellationToken()
+
+ async def simple_coro() -> str:
+ await asyncio.sleep(0.01)
+ return "result"
+
+ result = await await_with_cancellation(simple_coro(), token)
+ assert result == "result"
+
+
+@pytest.mark.asyncio
+async def test_already_cancelled_raises_immediately() -> None:
+ """With an already-cancelled token, should raise immediately."""
+ token = CancellationToken()
+ token.cancel()
+
+ async def simple_coro() -> str:
+ await asyncio.sleep(10) # Would block if actually awaited
+ return "result"
+
+ with pytest.raises(asyncio.CancelledError):
+ await await_with_cancellation(simple_coro(), token)
+
+
+@pytest.mark.asyncio
+async def test_cancellation_during_await_raises() -> None:
+ """Should raise CancelledError when token is cancelled during await."""
+ token = CancellationToken()
+
+ async def slow_coro() -> str:
+ await asyncio.sleep(10)
+ return "result"
+
+ async def cancel_after_delay() -> None:
+ await asyncio.sleep(0.1)
+ token.cancel()
+
+ asyncio.create_task(cancel_after_delay())
+
+ start = time.monotonic()
+ with pytest.raises(asyncio.CancelledError):
+ await await_with_cancellation(slow_coro(), token)
+ elapsed = time.monotonic() - start
+
+ assert elapsed < 0.5 # Should be cancelled quickly
+
+
+@pytest.mark.asyncio
+async def test_cancel_callback_invoked() -> None:
+ """Cancel callback should be invoked on cancellation."""
+ token = CancellationToken()
+ callback_called = []
+
+ async def cancel_callback() -> None:
+ callback_called.append(True)
+
+ async def slow_coro() -> str:
+ await asyncio.sleep(10)
+ return "result"
+
+ async def cancel_after_delay() -> None:
+ await asyncio.sleep(0.1)
+ token.cancel()
+
+ asyncio.create_task(cancel_after_delay())
+
+ with pytest.raises(asyncio.CancelledError):
+ await await_with_cancellation(
+ slow_coro(), token, cancel_callback=cancel_callback
+ )
+
+ assert callback_called == [True]
+
+
+@pytest.mark.asyncio
+async def test_sync_cancel_callback_invoked() -> None:
+ """Cancel callback should be invoked on cancellation."""
+ token = CancellationToken()
+ callback_called = []
+
+ async def cancel_callback() -> None:
+ callback_called.append(True)
+
+ async def slow_coro() -> str:
+ await asyncio.sleep(10)
+ return "result"
+
+ async def cancel_after_delay() -> None:
+ await asyncio.sleep(0.1)
+ token.cancel()
+
+ asyncio.create_task(cancel_after_delay())
+
+ with pytest.raises(asyncio.CancelledError):
+ await await_with_cancellation(
+ slow_coro(), token, cancel_callback=cancel_callback
+ )
+
+ assert callback_called == [True]
+
+
+@pytest.mark.asyncio
+async def test_cancel_callback_invoked_on_external_task_cancel() -> None:
+ """Cancel callback should be invoked if the awaiting task is cancelled externally."""
+ token = CancellationToken()
+ callback_called = asyncio.Event()
+
+ async def cancel_callback() -> None:
+ callback_called.set()
+
+ async def slow_coro() -> str:
+ await asyncio.sleep(10)
+ return "result"
+
+ task = asyncio.create_task(
+ await_with_cancellation(slow_coro(), token, cancel_callback=cancel_callback)
+ )
+
+ await asyncio.sleep(0.1)
+ task.cancel()
+
+ with pytest.raises(asyncio.CancelledError):
+ await task
+
+ await asyncio.wait_for(callback_called.wait(), timeout=1.0)
+
+
+@pytest.mark.asyncio
+async def test_cancel_callback_not_invoked_on_success() -> None:
+ """Cancel callback should NOT be invoked when coroutine completes normally."""
+ token = CancellationToken()
+ callback_called = []
+
+ async def cancel_callback() -> None:
+ callback_called.append(True)
+
+ async def fast_coro() -> str:
+ await asyncio.sleep(0.01)
+ return "result"
+
+ result = await await_with_cancellation(
+ fast_coro(), token, cancel_callback=cancel_callback
+ )
+
+ assert result == "result"
+ assert callback_called == []
+
+
+# CancellationReason
+
+
+def test_all_reasons_exist() -> None:
+ """All expected cancellation reasons should exist."""
+ assert CancellationReason.USER_REQUESTED.value == "user_requested"
+ assert CancellationReason.TIMEOUT.value == "timeout"
+ assert CancellationReason.PARENT_CANCELLED.value == "parent_cancelled"
+ assert CancellationReason.WORKFLOW_CANCELLED.value == "workflow_cancelled"
+ assert CancellationReason.TOKEN_CANCELLED.value == "token_cancelled"
+
+
+def test_reasons_are_strings() -> None:
+ """Cancellation reason values should be strings."""
+ for reason in CancellationReason:
+ assert isinstance(reason.value, str)
+
+
+# CancelledError
+
+
+def test_cancelled_error_is_base_exception() -> None:
+ """CancelledError should be a BaseException (not Exception)."""
+ err = CancelledError("test message")
+ assert isinstance(err, BaseException)
+ assert not isinstance(err, Exception) # Should NOT be caught by except Exception
+ assert str(err) == "test message"
+
+
+def test_cancelled_error_not_caught_by_except_exception() -> None:
+ """CancelledError should NOT be caught by except Exception."""
+ caught_by_exception = False
+ caught_by_cancelled_error = False
+
+ try:
+ raise CancelledError("test")
+ except Exception:
+ caught_by_exception = True
+ except CancelledError:
+ caught_by_cancelled_error = True
+
+ assert not caught_by_exception
+ assert caught_by_cancelled_error
+
+
+def test_cancelled_error_with_reason() -> None:
+ """CancelledError should accept and store a reason."""
+ err = CancelledError("test message", reason=CancellationReason.TIMEOUT)
+ assert err.reason == CancellationReason.TIMEOUT
+
+
+def test_cancelled_error_reason_defaults_to_none() -> None:
+ """CancelledError reason should default to None."""
+ err = CancelledError("test message")
+ assert err.reason is None
+
+
+def test_cancelled_error_message_property() -> None:
+ """CancelledError should have a message property."""
+ err = CancelledError("test message")
+ assert err.message == "test message"
+
+
+def test_cancelled_error_default_message() -> None:
+ """CancelledError should have a default message."""
+ err = CancelledError()
+ assert err.message == "Operation cancelled"
+
+
+def test_can_be_raised_and_caught() -> None:
+ """CancelledError should be raisable and catchable."""
+ with pytest.raises(CancelledError) as exc_info:
+ raise CancelledError("Operation cancelled")
+
+ assert "Operation cancelled" in str(exc_info.value)
+
+
+def test_can_be_raised_with_reason() -> None:
+ """CancelledError should be raisable with a reason."""
+ with pytest.raises(CancelledError) as exc_info:
+ raise CancelledError(
+ "Parent was cancelled", reason=CancellationReason.PARENT_CANCELLED
+ )
+
+ assert exc_info.value.reason == CancellationReason.PARENT_CANCELLED
+
+
+# Context var propagation
+
+
+def test_context_var_default_is_none() -> None:
+ """ctx_cancellation_token should default to None."""
+ assert ctx_cancellation_token.get() is None
+
+
+def test_context_var_can_be_set_and_retrieved() -> None:
+ """ctx_cancellation_token should be settable and retrievable."""
+ token = CancellationToken()
+ ctx_cancellation_token.set(token)
+ try:
+ assert ctx_cancellation_token.get() is token
+ finally:
+ ctx_cancellation_token.set(None)
+
+
+@pytest.mark.asyncio
+async def test_context_var_propagates_in_async() -> None:
+ """ctx_cancellation_token should propagate in async context."""
+ token = CancellationToken()
+ ctx_cancellation_token.set(token)
+
+ async def check_token() -> CancellationToken | None:
+ return ctx_cancellation_token.get()
+
+ try:
+ retrieved = await check_token()
+ assert retrieved is token
+ finally:
+ ctx_cancellation_token.set(None)
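
For reference, the `CancellationToken` contract these tests pin down can be summarised in a few lines (sketch; only behaviours asserted in test_cancellation.py are used):

```python
# Summary sketch of the CancellationToken behaviour asserted by the tests above.
from hatchet_sdk.cancellation import CancellationToken
from hatchet_sdk.exceptions import CancellationReason

token = CancellationToken()
token.add_callback(lambda: print("cancelled"))   # invoked once, when cancel() fires
token.register_child("child-run-id")             # tracked in token.child_run_ids

token.cancel(CancellationReason.USER_REQUESTED)  # sets the flag and reason, runs callbacks
token.cancel(CancellationReason.TIMEOUT)         # idempotent: original reason is preserved
assert token.is_cancelled
assert token.reason == CancellationReason.USER_REQUESTED
assert token.wait(timeout=0.1)                   # returns True immediately once cancelled
```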