Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Worker pings #133

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
189 changes: 124 additions & 65 deletions django_tasks/backends/database/management/commands/db_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import random
import signal
import sys
import threading
import time
from argparse import ArgumentParser, ArgumentTypeError
from types import FrameType
Expand All @@ -15,15 +16,18 @@

from django_tasks import DEFAULT_TASK_BACKEND_ALIAS, tasks
from django_tasks.backends.database.backend import DatabaseBackend
from django_tasks.backends.database.models import DBTaskResult
from django_tasks.backends.database.models import DBTaskResult, DBWorkerPing
from django_tasks.backends.database.utils import exclusive_transaction
from django_tasks.exceptions import InvalidTaskBackendError
from django_tasks.signals import task_finished
from django_tasks.task import DEFAULT_QUEUE_NAME
from django_tasks.utils import get_random_id

package_logger = logging.getLogger("django_tasks")
logger = logging.getLogger("django_tasks.backends.database.db_worker")

PING_TIMEOUT = 10


class Worker:
def __init__(
Expand All @@ -34,6 +38,7 @@ def __init__(
batch: bool,
backend_name: str,
startup_delay: bool,
worker_id: str,
):
self.queue_names = queue_names
self.process_all_queues = "*" in queue_names
Expand All @@ -42,11 +47,33 @@ def __init__(
self.backend_name = backend_name
self.startup_delay = startup_delay

self.running = True
self.running_task = False
self.worker_id = worker_id

self.ping_thread = threading.Thread(target=self.run_ping)

self.should_exit = threading.Event()
self.running_task = threading.Lock()

def run_ping(self) -> None:
    """Background-thread loop: periodically record a liveness ping.

    Sends one ping immediately, then one every ``PING_TIMEOUT`` seconds,
    until ``self.should_exit`` is set. A failing ping is logged and never
    kills the thread. Connections opened by this thread are closed on exit.
    """
    try:
        stopping = False
        while not stopping:
            try:
                DBWorkerPing.ping(self.worker_id)
            except Exception:
                logger.exception(
                    "Error updating worker ping worker_id=%s", self.worker_id
                )
            else:
                logger.debug("Sent ping worker_id=%s", self.worker_id)

            # Sleep between pings, but wake immediately on shutdown.
            stopping = self.should_exit.wait(timeout=PING_TIMEOUT)
    finally:
        # Close any connections opened in this thread
        for conn in connections.all():
            conn.close()

def shutdown(self, signum: int, frame: Optional[FrameType]) -> None:
if not self.running:
if self.should_exit.is_set():
logger.warning(
"Received %s - terminating current task.", signal.strsignal(signum)
)
Expand All @@ -56,9 +83,9 @@ def shutdown(self, signum: int, frame: Optional[FrameType]) -> None:
"Received %s - shutting down gracefully... (press Ctrl+C again to force)",
signal.strsignal(signum),
)
self.running = False
self.should_exit.set()

if not self.running_task:
if not self.running_task.locked():
# If we're not currently running a task, exit immediately.
# This is useful if we're currently in a `sleep`.
sys.exit(0)
Expand All @@ -69,93 +96,103 @@ def configure_signals(self) -> None:
if hasattr(signal, "SIGQUIT"):
signal.signal(signal.SIGQUIT, self.shutdown)

def start(self) -> None:
def run(self) -> None:
self.configure_signals()

logger.info("Starting worker for queues=%s", ",".join(self.queue_names))
logger.info(
"Starting worker worker_id=%s for queues=%s",
self.worker_id,
",".join(self.queue_names),
)

if self.startup_delay and self.interval:
# Add a random small delay before starting the loop to avoid a thundering herd
# Add a random small delay before starting to avoid a thundering herd
time.sleep(random.random())

while self.running:
self.ping_thread.start()

while not self.should_exit.is_set():
tasks = DBTaskResult.objects.ready().filter(backend_name=self.backend_name)
if not self.process_all_queues:
tasks = tasks.filter(queue_name__in=self.queue_names)

try:
self.running_task = True

# During this transaction, all "ready" tasks are locked. Therefore, it's important
# it be as efficient as possible.
with exclusive_transaction(tasks.db):
try:
task_result = tasks.get_locked()
except OperationalError as e:
# Ignore locked databases and keep trying.
# It should unlock eventually.
if "is locked" in e.args[0]:
task_result = None
else:
raise

if task_result is not None:
# "claim" the task, so it isn't run by another worker process
task_result.claim()
# During this transaction, all "ready" tasks are locked. Therefore, it's important
# it be as efficient as possible.
with exclusive_transaction(tasks.db):
try:
task_result = tasks.get_locked()
except OperationalError as e:
# Ignore locked databases and keep trying.
# It should unlock eventually.
if "is locked" in e.args[0]:
task_result = None
else:
raise

if task_result is not None:
self.run_task(task_result)

finally:
self.running_task = False
# "claim" the task, so it isn't run by another worker process
task_result.claim(self.worker_id)

for conn in connections.all(initialized_only=True):
conn.close()
if task_result is not None:
self.run_task(task_result)

if self.batch and task_result is None:
# If we're running in "batch" mode, terminate the loop (and thus the worker)
return None

# If ctrl-c has just interrupted a task, self.running was cleared,
# If ctrl-c has just interrupted a task, self.should_exit was cleared,
# and we should not sleep, but rather exit immediately.
if self.running and not task_result:
if not self.should_exit.is_set() and not task_result:
# Wait before checking for another task
time.sleep(self.interval)

def run_task(self, db_task_result: DBTaskResult) -> None:
"""
Run the given task, marking it as succeeded or failed.
"""
try:
task = db_task_result.task
task_result = db_task_result.task_result

logger.info(
"Task id=%s path=%s state=%s",
db_task_result.id,
db_task_result.task_path,
task_result.status,
)
return_value = task.call(*task_result.args, **task_result.kwargs)

# Setting the return and success value inside the error handling,
# So errors setting it (eg JSON encode) can still be recorded
db_task_result.set_succeeded(return_value)
task_finished.send(
sender=type(task.get_backend()), task_result=db_task_result.task_result
)
except BaseException as e:
db_task_result.set_failed(e)
with self.running_task:
try:
sender = type(db_task_result.task.get_backend())
task = db_task_result.task
task_result = db_task_result.task_result
except (ModuleNotFoundError, SuspiciousOperation):
logger.exception("Task id=%s failed unexpectedly", db_task_result.id)
else:

logger.info(
"Task id=%s path=%s state=%s",
db_task_result.id,
db_task_result.task_path,
task_result.status,
)
return_value = task.call(*task_result.args, **task_result.kwargs)

# Setting the return and success value inside the error handling,
# So errors setting it (eg JSON encode) can still be recorded
db_task_result.set_succeeded(return_value)
task_finished.send(
sender=sender,
task_result=task_result,
sender=type(task.get_backend()),
task_result=db_task_result.task_result,
)
except BaseException as e:
db_task_result.set_failed(e)
try:
sender = type(db_task_result.task.get_backend())
task_result = db_task_result.task_result
except (ModuleNotFoundError, SuspiciousOperation):
logger.exception(
"Task id=%s failed unexpectedly",
db_task_result.id,
)
else:
task_finished.send(
sender=sender,
task_result=task_result,
)

for conn in connections.all(initialized_only=True):
conn.close()

def stop(self) -> None:
    """Shut down the ping thread and remove this worker's ping row.

    Safe to call from a ``finally`` even if ``run()`` failed before the
    ping thread was started: joining a never-started thread raises
    ``RuntimeError``, which would mask the original exception, so the
    join is guarded.
    """
    self.should_exit.set()
    # `ident` is None until the thread has actually been started.
    if self.ping_thread.ident is not None:
        self.ping_thread.join()
    DBWorkerPing.cleanup_ping(self.worker_id)


def valid_backend_name(val: str) -> str:
Expand All @@ -177,6 +214,14 @@ def valid_interval(val: str) -> float:
return num


def validate_worker_id(val: str) -> str:
    """argparse type-validator for ``--worker-id``.

    Returns *val* unchanged when it is non-empty and at most 64 characters
    long (matching the ``max_length=64`` of the worker id columns); raises
    ``ArgumentTypeError`` otherwise.
    """
    if not val:
        raise ArgumentTypeError("Worker id must not be empty")
    if len(val) > 64:
        # The check accepts exactly 64 characters, so the message must not
        # claim ids have to be *shorter* than 64.
        raise ArgumentTypeError("Worker ids must be no longer than 64 characters")
    return val


class Command(BaseCommand):
help = "Run a database background worker"

Expand Down Expand Up @@ -214,6 +259,13 @@ def add_arguments(self, parser: ArgumentParser) -> None:
dest="startup_delay",
help="Don't add a small delay at startup.",
)
parser.add_argument(
"--worker-id",
nargs="?",
type=validate_worker_id,
help="Worker id. MUST be unique across worker pool (default: auto-generate)",
default=get_random_id(),
)

