Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG FIX] Exception handling and enabling graceful shutdown of connection #2130

Closed
wants to merge 13 commits into from
44 changes: 27 additions & 17 deletions channels/generic/http.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
import logging
import traceback

from channels.consumer import AsyncConsumer

from ..db import aclose_old_connections
from ..exceptions import StopConsumer

logger = logging.getLogger("channels.consumer")
if not logger.hasHandlers():
handler = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)


class AsyncHttpConsumer(AsyncConsumer):
"""
Expand All @@ -17,10 +30,6 @@ async def send_headers(self, *, status=200, headers=None):
"""
Sets the HTTP response status and headers. Headers may be provided as
a list of tuples or as a dictionary.

Note that the ASGI spec requires that the protocol server only starts
sending the response to the client after ``self.send_body`` has been
called the first time.
"""
if headers is None:
headers = []
Expand All @@ -34,10 +43,6 @@ async def send_headers(self, *, status=200, headers=None):
async def send_body(self, body, *, more_body=False):
"""
Sends a response body to the client. The method expects a bytestring.

Set ``more_body=True`` if you want to send more body content later.
The default behavior closes the response, and further messages on
the channel will be ignored.
"""
assert isinstance(body, bytes), "Body is not bytes"
await self.send(
Expand All @@ -46,18 +51,14 @@ async def send_body(self, body, *, more_body=False):

async def send_response(self, status, body, **kwargs):
"""
Sends a response to the client. This is a thin wrapper over
``self.send_headers`` and ``self.send_body``, and everything said
above applies here as well. This method may only be called once.
Sends a response to the client.
"""
await self.send_headers(status=status, **kwargs)
await self.send_body(body)

async def handle(self, body):
"""
Receives the request body as a bytestring. Response may be composed
using the ``self.send*`` methods; the return value of this method is
thrown away.
Receives the request body as a bytestring.
"""
raise NotImplementedError(
"Subclasses of AsyncHttpConsumer must provide a handle() method."
Expand All @@ -77,9 +78,14 @@ async def http_request(self, message):
"""
if "body" in message:
self.body.append(message["body"])

if not message.get("more_body"):
try:
await self.handle(b"".join(self.body))
except Exception:
logger.error(f"Error in handle(): {traceback.format_exc()}")
await self.send_response(500, b"Internal Server Error")
raise
finally:
await self.disconnect()
raise StopConsumer()
Expand All @@ -88,6 +94,10 @@ async def http_disconnect(self, message):
"""
Let the user do their cleanup and close the consumer.
"""
await self.disconnect()
await aclose_old_connections()
raise StopConsumer()
try:
await self.disconnect()
await aclose_old_connections()
except Exception as e:
logger.error(f"Error during disconnect: {str(e)}")
finally:
raise StopConsumer()
54 changes: 25 additions & 29 deletions channels/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,35 +144,31 @@ def match_type_and_length(self, name):
invalid_name_error = (
"{} name must be a valid unicode string "
+ "with length < {} ".format(MAX_NAME_LENGTH)
+ "containing only ASCII alphanumerics, hyphens, underscores, or periods, "
+ "not {}"
+ "containing only ASCII alphanumerics, hyphens, underscores, or periods."
)

def valid_channel_name(self, name, receive=False):
if self.match_type_and_length(name):
if bool(self.channel_name_regex.match(name)):
# Check cases for special channels
if "!" in name and not name.endswith("!") and receive:
raise TypeError(
"Specific channel names in receive() must end at the !"
)
return True
raise TypeError(self.invalid_name_error.format("Channel", name))

def valid_group_name(self, name):
if self.match_type_and_length(name):
if bool(self.group_name_regex.match(name)):
return True
raise TypeError(self.invalid_name_error.format("Group", name))
def require_valid_channel_name(self, name, receive=False):
if not self.match_type_and_length(name):
raise TypeError(self.invalid_name_error.format("Channel"))
if not bool(self.channel_name_regex.match(name)):
raise TypeError(self.invalid_name_error.format("Channel"))
if "!" in name and not name.endswith("!") and receive:
raise TypeError("Specific channel names in receive() must end at the !")
return True

