Skip to content

Commit c7d1e16

Browse files
Prevent writing to task results (#106)
The implementation is slightly slower, and the internal code gets messier, but it prevents a number of potential foot-guns
1 parent 9aa591f commit c7d1e16

File tree

5 files changed

+20
-16
lines changed

5 files changed

+20
-16
lines changed

django_tasks/backends/database/backend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
P = ParamSpec("P")
2222

2323

24-
@dataclass
24+
@dataclass(frozen=True)
2525
class TaskResult(BaseTaskResult[T]):
2626
db_result: "DBTaskResult"
2727

django_tasks/backends/database/models.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def task_result(self) -> "TaskResult[T]":
148148
backend=self.backend_name,
149149
)
150150

151-
result._return_value = self.return_value
152-
result._exception_data = self.exception_data
151+
object.__setattr__(result, "_return_value", self.return_value)
152+
object.__setattr__(result, "_exception_data", self.exception_data)
153153

154154
return result
155155

django_tasks/backends/immediate.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,19 @@ def _execute_task(self, task_result: TaskResult) -> None:
3434
async_to_sync(task.func) if iscoroutinefunction(task.func) else task.func
3535
)
3636

37-
task_result.started_at = timezone.now()
37+
object.__setattr__(task_result, "started_at", timezone.now())
3838
try:
39-
task_result._return_value = json_normalize(
40-
calling_task_func(*task_result.args, **task_result.kwargs)
39+
object.__setattr__(
40+
task_result,
41+
"_return_value",
42+
json_normalize(
43+
calling_task_func(*task_result.args, **task_result.kwargs)
44+
),
4145
)
4246
except BaseException as e:
43-
task_result.finished_at = timezone.now()
47+
object.__setattr__(task_result, "finished_at", timezone.now())
4448
try:
45-
task_result._exception_data = exception_to_dict(e)
49+
object.__setattr__(task_result, "_exception_data", exception_to_dict(e))
4650
except Exception:
4751
logger.exception("Task id=%s unable to save exception", task_result.id)
4852

@@ -53,14 +57,14 @@ def _execute_task(self, task_result: TaskResult) -> None:
5357
task.module_path,
5458
ResultStatus.FAILED,
5559
)
56-
task_result.status = ResultStatus.FAILED
60+
object.__setattr__(task_result, "status", ResultStatus.FAILED)
5761

5862
# If the user tried to terminate, let them
5963
if isinstance(e, KeyboardInterrupt):
6064
raise
6165
else:
62-
task_result.finished_at = timezone.now()
63-
task_result.status = ResultStatus.COMPLETE
66+
object.__setattr__(task_result, "finished_at", timezone.now())
67+
object.__setattr__(task_result, "status", ResultStatus.COMPLETE)
6468

6569
def enqueue(
6670
self, task: Task[P, T], args: P.args, kwargs: P.kwargs

django_tasks/task.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def wrapper(f: Callable[P, T]) -> Task[P, T]:
218218
return wrapper
219219

220220

221-
@dataclass
221+
@dataclass(frozen=True)
222222
class TaskResult(Generic[T]):
223223
task: Task
224224
"""The task for which this is a result"""
@@ -292,7 +292,7 @@ def refresh(self) -> None:
292292
refreshed_task = self.task.get_backend().get_result(self.id)
293293

294294
for attr in TASK_REFRESH_ATTRS:
295-
setattr(self, attr, getattr(refreshed_task, attr))
295+
object.__setattr__(self, attr, getattr(refreshed_task, attr))
296296

297297
async def arefresh(self) -> None:
298298
"""
@@ -301,4 +301,4 @@ async def arefresh(self) -> None:
301301
refreshed_task = await self.task.get_backend().aget_result(self.id)
302302

303303
for attr in TASK_REFRESH_ATTRS:
304-
setattr(self, attr, getattr(refreshed_task, attr))
304+
object.__setattr__(self, attr, getattr(refreshed_task, attr))

tests/tests/test_dummy_backend.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_refresh_result(self) -> None:
7474
)
7575

7676
enqueued_result = default_task_backend.results[0] # type:ignore[attr-defined]
77-
enqueued_result.status = ResultStatus.COMPLETE
77+
object.__setattr__(enqueued_result, "status", ResultStatus.COMPLETE)
7878

7979
self.assertEqual(result.status, ResultStatus.NEW)
8080
result.refresh()
@@ -86,7 +86,7 @@ async def test_refresh_result_async(self) -> None:
8686
)
8787

8888
enqueued_result = default_task_backend.results[0] # type:ignore[attr-defined]
89-
enqueued_result.status = ResultStatus.COMPLETE
89+
object.__setattr__(enqueued_result, "status", ResultStatus.COMPLETE)
9090

9191
self.assertEqual(result.status, ResultStatus.NEW)
9292
await result.arefresh()

0 commit comments

Comments
 (0)