|
1 | 1 | import asyncio |
2 | | -import contextlib |
3 | 2 | import logging |
4 | | -import typing |
| 3 | +from typing import Awaitable, Callable, Hashable, Literal, TypeVar |
5 | 4 |
|
6 | | -import typing_extensions |
| 5 | +from typing_extensions import ParamSpec |
7 | 6 |
|
8 | 7 | from pgcachewatch import strategies, utils |
9 | 8 |
|
10 | | -P = typing_extensions.ParamSpec("P") |
11 | | -T = typing.TypeVar("T") |
| 9 | +P = ParamSpec("P") |
| 10 | +T = TypeVar("T") |
12 | 11 |
|
13 | 12 |
|
14 | 13 | def cache( |
15 | 14 | strategy: strategies.Strategy, |
16 | | - statistics_callback: typing.Callable[[typing.Literal["hit", "miss"]], None] |
17 | | - | None = None, |
18 | | -) -> typing.Callable[ |
19 | | - [typing.Callable[P, typing.Awaitable[T]]], |
20 | | - typing.Callable[P, typing.Awaitable[T]], |
21 | | -]: |
22 | | - def outer( |
23 | | - fn: typing.Callable[P, typing.Awaitable[T]], |
24 | | - ) -> typing.Callable[P, typing.Awaitable[T]]: |
25 | | - cached = dict[typing.Hashable, asyncio.Future[T]]() |
| 15 | + statistics_callback: Callable[[Literal["hit", "miss"]], None] | None = None, |
| 16 | +) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]: |
| 17 | + def outer(fn: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: |
| 18 | + cached = dict[Hashable, asyncio.Future[T]]() |
26 | 19 |
|
27 | | - async def inner(*args: P.args, **kw: P.kwargs) -> T: |
| 20 | + async def inner(*args: P.args, **kwargs: P.kwargs) -> T: |
28 | 21 | # If db-conn is down, disable cache. |
29 | 22 | if not strategy.pg_connection_healthy(): |
30 | 23 | logging.critical("Database connection is closed, caching disabled.") |
31 | | - return await fn(*args, **kw) |
| 24 | + return await fn(*args, **kwargs) |
32 | 25 |
|
33 | 26 | # Clear the cache if we receive an event from |
34 | 27 | # the database that instructs us to clear it. |
35 | 28 | if strategy.clear(): |
36 | 29 | logging.debug("Cache clear") |
37 | 30 | cached.clear() |
38 | 31 |
|
39 | | - # Check for cache hit |
40 | | - key = utils.make_key(args, kw) |
41 | | - with contextlib.suppress(KeyError): |
42 | | - # OBS: Will only await if the cache key hits. |
43 | | - result = await cached[key] |
| 32 | + key = utils.make_key(args, kwargs) |
| 33 | + |
| 34 | + try: |
| 35 | + waiter = cached[key] |
| 36 | + except KeyError: |
| 37 | + # Cache miss |
| 38 | + ... |
| 39 | + else: |
| 40 | + # Cache hit |
44 | 41 | logging.debug("Cache hit") |
45 | 42 | if statistics_callback: |
46 | 43 | statistics_callback("hit") |
47 | | - return result |
| 44 | + return await waiter |
48 | 45 |
|
49 | | - # Below deals with a cache miss. |
50 | 46 | logging.debug("Cache miss") |
51 | 47 | if statistics_callback: |
52 | 48 | statistics_callback("miss") |
53 | 49 |
|
54 | | - # By using a future as placeholder we avoid |
55 | | - # cache stampeded. Note that on the "miss" branch/path, controll |
56 | | - # is never given to the eventloopscheduler before the future |
57 | | - # is create. |
| 50 | + # Initialize Future to prevent cache stampedes. |
58 | 51 | cached[key] = waiter = asyncio.Future[T]() |
| 52 | + |
59 | 53 | try: |
60 | | - result = await fn(*args, **kw) |
| 54 | + # Attempt to compute the result and set it on the waiter. |
| 55 | + waiter.set_result(await fn(*args, **kwargs)) |
61 | 56 | except Exception as e: |
62 | | - cached.pop( |
63 | | - key, None |
64 | | - ) # Next try should not result in a repeating exception |
65 | | - waiter.set_exception( |
66 | | - e |
67 | | - ) # Propegate exception to other callers who are waiting. |
68 | | - raise e from None # Propegate exception to first caller. |
69 | | - else: |
70 | | - waiter.set_result(result) |
| 57 | + # Remove key from cache on failure. |
| 58 | + cached.pop(key, None) |
| 59 | + # Propagate exception to all awaiting the future. |
| 60 | + waiter.set_exception(e) |
71 | 61 |
|
72 | | - return result |
| 62 | + return await waiter |
73 | 63 |
|
74 | 64 | return inner |
75 | 65 |
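
The core of this change is that the miss path now publishes an asyncio.Future into the cache dict before the first await, and both the hit and the miss path end with `await waiter`, so concurrent callers for the same key coalesce onto one computation. A generic sketch of that pattern, independent of pgcachewatch (all names here are illustrative, not part of the package):

```python
import asyncio
from typing import Awaitable, Callable, Hashable, TypeVar

T = TypeVar("T")


async def get_or_compute(
    cache: dict[Hashable, asyncio.Future[T]],
    key: Hashable,
    compute: Callable[[], Awaitable[T]],
) -> T:
    try:
        waiter = cache[key]
    except KeyError:
        # Miss: publish the Future before the first await, so concurrent
        # callers with the same key find it and wait instead of recomputing.
        cache[key] = waiter = asyncio.Future()
        try:
            waiter.set_result(await compute())
        except Exception as exc:
            cache.pop(key, None)       # let the next call retry
            waiter.set_exception(exc)  # wake the other waiters with the error
    # Hit and miss converge here: every caller awaits the same Future.
    return await waiter
```

If `compute` raises, the waiter carries the exception to everyone already waiting, while the key is dropped so a later call can retry, matching the behavior of the rewritten decorator above.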
|
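For context, a minimal self-contained usage sketch of the decorator itself. The module path pgcachewatch.decorators is an assumption (the diff does not show the filename), and _StubStrategy is a hypothetical stand-in that implements only the two methods the decorator calls; a real deployment would plug in one of the package's strategies instead.

```python
import asyncio

from pgcachewatch import decorators  # assumed module path for the code above


class _StubStrategy:
    """Hypothetical stand-in exposing the two methods the decorator calls."""

    def pg_connection_healthy(self) -> bool:
        return True  # pretend the DB connection is always up

    def clear(self) -> bool:
        return False  # never signal a cache clear


async def main() -> None:
    calls = 0

    @decorators.cache(strategy=_StubStrategy(), statistics_callback=print)
    async def slow_double(x: int) -> int:
        nonlocal calls
        calls += 1
        await asyncio.sleep(0.1)
        return x * 2

    # Ten concurrent callers share one Future, so the body runs only once:
    # one "miss" is reported, followed by nine "hit"s.
    results = await asyncio.gather(*(slow_double(21) for _ in range(10)))
    assert results == [42] * 10 and calls == 1


asyncio.run(main())
```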