Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion channels/testing/live.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import partial
import multiprocessing

from daphne.testing import DaphneProcess
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
Expand All @@ -11,9 +12,15 @@
from channels.routing import get_default_application


def make_application(*, static_wrapper):
# Global queue for commands from test process to server process
_server_command_queue = None


def make_application(*, static_wrapper, commands={}):
# Module-level function for pickle-ability
application = get_default_application()
# Wrap the application with our command processing middleware
application = ServerCommandMiddleware(application, commands)
if static_wrapper is not None:
application = static_wrapper(application)
return application
Expand All @@ -28,6 +35,34 @@ def set_database_connection():
settings.DATABASES["default"]["NAME"] = test_db_name


class ServerCommandMiddleware:
"""
Middleware that processes commands from the test process.
This is automatically added to the ASGI application in test mode.
"""
def __init__(self, app, commands):
self.app = app
self.commands = commands

async def __call__(self, scope, receive, send):
# Process any pending server commands before handling the request
self.process_server_commands()
return await self.app(scope, receive, send)

def process_server_commands(self):
global _server_command_queue
if _server_command_queue is None:
return

while not _server_command_queue.empty():
try:
command = _server_command_queue.get_nowait()
if command in self.commands:
self.commands[command]()
except:
break


class ChannelsLiveServerTestCase(TransactionTestCase):
"""
Does basically the same as TransactionTestCase but also launches a
Expand All @@ -40,6 +75,7 @@ class ChannelsLiveServerTestCase(TransactionTestCase):
ProtocolServerProcess = DaphneProcess
static_wrapper = ASGIStaticFilesHandler
serve_static = True
commands = {}

@property
def live_server_url(self):
Expand All @@ -51,6 +87,8 @@ def live_server_ws_url(self):

@classmethod
def setUpClass(cls):
global _server_command_queue

for connection in connections.all():
if cls._is_in_memory_db(connection):
raise ImproperlyConfigured(
Expand All @@ -64,9 +102,14 @@ def setUpClass(cls):
)
cls._live_server_modified_settings.enable()

# Create a command queue for communication with the server process
_server_command_queue = multiprocessing.Queue()
cls._server_command_queue = _server_command_queue

get_application = partial(
make_application,
static_wrapper=cls.static_wrapper if cls.serve_static else None,
commands=cls.commands,
)
cls._server_process = cls.ProtocolServerProcess(
cls.host,
Expand All @@ -89,6 +132,13 @@ def tearDownClass(cls):
cls._live_server_modified_settings.disable()
super().tearDownClass()

def run_server_command(self, command):
"""
Add command to server command queue.
"""
if hasattr(self.__class__, '_server_command_queue'):
self._server_command_queue.put(command)

@classmethod
def _is_in_memory_db(cls, connection):
"""
Expand Down