Skip to content

Commit 5e6cd4b

Browse files
committed
Add worker pings
1 parent b08dfa6 commit 5e6cd4b

File tree

5 files changed

+117
-30
lines changed

5 files changed

+117
-30
lines changed

django_tasks/backends/database/management/commands/db_worker.py

+40-12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import random
44
import signal
55
import sys
6+
import threading
67
import time
78
import uuid
89
from argparse import ArgumentParser, ArgumentTypeError
@@ -16,7 +17,7 @@
1617

1718
from django_tasks import DEFAULT_TASK_BACKEND_ALIAS, tasks
1819
from django_tasks.backends.database.backend import DatabaseBackend
19-
from django_tasks.backends.database.models import DBTaskResult
20+
from django_tasks.backends.database.models import DBTaskResult, DBWorkerPing
2021
from django_tasks.backends.database.utils import exclusive_transaction
2122
from django_tasks.exceptions import InvalidTaskBackendError
2223
from django_tasks.signals import task_finished
@@ -46,11 +47,29 @@ def __init__(
4647

4748
self.worker_id = worker_id
4849

49-
self.running = True
50+
self.ping_thread = threading.Thread(target=self.run_ping)
51+
52+
self.should_exit = threading.Event()
5053
self.running_task = False
5154

55+
def run_ping(self) -> None:
    """Record a heartbeat for this worker until shutdown is requested.

    This is the target of ``self.ping_thread``. It calls
    ``DBWorkerPing.ping(self.worker_id)`` straight away, then again each
    time the 10-second wait on ``self.should_exit`` times out. The loop
    ends as soon as ``self.should_exit`` is set.
    """
    try:
        while True:
            try:
                DBWorkerPing.ping(self.worker_id)
            except Exception:
                # Best effort: a failed ping is logged and retried on the
                # next cycle instead of ending the heartbeat loop.
                logger.exception(
                    "Error sending worker ping worker_id=%s", self.worker_id
                )
            # wait() returns True once the event is set, so a shutdown
            # request interrupts the pause between pings.
            if self.should_exit.wait(timeout=10):
                break
    finally:
        # Close only the connections that were actually initialised,
        # matching the connection cleanup done after each task is run.
        for conn in connections.all(initialized_only=True):
            conn.close()
70+
5271
def shutdown(self, signum: int, frame: Optional[FrameType]) -> None:
53-
if not self.running:
72+
if self.should_exit.is_set():
5473
logger.warning(
5574
"Received %s - terminating current task.", signal.strsignal(signum)
5675
)
@@ -60,7 +79,7 @@ def shutdown(self, signum: int, frame: Optional[FrameType]) -> None:
6079
"Received %s - shutting down gracefully... (press Ctrl+C again to force)",
6180
signal.strsignal(signum),
6281
)
63-
self.running = False
82+
self.should_exit.set()
6483

6584
if not self.running_task:
6685
# If we're not currently running a task, exit immediately.
@@ -73,8 +92,9 @@ def configure_signals(self) -> None:
7392
if hasattr(signal, "SIGQUIT"):
7493
signal.signal(signal.SIGQUIT, self.shutdown)
7594

76-
def start(self) -> None:
95+
def run(self) -> None:
7796
self.configure_signals()
97+
self.ping_thread.start()
7898

7999
logger.info(
80100
"Starting worker worker_id=%s for queues=%s",
@@ -86,7 +106,7 @@ def start(self) -> None:
86106
# Add a random small delay before starting the loop to avoid a thundering herd
87107
time.sleep(random.random())
88108

89-
while self.running:
109+
while not self.should_exit.is_set():
90110
tasks = DBTaskResult.objects.ready().filter(backend_name=self.backend_name)
91111
if not self.process_all_queues:
92112
tasks = tasks.filter(queue_name__in=self.queue_names)
@@ -117,16 +137,13 @@ def start(self) -> None:
117137
finally:
118138
self.running_task = False
119139

120-
for conn in connections.all(initialized_only=True):
121-
conn.close()
122-
123140
if self.batch and task_result is None:
124141
# If we're running in "batch" mode, terminate the loop (and thus the worker)
125142
return None
126143

127-
# If ctrl-c has just interrupted a task, self.running was cleared,
144+
# If ctrl-c has just interrupted a task, self.should_exit was set,
128145
# and we should not sleep, but rather exit immediately.
129-
if self.running and not task_result:
146+
if not self.should_exit.is_set() and not task_result:
130147
# Wait before checking for another task
131148
time.sleep(self.interval)
132149

@@ -168,6 +185,14 @@ def run_task(self, db_task_result: DBTaskResult) -> None:
168185
task_result=task_result,
169186
)
170187

