Skip to content

Commit 391338b

Browse files
Cancel Temporal timer when workflow.sleep() task is cancelled (#1352)
* fix: cancel Temporal timer when workflow.sleep() task is cancelled * chore: simplify unit test to check for TimerCanceled event in workflow history * fix: linting --------- Co-authored-by: tconley1428 <tconley1428@gmail.com>
1 parent afa83e1 commit 391338b

2 files changed

Lines changed: 65 additions & 2 deletions

File tree

temporalio/worker/_workflow_instance.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1739,10 +1739,13 @@ async def workflow_sleep(
17391739
else None
17401740
)
17411741
fut = self.create_future()
1742-
self._timer_impl(
1742+
timer_handle = self._timer_impl(
17431743
duration,
17441744
_TimerOptions(user_metadata=user_metadata),
1745-
lambda: fut.set_result(None),
1745+
lambda: fut.set_result(None) if not fut.done() else None,
1746+
)
1747+
fut.add_done_callback(
1748+
lambda f: timer_handle.cancel() if f.cancelled() else None
17461749
)
17471750
await fut
17481751

tests/worker/test_workflow.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3431,6 +3431,66 @@ async def test_workflow_cancel_signal_and_timer_fired_in_same_task(
34313431
await result_task
34323432

34333433

3434+
@workflow.defn
3435+
class CancelWorkflowSleepTaskWorkflow:
3436+
"""Like CancelSignalAndTimerFiredInSameTaskWorkflow but uses workflow.sleep."""
3437+
3438+
_ready = False
3439+
timer_task: asyncio.Task[None] # type: ignore[reportUninitializedInstanceVariable]
3440+
3441+
@workflow.run
3442+
async def run(self) -> str:
3443+
self.timer_task = asyncio.create_task(workflow.sleep(60 * 60))
3444+
self._ready = True
3445+
try:
3446+
await self.timer_task
3447+
return "timer_completed"
3448+
except asyncio.CancelledError:
3449+
return "timer_cancelled"
3450+
3451+
@workflow.query
3452+
def ready(self) -> bool:
3453+
return self._ready
3454+
3455+
@workflow.signal
3456+
def cancel_timer(self) -> None:
3457+
self.timer_task.cancel()
3458+
3459+
3460+
async def test_workflow_sleep_task_cancellation(
3461+
client: Client,
3462+
):
3463+
async with new_worker(
3464+
client,
3465+
CancelWorkflowSleepTaskWorkflow,
3466+
) as worker:
3467+
handle = await client.start_workflow(
3468+
CancelWorkflowSleepTaskWorkflow.run,
3469+
id=f"workflow-{uuid.uuid4()}",
3470+
task_queue=worker.task_queue,
3471+
)
3472+
3473+
async def ready() -> bool:
3474+
return await handle.query(CancelWorkflowSleepTaskWorkflow.ready)
3475+
3476+
await assert_eq_eventually(True, ready)
3477+
await handle.signal(CancelWorkflowSleepTaskWorkflow.cancel_timer)
3478+
result = await handle.result()
3479+
3480+
assert result == "timer_cancelled"
3481+
# Verify the Temporal timer was actually cancelled on the server
3482+
resp = await client.workflow_service.get_workflow_execution_history(
3483+
GetWorkflowExecutionHistoryRequest(
3484+
namespace=client.namespace,
3485+
execution=WorkflowExecution(workflow_id=handle.id),
3486+
)
3487+
)
3488+
timer_canceled = any(
3489+
e.event_type == EventType.EVENT_TYPE_TIMER_CANCELED for e in resp.history.events
3490+
)
3491+
assert timer_canceled, "Expected TimerCanceled event in history"
3492+
3493+
34343494
class MyCustomError(ApplicationError):
34353495
def __init__(self, message: str) -> None:
34363496
super().__init__(message, type="MyCustomError", non_retryable=True)

0 commit comments

Comments
 (0)