def require_valid_group_name(self, name):
if not self.match_type_and_length(name):
raise TypeError(self.invalid_name_error.format("Group"))
if not bool(self.group_name_regex.match(name)):
raise TypeError(self.invalid_name_error.format("Group"))
return True

def valid_channel_names(self, names, receive=False):
_non_empty_list = True if names else False
_names_type = isinstance(names, list)
assert _non_empty_list and _names_type, "names must be a non-empty list"

assert all(
self.valid_channel_name(channel, receive=receive) for channel in names
)
for channel in names:
self.require_valid_channel_name(channel, receive=receive)
return True

def non_local_name(self, name):
Expand Down Expand Up @@ -243,7 +239,7 @@ async def send(self, channel, message):
"""
# Typecheck
assert isinstance(message, dict), "message is not a dict"
assert self.valid_channel_name(channel), "Channel name not valid"
self.require_valid_channel_name(channel)
# If it's a process-local channel, strip off local part and stick full
# name in message
assert "__asgi_channel__" not in message
Expand All @@ -263,7 +259,7 @@ async def receive(self, channel):
If more than one coroutine waits on the same channel, a random one
of the waiting coroutines will get the result.
"""
assert self.valid_channel_name(channel)
self.require_valid_channel_name(channel)
self._clean_expired()

