Skip to content

websockets: fix ping_timeout #3376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 75 additions & 3 deletions tornado/test/websocket_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,11 @@ class PingHandler(TestWebSocketHandler):
def on_pong(self, data):
self.write_message("got pong")

return Application([("/", PingHandler)], websocket_ping_interval=0.01)
return Application(
[("/", PingHandler)],
websocket_ping_interval=0.01,
websocket_ping_timeout=0,
)

@gen_test
def test_server_ping(self):
Expand All @@ -831,14 +835,82 @@ def on_ping(self, data):

@gen_test
def test_client_ping(self):
    """Connect with a short client ping interval and read three replies.

    The ping timeout is disabled (``ping_timeout=0``); each of the three
    messages read from the connection is expected to be ``"got ping"``.
    """
    conn = yield self.ws_connect("/", ping_interval=0.01, ping_timeout=0)
    for _ in range(3):
        reply = yield conn.read_message()
        self.assertEqual(reply, "got ping")
    # TODO: test that the connection gets closed if ping responses stop.
    conn.close()


class ServerPingTimeoutTest(WebSocketBaseTestCase):
    """Check that a connection is closed once pongs stop being registered."""

    def get_app(self):
        # Handler instances are collected on the test case so the test can
        # inspect the server side of the connection after connecting.
        self.handlers: list[WebSocketHandler] = []
        test = self

        class PingHandler(TestWebSocketHandler):
            def initialize(self, close_future=None, compression_options=None):
                self.handlers = test.handlers
                # capture the handler instance so we can interrogate it later
                self.handlers.append(self)
                return super().initialize(
                    close_future=close_future, compression_options=compression_options
                )

        app = Application([("/", PingHandler)])
        return app

    @staticmethod
    def suppress_pong(ws):
        """Make ``ws`` silently drop every pong frame it receives.

        ``ws.protocol._handle_message`` is replaced by a wrapper that
        returns early for opcode ``0xA`` (pong) and forwards every other
        opcode to the original handler unchanged.
        """

        def wrapper(fcn):
            def _inner(opcode: int, data: bytes):
                if opcode == 0xA:  # NOTE: 0x9=ping, 0xA=pong
                    # prevent pong responses from being processed
                    return
                # leave all other messages unchanged
                return fcn(opcode, data)

            return _inner

        ws.protocol._handle_message = wrapper(ws.protocol._handle_message)

    @gen_test
    def test_client_ping_timeout(self):
        # websocket client
        interval = 0.2
        ws = yield self.ws_connect(
            "/", ping_interval=interval, ping_timeout=interval / 4
        )

        # websocket handler (server side)
        handler = self.handlers[0]

        for _ in range(5):
            # wait for the ping period (kept in sync with ``interval``)
            yield gen.sleep(interval)

            # connection should still be open from the server end
            self.assertIsNone(handler.close_code)
            self.assertIsNone(handler.close_reason)

            # connection should still be open from the client end
            self.assertIsNone(ws.protocol.close_code)

        # suppress the pong response message
        self.suppress_pong(ws)

        # give the server time to register this
        yield gen.sleep(interval * 1.5)

        # connection should be closed from the server side
        self.assertEqual(handler.close_code, 1000)
        self.assertEqual(handler.close_reason, "ping timed out")

        # client should have received a close operation
        self.assertEqual(ws.protocol.close_code, 1000)