188+
for conn in connections.all(initialized_only=True):
189+
conn.close()
190+
191+
def stop(self) -> None:
    """Request shutdown, wait for the ping thread, then delete this worker's ping.

    Sets ``self.should_exit`` so the ping loop ends, joins
    ``self.ping_thread`` if it is running, and finally calls
    ``DBWorkerPing.cleanup_ping(self.worker_id)``.
    """
    self.should_exit.set()

    # Thread.join() raises RuntimeError for a thread that was never started,
    # which is the case if the worker failed before starting the ping thread.
    if self.ping_thread.is_alive():
        self.ping_thread.join()

    # Clean up only after the ping thread has finished: a ping still in
    # flight uses update_or_create and could re-create the row if the
    # delete ran first.
    DBWorkerPing.cleanup_ping(self.worker_id)
195+
171196

172197
def valid_backend_name(val: str) -> str:
173198
try:
@@ -271,7 +296,10 @@ def handle(
271296
worker_id=worker_id,
272297
)
273298

274-
worker.start()
299+
try:
300+
worker.run()
301+
finally:
302+
worker.stop()
275303

276304
if batch:
277305
logger.info(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Generated by Django 5.1.5 on 2025-02-01 13:08
2+
3+
import uuid
4+
5+
from django.db import migrations, models
6+
7+
8+
class Migration(migrations.Migration):
    """Add the ``DBWorkerPing`` model: a UUID ``worker_id`` key and a ``last_ping`` timestamp."""

    dependencies = [("django_tasks_database", "0015_dbtaskresult_worker_id")]

    operations = [
        migrations.CreateModel(
            name="DBWorkerPing",
            fields=[
                (
                    "worker_id",
                    models.UUIDField(
                        primary_key=True,
                        default=uuid.uuid4,
                        editable=False,
                        serialize=False,
                    ),
                ),
                ("last_ping", models.DateTimeField()),
            ],
        ),
    ]

django_tasks/backends/database/models.py

+30
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import uuid
3+
from datetime import timedelta
34
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
45

56
import django
@@ -235,3 +236,32 @@ def set_failed(self, exc: BaseException) -> None:
235236
"worker_id",
236237
]
237238
)
239+
240+
241+
class DBWorkerPingQuerySet(models.QuerySet):
    """QuerySet for worker pings, with a filter for pings that have gone stale."""

    # Age in seconds at or beyond which a ping counts as stale.
    STALE_TIME = 600

    def stale(self) -> "DBWorkerPingQuerySet":
        """Return the pings whose ``last_ping`` is at least ``STALE_TIME`` seconds old."""
        max_age = timedelta(seconds=self.STALE_TIME)
        cutoff = timezone.now() - max_age
        return self.filter(last_ping__lte=cutoff)
248+
249+
250+
class DBWorkerPing(models.Model):
    """Most recent ping recorded for a worker, keyed by the worker's id."""

    worker_id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)

    # Written with timezone.now() on every ping() call.
    last_ping = models.DateTimeField()

    objects = DBWorkerPingQuerySet.as_manager()

    @classmethod
    @retry()
    def ping(cls, worker_id: str) -> None:
        """Create or refresh the row for ``worker_id`` with the current time."""
        now = timezone.now()
        cls.objects.update_or_create(worker_id=worker_id, defaults={"last_ping": now})

    @classmethod
    @retry()
    def cleanup_ping(cls, worker_id: str) -> None:
        """Delete the ping row for ``worker_id``, if one exists."""
        matching = cls.objects.filter(worker_id=worker_id)
        matching.delete()

tests/settings.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@
5959
)
6060
}
6161

62-
# Set exclusive transactions in 5.1+
63-
if django.VERSION >= (5, 1) and "sqlite" in DATABASES["default"]["ENGINE"]:
64-
DATABASES["default"].setdefault("OPTIONS", {})["transaction_mode"] = "EXCLUSIVE"
65-
6662
if "sqlite" in DATABASES["default"]["ENGINE"]:
63+
if django.VERSION >= (5, 1):
64+
# Set exclusive transactions in 5.1+
65+
DATABASES["default"].setdefault("OPTIONS", {})["transaction_mode"] = "EXCLUSIVE"
66+
6767
DATABASES["default"]["TEST"] = {"NAME": os.path.join(BASE_DIR, "db-test.sqlite3")}
6868

6969

tests/tests/test_database_backend.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def test_run_enqueued_task(self) -> None:
447447

448448
self.assertEqual(result.status, ResultStatus.NEW)
449449

450-
with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
450+
with self.assertNumQueries(12 if connection.vendor == "mysql" else 11):
451451
self.run_worker()
452452

453453
self.assertEqual(result.status, ResultStatus.NEW)
@@ -468,28 +468,28 @@ def test_batch_processes_all_tasks(self) -> None:
468468

