diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index 6d95084bd..097dca85a 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -810,7 +810,11 @@ class PingHandler(TestWebSocketHandler): def on_pong(self, data): self.write_message("got pong") - return Application([("/", PingHandler)], websocket_ping_interval=0.01) + return Application( + [("/", PingHandler)], + websocket_ping_interval=0.01, + websocket_ping_timeout=0, + ) @gen_test def test_server_ping(self): @@ -831,14 +835,82 @@ def on_ping(self, data): @gen_test def test_client_ping(self): - ws = yield self.ws_connect("/", ping_interval=0.01) + ws = yield self.ws_connect("/", ping_interval=0.01, ping_timeout=0) for i in range(3): response = yield ws.read_message() self.assertEqual(response, "got ping") - # TODO: test that the connection gets closed if ping responses stop. ws.close() +class ServerPingTimeoutTest(WebSocketBaseTestCase): + def get_app(self): + self.handlers: list[WebSocketHandler] = [] + test = self + + class PingHandler(TestWebSocketHandler): + def initialize(self, close_future=None, compression_options=None): + self.handlers = test.handlers + # capture the handler instance so we can interrogate it later + self.handlers.append(self) + return super().initialize( + close_future=close_future, compression_options=compression_options + ) + + app = Application([("/", PingHandler)]) + return app + + @staticmethod + def suppress_pong(ws): + """Suppress the client's "pong" response.""" + + def wrapper(fcn): + def _inner(oppcode: int, data: bytes): + if oppcode == 0xA: # NOTE: 0x9=ping, 0xA=pong + # prevent pong responses + return + # leave all other responses unchanged + return fcn(oppcode, data) + + return _inner + + ws.protocol._handle_message = wrapper(ws.protocol._handle_message) + + @gen_test + def test_client_ping_timeout(self): + # websocket client + interval = 0.2 + ws = yield self.ws_connect( + "/", ping_interval=interval, ping_timeout=interval / 4 + ) + + # websocket handler (server side) + handler = self.handlers[0] + + for _ in range(5): + # wait for the ping period + yield gen.sleep(0.2) + + # connection should still be open from the server end + self.assertIsNone(handler.close_code) + self.assertIsNone(handler.close_reason) + + # connection should still be open from the client end + assert ws.protocol.close_code is None + + # suppress the pong response message + self.suppress_pong(ws) + + # give the server time to register this + yield gen.sleep(interval * 1.5) + + # connection should be closed from the server side + self.assertEqual(handler.close_code, 1000) + self.assertEqual(handler.close_reason, "ping timed out") + + # client should have received a close operation + self.assertEqual(ws.protocol.close_code, 1000) + + class ManualPingTest(WebSocketBaseTestCase): def get_app(self): class PingHandler(TestWebSocketHandler): diff --git a/tornado/websocket.py b/tornado/websocket.py index 1e0161e1b..4fbb2da12 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -14,7 +14,9 @@ import abc import asyncio import base64 +import functools import hashlib +import logging import os import sys import struct @@ -26,7 +28,7 @@ from tornado.concurrent import Future, future_set_result_unless_cancelled from tornado.escape import utf8, native_str, to_unicode from tornado import gen, httpclient, httputil -from tornado.ioloop import IOLoop, PeriodicCallback +from tornado.ioloop import IOLoop from tornado.iostream import StreamClosedError, IOStream from tornado.log import gen_log, app_log from tornado.netutil import Resolver @@ -97,6 +99,9 @@ def log_exception( _default_max_message_size = 10 * 1024 * 1024 +# log to "gen_log" but suppress duplicate log messages +de_dupe_gen_log = functools.lru_cache(gen_log.log) + class WebSocketError(Exception): pass @@ -274,17 +279,41 @@ async def get(self, *args: Any, **kwargs: Any) -> None: @property def ping_interval(self) -> Optional[float]: - """The interval for websocket keep-alive pings. + """The interval for sending websocket pings. + + If this is non-zero, the websocket will send a ping every + ping_interval seconds. + The client will respond with a "pong". The connection can be configured + to timeout on late pong delivery using ``websocket_ping_timeout``. - Set websocket_ping_interval = 0 to disable pings. + Set ``websocket_ping_interval = 0`` to disable pings. + + Default: ``0`` """ return self.settings.get("websocket_ping_interval", None) @property def ping_timeout(self) -> Optional[float]: - """If no ping is received in this many seconds, - close the websocket connection (VPNs, etc. can fail to cleanly close ws connections). - Default is max of 3 pings or 30 seconds. + """Timeout if no pong is received in this many seconds. + + To be used in combination with ``websocket_ping_interval > 0``. + If a ping response (a "pong") is not received within + ``websocket_ping_timeout`` seconds, then the websocket connection + will be closed. + + This can help to clean up clients which have disconnected without + cleanly closing the websocket connection. + + Note, the ping timeout cannot be longer than the ping interval. + + Set ``websocket_ping_timeout = 0`` to disable the ping timeout. + + Default: ``min(ping_interval, 30)`` + + .. versionchanged:: 6.5.0 + Default changed from the max of 3 pings or 30 seconds. + The ping timeout can no longer be configured longer than the + ping interval. """ return self.settings.get("websocket_ping_timeout", None) @@ -831,11 +860,10 @@ def __init__( # the effect of compression, frame overhead, and control frames. self._wire_bytes_in = 0 self._wire_bytes_out = 0 - self.ping_callback = None # type: Optional[PeriodicCallback] - self.last_ping = 0.0 - self.last_pong = 0.0 + self._received_pong = False # type: bool self.close_code = None # type: Optional[int] self.close_reason = None # type: Optional[str] + self._ping_coroutine = None # type: Optional[asyncio.Task] # Use a property for this to satisfy the abc. @property @@ -1232,7 +1260,7 @@ def _handle_message(self, opcode: int, data: bytes) -> "Optional[Future[None]]": self._run_callback(self.handler.on_ping, data) elif opcode == 0xA: # Pong - self.last_pong = IOLoop.current().time() + self._received_pong = True return self._run_callback(self.handler.on_pong, data) else: self._abort() @@ -1266,9 +1294,9 @@ def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> Non self._waiting = self.stream.io_loop.add_timeout( self.stream.io_loop.time() + 5, self._abort ) - if self.ping_callback: - self.ping_callback.stop() - self.ping_callback = None + if self._ping_coroutine: + self._ping_coroutine.cancel() + self._ping_coroutine = None def is_closing(self) -> bool: """Return ``True`` if this connection is closing. @@ -1279,60 +1307,69 @@ def is_closing(self) -> bool: """ return self.stream.closed() or self.client_terminated or self.server_terminated + def set_nodelay(self, x: bool) -> None: + self.stream.set_nodelay(x) + @property - def ping_interval(self) -> Optional[float]: + def ping_interval(self) -> float: interval = self.params.ping_interval if interval is not None: return interval return 0 @property - def ping_timeout(self) -> Optional[float]: + def ping_timeout(self) -> float: timeout = self.params.ping_timeout if timeout is not None: + if self.ping_interval and timeout > self.ping_interval: + de_dupe_gen_log( + # Note: using de_dupe_gen_log to prevent this message from + # being duplicated for each connection + logging.WARNING, + f"The websocket_ping_timeout ({timeout}) cannot be longer" + f" than the websocket_ping_interval ({self.ping_interval})." + f"\nSetting websocket_ping_timeout={self.ping_interval}", + ) + return self.ping_interval return timeout - assert self.ping_interval is not None - return max(3 * self.ping_interval, 30) + return self.ping_interval def start_pinging(self) -> None: """Start sending periodic pings to keep the connection alive""" - assert self.ping_interval is not None - if self.ping_interval > 0: - self.last_ping = self.last_pong = IOLoop.current().time() - self.ping_callback = PeriodicCallback( - self.periodic_ping, self.ping_interval * 1000 - ) - self.ping_callback.start() + if ( + # prevent multiple ping coroutines being run in parallel + not self._ping_coroutine + # only run the ping coroutine if a ping interval is configured + and self.ping_interval > 0 + ): + self._ping_coroutine = asyncio.create_task(self.periodic_ping()) - def periodic_ping(self) -> None: - """Send a ping to keep the websocket alive + async def periodic_ping(self) -> None: + """Send a ping and wait for a pong if ping_timeout is configured. Called periodically if the websocket_ping_interval is set and non-zero. """ - if self.is_closing() and self.ping_callback is not None: - self.ping_callback.stop() - return + interval = self.ping_interval + timeout = self.ping_timeout - # Check for timeout on pong. Make sure that we really have - # sent a recent ping in case the machine with both server and - # client has been suspended since the last ping. - now = IOLoop.current().time() - since_last_pong = now - self.last_pong - since_last_ping = now - self.last_ping - assert self.ping_interval is not None - assert self.ping_timeout is not None - if ( - since_last_ping < 2 * self.ping_interval - and since_last_pong > self.ping_timeout - ): - self.close() - return + await asyncio.sleep(interval) - self.write_ping(b"") - self.last_ping = now + while True: + # send a ping + self._received_pong = False + ping_time = IOLoop.current().time() + self.write_ping(b"") - def set_nodelay(self, x: bool) -> None: - self.stream.set_nodelay(x) + # wait until the ping timeout + await asyncio.sleep(timeout) + + # make sure we received a pong within the timeout + if timeout > 0 and not self._received_pong: + self.close(reason="ping timed out") + return + + # wait until the next scheduled ping + await asyncio.sleep(IOLoop.current().time() - ping_time + interval) class WebSocketClientConnection(simple_httpclient._HTTPConnection):