Skip to content

Commit e533186

Browse files
authored
InMemoryChannelLayer improvements, test fixes (#1976)
1 parent e39fe13 commit e533186

File tree

2 files changed

+62
-31
lines changed

2 files changed

+62
-31
lines changed

channels/layers.py

+35-28
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,13 @@ def __init__(
198198
group_expiry=86400,
199199
capacity=100,
200200
channel_capacity=None,
201-
**kwargs
201+
**kwargs,
202202
):
203203
super().__init__(
204204
expiry=expiry,
205205
capacity=capacity,
206206
channel_capacity=channel_capacity,
207-
**kwargs
207+
**kwargs,
208208
)
209209
self.channels = {}
210210
self.groups = {}
@@ -225,13 +225,14 @@ async def send(self, channel, message):
225225
# name in message
226226
assert "__asgi_channel__" not in message
227227

228-
queue = self.channels.setdefault(channel, asyncio.Queue())
229-
# Are we full
230-
if queue.qsize() >= self.capacity:
231-
raise ChannelFull(channel)
232-
228+
queue = self.channels.setdefault(
229+
channel, asyncio.Queue(maxsize=self.get_capacity(channel))
230+
)
233231
# Add message
234-
await queue.put((time.time() + self.expiry, deepcopy(message)))
232+
try:
233+
queue.put_nowait((time.time() + self.expiry, deepcopy(message)))
234+
except asyncio.queues.QueueFull:
235+
raise ChannelFull(channel)
235236

236237
async def receive(self, channel):
237238
"""
@@ -242,14 +243,16 @@ async def receive(self, channel):
242243
assert self.valid_channel_name(channel)
243244
self._clean_expired()
244245

245-
queue = self.channels.setdefault(channel, asyncio.Queue())
246+
queue = self.channels.setdefault(
247+
channel, asyncio.Queue(maxsize=self.get_capacity(channel))
248+
)
246249

247250
# Do a plain direct receive
248251
try:
249252
_, message = await queue.get()
250253
finally:
251254
if queue.empty():
252-
del self.channels[channel]
255+
self.channels.pop(channel, None)
253256

254257
return message
255258

@@ -279,19 +282,17 @@ def _clean_expired(self):
279282
self._remove_from_groups(channel)
280283
# Is the channel now empty and needs deleting?
281284
if queue.empty():
282-
del self.channels[channel]
285+
self.channels.pop(channel, None)
283286

284287
# Group Expiration
285288
timeout = int(time.time()) - self.group_expiry
286-
for group in self.groups:
287-
for channel in list(self.groups.get(group, set())):
288-
# If join time is older than group_expiry end the group membership
289-
if (
290-
self.groups[group][channel]
291-
and int(self.groups[group][channel]) < timeout
292-
):
289+
for channels in self.groups.values():
290+
for name, timestamp in list(channels.items()):
291+
# If join time is older than group_expiry
292+
# end the group membership
293+
if timestamp and timestamp < timeout:
293294
# Delete from group
294-
del self.groups[group][channel]
295+
channels.pop(name, None)
295296

296297
# Flush extension
297298

@@ -308,8 +309,7 @@ def _remove_from_groups(self, channel):
308309
Removes a channel from all groups. Used when a message on it expires.
309310
"""
310311
for channels in self.groups.values():
311-
if channel in channels:
312-
del channels[channel]
312+
channels.pop(channel, None)
313313

314314
# Groups extension
315315

@@ -329,22 +329,29 @@ async def group_discard(self, group, channel):
329329
assert self.valid_channel_name(channel), "Invalid channel name"
330330
assert self.valid_group_name(group), "Invalid group name"
331331
# Remove from group set
332-
if group in self.groups:
333-
if channel in self.groups[group]:
334-
del self.groups[group][channel]
335-
if not self.groups[group]:
336-
del self.groups[group]
332+
group_channels = self.groups.get(group, None)
333+
if group_channels:
334+
# remove channel if in group
335+
group_channels.pop(channel, None)
336+
# is group now empty? If yes remove it
337+
if not group_channels:
338+
self.groups.pop(group, None)
337339

338340
async def group_send(self, group, message):
339341
# Check types
340342
assert isinstance(message, dict), "Message is not a dict"
341343
assert self.valid_group_name(group), "Invalid group name"
342344
# Run clean
343345
self._clean_expired()
346+
344347
# Send to each channel
345-
for channel in self.groups.get(group, set()):
348+
ops = []
349+
if group in self.groups:
350+
for channel in self.groups[group].keys():
351+
ops.append(asyncio.create_task(self.send(channel, message)))
352+
for send_result in asyncio.as_completed(ops):
346353
try:
347-
await self.send(channel, message)
354+
await send_result
348355
except ChannelFull:
349356
pass
350357

tests/test_inmemorychannel.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,36 @@ async def test_send_receive(channel_layer):
2626
await channel_layer.send(
2727
"test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"}
2828
)
29+
await channel_layer.send(
30+
"test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"}
31+
)
2932
message = await channel_layer.receive("test-channel-1")
3033
assert message["type"] == "test.message"
3134
assert message["text"] == "Ahoy-hoy!"
35+
# not removed because not empty
36+
assert "test-channel-1" in channel_layer.channels
37+
message = await channel_layer.receive("test-channel-1")
38+
assert message["type"] == "test.message"
39+
assert message["text"] == "Ahoy-hoy!"
40+
# removed because empty
41+
assert "test-channel-1" not in channel_layer.channels
42+
43+
44+
@pytest.mark.asyncio
45+
async def test_race_empty(channel_layer):
46+
"""
47+
Makes sure the race is handled gracefully.
48+
"""
49+
receive_task = asyncio.create_task(channel_layer.receive("test-channel-1"))
50+
await asyncio.sleep(0.1)
51+
await channel_layer.send(
52+
"test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"}
53+
)
54+
del channel_layer.channels["test-channel-1"]
55+
await asyncio.sleep(0.1)
56+
message = await receive_task
57+
assert message["type"] == "test.message"
58+
assert message["text"] == "Ahoy-hoy!"
3259

3360

3461
@pytest.mark.asyncio
@@ -62,7 +89,6 @@ async def test_multi_send_receive(channel_layer):
6289
"""
6390
Tests overlapping sends and receives, and ordering.
6491
"""
65-
channel_layer = InMemoryChannelLayer()
6692
await channel_layer.send("test-channel-3", {"type": "message.1"})
6793
await channel_layer.send("test-channel-3", {"type": "message.2"})
6894
await channel_layer.send("test-channel-3", {"type": "message.3"})
@@ -76,7 +102,6 @@ async def test_groups_basic(channel_layer):
76102
"""
77103
Tests basic group operation.
78104
"""
79-
channel_layer = InMemoryChannelLayer()
80105
await channel_layer.group_add("test-group", "test-gr-chan-1")
81106
await channel_layer.group_add("test-group", "test-gr-chan-2")
82107
await channel_layer.group_add("test-group", "test-gr-chan-3")
@@ -97,7 +122,6 @@ async def test_groups_channel_full(channel_layer):
97122
"""
98123
Tests that group_send ignores ChannelFull
99124
"""
100-
channel_layer = InMemoryChannelLayer()
101125
await channel_layer.group_add("test-group", "test-gr-chan-1")
102126
await channel_layer.group_send("test-group", {"type": "message.1"})
103127
await channel_layer.group_send("test-group", {"type": "message.1"})

0 commit comments

Comments
 (0)