@@ -198,13 +198,13 @@ def __init__(
198
198
group_expiry = 86400 ,
199
199
capacity = 100 ,
200
200
channel_capacity = None ,
201
- ** kwargs
201
+ ** kwargs ,
202
202
):
203
203
super ().__init__ (
204
204
expiry = expiry ,
205
205
capacity = capacity ,
206
206
channel_capacity = channel_capacity ,
207
- ** kwargs
207
+ ** kwargs ,
208
208
)
209
209
self .channels = {}
210
210
self .groups = {}
@@ -225,13 +225,14 @@ async def send(self, channel, message):
225
225
# name in message
226
226
assert "__asgi_channel__" not in message
227
227
228
- queue = self .channels .setdefault (channel , asyncio .Queue ())
229
- # Are we full
230
- if queue .qsize () >= self .capacity :
231
- raise ChannelFull (channel )
232
-
228
+ queue = self .channels .setdefault (
229
+ channel , asyncio .Queue (maxsize = self .get_capacity (channel ))
230
+ )
233
231
# Add message
234
- await queue .put ((time .time () + self .expiry , deepcopy (message )))
232
+ try :
233
+ queue .put_nowait ((time .time () + self .expiry , deepcopy (message )))
234
+ except asyncio .queues .QueueFull :
235
+ raise ChannelFull (channel )
235
236
236
237
async def receive (self , channel ):
237
238
"""
@@ -242,14 +243,16 @@ async def receive(self, channel):
242
243
assert self .valid_channel_name (channel )
243
244
self ._clean_expired ()
244
245
245
- queue = self .channels .setdefault (channel , asyncio .Queue ())
246
+ queue = self .channels .setdefault (
247
+ channel , asyncio .Queue (maxsize = self .get_capacity (channel ))
248
+ )
246
249
247
250
# Do a plain direct receive
248
251
try :
249
252
_ , message = await queue .get ()
250
253
finally :
251
254
if queue .empty ():
252
- del self .channels [ channel ]
255
+ self .channels . pop ( channel , None )
253
256
254
257
return message
255
258
@@ -279,19 +282,17 @@ def _clean_expired(self):
279
282
self ._remove_from_groups (channel )
280
283
# Is the channel now empty and needs deleting?
281
284
if queue .empty ():
282
- del self .channels [ channel ]
285
+ self .channels . pop ( channel , None )
283
286
284
287
# Group Expiration
285
288
timeout = int (time .time ()) - self .group_expiry
286
- for group in self .groups :
287
- for channel in list (self .groups .get (group , set ())):
288
- # If join time is older than group_expiry end the group membership
289
- if (
290
- self .groups [group ][channel ]
291
- and int (self .groups [group ][channel ]) < timeout
292
- ):
289
+ for channels in self .groups .values ():
290
+ for name , timestamp in list (channels .items ()):
291
+ # If join time is older than group_expiry
292
+ # end the group membership
293
+ if timestamp and timestamp < timeout :
293
294
# Delete from group
294
- del self . groups [ group ][ channel ]
295
+ channels . pop ( name , None )
295
296
296
297
# Flush extension
297
298
@@ -308,8 +309,7 @@ def _remove_from_groups(self, channel):
308
309
Removes a channel from all groups. Used when a message on it expires.
309
310
"""
310
311
for channels in self .groups .values ():
311
- if channel in channels :
312
- del channels [channel ]
312
+ channels .pop (channel , None )
313
313
314
314
# Groups extension
315
315
@@ -329,22 +329,29 @@ async def group_discard(self, group, channel):
329
329
assert self .valid_channel_name (channel ), "Invalid channel name"
330
330
assert self .valid_group_name (group ), "Invalid group name"
331
331
# Remove from group set
332
- if group in self .groups :
333
- if channel in self .groups [group ]:
334
- del self .groups [group ][channel ]
335
- if not self .groups [group ]:
336
- del self .groups [group ]
332
+ group_channels = self .groups .get (group , None )
333
+ if group_channels :
334
+ # remove channel if in group
335
+ group_channels .pop (channel , None )
336
+ # is group now empty? If yes remove it
337
+ if not group_channels :
338
+ self .groups .pop (group , None )
337
339
338
340
async def group_send (self , group , message ):
339
341
# Check types
340
342
assert isinstance (message , dict ), "Message is not a dict"
341
343
assert self .valid_group_name (group ), "Invalid group name"
342
344
# Run clean
343
345
self ._clean_expired ()
346
+
344
347
# Send to each channel
345
- for channel in self .groups .get (group , set ()):
348
+ ops = []
349
+ if group in self .groups :
350
+ for channel in self .groups [group ].keys ():
351
+ ops .append (asyncio .create_task (self .send (channel , message )))
352
+ for send_result in asyncio .as_completed (ops ):
346
353
try :
347
- await self . send ( channel , message )
354
+ await send_result
348
355
except ChannelFull :
349
356
pass
350
357
0 commit comments