Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/docket/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,20 @@ def worker(
help="Exit after the current docket is finished",
),
] = False,
healthcheck_port: Annotated[
int | None,
typer.Option(
"--healthcheck-port",
help="The port to serve a healthcheck on",
envvar="DOCKET_WORKER_HEALTHCHECK_PORT",
),
] = None,
metrics_port: Annotated[
int | None,
typer.Option(
"--metrics-port",
help="The port to serve Prometheus metrics on",
envvar="DOCKET_WORKER_METRICS_PORT",
),
] = None,
) -> None:
Expand All @@ -279,6 +288,7 @@ def worker(
scheduling_resolution=scheduling_resolution,
schedule_automatic_tasks=schedule_automatic_tasks,
until_finished=until_finished,
healthcheck_port=healthcheck_port,
metrics_port=metrics_port,
tasks=tasks,
)
Expand Down
34 changes: 30 additions & 4 deletions src/docket/instrumentation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import threading
from contextlib import contextmanager
from threading import Thread
from typing import Generator, cast

from opentelemetry import metrics
Expand Down Expand Up @@ -145,6 +145,34 @@ def set(
message_setter: MessageSetter = MessageSetter()


@contextmanager
def healthcheck_server(
host: str = "0.0.0.0", port: int | None = None
) -> Generator[None, None, None]:
if port is None:
yield
return

from http.server import BaseHTTPRequestHandler, HTTPServer

class HealthcheckHandler(BaseHTTPRequestHandler):
def do_GET(self):
self.send_response(200)
self.send_header("Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"OK")

def log_message(self, format: str, *args: object) -> None:
# Suppress access logs from the webserver
pass

server = HTTPServer((host, port), HealthcheckHandler)
with server:
Thread(target=server.serve_forever, daemon=True).start()

yield


@contextmanager
def metrics_server(
host: str = "0.0.0.0", port: int | None = None
Expand Down Expand Up @@ -173,8 +201,6 @@ def metrics_server(
handler_class=_SilentHandler,
)
with server:
t = threading.Thread(target=server.serve_forever)
t.daemon = True
t.start()
Thread(target=server.serve_forever, daemon=True).start()

yield
7 changes: 6 additions & 1 deletion src/docket/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
TASKS_STARTED,
TASKS_STRICKEN,
TASKS_SUCCEEDED,
healthcheck_server,
metrics_server,
)

Expand Down Expand Up @@ -152,10 +153,14 @@ async def run(
scheduling_resolution: timedelta = timedelta(milliseconds=250),
schedule_automatic_tasks: bool = True,
until_finished: bool = False,
healthcheck_port: int | None = None,
metrics_port: int | None = None,
tasks: list[str] = ["docket.tasks:standard_tasks"],
) -> None:
with metrics_server(port=metrics_port):
with (
healthcheck_server(port=healthcheck_port),
metrics_server(port=metrics_port),
):
async with Docket(name=docket_name, url=url) as docket:
for task_path in tasks:
docket.register_collection(task_path)
Expand Down
26 changes: 25 additions & 1 deletion tests/test_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@

from docket import Docket, Worker
from docket.dependencies import Retry
from docket.instrumentation import message_getter, message_setter, metrics_server
from docket.instrumentation import (
healthcheck_server,
message_getter,
message_setter,
metrics_server,
)

tracer = trace.get_tracer(__name__)

Expand Down Expand Up @@ -560,3 +565,22 @@ async def test_worker_publishes_depth_gauges(

QUEUE_DEPTH.assert_called_once_with(2, docket_labels)
SCHEDULE_DEPTH.assert_called_once_with(3, docket_labels)


@pytest.fixture
def healthcheck_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]


def test_healthcheck_server_returns_ok(healthcheck_port: int):
"""Should return 200 and OK body from the liveness endpoint."""
with healthcheck_server(port=healthcheck_port):
conn = http.client.HTTPConnection(f"localhost:{healthcheck_port}")
conn.request("GET", "/")
response = conn.getresponse()

assert response.status == 200
assert response.headers["Content-Type"] == "text/plain"
assert response.read().decode() == "OK"