|
1 | 1 | import logging
|
| 2 | +from functools import partial |
2 | 3 | from inspect import iscoroutinefunction
|
3 | 4 | from typing import TypeVar
|
4 | 5 | from uuid import uuid4
|
5 | 6 |
|
6 | 7 | from asgiref.sync import async_to_sync
|
| 8 | +from django.db import transaction |
7 | 9 | from django.utils import timezone
|
8 | 10 | from typing_extensions import ParamSpec
|
9 | 11 |
|
|
22 | 24 | class ImmediateBackend(BaseTaskBackend):
|
23 | 25 | supports_async_task = True
|
24 | 26 |
|
25 |
| - def enqueue( |
26 |
| - self, task: Task[P, T], args: P.args, kwargs: P.kwargs |
27 |
| - ) -> TaskResult[T]: |
28 |
| - self.validate_task(task) |
| 27 | + def _execute_task(self, task_result: TaskResult) -> None: |
| 28 | + """ |
| 29 | + Execute the task for the given `TaskResult`, mutating it with the outcome |
| 30 | + """ |
| 31 | + task = task_result.task |
29 | 32 |
|
30 | 33 | calling_task_func = (
|
31 | 34 | async_to_sync(task.func) if iscoroutinefunction(task.func) else task.func
|
32 | 35 | )
|
33 | 36 |
|
34 |
| - enqueued_at = timezone.now() |
35 |
| - started_at = timezone.now() |
36 |
| - result_id = str(uuid4()) |
| 37 | + task_result.started_at = timezone.now() |
37 | 38 | try:
|
38 |
| - result = json_normalize(calling_task_func(*args, **kwargs)) |
39 |
| - status = ResultStatus.COMPLETE |
| 39 | + task_result._result = json_normalize( |
| 40 | + calling_task_func(*task_result.args, **task_result.kwargs) |
| 41 | + ) |
40 | 42 | except BaseException as e:
|
| 43 | + task_result.finished_at = timezone.now() |
41 | 44 | try:
|
42 |
| - result = exception_to_dict(e) |
| 45 | + task_result._result = exception_to_dict(e) |
43 | 46 | except Exception:
|
44 |
| - logger.exception("Task id=%s unable to save exception", result_id) |
45 |
| - result = None |
| 47 | + logger.exception("Task id=%s unable to save exception", task_result.id) |
| 48 | + task_result._result = None |
46 | 49 |
|
47 | 50 | # Use `.exception` to integrate with error monitoring tools (eg Sentry)
|
48 | 51 | logger.exception(
|
49 | 52 | "Task id=%s path=%s state=%s",
|
50 |
| - result_id, |
| 53 | + task_result.id, |
51 | 54 | task.module_path,
|
52 | 55 | ResultStatus.FAILED,
|
53 | 56 | )
|
54 |
| - status = ResultStatus.FAILED |
| 57 | + task_result.status = ResultStatus.FAILED |
55 | 58 |
|
56 | 59 | # If the user tried to terminate, let them
|
57 | 60 | if isinstance(e, KeyboardInterrupt):
|
58 | 61 | raise
|
| 62 | + else: |
| 63 | + task_result.finished_at = timezone.now() |
| 64 | + task_result.status = ResultStatus.COMPLETE |
| 65 | + |
| 66 | + def enqueue( |
| 67 | + self, task: Task[P, T], args: P.args, kwargs: P.kwargs |
| 68 | + ) -> TaskResult[T]: |
| 69 | + self.validate_task(task) |
59 | 70 |
|
60 | 71 | task_result = TaskResult[T](
|
61 | 72 | task=task,
|
62 |
| - id=result_id, |
63 |
| - status=status, |
64 |
| - enqueued_at=enqueued_at, |
65 |
| - started_at=started_at, |
66 |
| - finished_at=timezone.now(), |
| 73 | + id=str(uuid4()), |
| 74 | + status=ResultStatus.NEW, |
| 75 | + enqueued_at=timezone.now(), |
| 76 | + started_at=None, |
| 77 | + finished_at=None, |
67 | 78 | args=json_normalize(args),
|
68 | 79 | kwargs=json_normalize(kwargs),
|
69 | 80 | backend=self.alias,
|
70 | 81 | )
|
71 | 82 |
|
72 |
| - task_result._result = result |
| 83 | + if self._get_enqueue_on_commit_for_task(task) is not False: |
| 84 | + transaction.on_commit(partial(self._execute_task, task_result)) |
| 85 | + else: |
| 86 | + self._execute_task(task_result) |
73 | 87 |
|
74 | 88 | return task_result
|
0 commit comments