def configure_logging(self, verbosity: int) -> None:
if verbosity == 0:
Expand All @@ -239,6 +291,7 @@ def handle(
batch: bool,
backend_name: str,
startup_delay: bool,
worker_id: str,
**options: dict,
) -> None:
self.configure_logging(verbosity)
Expand All @@ -249,9 +302,15 @@ def handle(
batch=batch,
backend_name=backend_name,
startup_delay=startup_delay,
worker_id=worker_id,
)

worker.start()
try:
worker.run()
finally:
worker.stop()

if batch:
logger.info("No more tasks to run - exiting gracefully.")
logger.info(
"No more tasks to run for worker_id=%s - exiting gracefully.", worker_id
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Generated by Django 5.1.5 on 2025-02-04 17:10

from django.db import migrations, models


class Migration(migrations.Migration):
    # Must apply after the migration that removed DBTaskResult.exception_data.
    dependencies = [
        ("django_tasks_database", "0014_remove_dbtaskresult_exception_data"),
    ]

    operations = [
        # One row per worker, keyed by its (caller-supplied or generated)
        # worker id; `last_ping` is updated by DBWorkerPing.ping(worker_id).
        migrations.CreateModel(
            name="DBWorkerPing",
            fields=[
                (
                    "worker_id",
                    models.CharField(
                        editable=False, max_length=64, primary_key=True, serialize=False
                    ),
                ),
                ("last_ping", models.DateTimeField()),
            ],
        ),
        # Track which worker claimed a task result. Defaults to "" so
        # existing rows (claimed before worker ids existed) stay valid.
        migrations.AddField(
            model_name="dbtaskresult",
            name="worker_id",
            field=models.CharField(default="", max_length=64, verbose_name="worker id"),
        ),
    ]
Loading
Loading