diff --git a/flask_testing/utils.py b/flask_testing/utils.py index 7362c40..7ef7505 100644 --- a/flask_testing/utils.py +++ b/flask_testing/utils.py @@ -13,6 +13,7 @@ import gc import multiprocessing import socket +import threading import time try: @@ -32,7 +33,7 @@ # Python 2 urlparse fallback from urlparse import urlparse, urljoin -from werkzeug import cached_property +from werkzeug import cached_property, serving # Use Flask's preferred JSON module so that our runtime behavior matches. from flask import json_available, templating, template_rendered @@ -436,7 +437,6 @@ def __call__(self, result=None): self.app = self.create_app() self._configured_port = self.app.config.get('LIVESERVER_PORT', 5000) - self._port_value = multiprocessing.Value('i', self._configured_port) # We need to create a context in order for extensions to catch up self._ctx = self.app.test_request_context() @@ -453,37 +453,16 @@ def get_server_url(self): """ Return the url of the test server """ - return 'http://localhost:%s' % self._port_value.value + return 'http://localhost:%s' % self._port def _spawn_live_server(self): - self._process = None - port_value = self._port_value - - def worker(app, port): - # Based on solution: http://stackoverflow.com/a/27598916 - # Monkey-patch the server_bind so we can determine the port bound by Flask. - # This handles the case where the port specified is `0`, which means that - # the OS chooses the port. This is the only known way (currently) of getting - # the port out of Flask once we call `run`. - original_socket_bind = socketserver.TCPServer.server_bind - def socket_bind_wrapper(self): - ret = original_socket_bind(self) - - # Get the port and save it into the port_value, so the parent process - # can read it. - (_, port) = self.socket.getsockname() - port_value.value = port - socketserver.TCPServer.server_bind = original_socket_bind - return ret - - socketserver.TCPServer.server_bind = socket_bind_wrapper - app.run(port=port, use_reloader=False) - - self._process = multiprocessing.Process( - target=worker, args=(self.app, self._configured_port) + self._server = serving.make_server( + 'localhost', self._configured_port, self.app, ) + (_, self._port) = self._server.socket.getsockname() - self._process.start() + self._thread = threading.Thread(target=self._server.serve_forever, args=()) + self._thread.start() # We must wait for the server to start listening, but give up # after a specified maximum timeout @@ -548,5 +527,10 @@ def _post_teardown(self): del self._ctx def _terminate_live_server(self): - if self._process: - self._process.terminate() + if self._server: + self._server.shutdown() + self._server = None + + if self._thread: + self._thread.join() + self._thread = None diff --git a/tests/test_utils.py b/tests/test_utils.py index 991f510..a3fa0f6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -210,14 +210,14 @@ def test_assert_no_flashed_messages_fail(self): class BaseTestLiveServer(LiveServerTestCase): - def test_server_process_is_spawned(self): - process = self._process + def test_server_thread_is_spawned(self): + thread = self._thread - # Check the process is spawned - self.assertNotEqual(process, None) + # Check the thread is spawned + self.assertNotEqual(thread, None) - # Check the process is alive - self.assertTrue(process.is_alive()) + # Check the thread is alive + self.assertTrue(thread.is_alive()) def test_server_listening(self): response = urlopen(self.get_server_url())