469469
self.assertEqual(DBTaskResult.objects.ready().count(), 4)
470470

471-
with self.assertNumQueries(27 if connection.vendor == "mysql" else 23):
471+
with self.assertNumQueries(30 if connection.vendor == "mysql" else 26):
472472
self.run_worker()
473473

474474
self.assertEqual(DBTaskResult.objects.ready().count(), 0)
475475
self.assertEqual(DBTaskResult.objects.succeeded().count(), 3)
476476
self.assertEqual(DBTaskResult.objects.failed().count(), 1)
477477

478478
def test_no_tasks(self) -> None:
479-
with self.assertNumQueries(3):
479+
with self.assertNumQueries(6):
480480
self.run_worker()
481481

482482
def test_doesnt_process_different_queue(self) -> None:
483483
result = test_tasks.noop_task.using(queue_name="queue-1").enqueue()
484484

485485
self.assertEqual(DBTaskResult.objects.ready().count(), 1)
486486

487-
with self.assertNumQueries(3):
487+
with self.assertNumQueries(6):
488488
self.run_worker()
489489

490490
self.assertEqual(DBTaskResult.objects.ready().count(), 1)
491491

492-
with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
492+
with self.assertNumQueries(12 if connection.vendor == "mysql" else 11):
493493
self.run_worker(queue_name=result.task.queue_name)
494494

495495
self.assertEqual(DBTaskResult.objects.ready().count(), 0)
@@ -499,12 +499,12 @@ def test_process_all_queues(self) -> None:
499499

500500
self.assertEqual(DBTaskResult.objects.ready().count(), 1)
501501

502-
with self.assertNumQueries(3):
502+
with self.assertNumQueries(6):
503503
self.run_worker()
504504

505505
self.assertEqual(DBTaskResult.objects.ready().count(), 1)
506506

507-
with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
507+
with self.assertNumQueries(12 if connection.vendor == "mysql" else 11):
508508
self.run_worker(queue_name="*")
509509

510510
self.assertEqual(DBTaskResult.objects.ready().count(), 0)
@@ -519,7 +519,7 @@ def test_failing_task(self) -> None:
519519
with self.assertRaisesMessage(ValueError, "Task has not finished yet"):
520520
result.traceback # noqa: B018
521521

522-
with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
522+
with self.assertNumQueries(12 if connection.vendor == "mysql" else 11):
523523
self.run_worker()
524524

525525
self.assertEqual(result.status, ResultStatus.NEW)
@@ -553,7 +553,7 @@ def test_complex_exception(self) -> None:
553553
with self.assertRaisesMessage(ValueError, "Task has not finished"):
554554
result.traceback # noqa: B018
555555

556-
with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
556+
with self.assertNumQueries(12 if connection.vendor == "mysql" else 11):
557557
self.run_worker()
558558

559559
self.assertEqual(result.status, ResultStatus.NEW)
@@ -577,12 +577,12 @@ def test_doesnt_process_different_backend(self) -> None:
577577

578578
self.assertEqual(DBTaskResult.objects.ready().count(), 1)
579579

580-
with self.assertNumQueries(3):
580+
with self.assertNumQueries(6):
581581
self.run_worker(backend_name="dummy")
582582

583583
self.assertEqual(DBTaskResult.objects.ready().count(), 1)
584584

585-
with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
585+
with self.assertNumQueries(12 if connection.vendor == "mysql" else 11):
586586
self.run_worker(backend_name=result.backend)
587587

588588
self.assertEqual(DBTaskResult.objects.ready().count(), 0)
@@ -650,7 +650,7 @@ def test_run_after(self) -> None:
650650
self.assertEqual(DBTaskResult.objects.count(), 1)
651651
self.assertEqual(DBTaskResult.objects.ready().count(), 0)
652652

653-
with self.assertNumQueries(3):
653+
with self.assertNumQueries(6):
654654
self.run_worker()
655655

656656
self.assertEqual(DBTaskResult.objects.count(), 1)
@@ -661,7 +661,7 @@ def test_run_after(self) -> None:
661661

662662
self.assertEqual(DBTaskResult.objects.ready().count(), 1)
663663

664-
with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
664+
with self.assertNumQueries(12 if connection.vendor == "mysql" else 11):
665665
self.run_worker()
666666

667667
self.assertEqual(DBTaskResult.objects.ready().count(), 0)
@@ -711,7 +711,7 @@ def test_verbose_logging(self) -> None:
711711

712712
stdout = StringIO()
713713
self.run_worker(verbosity=3, stdout=stdout, stderr=stdout)
714-
714+
self.maxDiff = None
715715
self.assertEqual(
716716
stdout.getvalue().splitlines(),
717717
[

0 commit comments

Comments
 (0)