Skip to content

Commit 594185e

Browse files
committed
Allow in-memory sessions.
1 parent 5a97b3c commit 594185e

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

starlette_plus/middleware/sessions.py

+22-8
Original file line numberDiff line numberDiff line change
@@ -42,27 +42,39 @@
4242

4343

4444
class Storage:
45-
__slots__ = "pool"
45+
__slots__ = ("pool", "_keys")
4646

47-
def __init__(self, *, redis: Redis) -> None:
48-
self.pool: redis.Redis = redis.pool
47+
def __init__(self, *, redis: Redis | None = None) -> None:
48+
self.pool: redis.Redis | None = redis.pool if redis else None
49+
self._keys: dict[str, Any] = {}
4950

5051
async def get(self, data: dict[str, Any]) -> dict[str, Any]:
5152
expiry: datetime.datetime = datetime.datetime.fromisoformat(data["expiry"])
5253
key: str = data["_session_secret_key"]
5354

5455
if expiry <= datetime.datetime.now():
55-
await self.pool.delete(key) # type: ignore
56+
await self.delete(key)
5657
return {}
5758

58-
session: Any = await self.pool.get(key) # type: ignore
59+
if self.pool:
60+
session: Any = await self.pool.get(key) # type: ignore
61+
else:
62+
session: Any = self._keys.get(key)
63+
5964
return json.loads(session) if session else {}
6065

6166
async def set(self, key: str, value: dict[str, Any], *, max_age: int) -> None:
62-
await self.pool.set(key, json.dumps(value), ex=max_age) # type: ignore
67+
if self.pool:
68+
await self.pool.set(key, json.dumps(value), ex=max_age) # type: ignore
69+
return
70+
71+
self._keys[key] = json.dumps(value)
6372

6473
async def delete(self, key: str) -> None:
65-
await self.pool.delete(key) # type: ignore
74+
if self.pool:
75+
await self.pool.delete(key) # type: ignore
76+
else:
77+
self._keys.pop(key, None)
6678

6779

6880
class SessionMiddleware:
@@ -75,7 +87,7 @@ def __init__(
7587
max_age: int | None = None,
7688
same_site: str = "lax",
7789
secure: bool = True,
78-
redis: Redis,
90+
redis: Redis | None = None,
7991
) -> None:
8092
self.app: ASGIApp = app
8193
self.name: str = name or "__session_cookie"
@@ -113,6 +125,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
113125
scope["session"] = session
114126

115127
async def wrapper(message: Message) -> None:
128+
nonlocal original, session, cookie
129+
116130
if message["type"] != "http.response.start":
117131
await send(message)
118132
return

0 commit comments

Comments
 (0)