class ManualPingTest(WebSocketBaseTestCase):
def get_app(self):
class PingHandler(TestWebSocketHandler):
Expand Down
131 changes: 84 additions & 47 deletions tornado/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
import abc
import asyncio
import base64
import functools
import hashlib
import logging
import os
import sys
import struct
Expand All @@ -26,7 +28,7 @@
from tornado.concurrent import Future, future_set_result_unless_cancelled
from tornado.escape import utf8, native_str, to_unicode
from tornado import gen, httpclient, httputil
from tornado.ioloop import IOLoop, PeriodicCallback
from tornado.ioloop import IOLoop
from tornado.iostream import StreamClosedError, IOStream
from tornado.log import gen_log, app_log
from tornado.netutil import Resolver
Expand Down Expand Up @@ -97,6 +99,9 @@ def log_exception(

_default_max_message_size = 10 * 1024 * 1024

# log to "gen_log" but suppress duplicate log messages
# ``functools.lru_cache`` memoizes calls by their arguments, so a repeated
# call with the same (level, message) arguments is answered from the cache
# and ``gen_log.log`` is not invoked again while that entry stays cached.
# All arguments must therefore be hashable.
de_dupe_gen_log = functools.lru_cache(gen_log.log)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat trick; I haven't seen this one before.



class WebSocketError(Exception):
    """Websocket-specific exception type."""
Expand Down Expand Up @@ -274,17 +279,41 @@ async def get(self, *args: Any, **kwargs: Any) -> None:

@property
def ping_interval(self) -> Optional[float]:
    """The interval, in seconds, between websocket pings.

    When non-zero, the websocket sends a ping every ``ping_interval``
    seconds and the peer is expected to answer with a "pong". Late pongs
    can be made to close the connection via ``websocket_ping_timeout``.

    Set ``websocket_ping_interval = 0`` to disable pings.

    The value is read from the ``websocket_ping_interval`` application
    setting; ``None`` is returned when the setting is not configured.
    """
    return self.settings.get("websocket_ping_interval")

@property
def ping_timeout(self) -> Optional[float]:
    """Timeout if no pong is received in this many seconds.

    To be used in combination with ``websocket_ping_interval > 0``.
    If a ping response (a "pong") is not received within
    ``websocket_ping_timeout`` seconds, then the websocket connection
    will be closed.

    This can help to clean up clients which have disconnected without
    cleanly closing the websocket connection.

    Note, the ping timeout cannot be longer than the ping interval.

    Set ``websocket_ping_timeout = 0`` to disable the ping timeout.

    Default: equal to the ping interval.

    The value is read from the ``websocket_ping_timeout`` application
    setting; ``None`` is returned when the setting is not configured.

    .. versionchanged:: 6.5.0
       Default changed from the max of 3 pings or 30 seconds.
       The ping timeout can no longer be configured longer than the
       ping interval.
    """
    return self.settings.get("websocket_ping_timeout", None)

Expand Down Expand Up @@ -831,11 +860,10 @@ def __init__(
# the effect of compression, frame overhead, and control frames.
self._wire_bytes_in = 0
self._wire_bytes_out = 0
self.ping_callback = None # type: Optional[PeriodicCallback]
self.last_ping = 0.0
self.last_pong = 0.0
self._received_pong = False # type: bool
self.close_code = None # type: Optional[int]
self.close_reason = None # type: Optional[str]
self._ping_coroutine = None # type: Optional[asyncio.Task]

# Use a property for this to satisfy the abc.
@property
Expand Down Expand Up @@ -1232,7 +1260,7 @@ def _handle_message(self, opcode: int, data: bytes) -> "Optional[Future[None]]":
self._run_callback(self.handler.on_ping, data)
elif opcode == 0xA:
# Pong
self.last_pong = IOLoop.current().time()
self._received_pong = True
return self._run_callback(self.handler.on_pong, data)
else:
self._abort()
Expand Down Expand Up @@ -1266,9 +1294,9 @@ def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> Non
self._waiting = self.stream.io_loop.add_timeout(
self.stream.io_loop.time() + 5, self._abort
)
if self.ping_callback:
self.ping_callback.stop()
self.ping_callback = None
if self._ping_coroutine:
self._ping_coroutine.cancel()
self._ping_coroutine = None

def is_closing(self) -> bool:
"""Return ``True`` if this connection is closing.
Expand All @@ -1279,60 +1307,69 @@ def is_closing(self) -> bool:
"""
return self.stream.closed() or self.client_terminated or self.server_terminated

def set_nodelay(self, x: bool) -> None:
    """Forward the no-delay flag ``x`` to the underlying stream."""
    stream = self.stream
    stream.set_nodelay(x)

@property
def ping_interval(self) -> float:
    """Ping interval taken from the connection params; ``0`` when unset."""
    configured = self.params.ping_interval
    return 0 if configured is None else configured

@property
def ping_timeout(self) -> float:
    """Effective ping timeout for this connection.

    Falls back to the ping interval when no timeout is configured. A
    configured timeout longer than a non-zero ping interval is clamped to
    the interval, and a warning is logged.
    """
    timeout = self.params.ping_timeout
    if timeout is None:
        return self.ping_interval

    interval = self.ping_interval
    if interval and timeout > interval:
        # Note: using de_dupe_gen_log to prevent this message from
        # being duplicated for each connection
        de_dupe_gen_log(
            logging.WARNING,
            f"The websocket_ping_timeout ({timeout}) cannot be longer"
            f" than the websocket_ping_interval ({interval})."
            f"\nSetting websocket_ping_timeout={interval}",
        )
        return interval
    return timeout

def start_pinging(self) -> None:
    """Start sending periodic pings to keep the connection alive.

    Does nothing when a ping task already exists or when no positive ping
    interval is configured.
    """
    # prevent multiple ping coroutines being run in parallel
    if self._ping_coroutine:
        return
    # only run the ping coroutine if a ping interval is configured
    if self.ping_interval > 0:
        self._ping_coroutine = asyncio.create_task(self.periodic_ping())

async def periodic_ping(self) -> None:
    """Send a ping and wait for a pong if ping_timeout is configured.

    Runs as a long-lived task when the websocket_ping_interval is set and
    non-zero. Each cycle sends a ping, waits ``ping_timeout`` seconds,
    closes the connection with reason ``"ping timed out"`` if a timeout is
    configured and no pong was registered, and otherwise sleeps for the
    remainder of the interval so pings go out once per ``ping_interval``.
    """
    interval = self.ping_interval
    timeout = self.ping_timeout

    await asyncio.sleep(interval)

    while True:
        # send a ping
        self._received_pong = False
        ping_time = IOLoop.current().time()
        self.write_ping(b"")

        # wait until the ping timeout
        await asyncio.sleep(timeout)

        # make sure we received a pong within the timeout
        if timeout > 0 and not self._received_pong:
            self.close(reason="ping timed out")
            return

        # wait until the next scheduled ping: sleep only for what is left
        # of the interval since the ping was sent (not elapsed + interval,
        # which would stretch the period to interval + 2 * timeout).
        remaining = ping_time + interval - IOLoop.current().time()
        # clamp in case the timeout wait overran the interval
        await asyncio.sleep(max(0.0, remaining))


class WebSocketClientConnection(simple_httpclient._HTTPConnection):
Expand Down