Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion examples/python/cancellation/worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import time

from hatchet_sdk import Context, EmptyModel, Hatchet
from hatchet_sdk import CancellationReason, CancelledError, Context, EmptyModel, Hatchet

hatchet = Hatchet(debug=True)

Expand Down Expand Up @@ -40,6 +40,25 @@ def check_flag(input: EmptyModel, ctx: Context) -> dict[str, str]:



# > Handling cancelled error
@cancellation_workflow.task()
def my_task(input: EmptyModel, ctx: Context) -> dict:
    """Call `ctx.playground` and demonstrate how a task should handle `CancelledError`.

    Returns:
        Whatever `ctx.playground("test", "default")` returns.

    Raises:
        CancelledError: Re-raised unchanged after the cancellation reason is printed.
        Exception: Any other error is printed and re-raised unchanged.
    """
    try:
        outcome = ctx.playground("test", "default")
    except CancelledError as cancelled:
        # Handle parent cancellation here (e.g. perform cleanup) before re-raising
        print(f"Parent Task cancelled: {cancelled.reason}")
        # Always re-raise CancelledError so Hatchet can properly handle the cancellation
        raise
    except Exception as err:
        # CancelledError never reaches this handler; only other errors do
        print(f"Other error: {err}")
        raise
    else:
        return outcome




def main() -> None:
    """Create the "cancellation-worker" worker for `cancellation_workflow` and start it."""
    hatchet.worker(
        "cancellation-worker",
        workflows=[cancellation_workflow],
    ).start()
Expand Down
6 changes: 1 addition & 5 deletions examples/python/durable/test_durable.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,11 @@ async def test_durable(hatchet: Hatchet) -> None:

active_workers = [w for w in workers.rows if w.status == "ACTIVE"]

assert len(active_workers) == 2
assert len(active_workers) == 1
assert any(
w.name == hatchet.config.apply_namespace("e2e-test-worker")
for w in active_workers
)
assert any(
w.name == hatchet.config.apply_namespace("e2e-test-worker_durable")
for w in active_workers
)

assert result["durable_task"]["status"] == "success"

Expand Down
8 changes: 6 additions & 2 deletions examples/python/simple/worker.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
# > Simple
import time

from hatchet_sdk import Context, EmptyModel, Hatchet
from hatchet_sdk import Context, DurableContext, EmptyModel, Hatchet

hatchet = Hatchet(debug=True)


@hatchet.task()
def simple(input: EmptyModel, ctx: Context) -> dict[str, str]:
    """Sleep for 50 seconds, then return a fixed greeting payload.

    Neither `input` nor `ctx` is read by the body.
    """
    time.sleep(50)
    greeting = {"result": "Hello, world!"}
    return greeting


@hatchet.durable_task()
def simple_durable(input: EmptyModel, ctx: Context) -> dict[str, str]:
async def simple_durable(input: EmptyModel, ctx: DurableContext) -> dict[str, str]:
    """Await a run of the `simple` task with the same input, print its result, and return a greeting.

    The value returned by `simple` is only printed; this task's own return
    value is a fixed payload.
    """
    child_output = await simple.aio_run(input)
    print(child_output)
    greeting = {"result": "Hello, world!"}
    return greeting


Expand Down
27 changes: 25 additions & 2 deletions frontend/docs/pages/home/cancellation.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,38 @@ When a task is canceled, Hatchet sends a cancellation signal to the task. The ta

/>

### CancelledError Exception

When a sync task is cancelled while waiting for a child workflow or during a cancellation-aware operation, a `CancelledError` exception is raised.

<Callout type="warning">
**Important:** `CancelledError` inherits from `BaseException`, not
`Exception`. This means it will **not** be caught by `except Exception:`
handlers (a bare `except:` would still catch it). This is intentional and
mirrors the behavior of Python's `asyncio.CancelledError`.
</Callout>

