Skip to content

Commit 4c5f63e

Browse files
Add configurable transaction management (#57)
1 parent 87e1fdc commit 4c5f63e

11 files changed

+419
-59
lines changed

README.md

+19-1
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ The task decorator accepts a few arguments to customize the task:
6767
- `priority`: The priority of the task (between -100 and 100. Larger numbers are higher priority. 0 by default)
6868
- `queue_name`: Whether to run the task on a specific queue
6969
- `backend`: Name of the backend for this task to use (as defined in `TASKS`)
70+
- `enqueue_on_commit`: Whether the task is enqueued when the current transaction commits successfully, or enqueued immediately. By default, this is handled by the backend (see below). `enqueue_on_commit` may not be modified with `.using`.
7071

71-
These attributes can also be modified at run-time with `.using`:
72+
These attributes (besides `enqueue_on_commit`) can also be modified at run-time with `.using`:
7273

7374
```python
7475
modified_task = calculate_meaning_of_life.using(priority=10)
@@ -88,6 +89,23 @@ The returned `TaskResult` can be interrogated to query the current state of the
8889

8990
If the task takes arguments, these can be passed as-is to `enqueue`.
9091

92+
#### Transactions
93+
94+
By default, tasks are enqueued after the current transaction (if there is one) commits successfully (using Django's `transaction.on_commit` method), rather than enqueueing immediately.
95+
96+
This can be configured using the `ENQUEUE_ON_COMMIT` setting. `True` and `False` force the behaviour.
97+
98+
```python
99+
TASKS = {
100+
"default": {
101+
"BACKEND": "django_tasks.backends.immediate.ImmediateBackend",
102+
"ENQUEUE_ON_COMMIT": False
103+
}
104+
}
105+
```
106+
107+
This can also be configured per-task by passing `enqueue_on_commit` to the `task` decorator.
108+
91109
### Queue names
92110

93111
By default, tasks are enqueued onto the "default" queue. When using multiple queues, it can be useful to constrain the allowed names, so tasks aren't missed.

django_tasks/backends/base.py

+34-6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from typing import Any, Iterable, TypeVar
44

55
from asgiref.sync import sync_to_async
6-
from django.core.checks.messages import CheckMessage
6+
from django.core.checks import messages
7+
from django.db import connections
8+
from django.test.testcases import _DatabaseFailure
79
from django.utils import timezone
810
from typing_extensions import ParamSpec
911

@@ -16,6 +18,9 @@
1618

1719

1820
class BaseTaskBackend(metaclass=ABCMeta):
21+
alias: str
22+
enqueue_on_commit: bool
23+
1924
task_class = Task
2025

2126
supports_defer = False
@@ -32,6 +37,27 @@ def __init__(self, options: dict) -> None:
3237

3338
self.alias = options["ALIAS"]
3439
self.queues = set(options.get("QUEUES", [DEFAULT_QUEUE_NAME]))
40+
self.enqueue_on_commit = bool(options.get("ENQUEUE_ON_COMMIT", True))
41+
42+
def _get_enqueue_on_commit_for_task(self, task: Task) -> bool:
43+
"""
44+
Determine the correct `enqueue_on_commit` setting to use for a given task.
45+
46+
If the task defines it, use that, otherwise, fall back to the backend.
47+
"""
48+
# If this project doesn't use a database, there's nothing to commit to
49+
if not connections.settings:
50+
return False
51+
52+
# If connections are disabled during tests, there's nothing to commit to
53+
for conn in connections.all():
54+
if isinstance(conn.connect, _DatabaseFailure):
55+
return False
56+
57+
if isinstance(task.enqueue_on_commit, bool):
58+
return task.enqueue_on_commit
59+
60+
return self.enqueue_on_commit
3561

3662
def validate_task(self, task: Task) -> None:
3763
"""
@@ -101,8 +127,10 @@ async def aget_result(self, result_id: str) -> TaskResult:
101127
result_id=result_id
102128
)
103129

104-
def check(self, **kwargs: Any) -> Iterable[CheckMessage]:
105-
raise NotImplementedError(
106-
"subclasses may provide a check() method to verify that task "
107-
"backend is configured correctly."
108-
)
130+
def check(self, **kwargs: Any) -> Iterable[messages.CheckMessage]:
131+
if self.enqueue_on_commit and not connections.settings:
132+
yield messages.CheckMessage(
133+
messages.ERROR,
134+
"`ENQUEUE_ON_COMMIT` cannot be used when no databases are configured",
135+
hint="Set `ENQUEUE_ON_COMMIT` to False",
136+
)

django_tasks/backends/database/backend.py

+13-19
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
from typing import TYPE_CHECKING, Any, Iterable, TypeVar
33

44
from django.apps import apps
5-
from django.core.checks import ERROR, CheckMessage
5+
from django.core.checks import messages
66
from django.core.exceptions import ValidationError
7-
from django.db import connections, router
7+
from django.db import connections, router, transaction
88
from typing_extensions import ParamSpec
99

1010
from django_tasks.backends.base import BaseTaskBackend
@@ -51,18 +51,10 @@ def enqueue(
5151

5252
db_result = self._task_to_db_task(task, args, kwargs)
5353

54-
db_result.save()
55-
56-
return db_result.task_result
57-
58-
async def aenqueue(
59-
self, task: Task[P, T], args: P.args, kwargs: P.kwargs
60-
) -> TaskResult[T]:
61-
self.validate_task(task)
62-
63-
db_result = self._task_to_db_task(task, args, kwargs)
64-
65-
await db_result.asave()
54+
if self._get_enqueue_on_commit_for_task(task):
55+
transaction.on_commit(db_result.save)
56+
else:
57+
db_result.save()
6658

6759
return db_result.task_result
6860

@@ -82,14 +74,16 @@ async def aget_result(self, result_id: str) -> TaskResult:
8274
except (DBTaskResult.DoesNotExist, ValidationError) as e:
8375
raise ResultDoesNotExist(result_id) from e
8476

85-
def check(self, **kwargs: Any) -> Iterable[CheckMessage]:
77+
def check(self, **kwargs: Any) -> Iterable[messages.CheckMessage]:
8678
from .models import DBTaskResult
8779

80+
yield from super().check(**kwargs)
81+
8882
backend_name = self.__class__.__name__
8983

9084
if not apps.is_installed("django_tasks.backends.database"):
91-
yield CheckMessage(
92-
ERROR,
85+
yield messages.CheckMessage(
86+
messages.ERROR,
9387
f"{backend_name} configured as django_tasks backend, but database app not installed",
9488
"Insert 'django_tasks.backends.database' in INSTALLED_APPS",
9589
)
@@ -100,8 +94,8 @@ def check(self, **kwargs: Any) -> Iterable[CheckMessage]:
10094
and hasattr(db_connection, "transaction_mode")
10195
and db_connection.transaction_mode != "EXCLUSIVE"
10296
):
103-
yield CheckMessage(
104-
ERROR,
97+
yield messages.CheckMessage(
98+
messages.ERROR,
10599
f"{backend_name} is using SQLite non-exclusive transactions",
106100
f"Set settings.DATABASES[{db_connection.alias!r}]['OPTIONS']['transaction_mode'] to 'EXCLUSIVE'",
107101
)

django_tasks/backends/dummy.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from copy import deepcopy
2+
from functools import partial
23
from typing import List, TypeVar
34
from uuid import uuid4
45

6+
from django.db import transaction
57
from django.utils import timezone
68
from typing_extensions import ParamSpec
79

@@ -42,8 +44,11 @@ def enqueue(
4244
backend=self.alias,
4345
)
4446

45-
# Copy the task to prevent mutation issues
46-
self.results.append(deepcopy(result))
47+
if self._get_enqueue_on_commit_for_task(task) is not False:
48+
# Copy the task to prevent mutation issues
49+
transaction.on_commit(partial(self.results.append, deepcopy(result)))
50+
else:
51+
self.results.append(deepcopy(result))
4752

4853
return result
4954

django_tasks/backends/immediate.py

+34-20
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import logging
2+
from functools import partial
23
from inspect import iscoroutinefunction
34
from typing import TypeVar
45
from uuid import uuid4
56

67
from asgiref.sync import async_to_sync
8+
from django.db import transaction
79
from django.utils import timezone
810
from typing_extensions import ParamSpec
911

@@ -22,53 +24,65 @@
2224
class ImmediateBackend(BaseTaskBackend):
2325
supports_async_task = True
2426

25-
def enqueue(
26-
self, task: Task[P, T], args: P.args, kwargs: P.kwargs
27-
) -> TaskResult[T]:
28-
self.validate_task(task)
27+
def _execute_task(self, task_result: TaskResult) -> None:
28+
"""
29+
Execute the task for the given `TaskResult`, mutating it with the outcome
30+
"""
31+
task = task_result.task
2932

3033
calling_task_func = (
3134
async_to_sync(task.func) if iscoroutinefunction(task.func) else task.func
3235
)
3336

34-
enqueued_at = timezone.now()
35-
started_at = timezone.now()
36-
result_id = str(uuid4())
37+
task_result.started_at = timezone.now()
3738
try:
38-
result = json_normalize(calling_task_func(*args, **kwargs))
39-
status = ResultStatus.COMPLETE
39+
task_result._result = json_normalize(
40+
calling_task_func(*task_result.args, **task_result.kwargs)
41+
)
4042
except BaseException as e:
43+
task_result.finished_at = timezone.now()
4144
try:
42-
result = exception_to_dict(e)
45+
task_result._result = exception_to_dict(e)
4346
except Exception:
44-
logger.exception("Task id=%s unable to save exception", result_id)
45-
result = None
47+
logger.exception("Task id=%s unable to save exception", task_result.id)
48+
task_result._result = None
4649

4750
# Use `.exception` to integrate with error monitoring tools (eg Sentry)
4851
logger.exception(
4952
"Task id=%s path=%s state=%s",
50-
result_id,
53+
task_result.id,
5154
task.module_path,
5255
ResultStatus.FAILED,
5356
)
54-
status = ResultStatus.FAILED
57+
task_result.status = ResultStatus.FAILED
5558

5659
# If the user tried to terminate, let them
5760
if isinstance(e, KeyboardInterrupt):
5861
raise
62+
else:
63+
task_result.finished_at = timezone.now()
64+
task_result.status = ResultStatus.COMPLETE
65+
66+
def enqueue(
67+
self, task: Task[P, T], args: P.args, kwargs: P.kwargs
68+
) -> TaskResult[T]:
69+
self.validate_task(task)
5970

6071
task_result = TaskResult[T](
6172
task=task,
62-
id=result_id,
63-
status=status,
64-
enqueued_at=enqueued_at,
65-
started_at=started_at,
66-
finished_at=timezone.now(),
73+
id=str(uuid4()),
74+
status=ResultStatus.NEW,
75+
enqueued_at=timezone.now(),
76+
started_at=None,
77+
finished_at=None,
6778
args=json_normalize(args),
6879
kwargs=json_normalize(kwargs),
6980
backend=self.alias,
7081
)
7182

72-
task_result._result = result
83+
if self._get_enqueue_on_commit_for_task(task) is not False:
84+
transaction.on_commit(partial(self._execute_task, task_result))
85+
else:
86+
self._execute_task(task_result)
7387

7488
return task_result

django_tasks/task.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,17 @@ class Task(Generic[P, T]):
5656
"""The name of the backend the task will run on"""
5757

5858
queue_name: str = DEFAULT_QUEUE_NAME
59-
"""The name of the queue the task will run on """
59+
"""The name of the queue the task will run on"""
6060

6161
run_after: Optional[datetime] = None
6262
"""The earliest this task will run"""
6363

64+
enqueue_on_commit: Optional[bool] = None
65+
"""
66+
Whether the task will be enqueued when the current transaction commits,
67+
immediately, or whatever the backend decides
68+
"""
69+
6470
def __post_init__(self) -> None:
6571
self.get_backend().validate_task(self)
6672

@@ -170,6 +176,7 @@ def task(
170176
priority: int = DEFAULT_PRIORITY,
171177
queue_name: str = DEFAULT_QUEUE_NAME,
172178
backend: str = DEFAULT_TASK_BACKEND_ALIAS,
179+
enqueue_on_commit: Optional[bool] = None,
173180
) -> Callable[[Callable[P, T]], Task[P, T]]: ...
174181

175182

@@ -180,6 +187,7 @@ def task(
180187
priority: int = DEFAULT_PRIORITY,
181188
queue_name: str = DEFAULT_QUEUE_NAME,
182189
backend: str = DEFAULT_TASK_BACKEND_ALIAS,
190+
enqueue_on_commit: Optional[bool] = None,
183191
) -> Union[Task[P, T], Callable[[Callable[P, T]], Task[P, T]]]:
184192
"""
185193
A decorator used to create a task.
@@ -188,7 +196,11 @@ def task(
188196

189197
def wrapper(f: Callable[P, T]) -> Task[P, T]:
190198
return tasks[backend].task_class(
191-
priority=priority, func=f, queue_name=queue_name, backend=backend
199+
priority=priority,
200+
func=f,
201+
queue_name=queue_name,
202+
backend=backend,
203+
enqueue_on_commit=enqueue_on_commit,
192204
)
193205

194206
if function:

tests/tasks.py

+10
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,13 @@ def complex_exception() -> None:
4444
@task()
4545
def exit_task() -> None:
4646
exit(1)
47+
48+
49+
@task(enqueue_on_commit=True)
50+
def enqueue_on_commit_task() -> None:
51+
pass
52+
53+
54+
@task(enqueue_on_commit=False)
55+
def never_enqueue_on_commit_task() -> None:
56+
pass

0 commit comments

Comments
 (0)