Skip to content

Commit f21169a

Browse files
committed
Perform periodic health checks on Redis, default to in-memory after a failure.
1 parent 2dbaa3e commit f21169a

File tree

5 files changed

+51
-35
lines changed

5 files changed

+51
-35
lines changed

.gitignore

+4-1
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,7 @@ cython_debug/
157157
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158158
# and can be added to the global gitignore or merged into this file. For a more nuclear
159159
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
160-
#.idea/
160+
#.idea/
161+
162+
# Test files
163+
test*.py

starlette_plus/core.py

+1-21
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
if TYPE_CHECKING:
3535
from starlette.types import ASGIApp, Message, Receive, Scope, Send
3636

37-
from .redis import Redis
3837
from .types_.core import Methods, RouteOptions
3938
from .types_.limiter import BucketType, ExemptCallable, RateLimitData
4039

@@ -192,31 +191,12 @@ def __init__(self, *args: Any, **kwargs: Unpack[ApplicationOptions]) -> None:
192191
middleware_: list[Middleware] = kwargs.pop("middleware", [])
193192
middleware_.insert(0, Middleware(LoggingMiddleware)) if self._access_log else None
194193

195-
statrtups = kwargs.pop("on_startup", [])
196-
statrtups.append(self.__startup)
197-
198-
super().__init__(*args, **kwargs, middleware=middleware_, on_startup=statrtups) # type: ignore
194+
super().__init__(*args, **kwargs, middleware=middleware_) # type: ignore
199195

200196
self.add_view(self)
201197
for view in views:
202198
self.add_view(view)
203199

204-
async def __startup(self) -> None:
205-
for middleware in self.user_middleware:
206-
redis: Redis | None = middleware.kwargs.get("redis", None) # type: ignore
207-
208-
if not redis:
209-
continue
210-
211-
try:
212-
resp: bool = await redis.ping()
213-
except Exception:
214-
resp = False
215-
216-
if not resp:
217-
logger.warning("Unable to connect to redis on %s, defaulting to in-memory.", middleware.cls.__name__)
218-
middleware.kwargs["redis"] = None
219-
220200
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
221201
self: Self = super().__new__(cls)
222202
self.__routes__ = []

starlette_plus/limiter.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,14 @@ def __init__(self, redis: Redis | None = None) -> None:
5050
async def get_tat(self, key: str, /) -> datetime.datetime:
5151
now: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
5252

53-
if self.redis:
53+
if self.redis and self.redis.could_connect:
5454
value: str | None = await self.redis.pool.get(key) # type: ignore
5555
return datetime.datetime.fromisoformat(value) if value else now # type: ignore
5656

5757
return self._keys.get(key, {"tat": now}).get("tat", now)
5858

5959
async def set_tat(self, key: str, /, *, tat: datetime.datetime, limit: RateLimit) -> None:
60-
if self.redis:
60+
if self.redis and self.redis.could_connect:
6161
await self.redis.pool.set(key, tat.isoformat(), ex=int(limit.period.total_seconds() + 60)) # type: ignore
6262
else:
6363
self._keys[key] = {"tat": tat, "limit": limit}

starlette_plus/middleware/sessions.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,13 @@
2525
from typing import TYPE_CHECKING, Any
2626

2727
import itsdangerous
28-
import redis.asyncio as redis
2928
from starlette.datastructures import MutableHeaders
3029
from starlette.requests import HTTPConnection
3130

3231
from ..redis import Redis
3332

3433

3534
if TYPE_CHECKING:
36-
import redis.asyncio as redis
3735
from starlette.types import ASGIApp, Message, Receive, Scope, Send
3836

3937
from ..redis import Redis
@@ -43,10 +41,10 @@
4341

4442

4543
class Storage:
46-
__slots__ = ("pool", "_keys")
44+
__slots__ = ("redis", "_keys")
4745

4846
def __init__(self, *, redis: Redis | None = None) -> None:
49-
self.pool: redis.Redis | None = redis.pool if redis else None
47+
self.redis: Redis | None = redis
5048
self._keys: dict[str, Any] = {}
5149

5250
async def get(self, data: dict[str, Any]) -> dict[str, Any]:
@@ -57,23 +55,23 @@ async def get(self, data: dict[str, Any]) -> dict[str, Any]:
5755
await self.delete(key)
5856
return {}
5957

60-
if self.pool:
61-
session: Any = await self.pool.get(key) # type: ignore
58+
if self.redis and self.redis.could_connect:
59+
session: Any = await self.redis.pool.get(key) # type: ignore
6260
else:
6361
session: Any = self._keys.get(key)
6462

6563
return json.loads(session) if session else {}
6664

6765
async def set(self, key: str, value: dict[str, Any], *, max_age: int) -> None:
68-
if self.pool:
69-
await self.pool.set(key, json.dumps(value), ex=max_age) # type: ignore
66+
if self.redis and self.redis.could_connect:
67+
await self.redis.pool.set(key, json.dumps(value), ex=max_age) # type: ignore
7068
return
7169

7270
self._keys[key] = json.dumps(value)
7371

7472
async def delete(self, key: str) -> None:
75-
if self.pool:
76-
await self.pool.delete(key) # type: ignore
73+
if self.redis and self.redis.could_connect:
74+
await self.redis.pool.delete(key) # type: ignore
7775
else:
7876
self._keys.pop(key, None)
7977

starlette_plus/redis.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,50 @@
1313
limitations under the License.
1414
"""
1515

16+
import asyncio
17+
import logging
18+
1619
import redis.asyncio as redis
1720

1821

22+
logger: logging.Logger = logging.getLogger(__name__)
23+
24+
1925
class Redis:
2026
def __init__(self, *, url: str | None = None) -> None:
2127
url = url or "redis://localhost:6379/0"
2228
pool = redis.ConnectionPool.from_url(url, decode_responses=True) # type: ignore
2329

2430
self.pool: redis.Redis = redis.Redis.from_pool(pool)
31+
self.url = url
32+
33+
self._could_connect: bool | None = None
34+
self._task = asyncio.create_task(self._health_task())
35+
36+
@property
37+
def could_connect(self) -> bool | None:
38+
return self._could_connect
2539

2640
async def ping(self) -> bool:
27-
return bool(await self.pool.ping()) # type: ignore
41+
try:
42+
async with asyncio.timeout(3.0):
43+
self._could_connect = bool(await self.pool.ping()) # type: ignore
44+
except Exception:
45+
if self._could_connect is not False:
46+
logger.warning(
47+
"Unable to connect to Redis: %s. Services relying on this instance will now be in-memory.", self.url
48+
)
49+
50+
self._could_connect = False
51+
52+
return self._could_connect
53+
54+
async def _health_task(self) -> None:
55+
while True:
56+
previous = self.could_connect
57+
await self.ping()
58+
59+
if not previous and self.could_connect:
60+
logger.info("Redis connection has been (re)established: %s", self.url)
61+
62+
await asyncio.sleep(5)

0 commit comments

Comments
 (0)