|
1 | 1 | from copy import deepcopy
|
2 | 2 | from dataclasses import dataclass, field
|
3 | 3 | from datetime import datetime, timedelta
|
4 |
| -from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar, Union |
5 |
| - |
| 4 | +from inspect import iscoroutinefunction |
| 5 | +from typing import ( |
| 6 | + TYPE_CHECKING, |
| 7 | + Any, |
| 8 | + Callable, |
| 9 | + Coroutine, |
| 10 | + Dict, |
| 11 | + Generic, |
| 12 | + Optional, |
| 13 | + TypeVar, |
| 14 | + Union, |
| 15 | +) |
| 16 | + |
| 17 | +from asgiref.sync import async_to_sync, sync_to_async |
6 | 18 | from django.db.models.enums import TextChoices
|
7 | 19 | from django.utils import timezone
|
8 | 20 | from typing_extensions import ParamSpec, Self
|
@@ -117,9 +129,21 @@ async def aget_result(self, result_id: str) -> "TaskResult[T]":
|
117 | 129 |
|
118 | 130 | return result
|
119 | 131 |
|
120 |
| - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: |
| 132 | + def __call__( |
| 133 | + self, *args: P.args, **kwargs: P.kwargs |
| 134 | + ) -> Union[T, Coroutine[T, None, None]]: |
| 135 | + return self.func(*args, **kwargs) |
| 136 | + |
| 137 | + def call(self, *args: P.args, **kwargs: P.kwargs) -> T: |
| 138 | + if iscoroutinefunction(self.func): |
| 139 | + return async_to_sync(self.func)(*args, **kwargs) # type:ignore[no-any-return] |
121 | 140 | return self.func(*args, **kwargs)
|
122 | 141 |
|
| 142 | + async def acall(self, *args: P.args, **kwargs: P.kwargs) -> T: |
| 143 | + if iscoroutinefunction(self.func): |
| 144 | + return await self.func(*args, **kwargs) # type:ignore[no-any-return] |
| 145 | + return await sync_to_async(self.func)(*args, **kwargs) |
| 146 | + |
123 | 147 | def get_backend(self) -> "BaseTaskBackend":
|
124 | 148 | from . import tasks
|
125 | 149 |
|
|
0 commit comments