Skip to content

Commit 2563949

Browse files
committed
rework into decorator
1 parent 122d6e4 commit 2563949

File tree

1 file changed

+46
-25
lines changed

1 file changed

+46
-25
lines changed

src/docket/worker.py

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import base64
3+
import functools
34
import logging
45
import os
56
import signal
@@ -8,7 +9,7 @@
89
import time
910
from datetime import datetime, timedelta, timezone
1011
from types import TracebackType
11-
from typing import Any, Coroutine, Mapping, Protocol, cast
12+
from typing import Any, Callable, Coroutine, Mapping, Protocol, TypeVar, cast
1213

1314
import cloudpickle # type: ignore[import]
1415

@@ -78,6 +79,40 @@ def __init__(self, execution: Execution):
7879
logger: logging.Logger = logging.getLogger(__name__)
7980
tracer: Tracer = trace.get_tracer(__name__)
8081

82+
F = TypeVar("F", bound=Callable[..., Coroutine[Any, Any, None]])
83+
84+
85+
def _with_sigterm_handler(func: F) -> F: # pragma: no cover
86+
"""Decorator that installs a SIGTERM handler for graceful shutdown.
87+
88+
On SIGTERM, cancels the wrapped coroutine, allowing in-flight tasks to
89+
complete before the worker exits. Tested via subprocess in
90+
test_sigterm_gracefully_drains_inflight_tasks.
91+
"""
92+
if not hasattr(signal, "SIGTERM"):
93+
return func
94+
95+
@functools.wraps(func)
96+
async def wrapper(*args: Any, **kwargs: Any) -> None:
97+
loop = asyncio.get_running_loop()
98+
task: asyncio.Task[None] | None = None
99+
100+
def handle_sigterm() -> None:
101+
logger.info("Received SIGTERM, initiating graceful shutdown...")
102+
if task and not task.done():
103+
task.cancel()
104+
105+
loop.add_signal_handler(signal.SIGTERM, handle_sigterm)
106+
try:
107+
task = asyncio.create_task(func(*args, **kwargs))
108+
await task
109+
except asyncio.CancelledError:
110+
pass # Expected from signal handler
111+
finally:
112+
loop.remove_signal_handler(signal.SIGTERM)
113+
114+
return cast(F, wrapper)
115+
81116

82117
class _stream_due_tasks(Protocol):
83118
async def __call__(
@@ -199,35 +234,21 @@ async def run(
199234
):
200235
if until_finished:
201236
await worker.run_until_finished()
202-
else: # pragma: no cover
203-
loop = asyncio.get_running_loop()
204-
current_task = asyncio.current_task()
205-
206-
def handle_sigterm() -> None:
207-
logger.info(
208-
"Received SIGTERM, initiating graceful shutdown..."
209-
)
210-
if current_task and not current_task.done():
211-
current_task.cancel()
212-
213-
if hasattr(signal, "SIGTERM"):
214-
loop.add_signal_handler(signal.SIGTERM, handle_sigterm)
215-
216-
try:
217-
await worker.run_forever()
218-
except asyncio.CancelledError:
219-
pass # Expected from signal handler
220-
finally:
221-
if hasattr(signal, "SIGTERM"):
222-
loop.remove_signal_handler(signal.SIGTERM)
237+
else: # pragma: no cover - tested via subprocess
238+
await worker.run_forever()
223239

224240
async def run_until_finished(self) -> None:
225241
"""Run the worker until there are no more tasks to process."""
226242
return await self._run(forever=False)
227243

228-
async def run_forever(self) -> None:
229-
"""Run the worker indefinitely."""
230-
return await self._run(forever=True) # pragma: no cover
244+
@_with_sigterm_handler
245+
async def run_forever(self) -> None: # pragma: no cover - tested via subprocess
246+
"""Run the worker indefinitely.
247+
248+
Installs a SIGTERM handler that initiates graceful shutdown, allowing
249+
in-flight tasks to complete before the worker exits.
250+
"""
251+
await self._run(forever=True)
231252

232253
_execution_counts: dict[str, int]
233254

0 commit comments

Comments
 (0)