Skip to content

Commit b08dfa6

Browse files
committed
Store the worker processing a task on the task itself
1 parent 4edc550 commit b08dfa6

File tree

7 files changed

+82
-9
lines changed

7 files changed

+82
-9
lines changed

django_tasks/backends/database/management/commands/db_worker.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import signal
55
import sys
66
import time
7+
import uuid
78
from argparse import ArgumentParser, ArgumentTypeError
89
from types import FrameType
910
from typing import Optional
@@ -34,6 +35,7 @@ def __init__(
3435
batch: bool,
3536
backend_name: str,
3637
startup_delay: bool,
38+
worker_id: uuid.UUID,
3739
):
3840
self.queue_names = queue_names
3941
self.process_all_queues = "*" in queue_names
@@ -42,6 +44,8 @@ def __init__(
4244
self.backend_name = backend_name
4345
self.startup_delay = startup_delay
4446

47+
self.worker_id = worker_id
48+
4549
self.running = True
4650
self.running_task = False
4751

@@ -72,7 +76,11 @@ def configure_signals(self) -> None:
7276
def start(self) -> None:
7377
self.configure_signals()
7478

75-
logger.info("Starting worker for queues=%s", ",".join(self.queue_names))
79+
logger.info(
80+
"Starting worker worker_id=%s for queues=%s",
81+
self.worker_id,
82+
",".join(self.queue_names),
83+
)
7684

7785
if self.startup_delay and self.interval:
7886
# Add a random small delay before starting the loop to avoid a thundering herd
@@ -101,7 +109,7 @@ def start(self) -> None:
101109

102110
if task_result is not None:
103111
# "claim" the task, so it isn't run by another worker process
104-
task_result.claim()
112+
task_result.claim(self.worker_id)
105113

106114
if task_result is not None:
107115
self.run_task(task_result)
@@ -150,7 +158,10 @@ def run_task(self, db_task_result: DBTaskResult) -> None:
150158
sender = type(db_task_result.task.get_backend())
151159
task_result = db_task_result.task_result
152160
except (ModuleNotFoundError, SuspiciousOperation):
153-
logger.exception("Task id=%s failed unexpectedly", db_task_result.id)
161+
logger.exception(
162+
"Task id=%s failed unexpectedly",
163+
db_task_result.id,
164+
)
154165
else:
155166
task_finished.send(
156167
sender=sender,
@@ -214,6 +225,13 @@ def add_arguments(self, parser: ArgumentParser) -> None:
214225
dest="startup_delay",
215226
help="Don't add a small delay at startup.",
216227
)
228+
parser.add_argument(
229+
"--worker-id",
230+
nargs="?",
231+
type=uuid.UUID,
232+
help="Worker id. MUST be unique across worker pool (default: auto-generate)",
233+
default=uuid.uuid4(),
234+
)
217235

218236
def configure_logging(self, verbosity: int) -> None:
219237
if verbosity == 0:
@@ -239,6 +257,7 @@ def handle(
239257
batch: bool,
240258
backend_name: str,
241259
startup_delay: bool,
260+
worker_id: uuid.UUID,
242261
**options: dict,
243262
) -> None:
244263
self.configure_logging(verbosity)
@@ -249,9 +268,12 @@ def handle(
249268
batch=batch,
250269
backend_name=backend_name,
251270
startup_delay=startup_delay,
271+
worker_id=worker_id,
252272
)
253273

254274
worker.start()
255275

256276
if batch:
257-
logger.info("No more tasks to run - exiting gracefully.")
277+
logger.info(
278+
"No more tasks to run for worker_id=%s - exiting gracefully.", worker_id
279+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Generated by Django 5.1.5 on 2025-02-01 12:16
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
dependencies = [
8+
("django_tasks_database", "0014_remove_dbtaskresult_exception_data"),
9+
]
10+
11+
operations = [
12+
migrations.AddField(
13+
model_name="dbtaskresult",
14+
name="worker_id",
15+
field=models.UUIDField(null=True, verbose_name="worker id"),
16+
),
17+
]

django_tasks/backends/database/models.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ class DBTaskResult(GenericBase[P, T], models.Model):
8989
started_at = models.DateTimeField(_("started at"), null=True)
9090
finished_at = models.DateTimeField(_("finished at"), null=True)
9191

92+
worker_id = models.UUIDField(_("worker id"), null=True)
93+
9294
args_kwargs = models.JSONField(_("args kwargs"))
9395

9496
priority = models.IntegerField(_("priority"), default=DEFAULT_PRIORITY)
@@ -163,6 +165,7 @@ def task_result(self) -> "TaskResult[T]":
163165
args=self.args_kwargs["args"],
164166
kwargs=self.args_kwargs["kwargs"],
165167
backend=self.backend_name,
168+
worker_id=None if self.worker_id is None else str(self.worker_id),
166169
)
167170

168171
object.__setattr__(task_result, "_exception_class", exception_class)
@@ -186,13 +189,14 @@ def task_name(self) -> str:
186189
return self.task_path
187190

188191
@retry(backoff_delay=0)
189-
def claim(self) -> None:
192+
def claim(self, worker_id: uuid.UUID) -> None:
190193
"""
191-
Mark as job as being run
194+
Mark as job as being run by a worker
192195
"""
193196
self.status = ResultStatus.RUNNING
194197
self.started_at = timezone.now()
195-
self.save(update_fields=["status", "started_at"])
198+
self.worker_id = worker_id
199+
self.save(update_fields=["status", "started_at", "worker_id"])
196200

197201
@retry()
198202
def set_succeeded(self, return_value: Any) -> None:
@@ -201,13 +205,15 @@ def set_succeeded(self, return_value: Any) -> None:
201205
self.return_value = return_value
202206
self.exception_class_path = ""
203207
self.traceback = ""
208+
self.worker_id = None
204209
self.save(
205210
update_fields=[
206211
"status",
207212
"return_value",
208213
"finished_at",
209214
"exception_class_path",
210215
"traceback",
216+
"worker_id",
211217
]
212218
)
213219

@@ -218,12 +224,14 @@ def set_failed(self, exc: BaseException) -> None:
218224
self.exception_class_path = get_module_path(type(exc))
219225
self.traceback = get_exception_traceback(exc)
220226
self.return_value = None
227+
self.worker_id = None
221228
self.save(
222229
update_fields=[
223230
"status",
224231
"return_value",
225232
"finished_at",
226233
"exception_class_path",
227234
"traceback",
235+
"worker_id",
228236
]
229237
)

django_tasks/backends/dummy.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def enqueue(
4747
args=args,
4848
kwargs=kwargs,
4949
backend=self.alias,
50+
worker_id=None,
5051
)
5152

5253
if self._get_enqueue_on_commit_for_task(task) is not False:

django_tasks/backends/immediate.py

+1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def enqueue(
8080
args=args,
8181
kwargs=kwargs,
8282
backend=self.alias,
83+
worker_id=None,
8384
)
8485

8586
if self._get_enqueue_on_commit_for_task(task) is not False:

django_tasks/task.py

+4
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
"started_at",
4242
"status",
4343
"enqueued_at",
44+
"worker_id",
4445
}
4546