<Snippet src={snippets.python.cancellation.worker.handling_cancelled_error} />

### Cancellation Reasons

The `CancelledError` includes a `reason` attribute that indicates why the cancellation occurred:

| Reason | Description |
| --------------------------------------- | --------------------------------------------------------------------- |
| `CancellationReason.USER_REQUESTED` | The user explicitly requested cancellation via `ctx.cancel()` |
| `CancellationReason.WORKFLOW_CANCELLED` | The workflow run was cancelled (e.g., via API or concurrency control) |
| `CancellationReason.PARENT_CANCELLED` | The parent workflow was cancelled while waiting for a child |
| `CancellationReason.TIMEOUT` | The operation timed out |
| `CancellationReason.UNKNOWN` | Unknown or unspecified reason |

</Tabs.Tab>
<Tabs.Tab title="Typescript">
<Snippet
src={snippets.typescript.cancellations.workflow.declaring_a_task}

/>
<Snippet
src={snippets.typescript.cancellations.workflow.abort_signal}

/>

</Tabs.Tab>
Expand Down
22 changes: 21 additions & 1 deletion sdks/python/examples/cancellation/worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import time

from hatchet_sdk import Context, EmptyModel, Hatchet
from hatchet_sdk import CancellationReason, CancelledError, Context, EmptyModel, Hatchet

hatchet = Hatchet(debug=True)

Expand Down Expand Up @@ -42,6 +42,26 @@ def check_flag(input: EmptyModel, ctx: Context) -> dict[str, str]:
# !!


# > Handling cancelled error
@cancellation_workflow.task()
def my_task(input: EmptyModel, ctx: Context) -> dict:
    """Call `ctx.playground` and demonstrate how a task should handle `CancelledError`.

    Returns:
        Whatever `ctx.playground("test", "default")` returns.

    Raises:
        CancelledError: Re-raised unchanged after the cancellation reason is printed.
        Exception: Any other error is printed and re-raised unchanged.
    """
    try:
        outcome = ctx.playground("test", "default")
    except CancelledError as cancelled:
        # Handle parent cancellation here (e.g. perform cleanup) before re-raising
        print(f"Parent Task cancelled: {cancelled.reason}")
        # Always re-raise CancelledError so Hatchet can properly handle the cancellation
        raise
    except Exception as err:
        # CancelledError never reaches this handler; only other errors do
        print(f"Other error: {err}")
        raise
    else:
        return outcome


# !!


def main() -> None:
    """Create the "cancellation-worker" worker for `cancellation_workflow` and start it."""
    hatchet.worker(
        "cancellation-worker",
        workflows=[cancellation_workflow],
    ).start()
Expand Down
6 changes: 1 addition & 5 deletions sdks/python/examples/durable/test_durable.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,11 @@ async def test_durable(hatchet: Hatchet) -> None:

active_workers = [w for w in workers.rows if w.status == "ACTIVE"]

assert len(active_workers) >= 2
assert len(active_workers) == 1
assert any(
w.name == hatchet.config.apply_namespace("e2e-test-worker")
for w in active_workers
)
assert any(
w.name == hatchet.config.apply_namespace("e2e-test-worker_durable")
for w in active_workers
)

assert result["durable_task"]["status"] == "success"

