Skip to content

Commit 08342d6

Browse files
committed
Ensure async tasks can be called directly
1 parent 80799d8 commit 08342d6

File tree

3 files changed

+40
-5
lines changed

3 files changed

+40
-5
lines changed

django_tasks/task.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
11
from copy import deepcopy
22
from dataclasses import dataclass, field
33
from datetime import datetime, timedelta
4-
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar, Union
5-
4+
from inspect import iscoroutinefunction
5+
from typing import (
6+
TYPE_CHECKING,
7+
Any,
8+
Callable,
9+
Coroutine,
10+
Dict,
11+
Generic,
12+
Optional,
13+
TypeVar,
14+
Union,
15+
)
16+
17+
from asgiref.sync import async_to_sync, sync_to_async
618
from django.db.models.enums import TextChoices
719
from django.utils import timezone
820
from typing_extensions import ParamSpec, Self
@@ -117,9 +129,21 @@ async def aget_result(self, result_id: str) -> "TaskResult[T]":
117129

118130
return result
119131

120-
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
132+
def __call__(
133+
self, *args: P.args, **kwargs: P.kwargs
134+
) -> Union[T, Coroutine[T, None, None]]:
135+
return self.func(*args, **kwargs)
136+
137+
def call(self, *args: P.args, **kwargs: P.kwargs) -> T:
138+
if iscoroutinefunction(self.func):
139+
return async_to_sync(self.func)(*args, **kwargs) # type:ignore[no-any-return]
121140
return self.func(*args, **kwargs)
122141

142+
async def acall(self, *args: P.args, **kwargs: P.kwargs) -> T:
143+
if iscoroutinefunction(self.func):
144+
return await self.func(*args, **kwargs) # type:ignore[no-any-return]
145+
return await sync_to_async(self.func)(*args, **kwargs)
146+
123147
def get_backend(self) -> "BaseTaskBackend":
124148
from . import tasks
125149

tests/tasks.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33

44
@task()
55
def noop_task(*args: tuple, **kwargs: dict) -> None:
6-
pass
6+
return None
77

88

99
@task()
1010
async def noop_task_async(*args: tuple, **kwargs: dict) -> None:
11-
pass
11+
return None
1212

1313

1414
@task()

tests/tests/test_tasks.py

+11
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,17 @@ async def test_invalid_priority(self) -> None:
135135

136136
def test_call_task(self) -> None:
137137
self.assertEqual(test_tasks.calculate_meaning_of_life(), 42)
138+
self.assertEqual(test_tasks.calculate_meaning_of_life.call(), 42)
139+
140+
async def test_call_task_async(self) -> None:
141+
self.assertEqual(await test_tasks.calculate_meaning_of_life.acall(), 42)
142+
143+
async def test_call_async_task(self) -> None:
144+
self.assertIsNone(await test_tasks.noop_task_async()) # type:ignore[func-returns-value]
145+
self.assertIsNone(await test_tasks.noop_task_async.acall())
146+
147+
def test_call_async_task_sync(self) -> None:
148+
self.assertIsNone(test_tasks.noop_task_async.call())
138149

139150
def test_get_result(self) -> None:
140151
result = default_task_backend.enqueue(test_tasks.noop_task, (), {})

0 commit comments

Comments
 (0)