Skip to content

Commit ddf67d4

Browse files
Jakub Grzmielfacebook-github-bot
authored andcommitted
Improve pyre typing
Summary: Small pyre improvement to later library. Reviewed By: fried, zsol Differential Revision: D62593644 fbshipit-source-id: ed2836c5e11253366fb0e5eb20b5f1b5db31ffcd
1 parent 1aa43dc commit ddf67d4

File tree

1 file changed

+28
-24
lines changed

1 file changed

+28
-24
lines changed

later/task.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
NewType,
3939
Optional,
4040
overload,
41+
ParamSpec,
4142
Protocol,
4243
Sequence,
4344
Tuple,
@@ -53,6 +54,7 @@
5354
FixerType = Callable[[asyncio.Task], Union[asyncio.Task, Awaitable[asyncio.Task]]]
5455
logger: logging.Logger = logging.getLogger(__name__)
5556
T = TypeVar("T")
57+
TParams = ParamSpec("TParams")
5658

5759
__all__: Sequence[str] = [
5860
"Watcher",
@@ -125,15 +127,15 @@ async def cancel(fut: asyncio.Future) -> None:
125127

126128

127129
def as_task(
128-
func: Callable[..., Coroutine[object, object, T]]
129-
) -> Callable[..., asyncio.Task[T]]:
130+
func: Callable[TParams, Coroutine[object, object, T]]
131+
) -> Callable[TParams, asyncio.Task[T]]:
130132
"""
131133
Decorate a function, So that when called it is wrapped in a task
132134
on the running loop.
133135
"""
134136

135137
@wraps(func)
136-
def create_task(*args: Any, **kws: Mapping[str, Any]) -> asyncio.Task[T]:
138+
def create_task(*args: TParams.args, **kws: TParams.kwargs) -> asyncio.Task[T]:
137139
loop = asyncio.get_running_loop()
138140
return loop.create_task(func(*args, **kws))
139141

@@ -161,7 +163,7 @@ class Watcher:
161163
_cancel_timeout: float
162164
_preexit_callbacks: List[Callable[[], None]]
163165
_shielded_tasks: Dict[asyncio.Task, asyncio.Future]
164-
# pyre-fixme[13]: Attribute `loop` is never initialized.
166+
# pyre-ignore[13]: loop is initialized in __aenter__
165167
loop: asyncio.AbstractEventLoop
166168
running: bool
167169
done_ok: bool
@@ -424,7 +426,7 @@ async def _handle_cancel(self) -> None:
424426
)
425427

426428

427-
CacheKey = NewType("CacheKey", Sequence[Hashable])
429+
CacheKey = NewType("CacheKey", tuple[Hashable, ...])
428430
ArgID = Union[int, str]
429431

430432

@@ -457,11 +459,9 @@ def _build_key(
457459
Allow for not including certain fields from args or kwargs
458460
"""
459461
if not ignored_args:
460-
# pyre-fixme[45]: Cannot instantiate abstract class `CacheKey`.
461462
return CacheKey((args, tuple(sorted(kwargs.items()))))
462463

463464
# If we do want to ignore something then do so
464-
# pyre-fixme[45]: Cannot instantiate abstract class `CacheKey`.
465465
return CacheKey(
466466
(
467467
tuple((value for idx, value in enumerate(args) if idx not in ignored_args)),
@@ -474,37 +474,40 @@ def _build_key(
474474

475475
class AsyncCallable(Protocol):
476476
def __call__(
477-
self, fn: Callable[..., Coroutine[object, object, T]]
478-
) -> Callable[..., Coroutine[object, object, T]]: # pragma: nocover
477+
self, fn: Callable[TParams, Coroutine[object, object, T]]
478+
) -> Callable[TParams, Coroutine[object, object, T]]: # pragma: nocover
479479
...
480480

481481

482-
FuncType = Callable[..., Coroutine[object, object, T]]
483-
484-
485-
@overload # noqa: 811
482+
@overload
486483
def herd(
487-
fn: FuncType[T], *, ignored_args: Optional[AbstractSet[ArgID]] = None
488-
) -> FuncType[T]: # pragma: nocover
484+
fn: Callable[TParams, Coroutine[object, object, T]],
485+
*,
486+
ignored_args: Optional[AbstractSet[ArgID]] = None,
487+
) -> Callable[TParams, Coroutine[object, object, T]]: # pragma: nocover
489488
...
490489

491490

492-
@overload # noqa: 811
491+
@overload
493492
def herd(
494-
fn: Optional[AsyncCallable] = None,
493+
fn: None = None,
495494
*,
496495
ignored_args: Optional[AbstractSet[ArgID]] = None,
497496
) -> AsyncCallable: # pragma: nocover
498497
...
499498

500499

501-
# pyre-ignore[3]: Defining the return type is pointless here
502500
def herd(
503-
# pyre-ignore[2]: This is fine, we don't need types the overloads cover it
504-
fn=None,
501+
fn: Callable[TParams, Coroutine[object, object, T]] | None = None,
505502
*,
506503
ignored_args: Optional[AbstractSet[ArgID]] = None,
507-
): # noqa: 811
504+
) -> (
505+
Callable[TParams, Coroutine[object, object, T]]
506+
| Callable[
507+
[Callable[TParams, Coroutine[object, object, T]]],
508+
Callable[TParams, Coroutine[object, object, T]],
509+
]
510+
):
508511
"""
509512
Provide a simple thundering herd protection as a decorator.
510513
if requests comes in while and existing request with those same args is pending,
@@ -518,11 +521,13 @@ def herd(
518521
Each member of the herd is "shielded" from cancellation effecting other herd members
519522
"""
520523

521-
def decorator(fn: FuncType[T]) -> T:
524+
def decorator(
525+
fn: Callable[TParams, Coroutine[object, object, T]]
526+
) -> Callable[TParams, Coroutine[object, object, T]]:
522527
local: threading.local = threading.local()
523528

524529
@functools.wraps(fn)
525-
async def wrapped(*args: Any, **kwargs: Any) -> T:
530+
async def wrapped(*args: TParams.args, **kwargs: TParams.kwargs) -> T:
526531
pending = cast(Dict[CacheKey, _CountTask], _get_local(local, "pending"))
527532
request = _build_key(tuple(args), kwargs, ignored_args)
528533
count_task = pending.setdefault(request, _CountTask())
@@ -548,7 +553,6 @@ async def wrapped(*args: Any, **kwargs: Any) -> T:
548553
return wrapped
549554

550555
if fn and callable(fn):
551-
# pyre-fixme[6]: For 1st param expected `F` but got `(...) -> object`.
552556
return decorator(fn)
553557

554558
return decorator

0 commit comments

Comments
 (0)