Skip to content

Commit

Permalink
Improve pyre typing
Browse files Browse the repository at this point in the history
Summary: Small pyre improvement to later library.

Reviewed By: fried, zsol

Differential Revision: D62593644

fbshipit-source-id: ed2836c5e11253366fb0e5eb20b5f1b5db31ffcd
  • Loading branch information
Jakub Grzmiel authored and facebook-github-bot committed Sep 13, 2024
1 parent 1aa43dc commit ddf67d4
Showing 1 changed file with 28 additions and 24 deletions.
52 changes: 28 additions & 24 deletions later/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
NewType,
Optional,
overload,
ParamSpec,
Protocol,
Sequence,
Tuple,
Expand All @@ -53,6 +54,7 @@
FixerType = Callable[[asyncio.Task], Union[asyncio.Task, Awaitable[asyncio.Task]]]
logger: logging.Logger = logging.getLogger(__name__)
T = TypeVar("T")
TParams = ParamSpec("TParams")

__all__: Sequence[str] = [
"Watcher",
Expand Down Expand Up @@ -125,15 +127,15 @@ async def cancel(fut: asyncio.Future) -> None:


def as_task(
func: Callable[..., Coroutine[object, object, T]]
) -> Callable[..., asyncio.Task[T]]:
func: Callable[TParams, Coroutine[object, object, T]]
) -> Callable[TParams, asyncio.Task[T]]:
"""
Decorate a function, So that when called it is wrapped in a task
on the running loop.
"""

@wraps(func)
def create_task(*args: Any, **kws: Mapping[str, Any]) -> asyncio.Task[T]:
def create_task(*args: TParams.args, **kws: TParams.kwargs) -> asyncio.Task[T]:
loop = asyncio.get_running_loop()
return loop.create_task(func(*args, **kws))

Expand Down Expand Up @@ -161,7 +163,7 @@ class Watcher:
_cancel_timeout: float
_preexit_callbacks: List[Callable[[], None]]
_shielded_tasks: Dict[asyncio.Task, asyncio.Future]
# pyre-fixme[13]: Attribute `loop` is never initialized.
# pyre-ignore[13]: loop is initialized in __aenter__
loop: asyncio.AbstractEventLoop
running: bool
done_ok: bool
Expand Down Expand Up @@ -424,7 +426,7 @@ async def _handle_cancel(self) -> None:
)


CacheKey = NewType("CacheKey", Sequence[Hashable])
CacheKey = NewType("CacheKey", tuple[Hashable, ...])
ArgID = Union[int, str]


Expand Down Expand Up @@ -457,11 +459,9 @@ def _build_key(
Allow for not including certain fields from args or kwargs
"""
if not ignored_args:
# pyre-fixme[45]: Cannot instantiate abstract class `CacheKey`.
return CacheKey((args, tuple(sorted(kwargs.items()))))

# If we do want to ignore something then do so
# pyre-fixme[45]: Cannot instantiate abstract class `CacheKey`.
return CacheKey(
(
tuple((value for idx, value in enumerate(args) if idx not in ignored_args)),
Expand All @@ -474,37 +474,40 @@ def _build_key(

class AsyncCallable(Protocol):
def __call__(
self, fn: Callable[..., Coroutine[object, object, T]]
) -> Callable[..., Coroutine[object, object, T]]: # pragma: nocover
self, fn: Callable[TParams, Coroutine[object, object, T]]
) -> Callable[TParams, Coroutine[object, object, T]]: # pragma: nocover
...


FuncType = Callable[..., Coroutine[object, object, T]]


@overload # noqa: 811
@overload
def herd(
fn: FuncType[T], *, ignored_args: Optional[AbstractSet[ArgID]] = None
) -> FuncType[T]: # pragma: nocover
fn: Callable[TParams, Coroutine[object, object, T]],
*,
ignored_args: Optional[AbstractSet[ArgID]] = None,
) -> Callable[TParams, Coroutine[object, object, T]]: # pragma: nocover
...


@overload # noqa: 811
@overload
def herd(
fn: Optional[AsyncCallable] = None,
fn: None = None,
*,
ignored_args: Optional[AbstractSet[ArgID]] = None,
) -> AsyncCallable: # pragma: nocover
...


# pyre-ignore[3]: Defining the return type is pointless here
def herd(
# pyre-ignore[2]: This is fine, we don't need types the overloads cover it
fn=None,
fn: Callable[TParams, Coroutine[object, object, T]] | None = None,
*,
ignored_args: Optional[AbstractSet[ArgID]] = None,
): # noqa: 811
) -> (
Callable[TParams, Coroutine[object, object, T]]
| Callable[
[Callable[TParams, Coroutine[object, object, T]]],
Callable[TParams, Coroutine[object, object, T]],
]
):
"""
Provide a simple thundering herd protection as a decorator.
if requests comes in while and existing request with those same args is pending,
Expand All @@ -518,11 +521,13 @@ def herd(
Each member of the herd is "shielded" from cancellation effecting other herd members
"""

def decorator(fn: FuncType[T]) -> T:
def decorator(
fn: Callable[TParams, Coroutine[object, object, T]]
) -> Callable[TParams, Coroutine[object, object, T]]:
local: threading.local = threading.local()

@functools.wraps(fn)
async def wrapped(*args: Any, **kwargs: Any) -> T:
async def wrapped(*args: TParams.args, **kwargs: TParams.kwargs) -> T:
pending = cast(Dict[CacheKey, _CountTask], _get_local(local, "pending"))
request = _build_key(tuple(args), kwargs, ignored_args)
count_task = pending.setdefault(request, _CountTask())
Expand All @@ -548,7 +553,6 @@ async def wrapped(*args: Any, **kwargs: Any) -> T:
return wrapped

if fn and callable(fn):
# pyre-fixme[6]: For 1st param expected `F` but got `(...) -> object`.
return decorator(fn)

return decorator

0 comments on commit ddf67d4

Please sign in to comment.