11import inspect
2- from collections .abc import Callable
2+ from collections .abc import Callable , Coroutine
33from functools import wraps
4- from typing import Any , ParamSpec , TypeAlias , cast
4+ from typing import Any , ParamSpec , cast , get_type_hints , overload
55
66from ._core import FluxQueueCore
77from .utils import get_task_name
88
99P = ParamSpec ("P" )
10- TaskDecorator : TypeAlias = Callable [[Callable [P , Any ]], Callable [P , Any ]]
1110
1211
1312class FluxQueue :
@@ -30,7 +29,7 @@ def task(
3029 name : str | None = None ,
3130 queue : str = "default" ,
3231 max_retries : int = 3 ,
33- ) -> TaskDecorator [ P ] :
32+ ):
3433 """
3534 Mark a function as a FluxQueue task.
3635
@@ -52,7 +51,23 @@ def task(
5251 before treating it as dead.
5352 """
5453
55- def decorator (func : Callable [P , Any ]) -> Callable [P , Any ]:
54+ @overload
55+ def decorator (func : Callable [P , None ]) -> Callable [P , None ]: ...
56+
57+ @overload
58+ def decorator (
59+ func : Callable [P , Coroutine [Any , Any , None ]],
60+ ) -> Callable [P , Coroutine [Any , Any , None ]]: ...
61+
62+ def decorator (
63+ func : Callable [P , None | Coroutine [Any , Any , None ]],
64+ ) -> Callable [P , None | Coroutine [Any , Any , None ]]:
65+ type_hints = get_type_hints (func )
66+ return_type = type_hints .get ("return" )
67+
68+ if return_type and return_type is not type (None ):
69+ raise TypeError (f"Task function must return None, got { return_type } " )
70+
5671 is_async = inspect .iscoroutinefunction (func )
5772 task_name = get_task_name (func , name )
5873
0 commit comments