diff --git a/aiocache/decorators.py b/aiocache/decorators.py index d2c41b24..03e16897 100644 --- a/aiocache/decorators.py +++ b/aiocache/decorators.py @@ -2,6 +2,10 @@ import functools import inspect import logging +from typing import Callable, TypeVar, ParamSpec, Coroutine, Any + +P = ParamSpec("P") +R = TypeVar("R") from aiocache.base import SENTINEL from aiocache.lock import RedLock @@ -43,12 +47,13 @@ def __init__( self.noself = noself self.cache = cache - def __call__(self, f): + # Use ParamSpec and TypeVar to preserve the decorated function's type signature for static type checkers. + def __call__(self, f: Callable[P, R]) -> Callable[P, Coroutine[Any, Any, R]]: @functools.wraps(f) - async def wrapper(*args, **kwargs): + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: return await self.decorator(f, *args, **kwargs) - wrapper.cache = self.cache + wrapper.cache = self.cache # type: ignore[attr-defined] return wrapper async def decorator( @@ -228,12 +233,13 @@ def __init__( self.skip_cache_func = skip_cache_func self.ttl = ttl - def __call__(self, f): + # Use ParamSpec and TypeVar to preserve the decorated function's type signature for static type checkers. + def __call__(self, f: Callable[P, R]) -> Callable[P, Coroutine[Any, Any, R]]: @functools.wraps(f) - async def wrapper(*args, **kwargs): + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: return await self.decorator(f, *args, **kwargs) - wrapper.cache = self.cache + wrapper.cache = self.cache # type: ignore[attr-defined] return wrapper async def decorator( diff --git a/tests/ut/test_decorators.py b/tests/ut/test_decorators.py index 7fe4c68e..d3081356 100644 --- a/tests/ut/test_decorators.py +++ b/tests/ut/test_decorators.py @@ -11,6 +11,7 @@ from aiocache.base import SENTINEL from aiocache.decorators import _get_args_dict from aiocache.lock import RedLock +from typing import Any, Dict, List async def stub(*args, value=None, seconds=0, **kwargs): @@ -177,6 +178,16 @@ async def what(self, a, b): assert str(inspect.signature(what)) == "(self, a, b)" assert inspect.getfullargspec(what.__wrapped__).args == ["self", "a", "b"] + async def test_cached_preserves_type_hints(self, mock_cache: Any) -> None: + mock_cache.get.return_value = None + + @cached(cache=mock_cache) + async def add(x: int, y: int) -> int: + return x + y + + # mypy limitation: async decorators with ParamSpec (pyright works). + assert (await add(1, 1)) == 2 # type: ignore[comparison-overlap] + async def test_reuses_cache_instance(self, mock_cache): @cached(cache=mock_cache) async def what(): @@ -477,6 +488,18 @@ async def what(self, keys=None, what=1): assert str(inspect.signature(what)) == "(self, keys=None, what=1)" assert inspect.getfullargspec(what.__wrapped__).args == ["self", "keys", "what"] + async def test_preserves_type_hints(self, mock_cache: Any) -> None: + mock_cache.multi_get.return_value = [None] + + @multi_cached(cache=mock_cache, keys_from_attr="keys") + async def add(x: int, y: int, keys: List[str]) -> Dict[str, int]: + return {k: x + y for k in keys} + + result = await add(1, 1, keys=["a"]) + + # mypy limitation: comparison-overlap with async decorators with ParamSpec. + assert result == {"a": 2} # type: ignore[comparison-overlap] + async def test_key_builder(self): @multi_cached(cache=SimpleMemoryCache(), keys_from_attr="keys", key_builder=lambda key, _, keys: key + 1)