queue = self.channels.setdefault(
Expand Down Expand Up @@ -341,16 +337,16 @@ async def group_add(self, group, channel):
Adds the channel name to a group.
"""
# Check the inputs
assert self.valid_group_name(group), "Group name not valid"
assert self.valid_channel_name(channel), "Channel name not valid"
self.require_valid_group_name(group)
self.require_valid_channel_name(channel)
# Add to group dict
self.groups.setdefault(group, {})
self.groups[group][channel] = time.time()

async def group_discard(self, group, channel):
# Both should be text and valid
assert self.valid_channel_name(channel), "Invalid channel name"
assert self.valid_group_name(group), "Invalid group name"
self.require_valid_channel_name(channel)
self.require_valid_group_name(group)
# Remove from group set
group_channels = self.groups.get(group, None)
if group_channels:
Expand All @@ -363,7 +359,7 @@ async def group_discard(self, group, channel):
async def group_send(self, group, message):
# Check types
assert isinstance(message, dict), "Message is not a dict"
assert self.valid_group_name(group), "Invalid group name"
self.require_valid_group_name(group)
# Run clean
self._clean_expired()

Expand Down
54 changes: 41 additions & 13 deletions tests/test_generic_http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import json
import time
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -57,8 +58,7 @@ async def handle(self, body):
@pytest.mark.asyncio
async def test_per_scope_consumers():
"""
Tests that a distinct consumer is used per scope, with AsyncHttpConsumer as
the example consumer class.
Tests that a distinct consumer is used per scope.
"""

class TestConsumer(AsyncHttpConsumer):
Expand All @@ -68,7 +68,6 @@ def __init__(self):

async def handle(self, body):
body = f"{self.__class__.__name__} {id(self)} {self.time}"

await self.send_response(
200,
body.encode("utf-8"),
Expand All @@ -77,27 +76,21 @@ async def handle(self, body):

app = TestConsumer.as_asgi()

# Open a connection
communicator = HttpCommunicator(app, method="GET", path="/test/")
response = await communicator.get_response()
assert response["status"] == 200

# And another one.
communicator = HttpCommunicator(app, method="GET", path="/test2/")
second_response = await communicator.get_response()
assert second_response["status"] == 200

assert response["body"] != second_response["body"]


@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio
async def test_async_http_consumer_future():
"""
Regression test for channels accepting only coroutines. The ASGI specification
states that the `receive` and `send` arguments to an ASGI application should be
"awaitable callable" objects. That includes non-coroutine functions that return
Futures.
Regression test for channels accepting only coroutines.
"""

class TestConsumer(AsyncHttpConsumer):
Expand All @@ -110,7 +103,6 @@ async def handle(self, body):

app = TestConsumer()

# Ensure the passed functions are specifically coroutines.
async def coroutine_app(scope, receive, send):
async def receive_coroutine():
return await asyncio.ensure_future(receive())
Expand All @@ -126,7 +118,6 @@ async def send_coroutine(*args, **kwargs):
assert response["status"] == 200
assert response["headers"] == [(b"Content-Type", b"text/plain")]

# Ensure the passed functions are "Awaitable Callables" and NOT coroutines.
async def awaitable_callable_app(scope, receive, send):
def receive_awaitable_callable():
return asyncio.ensure_future(receive())
Expand All @@ -136,9 +127,46 @@ def send_awaitable_callable(*args, **kwargs):

await app(scope, receive_awaitable_callable, send_awaitable_callable)

# Open a connection
communicator = HttpCommunicator(awaitable_callable_app, method="GET", path="/")
response = await communicator.get_response()
assert response["body"] == b"42"
assert response["status"] == 200
assert response["headers"] == [(b"Content-Type", b"text/plain")]


@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio
async def test_error_logging():
"""Regression test for error logging."""

class TestConsumer(AsyncHttpConsumer):
async def handle(self, body):
raise AssertionError("Error correctly raised")

communicator = HttpCommunicator(TestConsumer(), "GET", "/")
with patch("channels.generic.http.logger.error") as mock_logger_error:
try:
await communicator.get_response(timeout=0.05)
except AssertionError:
pass
args, _ = mock_logger_error.call_args
assert "Error in handle()" in args[0]
assert "AssertionError: Error correctly raised" in args[0]


@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio
async def test_error_handling_and_send_response():
"""Regression test to check error handling."""

class TestConsumer(AsyncHttpConsumer):
async def handle(self, body):
raise AssertionError("Error correctly raised")

communicator = HttpCommunicator(TestConsumer(), "GET", "/")
with patch.object(AsyncHttpConsumer, "send_response") as mock_send_response:
try:
await communicator.get_response(timeout=0.05)
except AssertionError:
pass
mock_send_response.assert_called_once_with(500, b"Internal Server Error")
41 changes: 40 additions & 1 deletion tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ async def test_send_receive():

@pytest.mark.parametrize(
"method",
[BaseChannelLayer().valid_channel_name, BaseChannelLayer().valid_group_name],
[
BaseChannelLayer().require_valid_channel_name,
BaseChannelLayer().require_valid_group_name,
],
)
@pytest.mark.parametrize(
"channel_name,expected_valid",
Expand All @@ -84,3 +87,39 @@ def test_channel_and_group_name_validation(method, channel_name, expected_valid)
else:
with pytest.raises(TypeError):
method(channel_name)


@pytest.mark.parametrize(
"name",
[
"a" * 101, # Group name too long
],
)
def test_group_name_length_error_message(name):
"""
Ensure the correct error message is raised when group names
exceed the character limit or contain invalid characters.
"""
layer = BaseChannelLayer()
expected_error_message = layer.invalid_name_error.format("Group")

with pytest.raises(TypeError, match=expected_error_message):
layer.require_valid_group_name(name)


@pytest.mark.parametrize(
"name",
[
"a" * 101, # Channel name too long
],
)
def test_channel_name_length_error_message(name):
"""
Ensure the correct error message is raised when group names
exceed the character limit or contain invalid characters.
"""
layer = BaseChannelLayer()
expected_error_message = layer.invalid_name_error.format("Channel")

with pytest.raises(TypeError, match=expected_error_message):
layer.require_valid_channel_name(name)
Loading