Skip to content

Commit 3b3222d

Browse files
authored
refactor(task): Improve task decorator type inference (#113)
* refactor(types): improve task decorator type inference * fix(task): Raise TypeError for invalid return type annotations
1 parent 7308b6f commit 3b3222d

3 files changed

Lines changed: 31 additions & 6 deletions

File tree

python/fluxqueue/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
__all__ = ["FluxQueue"]
22

3-
from .core import FluxQueue
3+
from .client import FluxQueue
Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import inspect
2-
from collections.abc import Callable
2+
from collections.abc import Callable, Coroutine
33
from functools import wraps
4-
from typing import Any, ParamSpec, TypeAlias, cast
4+
from typing import Any, ParamSpec, cast, get_type_hints, overload
55

66
from ._core import FluxQueueCore
77
from .utils import get_task_name
88

99
P = ParamSpec("P")
10-
TaskDecorator: TypeAlias = Callable[[Callable[P, Any]], Callable[P, Any]]
1110

1211

1312
class FluxQueue:
@@ -30,7 +29,7 @@ def task(
3029
name: str | None = None,
3130
queue: str = "default",
3231
max_retries: int = 3,
33-
) -> TaskDecorator[P]:
32+
):
3433
"""
3534
Mark a function as a FluxQueue task.
3635
@@ -52,7 +51,23 @@ def task(
5251
before treating it as dead.
5352
"""
5453

55-
def decorator(func: Callable[P, Any]) -> Callable[P, Any]:
54+
@overload
55+
def decorator(func: Callable[P, None]) -> Callable[P, None]: ...
56+
57+
@overload
58+
def decorator(
59+
func: Callable[P, Coroutine[Any, Any, None]],
60+
) -> Callable[P, Coroutine[Any, Any, None]]: ...
61+
62+
def decorator(
63+
func: Callable[P, None | Coroutine[Any, Any, None]],
64+
) -> Callable[P, None | Coroutine[Any, Any, None]]:
65+
type_hints = get_type_hints(func)
66+
return_type = type_hints.get("return")
67+
68+
if return_type and return_type is not type(None):
69+
raise TypeError(f"Task function must return None, got {return_type}")
70+
5671
is_async = inspect.iscoroutinefunction(func)
5772
task_name = get_task_name(func, name)
5873

tests/test_tasks.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,13 @@ async def async_hello(name: str):
3030
assert b"Async George" in redis_result[0] # type: ignore
3131

3232
test_env.redis_client.flushdb()
33+
34+
35+
def test_invalid_return_type(test_env: TestEnvFixture):
36+
with pytest.raises(TypeError):
37+
38+
@test_env.fluxqueue.task() # type: ignore
39+
def test_task() -> int:
40+
return 5
41+
42+
test_task()

0 commit comments

Comments
 (0)