Skip to content

Commit a144b4b

Browse files
authored
Add warning for the length of the group name (#2122)
1 parent b502c73 commit a144b4b

File tree

2 files changed

+65
-30
lines changed

2 files changed

+65
-30
lines changed

channels/layers.py

+25-29
Original file line numberDiff line numberDiff line change
@@ -144,35 +144,31 @@ def match_type_and_length(self, name):
144144
invalid_name_error = (
145145
"{} name must be a valid unicode string "
146146
+ "with length < {} ".format(MAX_NAME_LENGTH)
147-
+ "containing only ASCII alphanumerics, hyphens, underscores, or periods, "
148-
+ "not {}"
147+
+ "containing only ASCII alphanumerics, hyphens, underscores, or periods."
149148
)
150149

151-
def valid_channel_name(self, name, receive=False):
152-
if self.match_type_and_length(name):
153-
if bool(self.channel_name_regex.match(name)):
154-
# Check cases for special channels
155-
if "!" in name and not name.endswith("!") and receive:
156-
raise TypeError(
157-
"Specific channel names in receive() must end at the !"
158-
)
159-
return True
160-
raise TypeError(self.invalid_name_error.format("Channel", name))
161-
162-
def valid_group_name(self, name):
163-
if self.match_type_and_length(name):
164-
if bool(self.group_name_regex.match(name)):
165-
return True
166-
raise TypeError(self.invalid_name_error.format("Group", name))
150+
def require_valid_channel_name(self, name, receive=False):
151+
if not self.match_type_and_length(name):
152+
raise TypeError(self.invalid_name_error.format("Channel"))
153+
if not bool(self.channel_name_regex.match(name)):
154+
raise TypeError(self.invalid_name_error.format("Channel"))
155+
if "!" in name and not name.endswith("!") and receive:
156+
raise TypeError("Specific channel names in receive() must end at the !")
157+
return True
158+
159+
def require_valid_group_name(self, name):
160+
if not self.match_type_and_length(name):
161+
raise TypeError(self.invalid_name_error.format("Group"))
162+
if not bool(self.group_name_regex.match(name)):
163+
raise TypeError(self.invalid_name_error.format("Group"))
164+
return True
167165

168166
def valid_channel_names(self, names, receive=False):
169167
_non_empty_list = True if names else False
170168
_names_type = isinstance(names, list)
171169
assert _non_empty_list and _names_type, "names must be a non-empty list"
172-
173-
assert all(
174-
self.valid_channel_name(channel, receive=receive) for channel in names
175-
)
170+
for channel in names:
171+
self.require_valid_channel_name(channel, receive=receive)
176172
return True
177173

178174
def non_local_name(self, name):
@@ -243,7 +239,7 @@ async def send(self, channel, message):
243239
"""
244240
# Typecheck
245241
assert isinstance(message, dict), "message is not a dict"
246-
assert self.valid_channel_name(channel), "Channel name not valid"
242+
self.require_valid_channel_name(channel)
247243
# If it's a process-local channel, strip off local part and stick full
248244
# name in message
249245
assert "__asgi_channel__" not in message
@@ -263,7 +259,7 @@ async def receive(self, channel):
263259
If more than one coroutine waits on the same channel, a random one
264260
of the waiting coroutines will get the result.
265261
"""
266-
assert self.valid_channel_name(channel)
262+
self.require_valid_channel_name(channel)
267263
self._clean_expired()
268264

269265
queue = self.channels.setdefault(
@@ -341,16 +337,16 @@ async def group_add(self, group, channel):
341337
Adds the channel name to a group.
342338
"""
343339
# Check the inputs
344-
assert self.valid_group_name(group), "Group name not valid"
345-
assert self.valid_channel_name(channel), "Channel name not valid"
340+
self.require_valid_group_name(group)
341+
self.require_valid_channel_name(channel)
346342
# Add to group dict
347343
self.groups.setdefault(group, {})
348344
self.groups[group][channel] = time.time()
349345

350346
async def group_discard(self, group, channel):
351347
# Both should be text and valid
352-
assert self.valid_channel_name(channel), "Invalid channel name"
353-
assert self.valid_group_name(group), "Invalid group name"
348+
self.require_valid_channel_name(channel)
349+
self.require_valid_group_name(group)
354350
# Remove from group set
355351
group_channels = self.groups.get(group, None)
356352
if group_channels:
@@ -363,7 +359,7 @@ async def group_discard(self, group, channel):
363359
async def group_send(self, group, message):
364360
# Check types
365361
assert isinstance(message, dict), "Message is not a dict"
366-
assert self.valid_group_name(group), "Invalid group name"
362+
self.require_valid_group_name(group)
367363
# Run clean
368364
self._clean_expired()
369365

tests/test_layers.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ async def test_send_receive():
7272

7373
@pytest.mark.parametrize(
7474
"method",
75-
[BaseChannelLayer().valid_channel_name, BaseChannelLayer().valid_group_name],
75+
[
76+
BaseChannelLayer().require_valid_channel_name,
77+
BaseChannelLayer().require_valid_group_name,
78+
],
7679
)
7780
@pytest.mark.parametrize(
7881
"channel_name,expected_valid",
@@ -84,3 +87,39 @@ def test_channel_and_group_name_validation(method, channel_name, expected_valid)
8487
else:
8588
with pytest.raises(TypeError):
8689
method(channel_name)
90+
91+
92+
@pytest.mark.parametrize(
93+
"name",
94+
[
95+
"a" * 101, # Group name too long
96+
],
97+
)
98+
def test_group_name_length_error_message(name):
99+
"""
100+
Ensure the correct error message is raised when group names
101+
exceed the character limit or contain invalid characters.
102+
"""
103+
layer = BaseChannelLayer()
104+
expected_error_message = layer.invalid_name_error.format("Group")
105+
106+
with pytest.raises(TypeError, match=expected_error_message):
107+
layer.require_valid_group_name(name)
108+
109+
110+
@pytest.mark.parametrize(
111+
"name",
112+
[
113+
"a" * 101, # Channel name too long
114+
],
115+
)
116+
def test_channel_name_length_error_message(name):
117+
"""
118+
Ensure the correct error message is raised when group names
119+
exceed the character limit or contain invalid characters.
120+
"""
121+
layer = BaseChannelLayer()
122+
expected_error_message = layer.invalid_name_error.format("Channel")
123+
124+
with pytest.raises(TypeError, match=expected_error_message):
125+
layer.require_valid_channel_name(name)

0 commit comments

Comments
 (0)