Skip to content

Commit e39fe13

Browse files
authored
Use the async sessions api if it exists (#2092)
1 parent 8d90b07 commit e39fe13

File tree

3 files changed

+95
-11
lines changed

3 files changed

+95
-11
lines changed

channels/sessions.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import time
33
from importlib import import_module
44

5+
import django
56
from django.conf import settings
67
from django.contrib.sessions.backends.base import UpdateError
78
from django.core.exceptions import SuspiciousOperation
@@ -163,9 +164,7 @@ def __init__(self, scope, send):
163164

164165
async def resolve_session(self):
165166
session_key = self.scope["cookies"].get(self.cookie_name)
166-
self.scope["session"]._wrapped = await database_sync_to_async(
167-
self.session_store
168-
)(session_key)
167+
self.scope["session"]._wrapped = self.session_store(session_key)
169168

170169
async def send(self, message):
171170
"""
@@ -183,7 +182,7 @@ async def send(self, message):
183182
and message.get("status", 200) != 500
184183
and (modified or settings.SESSION_SAVE_EVERY_REQUEST)
185184
):
186-
await database_sync_to_async(self.save_session)()
185+
await self.save_session()
187186
# If this is a message type that can transport cookies back to the
188187
# client, then do so.
189188
if message["type"] in self.cookie_response_message_types:
@@ -221,12 +220,15 @@ async def send(self, message):
221220
# Pass up the send
222221
return await self.real_send(message)
223222

224-
def save_session(self):
223+
async def save_session(self):
225224
"""
226225
Saves the current session.
227226
"""
228227
try:
229-
self.scope["session"].save()
228+
if django.VERSION >= (5, 1):
229+
await self.scope["session"].asave()
230+
else:
231+
await database_sync_to_async(self.scope["session"].save)()
230232
except UpdateError:
231233
raise SuspiciousOperation(
232234
"The request's session was deleted before the "

docs/topics/sessions.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ whenever the session is modified.
7373

7474
If you are in a WebSocket consumer, however, the session is populated
7575
**but will never be saved automatically** - you must call
76-
``scope["session"].save()`` yourself whenever you want to persist a session
76+
``scope["session"].save()`` (or the asynchronous version,
77+
``scope["session"].asave()``) yourself whenever you want to persist a session
7778
to your session store. If you don't save, the session will still work correctly
7879
inside the consumer (as it's stored as an instance variable), but other
7980
connections or HTTP views won't be able to see the changes.

tests/test_http.py

+85-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import re
2+
from importlib import import_module
23

4+
import django
35
import pytest
6+
from django.conf import settings
47

58
from channels.consumer import AsyncConsumer
69
from channels.db import database_sync_to_async
@@ -93,15 +96,12 @@ async def test_session_samesite_invalid(samesite_invalid):
9396

9497
@pytest.mark.django_db(transaction=True)
9598
@pytest.mark.asyncio
96-
async def test_muliple_sessions():
99+
async def test_multiple_sessions():
97100
"""
98101
Create two application instances and test then out of order to verify that
99102
separate scopes are used.
100103
"""
101104

102-
async def inner(scope, receive, send):
103-
send(scope["path"])
104-
105105
class SimpleHttpApp(AsyncConsumer):
106106
async def http_request(self, event):
107107
await database_sync_to_async(self.scope["session"].save)()
@@ -123,3 +123,84 @@ async def http_request(self, event):
123123

124124
first_response = await first_communicator.get_response()
125125
assert first_response["body"] == b"/first/"
126+
127+
128+
@pytest.mark.django_db(transaction=True)
129+
@pytest.mark.asyncio
130+
async def test_session_saves():
131+
"""
132+
Saves information to a session and validates that it actually saves to the backend
133+
"""
134+
135+
class SimpleHttpApp(AsyncConsumer):
136+
@database_sync_to_async
137+
def set_fav_color(self):
138+
self.scope["session"]["fav_color"] = "blue"
139+
140+
async def http_request(self, event):
141+
if django.VERSION >= (5, 1):
142+
await self.scope["session"].aset("fav_color", "blue")
143+
else:
144+
await self.set_fav_color()
145+
await self.send(
146+
{"type": "http.response.start", "status": 200, "headers": []}
147+
)
148+
await self.send(
149+
{
150+
"type": "http.response.body",
151+
"body": self.scope["session"].session_key.encode(),
152+
}
153+
)
154+
155+
app = SessionMiddlewareStack(SimpleHttpApp.as_asgi())
156+
157+
communicator = HttpCommunicator(app, "GET", "/first/")
158+
159+
response = await communicator.get_response()
160+
session_key = response["body"].decode()
161+
162+
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
163+
session = SessionStore(session_key=session_key)
164+
if django.VERSION >= (5, 1):
165+
session_fav_color = await session.aget("fav_color")
166+
else:
167+
session_fav_color = await database_sync_to_async(session.get)("fav_color")
168+
169+
assert session_fav_color == "blue"
170+
171+
172+
@pytest.mark.django_db(transaction=True)
173+
@pytest.mark.asyncio
174+
async def test_session_save_update_error():
175+
"""
176+
Intentionally deletes the session to ensure that SuspiciousOperation is raised
177+
"""
178+
179+
async def inner(scope, receive, send):
180+
send(scope["path"])
181+
182+
class SimpleHttpApp(AsyncConsumer):
183+
@database_sync_to_async
184+
def set_fav_color(self):
185+
self.scope["session"]["fav_color"] = "blue"
186+
187+
async def http_request(self, event):
188+
# Create a session as normal:
189+
await database_sync_to_async(self.scope["session"].save)()
190+
191+
# Then simulate it's deletion from somewhere else:
192+
# (e.g. logging out from another request)
193+
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
194+
session = SessionStore(session_key=self.scope["session"].session_key)
195+
await database_sync_to_async(session.flush)()
196+
197+
await self.send(
198+
{"type": "http.response.start", "status": 200, "headers": []}
199+
)
200+
201+
app = SessionMiddlewareStack(SimpleHttpApp.as_asgi())
202+
203+
communicator = HttpCommunicator(app, "GET", "/first/")
204+
205+
with pytest.raises(django.core.exceptions.SuspiciousOperation):
206+
await communicator.get_response()

0 commit comments

Comments
 (0)