Skip to content

Commit 4bc2aff

Browse files
authored
Adds a healthcheck HTTP server for docket workers (#126)
1 parent 6366875 commit 4bc2aff

File tree

4 files changed

+71
-6
lines changed

4 files changed

+71
-6
lines changed

src/docket/cli.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,11 +259,20 @@ def worker(
259259
help="Exit after the current docket is finished",
260260
),
261261
] = False,
262+
healthcheck_port: Annotated[
263+
int | None,
264+
typer.Option(
265+
"--healthcheck-port",
266+
help="The port to serve a healthcheck on",
267+
envvar="DOCKET_WORKER_HEALTHCHECK_PORT",
268+
),
269+
] = None,
262270
metrics_port: Annotated[
263271
int | None,
264272
typer.Option(
265273
"--metrics-port",
266274
help="The port to serve Prometheus metrics on",
275+
envvar="DOCKET_WORKER_METRICS_PORT",
267276
),
268277
] = None,
269278
) -> None:
@@ -279,6 +288,7 @@ def worker(
279288
scheduling_resolution=scheduling_resolution,
280289
schedule_automatic_tasks=schedule_automatic_tasks,
281290
until_finished=until_finished,
291+
healthcheck_port=healthcheck_port,
282292
metrics_port=metrics_port,
283293
tasks=tasks,
284294
)

src/docket/instrumentation.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import threading
21
from contextlib import contextmanager
2+
from threading import Thread
33
from typing import Generator, cast
44

55
from opentelemetry import metrics
@@ -145,6 +145,34 @@ def set(
145145
message_setter: MessageSetter = MessageSetter()
146146

147147

148+
@contextmanager
149+
def healthcheck_server(
150+
host: str = "0.0.0.0", port: int | None = None
151+
) -> Generator[None, None, None]:
152+
if port is None:
153+
yield
154+
return
155+
156+
from http.server import BaseHTTPRequestHandler, HTTPServer
157+
158+
class HealthcheckHandler(BaseHTTPRequestHandler):
159+
def do_GET(self):
160+
self.send_response(200)
161+
self.send_header("Content-type", "text/plain")
162+
self.end_headers()
163+
self.wfile.write(b"OK")
164+
165+
def log_message(self, format: str, *args: object) -> None:
166+
# Suppress access logs from the webserver
167+
pass
168+
169+
server = HTTPServer((host, port), HealthcheckHandler)
170+
with server:
171+
Thread(target=server.serve_forever, daemon=True).start()
172+
173+
yield
174+
175+
148176
@contextmanager
149177
def metrics_server(
150178
host: str = "0.0.0.0", port: int | None = None
@@ -173,8 +201,6 @@ def metrics_server(
173201
handler_class=_SilentHandler,
174202
)
175203
with server:
176-
t = threading.Thread(target=server.serve_forever)
177-
t.daemon = True
178-
t.start()
204+
Thread(target=server.serve_forever, daemon=True).start()
179205

180206
yield

src/docket/worker.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
TASKS_STARTED,
5252
TASKS_STRICKEN,
5353
TASKS_SUCCEEDED,
54+
healthcheck_server,
5455
metrics_server,
5556
)
5657

@@ -152,10 +153,14 @@ async def run(
152153
scheduling_resolution: timedelta = timedelta(milliseconds=250),
153154
schedule_automatic_tasks: bool = True,
154155
until_finished: bool = False,
156+
healthcheck_port: int | None = None,
155157
metrics_port: int | None = None,
156158
tasks: list[str] = ["docket.tasks:standard_tasks"],
157159
) -> None:
158-
with metrics_server(port=metrics_port):
160+
with (
161+
healthcheck_server(port=healthcheck_port),
162+
metrics_server(port=metrics_port),
163+
):
159164
async with Docket(name=docket_name, url=url) as docket:
160165
for task_path in tasks:
161166
docket.register_collection(task_path)

tests/test_instrumentation.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313

1414
from docket import Docket, Worker
1515
from docket.dependencies import Retry
16-
from docket.instrumentation import message_getter, message_setter, metrics_server
16+
from docket.instrumentation import (
17+
healthcheck_server,
18+
message_getter,
19+
message_setter,
20+
metrics_server,
21+
)
1722

1823
tracer = trace.get_tracer(__name__)
1924

@@ -560,3 +565,22 @@ async def test_worker_publishes_depth_gauges(
560565

561566
QUEUE_DEPTH.assert_called_once_with(2, docket_labels)
562567
SCHEDULE_DEPTH.assert_called_once_with(3, docket_labels)
568+
569+
570+
@pytest.fixture
571+
def healthcheck_port() -> int:
572+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
573+
s.bind(("", 0))
574+
return s.getsockname()[1]
575+
576+
577+
def test_healthcheck_server_returns_ok(healthcheck_port: int):
578+
"""Should return 200 and OK body from the liveness endpoint."""
579+
with healthcheck_server(port=healthcheck_port):
580+
conn = http.client.HTTPConnection(f"localhost:{healthcheck_port}")
581+
conn.request("GET", "/")
582+
response = conn.getresponse()
583+
584+
assert response.status == 200
585+
assert response.headers["Content-Type"] == "text/plain"
586+
assert response.read().decode() == "OK"

0 commit comments

Comments
 (0)