From 0b69e8a81c3a4a72ff3cb044ba129deffffbd5a2 Mon Sep 17 00:00:00 2001
From: Chris Guidry
Date: Wed, 23 Apr 2025 07:44:12 -0400
Subject: [PATCH] Adds a healthcheck HTTP server for docket workers

---
 src/docket/cli.py             | 10 ++++++++++
 src/docket/instrumentation.py | 44 ++++++++++++++++++++++++++++++++++++++++----
 src/docket/worker.py          |  7 ++++++-
 tests/test_instrumentation.py | 26 +++++++++++++++++++++++++-
 4 files changed, 81 insertions(+), 6 deletions(-)

diff --git a/src/docket/cli.py b/src/docket/cli.py
index 7dd358b8..7dfbabc4 100644
--- a/src/docket/cli.py
+++ b/src/docket/cli.py
@@ -259,11 +259,20 @@ def worker(
             help="Exit after the current docket is finished",
         ),
     ] = False,
+    healthcheck_port: Annotated[
+        int | None,
+        typer.Option(
+            "--healthcheck-port",
+            help="The port to serve a healthcheck on",
+            envvar="DOCKET_WORKER_HEALTHCHECK_PORT",
+        ),
+    ] = None,
     metrics_port: Annotated[
         int | None,
         typer.Option(
             "--metrics-port",
             help="The port to serve Prometheus metrics on",
+            envvar="DOCKET_WORKER_METRICS_PORT",
         ),
     ] = None,
 ) -> None:
@@ -279,6 +288,7 @@ def worker(
             scheduling_resolution=scheduling_resolution,
             schedule_automatic_tasks=schedule_automatic_tasks,
             until_finished=until_finished,
+            healthcheck_port=healthcheck_port,
             metrics_port=metrics_port,
             tasks=tasks,
         )
diff --git a/src/docket/instrumentation.py b/src/docket/instrumentation.py
index 4db017b7..db623558 100644
--- a/src/docket/instrumentation.py
+++ b/src/docket/instrumentation.py
@@ -1,5 +1,5 @@
-import threading
 from contextlib import contextmanager
+from threading import Thread
 from typing import Generator, cast
 
 from opentelemetry import metrics
@@ -145,6 +145,44 @@ def set(
 message_setter: MessageSetter = MessageSetter()
 
 
+@contextmanager
+def healthcheck_server(
+    host: str = "0.0.0.0", port: int | None = None
+) -> Generator[None, None, None]:
+    """Serve a plain-text ``OK`` healthcheck over HTTP while the context is open.
+
+    When ``port`` is None no server is started and the context is a no-op.
+    The server runs on a daemon thread and is shut down when the context exits.
+    """
+    if port is None:
+        yield
+        return
+
+    from http.server import BaseHTTPRequestHandler, HTTPServer
+
+    class HealthcheckHandler(BaseHTTPRequestHandler):
+        def do_GET(self):
+            self.send_response(200)
+            self.send_header("Content-type", "text/plain")
+            self.end_headers()
+            self.wfile.write(b"OK")
+
+        def log_message(self, format: str, *args: object) -> None:
+            # Suppress access logs from the webserver
+            pass
+
+    server = HTTPServer((host, port), HealthcheckHandler)
+    with server:
+        thread = Thread(target=server.serve_forever, daemon=True)
+        thread.start()
+        try:
+            yield
+        finally:
+            # Stop the serve_forever() loop before the socket is closed
+            server.shutdown()
+            thread.join()
+
+
 @contextmanager
 def metrics_server(
     host: str = "0.0.0.0", port: int | None = None
@@ -173,8 +211,6 @@ def metrics_server(
         handler_class=_SilentHandler,
     )
     with server:
-        t = threading.Thread(target=server.serve_forever)
-        t.daemon = True
-        t.start()
+        Thread(target=server.serve_forever, daemon=True).start()
 
         yield
diff --git a/src/docket/worker.py b/src/docket/worker.py
index c0c672ce..10652288 100644
--- a/src/docket/worker.py
+++ b/src/docket/worker.py
@@ -51,6 +51,7 @@
     TASKS_STARTED,
     TASKS_STRICKEN,
     TASKS_SUCCEEDED,
+    healthcheck_server,
     metrics_server,
 )
 
@@ -152,10 +153,14 @@ async def run(
         scheduling_resolution: timedelta = timedelta(milliseconds=250),
         schedule_automatic_tasks: bool = True,
         until_finished: bool = False,
+        healthcheck_port: int | None = None,
         metrics_port: int | None = None,
         tasks: list[str] = ["docket.tasks:standard_tasks"],
     ) -> None:
-        with metrics_server(port=metrics_port):
+        with (
+            healthcheck_server(port=healthcheck_port),
+            metrics_server(port=metrics_port),
+        ):
             async with Docket(name=docket_name, url=url) as docket:
                 for task_path in tasks:
                     docket.register_collection(task_path)
diff --git a/tests/test_instrumentation.py b/tests/test_instrumentation.py
index 0207896c..104ed6ca 100644
--- a/tests/test_instrumentation.py
+++ b/tests/test_instrumentation.py
@@ -13,7 +13,12 @@
 
 from docket import Docket, Worker
 from docket.dependencies import Retry
-from docket.instrumentation import message_getter, message_setter, metrics_server
+from docket.instrumentation import (
+    healthcheck_server,
+    message_getter,
+    message_setter,
+    metrics_server,
+)
 
 tracer = trace.get_tracer(__name__)
 
@@ -560,3 +565,22 @@ async def test_worker_publishes_depth_gauges(
 
     QUEUE_DEPTH.assert_called_once_with(2, docket_labels)
     SCHEDULE_DEPTH.assert_called_once_with(3, docket_labels)
+
+
+@pytest.fixture
+def healthcheck_port() -> int:
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        return s.getsockname()[1]
+
+
+def test_healthcheck_server_returns_ok(healthcheck_port: int):
+    """Should return 200 and OK body from the liveness endpoint."""
+    with healthcheck_server(port=healthcheck_port):
+        conn = http.client.HTTPConnection(f"localhost:{healthcheck_port}")
+        conn.request("GET", "/")
+        response = conn.getresponse()
+
+        assert response.status == 200
+        assert response.headers["Content-Type"] == "text/plain"
+        assert response.read().decode() == "OK"