4647

@@ -249,6 +250,9 @@ class TaskResult(Generic[T]):
249250
backend: str
250251
"""The name of the backend the task will run on"""
251252

253+
worker_id: Optional[str]
254+
"""The id of the worker running the task"""
255+
252256
_exception_class: Optional[type[BaseException]] = field(init=False, default=None)
253257
_traceback: Optional[str] = field(init=False, default=None)
254258

tests/tests/test_database_backend.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,8 @@ def test_enqueue_logs(self) -> None:
415415
}
416416
)
417417
class DatabaseBackendWorkerTestCase(TransactionTestCase):
418+
worker_id = uuid.uuid4()
419+
418420
run_worker = staticmethod(
419421
partial(
420422
call_command,
@@ -423,6 +425,7 @@ class DatabaseBackendWorkerTestCase(TransactionTestCase):
423425
batch=True,
424426
interval=0,
425427
startup_delay=False,
428+
worker_id=worker_id,
426429
)
427430
)
428431

@@ -454,6 +457,7 @@ def test_run_enqueued_task(self) -> None:
454457
self.assertGreaterEqual(result.started_at, result.enqueued_at) # type:ignore[arg-type,misc]
455458
self.assertGreaterEqual(result.finished_at, result.started_at) # type:ignore[arg-type,misc]
456459
self.assertEqual(result.status, ResultStatus.SUCCEEDED)
460+
self.assertIsNone(result.worker_id)
457461