Expand Down
6 changes: 6 additions & 0 deletions sdks/python/hatchet_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from hatchet_sdk.cancellation import CancellationToken
from hatchet_sdk.clients.admin import (
RunStatus,
ScheduleTriggerWorkflowOptions,
Expand Down Expand Up @@ -148,6 +149,8 @@
WorkerLabelComparator,
)
from hatchet_sdk.exceptions import (
CancellationReason,
CancelledError,
DedupeViolationError,
FailedTaskRunExceptionGroup,
NonRetryableException,
Expand Down Expand Up @@ -186,6 +189,9 @@
"CELEvaluationResult",
"CELFailure",
"CELSuccess",
"CancellationReason",
"CancellationToken",
"CancelledError",
"ClientConfig",
"ClientTLSConfig",
"ConcurrencyExpression",
Expand Down
194 changes: 194 additions & 0 deletions sdks/python/hatchet_sdk/cancellation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""Cancellation token for coordinating cancellation across async and sync operations."""

from __future__ import annotations

import asyncio
import threading
from collections.abc import Callable
from typing import TYPE_CHECKING

from hatchet_sdk.exceptions import CancellationReason
from hatchet_sdk.logger import logger

if TYPE_CHECKING:
pass


class CancellationToken:
"""
A token that can be used to signal cancellation across async and sync operations.

The token provides both asyncio and threading event primitives, allowing it to work
seamlessly in both async and sync code paths. Child workflow run IDs can be registered
with the token so they can be cancelled when the parent is cancelled.

Example:
```python
token = CancellationToken()

# In async code
await token.aio_wait() # Blocks until cancelled

# In sync code
token.wait(timeout=1.0) # Returns True if cancelled within timeout

# Check if cancelled
if token.is_cancelled:
raise CancelledError("Operation was cancelled")

# Trigger cancellation
token.cancel()
```
"""

def __init__(self) -> None:
    """Create a token that is not cancelled and has no children or callbacks."""
    # Held by cancel(), register_child(), get_child_run_ids() and add_callback().
    self._lock = threading.Lock()

    # Cancellation state: flag plus the reason recorded by cancel().
    self._cancelled = False
    self._reason: CancellationReason | None = None

    # Wake-up primitives. The asyncio event is created on first use by
    # _get_async_event(); the threading event exists from the start.
    self._sync_event = threading.Event()
    self._async_event: asyncio.Event | None = None

    # Registered child run IDs and cancellation callbacks.
    self._child_run_ids: list[str] = []
    self._callbacks: list[Callable[[], None]] = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how does the async code play with threading.Lock?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be fine afaict; i refactored some of the callback stuff to minimize any contention


def _get_async_event(self) -> asyncio.Event:
    """Return the token's asyncio event, creating it on first use.

    Creation is deferred so that constructing a token does not need an
    event loop. If the token is already cancelled when the event is
    created, the event is returned already set.
    """
    if self._async_event is not None:
        return self._async_event

    # Store the event on the token *before* reading the cancelled flag:
    # cancel() sets the flag and then sets the event if one exists, so this
    # order means one side always sets the event.
    event = self._async_event = asyncio.Event()
    if self._cancelled:
        event.set()
    return event

def cancel(self, reason: CancellationReason = CancellationReason.TOKEN_CANCELLED) -> None:
    """
    Trigger cancellation.

    The first call will:
    - Set the cancelled flag and reason
    - Signal both async and sync events
    - Invoke all registered callbacks

    Later calls only log and return; the original reason is kept.

    Callbacks run after the internal lock has been released, so a callback
    may safely call back into this token (e.g. `get_child_run_ids()`).
    An exception raised by a callback is logged as a warning and does not
    stop the remaining callbacks.

    Args:
        reason: The reason for cancellation.
    """
    with self._lock:
        if self._cancelled:
            logger.debug(
                f"CancellationToken: cancel() called but already cancelled, "
                f"reason={self._reason.value if self._reason else 'none'}"
            )
            return

        logger.debug(
            f"CancellationToken: cancel() called, reason={reason.value}, "
            f"{len(self._child_run_ids)} children registered"
        )

        self._cancelled = True
        self._reason = reason

        # Signal both event types
        # NOTE(review): asyncio.Event is documented as not thread-safe; if
        # cancel() can run on a thread other than the event loop's, this
        # set() may need loop.call_soon_threadsafe - confirm against callers.
        if self._async_event is not None:
            self._async_event.set()
        self._sync_event.set()

        # Snapshot while holding the lock. add_callback() stops appending
        # once the flag is set, so every callback runs exactly once.
        callbacks = list(self._callbacks)

    # Invoke callbacks without holding the (non-reentrant) lock so that a
    # callback touching this token cannot deadlock.
    for callback in callbacks:
        try:
            logger.debug(f"CancellationToken: invoking callback {callback}")
            callback()
        except Exception as e:  # noqa: PERF203
            logger.warning(f"CancellationToken: callback raised exception: {e}")

    logger.debug(f"CancellationToken: cancel() complete, reason={reason.value}")

@property
def is_cancelled(self) -> bool:
    """Whether cancel() has already been called on this token."""
    cancelled_flag = self._cancelled
    return cancelled_flag

@property
def reason(self) -> CancellationReason | None:
    """The reason recorded by cancel(), or None while the token is not cancelled."""
    recorded_reason = self._reason
    return recorded_reason

async def aio_wait(self) -> None:
    """
    Suspend the calling coroutine until the token is cancelled.

    Returns immediately if the token has already been cancelled. Logs the
    recorded reason at debug level once the wait finishes.
    """
    cancel_event = self._get_async_event()
    await cancel_event.wait()

    reason_label = self._reason.value if self._reason else "none"
    logger.debug(
        f"CancellationToken: async wait completed (cancelled), "
        f"reason={reason_label}"
    )

def wait(self, timeout: float | None = None) -> bool:
    """
    Block the calling thread until the token is cancelled or the timeout expires.

    Args:
        timeout: Maximum number of seconds to block. None blocks until cancelled.

    Returns:
        True if the token was cancelled, False if the timeout ran out first.
    """
    was_cancelled = self._sync_event.wait(timeout)
    if not was_cancelled:
        return was_cancelled

    reason_label = self._reason.value if self._reason else "none"
    logger.debug(
        f"CancellationToken: sync wait interrupted by cancellation, "
        f"reason={reason_label}"
    )
    return was_cancelled

def register_child(self, run_id: str) -> None:
    """
    Record a child workflow run ID on this token.

    The ID is only stored; it can be read back later through
    `get_child_run_ids()`, e.g. to cancel the children when the parent is
    cancelled.

    Args:
        run_id: The workflow run ID of the child workflow.
    """
    with self._lock:
        logger.debug(f"CancellationToken: registering child workflow {run_id}")
        children = self._child_run_ids
        children.append(run_id)

def get_child_run_ids(self) -> list[str]:
    """
    Return a snapshot of the child run IDs registered so far.

    Returns:
        A new list; mutating it does not affect the token.
    """
    with self._lock:
        snapshot = list(self._child_run_ids)
    return snapshot

def add_callback(self, callback: Callable[[], None]) -> None:
    """
    Register a callback to be invoked when cancellation is triggered.

    If the token is already cancelled, the callback is invoked immediately
    on the calling thread instead of being stored. In that case it runs
    after the internal lock has been released, so it may safely call back
    into this token. An exception raised by the callback is logged as a
    warning and not propagated.

    Args:
        callback: A callable that takes no arguments.
    """
    # Only the decision is made under the lock. The lock is not reentrant,
    # so running user code while holding it would deadlock any callback
    # that calls a locking method on this token.
    with self._lock:
        if not self._cancelled:
            self._callbacks.append(callback)
            return

    # Already cancelled, invoke immediately
    logger.debug(
        f"CancellationToken: invoking callback immediately (already cancelled): {callback}"
    )
    try:
        callback()
    except Exception as e:
        logger.warning(f"CancellationToken: callback raised exception: {e}")

def __repr__(self) -> str:
    """Summarise the cancelled flag and the number of children and callbacks."""
    child_count = len(self._child_run_ids)
    callback_count = len(self._callbacks)
    return (
        f"CancellationToken(cancelled={self._cancelled}, "
        f"children={child_count}, callbacks={callback_count})"
    )
Loading
Loading