Skip to content

Commit 1323b1b

Browse files
chrisguidry and claude authored
Fix worker __aexit__ hanging on early cancellation (#263)
If a `CancelledError` arrived between `_worker_done.clear()` and entering the try block, the finally block never ran, leaving `_worker_done` cleared forever. This caused `__aexit__` to hang indefinitely waiting on `_worker_done.wait()`.

The fix moves `_cancellation_listener_task` from `__aenter__`/`__aexit__` to `_worker_loop`, matching the pattern used by `scheduler_task` and `lease_renewal_task`. All async setup now happens inside the try block, ensuring the finally block always runs and `_worker_done.set()` is called.

Also names all asyncio tasks for easier debugging (`docket.worker.*`, `docket.strikelist.*`).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 9a2ac88 commit 1323b1b

File tree

4 files changed

+346
-119
lines changed

4 files changed

+346
-119
lines changed

src/docket/strikelist.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -227,7 +227,9 @@ async def connect(self) -> None:
227227
self._connection_pool = await connection_pool_from_url(self.url)
228228

229229
self._strikes_loaded = asyncio.Event()
230-
self._monitor_task = asyncio.create_task(self._monitor_strikes())
230+
self._monitor_task = asyncio.create_task(
231+
self._monitor_strikes(), name="docket.strikelist.monitor"
232+
)
231233

232234
async def close(self) -> None:
233235
"""Close the Redis connection and stop monitoring.

src/docket/worker.py

Lines changed: 50 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -205,16 +205,15 @@ def _maybe_suppress_instrumentation(self) -> Generator[None, None, None]:
205205
yield
206206

207207
async def __aenter__(self) -> Self:
208-
self._heartbeat_task = asyncio.create_task(self._heartbeat())
208+
self._heartbeat_task = asyncio.create_task(
209+
self._heartbeat(), name="docket.worker.heartbeat"
210+
)
209211
self._execution_counts: dict[str, int] = {}
210212
# Track concurrency slots for active tasks so we can refresh them during
211213
# lease renewal. Maps execution.key → concurrency_key
212214
self._concurrency_slots: dict[str, str] = {}
213215
# Track running tasks for cancellation lookup
214216
self._tasks_by_key: dict[TaskKey, asyncio.Task[None]] = {}
215-
self._cancellation_listener_task = asyncio.create_task(
216-
self._cancellation_listener()
217-
)
218217
# Events for coordinating worker loop shutdown
219218
self._worker_stopping = asyncio.Event()
220219
self._worker_done = asyncio.Event()
@@ -233,11 +232,6 @@ async def __aexit__(
233232
self._worker_stopping.set()
234233
await self._worker_done.wait()
235234

236-
self._cancellation_listener_task.cancel()
237-
with suppress(asyncio.CancelledError):
238-
await self._cancellation_listener_task
239-
del self._cancellation_listener_task
240-
241235
self._heartbeat_task.cancel()
242236
with suppress(asyncio.CancelledError):
243237
await self._heartbeat_task
@@ -346,10 +340,12 @@ def handle_shutdown(sig_name: str) -> None: # pragma: no cover
346340

347341
try:
348342
if until_finished:
349-
run_task = asyncio.create_task(worker.run_until_finished())
343+
run_task = asyncio.create_task(
344+
worker.run_until_finished(), name="docket.worker.run"
345+
)
350346
else:
351347
run_task = asyncio.create_task(
352-
worker.run_forever()
348+
worker.run_forever(), name="docket.worker.run"
353349
) # pragma: no cover
354350
await run_task
355351
except asyncio.CancelledError: # pragma: no cover
@@ -419,21 +415,16 @@ async def _run(self, forever: bool = False) -> None:
419415
async def _worker_loop(self, redis: Redis, forever: bool = False):
420416
self._worker_stopping.clear()
421417
self._worker_done.clear()
418+
self._cancellation_ready.clear() # Reset for reconnection scenarios
422419

423-
await self._cancellation_ready.wait()
424-
425-
if self.schedule_automatic_tasks:
426-
await self._schedule_all_automatic_perpetual_tasks()
427-
420+
# Initialize task variables before try block so finally can check them.
421+
# This ensures _worker_done.set() is always called if _worker_done.clear() was.
422+
cancellation_listener_task: asyncio.Task[None] | None = None
423+
scheduler_task: asyncio.Task[None] | None = None
424+
lease_renewal_task: asyncio.Task[None] | None = None
428425
active_tasks: dict[asyncio.Task[None], RedisMessageID] = {}
429-
430-
scheduler_task = asyncio.create_task(self._scheduler_loop(redis))
431-
lease_renewal_task = asyncio.create_task(
432-
self._renew_leases(redis, active_tasks)
433-
)
434426
task_executions: dict[asyncio.Task[None], Execution] = {}
435427
available_slots = self.concurrency
436-
437428
log_context = self._log_context()
438429

439430
async def check_for_work() -> bool:
@@ -547,10 +538,27 @@ async def ack_message(redis: Redis, message_id: RedisMessageID) -> None:
547538
)
548539
await pipeline.execute()
549540

550-
has_work: bool = True
551-
stopping = self._worker_stopping.is_set
552-
553541
try:
542+
# Start cancellation listener and wait for it to be ready
543+
cancellation_listener_task = asyncio.create_task(
544+
self._cancellation_listener(),
545+
name="docket.worker.cancellation_listener",
546+
)
547+
await self._cancellation_ready.wait()
548+
549+
if self.schedule_automatic_tasks:
550+
await self._schedule_all_automatic_perpetual_tasks()
551+
552+
scheduler_task = asyncio.create_task(
553+
self._scheduler_loop(redis), name="docket.worker.scheduler"
554+
)
555+
lease_renewal_task = asyncio.create_task(
556+
self._renew_leases(redis, active_tasks),
557+
name="docket.worker.lease_renewal",
558+
)
559+
560+
has_work: bool = True
561+
stopping = self._worker_stopping.is_set
554562
while (forever or has_work or active_tasks) and not stopping():
555563
await process_completed_tasks()
556564

@@ -588,9 +596,21 @@ async def ack_message(redis: Redis, message_id: RedisMessageID) -> None:
588596
await asyncio.gather(*active_tasks, return_exceptions=True)
589597
await process_completed_tasks()
590598

599+
# Signal internal tasks to stop
591600
self._worker_stopping.set()
592-
await scheduler_task
593-
await lease_renewal_task
601+
602+
# These check _worker_stopping and exit cleanly
603+
if scheduler_task is not None:
604+
await scheduler_task
605+
if lease_renewal_task is not None:
606+
await lease_renewal_task
607+
608+
# Cancellation listener has while True loop, needs explicit cancellation
609+
if cancellation_listener_task is not None: # pragma: no branch
610+
cancellation_listener_task.cancel()
611+
with suppress(asyncio.CancelledError):
612+
await cancellation_listener_task
613+
594614
self._worker_done.set()
595615

596616
async def _scheduler_loop(self, redis: Redis) -> None:
@@ -1001,7 +1021,9 @@ async def _run_function_with_timeout(
10011021
},
10021022
),
10031023
)
1004-
task = asyncio.create_task(task_coro)
1024+
task = asyncio.create_task(
1025+
task_coro, name=f"docket.worker.task:{execution.key}"
1026+
)
10051027
try:
10061028
while not task.done(): # pragma: no branch
10071029
remaining = timeout.remaining().total_seconds()

tests/test_worker.py

Lines changed: 2 additions & 90 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
11
import asyncio
22
import logging
3-
import sys
4-
from contextlib import asynccontextmanager, suppress
3+
from contextlib import asynccontextmanager
54
from datetime import datetime, timedelta, timezone
65
from typing import AsyncGenerator, Callable
6+
77
from unittest.mock import AsyncMock, patch
88
from uuid import uuid4
99

@@ -18,30 +18,13 @@
1818
Docket,
1919
Perpetual,
2020
Worker,
21-
testing,
2221
)
2322
from docket.dependencies import Timeout
2423
from docket.execution import Execution
2524
from docket.tasks import standard_tasks
2625
from docket.worker import ms
2726
from tests._key_leak_checker import KeyCountChecker
2827

29-
if sys.version_info >= (3, 11): # pragma: no cover
30-
from asyncio import timeout as async_timeout
31-
else: # pragma: no cover
32-
33-
@asynccontextmanager
34-
async def async_timeout(delay: float):
35-
"""Compatibility shim for asyncio.timeout on Python 3.10."""
36-
task = asyncio.current_task()
37-
loop = asyncio.get_running_loop()
38-
deadline = loop.time() + delay
39-
handle = loop.call_at(deadline, task.cancel) # type: ignore[union-attr]
40-
try:
41-
yield
42-
finally:
43-
handle.cancel()
44-
4528

4629
async def test_worker_acknowledges_messages(
4730
docket: Docket, worker: Worker, the_task: AsyncMock
@@ -1135,74 +1118,3 @@ async def mock_xreadgroup( # pyright: ignore[reportUnknownParameterType]
11351118
assert task_executed
11361119
# Should have called xreadgroup at least twice (once NOGROUP, then success)
11371120
assert call_count >= 2
1138-
1139-
1140-
async def test_run_forever_cancels_promptly_with_future_tasks(
1141-
docket: Docket, the_task: AsyncMock, now: Callable[[], datetime]
1142-
):
1143-
"""run_forever() should cancel promptly even with future-scheduled tasks.
1144-
1145-
Issue #260: Perpetual tasks block worker shutdown.
1146-
"""
1147-
execution = await docket.add(the_task, now() + timedelta(seconds=15))()
1148-
1149-
async with Worker(
1150-
docket,
1151-
minimum_check_interval=timedelta(milliseconds=5),
1152-
scheduling_resolution=timedelta(milliseconds=5),
1153-
) as worker:
1154-
worker_task = asyncio.create_task(worker.run_forever())
1155-
await asyncio.sleep(0.05)
1156-
worker_task.cancel()
1157-
with suppress(asyncio.CancelledError): # pragma: no branch
1158-
async with async_timeout(1.0): # pragma: no branch
1159-
await worker_task
1160-
1161-
the_task.assert_not_called()
1162-
await testing.assert_task_scheduled(docket, the_task, key=execution.key)
1163-
1164-
1165-
async def test_run_until_finished_exits_promptly_with_future_tasks(
1166-
docket: Docket, the_task: AsyncMock, now: Callable[[], datetime]
1167-
):
1168-
"""run_until_finished() should exit promptly when only future tasks exist.
1169-
1170-
Issue #260: Perpetual tasks block worker shutdown.
1171-
"""
1172-
execution = await docket.add(the_task, now() + timedelta(seconds=15))()
1173-
1174-
async with Worker(
1175-
docket,
1176-
minimum_check_interval=timedelta(milliseconds=5),
1177-
scheduling_resolution=timedelta(milliseconds=5),
1178-
) as worker:
1179-
async with async_timeout(1.0):
1180-
await worker.run_until_finished()
1181-
1182-
the_task.assert_not_called()
1183-
await testing.assert_task_scheduled(docket, the_task, key=execution.key)
1184-
1185-
1186-
async def test_run_at_most_cancels_promptly_with_future_tasks(
1187-
docket: Docket, the_task: AsyncMock, now: Callable[[], datetime]
1188-
):
1189-
"""run_at_most() should cancel promptly even with future-scheduled tasks.
1190-
1191-
Issue #260: Perpetual tasks block worker shutdown.
1192-
"""
1193-
execution = await docket.add(the_task, now() + timedelta(seconds=15))()
1194-
1195-
async with Worker(
1196-
docket,
1197-
minimum_check_interval=timedelta(milliseconds=5),
1198-
scheduling_resolution=timedelta(milliseconds=5),
1199-
) as worker:
1200-
worker_task = asyncio.create_task(worker.run_at_most({execution.key: 1}))
1201-
await asyncio.sleep(0.05)
1202-
worker_task.cancel()
1203-
with suppress(asyncio.CancelledError): # pragma: no branch
1204-
async with async_timeout(1.0): # pragma: no branch
1205-
await worker_task
1206-
1207-
the_task.assert_not_called()
1208-
await testing.assert_task_scheduled(docket, the_task, key=execution.key)

0 commit comments

Comments
 (0)