458462
self.assertEqual(DBTaskResult.objects.ready().count(), 0)
459463

@@ -629,6 +633,15 @@ def test_fractional_interval(self) -> None:
629633

630634
self.assertEqual(worker_class.mock_calls[0].kwargs["interval"], 0.1)
631635

636+
def test_invalid_worker_id(self) -> None:
637+
output = StringIO()
638+
with redirect_stderr(output):
639+
with self.assertRaises(SystemExit):
640+
execute_from_command_line(
641+
["django-admin", "db_worker", "--worker-id", "123"]
642+
)
643+
self.assertIn("invalid UUID value", output.getvalue())
644+
632645
def test_run_after(self) -> None:
633646
result = test_tasks.noop_task.using(
634647
run_after=timezone.now() + timedelta(hours=10)
@@ -702,10 +715,10 @@ def test_verbose_logging(self) -> None:
702715
self.assertEqual(
703716
stdout.getvalue().splitlines(),
704717
[
705-
"Starting worker for queues=default",
718+
f"Starting worker worker_id={self.worker_id} for queues=default",
706719
f"Task id={result.id} path=tests.tasks.noop_task state=RUNNING",
707720
f"Task id={result.id} path=tests.tasks.noop_task state=SUCCEEDED",
708-
"No more tasks to run - exiting gracefully.",
721+
f"No more tasks to run for worker_id={self.worker_id} - exiting gracefully.",
709722
],
710723
)
711724

@@ -1333,6 +1346,7 @@ def test_interrupt_signals(self) -> None:
13331346

13341347
result.refresh()
13351348
self.assertEqual(result.status, ResultStatus.RUNNING)
1349+
self.assertIsNotNone(result.worker_id)
13361350

13371351
process.send_signal(sig)
13381352

@@ -1343,6 +1357,7 @@ def test_interrupt_signals(self) -> None:
13431357
result.refresh()
13441358

13451359
self.assertEqual(result.status, ResultStatus.SUCCEEDED)
1360+
self.assertIsNone(result.worker_id)
13461361

13471362
@skipIf(sys.platform == "win32", "Cannot emulate CTRL-C on Windows")
13481363
def test_repeat_ctrl_c(self) -> None:
@@ -1355,6 +1370,7 @@ def test_repeat_ctrl_c(self) -> None:
13551370

13561371
result.refresh()
13571372
self.assertEqual(result.status, ResultStatus.RUNNING)
1373+
self.assertIsNotNone(result.worker_id)
13581374

13591375
process.send_signal(signal.SIGINT)
13601376

@@ -1373,6 +1389,7 @@ def test_repeat_ctrl_c(self) -> None:
13731389
result.refresh()
13741390
self.assertEqual(result.status, ResultStatus.FAILED)
13751391
self.assertEqual(result.exception_class, SystemExit)
1392+
self.assertIsNone(result.worker_id)
13761393

13771394
@skipIf(sys.platform == "win32", "Windows doesn't support SIGKILL")
13781395
def test_kill(self) -> None:
@@ -1411,6 +1428,7 @@ def test_system_exit_task(self) -> None:
14111428
result.refresh()
14121429
self.assertEqual(result.status, ResultStatus.FAILED)
14131430
self.assertEqual(result.exception_class, SystemExit)
1431+
self.assertIsNone(result.worker_id)
14141432

14151433
def test_keyboard_interrupt_task(self) -> None:
14161434
result = test_tasks.failing_task_keyboard_interrupt.enqueue()
@@ -1423,6 +1441,7 @@ def test_keyboard_interrupt_task(self) -> None:
14231441
result.refresh()
14241442
self.assertEqual(result.status, ResultStatus.FAILED)
14251443
self.assertEqual(result.exception_class, KeyboardInterrupt)
1444+
self.assertIsNone(result.worker_id)
14261445

14271446
def test_multiple_workers(self) -> None:
14281447
results = [test_tasks.sleep_for.enqueue(0.1) for _ in range(10)]
@@ -1439,6 +1458,7 @@ def test_multiple_workers(self) -> None:
14391458
for result in results:
14401459
result.refresh()
14411460
self.assertEqual(result.status, ResultStatus.SUCCEEDED)
1461+
self.assertIsNone(result.worker_id)
14421462

14431463
all_output = ""
14441464

0 commit comments

Comments
 (0)