Skip to content

Commit cafdaca

Browse files
authored
Merge pull request #48 from tarasko/feature/44_transfer_exc_to_wait_disconnected
Feature/44 transfer exc to wait disconnected
2 parents 3c243fb + c577316 commit cafdaca

File tree

4 files changed

+133
-21
lines changed

4 files changed

+133
-21
lines changed

examples/reconnecting_client.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import asyncio
2+
from logging import getLogger, INFO, basicConfig
3+
4+
from picows import ws_connect, WSFrame, WSTransport, WSListener, WSMsgType
5+
6+
_logger = getLogger(__name__)
7+
8+
9+
class ClientListener(WSListener):
10+
def on_ws_connected(self, transport: WSTransport):
11+
transport.send(WSMsgType.TEXT, b"Hello world")
12+
13+
def on_ws_frame(self, transport: WSTransport, frame: WSFrame):
14+
if frame.msg_type == WSMsgType.TEXT:
15+
_logger.info("Echo reply: %s", frame.get_payload_as_ascii_text())
16+
17+
# Throw from on_ws_frame to illustrate how library deal with exceptions
18+
# picows will disconnect client and re-raise exception from wait_disconnected
19+
raise RuntimeError("some logic failed")
20+
21+
22+
async def main(url):
23+
while True:
24+
try:
25+
transport, client = await ws_connect(ClientListener, url)
26+
await transport.wait_disconnected()
27+
except asyncio.CancelledError:
28+
raise
29+
except Exception as e:
30+
_logger.error("Client disconnected, reconnect in 5 seconds: %s", str(e))
31+
await asyncio.sleep(5)
32+
33+
34+
if __name__ == '__main__':
35+
basicConfig(level=INFO)
36+
asyncio.run(main("ws://127.0.0.1:9001"))

picows/picows.pxd

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,11 @@ cdef class WSTransport:
111111
bint auto_ping_expect_pong
112112
object pong_received_at_future
113113
object listener_proxy
114+
object disconnected_future #: asyncio.Future
114115

115116
object _logger #: Logger
116117
bint _log_debug_enabled
117118
bint _close_frame_is_sent
118-
object _disconnected_future #: asyncio.Future
119119
MemoryBuffer _write_buf
120120
int _socket
121121

@@ -130,7 +130,6 @@ cdef class WSTransport:
130130

131131
cdef inline _send_http_handshake(self, bytes ws_path, bytes host_port, bytes websocket_key_b64, object extra_headers)
132132
cdef inline _send_http_handshake_response(self, WSUpgradeResponse response, bytes accept_val)
133-
cdef inline _mark_disconnected(self)
134133
cdef inline _try_native_write_then_transport_write(self, char * ptr, Py_ssize_t sz)
135134

136135

picows/picows.pyx

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,9 @@ cdef class WSListener:
407407
"""
408408
Called after websocket handshake is complete and websocket is ready to send and receive frames.
409409
Initiate disconnect if exception is thrown by user handler.
410+
411+
* client side: the exception will be transferred to and reraised by :any:`wait_disconnected`.
412+
* server side: the exception will be 'swallowed' by the library and logged at the ERROR level.
410413
411414
:param transport: :any:`WSTransport` object
412415
"""
@@ -418,8 +421,12 @@ cdef class WSListener:
418421
419422
Initiate disconnect if exception is thrown by user handler and
420423
`disconnect_on_exception` was set to True in :any:`ws_connect`
421-
or :any:`ws_create_server`
422-
424+
or :any:`ws_create_server`.
425+
In such case:
426+
427+
* client side: the exception will be transferred to and reraised by :any:`wait_disconnected`.
428+
* server side: the exception will be 'swallowed' by the library and logged at the ERROR level.
429+
423430
.. DANGER::
424431
WSFrame is essentially just a pointer to a chunk of memory in the receiving buffer. It does not own
425432
the memory. Do NOT cache or store WSFrame object for later processing because the data may be invalidated
@@ -498,10 +505,10 @@ cdef class WSTransport:
498505
self.auto_ping_expect_pong = False
499506
self.pong_received_at_future = None
500507
self.listener_proxy = None
508+
self.disconnected_future = loop.create_future()
501509
self._logger = logger
502510
self._log_debug_enabled = self._logger.isEnabledFor(PICOWS_DEBUG_LL)
503511
self._close_frame_is_sent = False
504-
self._disconnected_future = loop.create_future()
505512
self._write_buf = MemoryBuffer(1024)
506513
self._socket = underlying_transport.get_extra_info('socket').fileno()
507514

@@ -722,7 +729,7 @@ cdef class WSTransport:
722729
(underlying transport is closed, on_ws_disconnected has been called)
723730
724731
"""
725-
await asyncio.shield(self._disconnected_future)
732+
await asyncio.shield(self.disconnected_future)
726733

727734
async def measure_roundtrip_time(self, int rounds) -> List[float]:
728735
"""
@@ -828,10 +835,6 @@ cdef class WSTransport:
828835
self.response = response
829836
self.underlying_transport.write(response_bytes)
830837

831-
cdef _mark_disconnected(self):
832-
if not self._disconnected_future.done():
833-
self._disconnected_future.set_result(None)
834-
835838
cdef _try_native_write_then_transport_write(self, char* ptr, Py_ssize_t sz):
836839
if <size_t>self.underlying_transport.get_write_buffer_size() > 0:
837840
self.underlying_transport.write(PyBytes_FromStringAndSize(ptr, sz))
@@ -876,6 +879,7 @@ cdef class WSProtocol:
876879
bint _log_debug_enabled
877880
bint is_client_side
878881
bint _disconnect_on_exception
882+
object _disconnect_exception #: Optional[Exception]
879883

880884
object _loop
881885

@@ -936,6 +940,7 @@ cdef class WSProtocol:
936940
self._log_debug_enabled = self._logger.isEnabledFor(PICOWS_DEBUG_LL)
937941
self.is_client_side = is_client_side
938942
self._disconnect_on_exception = disconnect_on_exception
943+
self._disconnect_exception = None
939944

940945
self._loop = asyncio.get_running_loop()
941946

@@ -1035,7 +1040,13 @@ cdef class WSProtocol:
10351040
self.transport.pong_received_at_future.set_exception(ConnectionResetError())
10361041
self.transport.pong_received_at_future = None
10371042

1038-
self.transport._mark_disconnected()
1043+
if not self.transport.disconnected_future.done():
1044+
# The server side does not allow to await on a particular client or retrieve its disconnect exception.
1045+
# Do not set exception on future to avoid warnings about unconsumed exception from asyncio.
1046+
if self._disconnect_exception is None or not self.is_client_side:
1047+
self.transport.disconnected_future.set_result(None)
1048+
else:
1049+
self.transport.disconnected_future.set_exception(self._disconnect_exception)
10391050

10401051
def eof_received(self) -> bool:
10411052
if self._log_debug_enabled:
@@ -1531,8 +1542,12 @@ cdef class WSProtocol:
15311542
cdef inline _invoke_on_ws_connected(self):
15321543
try:
15331544
self.listener.on_ws_connected(self.transport)
1534-
except Exception as e:
1535-
self._logger.exception("Unhandled exception in on_ws_connected, initiate disconnect")
1545+
except Exception as exc:
1546+
if self.is_client_side:
1547+
self._logger.info("Exception from user's WSListener.on_ws_connected handler, initiate disconnect")
1548+
self._disconnect_exception = exc
1549+
else:
1550+
self._logger.exception("Exception from user's WSListener.on_ws_connected handler, initiate disconnect")
15361551
self.transport.send_close(WSCloseCode.INTERNAL_ERROR)
15371552
self._loop.call_later(DISCONNECT_AFTER_ERROR_DELAY, self.transport.disconnect)
15381553

@@ -1560,19 +1575,27 @@ cdef class WSProtocol:
15601575
return
15611576

15621577
self.listener.on_ws_frame(self.transport, frame)
1563-
except Exception as e:
1578+
except Exception as exc:
15641579
if self._disconnect_on_exception:
1565-
self._logger.exception("Unhandled exception in on_ws_frame, initiate disconnect")
1580+
if self.is_client_side:
1581+
if self._disconnect_exception is None:
1582+
self._disconnect_exception = exc
1583+
self._logger.info("Exception from user's WSListener.on_ws_frame, initiate disconnect")
1584+
else:
1585+
self._logger.exception("Secondary exception from user's WSListener.on_ws_frame")
1586+
else:
1587+
self._logger.exception("Exception from user's WSListener.on_ws_frame, initiate disconnect")
1588+
15661589
self.transport.send_close(WSCloseCode.INTERNAL_ERROR)
15671590
self._loop.call_later(DISCONNECT_AFTER_ERROR_DELAY, self.transport.disconnect)
15681591
else:
1569-
self._logger.exception("Unhandled exception in on_ws_frame")
1592+
self._logger.exception("Unhandled exception from user's WSListener.on_ws_frame")
15701593

15711594
cdef inline _invoke_on_ws_disconnected(self):
15721595
try:
15731596
self.listener.on_ws_disconnected(self.transport)
15741597
except:
1575-
self._logger.exception("Unhandled exception in on_ws_disconnected")
1598+
self._logger.exception("Unhandled exception from user's on_ws_disconnected")
15761599

15771600
cdef inline _shrink_buffer(self):
15781601
if self._f_curr_frame_start_pos > 0:

tests/test_basics.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
ServerAsyncContext, TIMEOUT, \
1515
materialize_frame, ClientAsyncContext
1616

17+
18+
class MyException(RuntimeError):
19+
pass
20+
21+
1722
if os.name == 'nt':
1823
@pytest.fixture(
1924
params=(
@@ -354,10 +359,28 @@ def factory_listener(r):
354359
(transport, _) = await picows.ws_connect(picows.WSListener, url)
355360

356361

357-
async def test_ws_on_connected_throw():
362+
async def test_ws_on_connected_throw_client_side():
363+
# Check that client side, initiate disconnect(no timeouts on wait_disconnected) and
364+
# transfer exception to wait_disconnected
365+
class ClientListener(picows.WSListener):
366+
def on_ws_connected(self, transport: picows.WSTransport):
367+
raise MyException("exception from client side on_ws_connected")
368+
369+
server = await picows.ws_create_server(lambda _: picows.WSListener(),
370+
"127.0.0.1", 0)
371+
async with ServerAsyncContext(server) as server_ctx:
372+
(transport, _) = await picows.ws_connect(ClientListener, server_ctx.plain_url)
373+
async with async_timeout.timeout(TIMEOUT):
374+
with pytest.raises(MyException):
375+
await transport.wait_disconnected()
376+
377+
378+
async def test_ws_on_connected_throw_server_side():
379+
# Check that server side initiate disconnect(no timeouts on wait_disconnected) and
380+
# swallow exception
358381
class ServerClientListener(picows.WSListener):
359382
def on_ws_connected(self, transport: picows.WSTransport):
360-
raise RuntimeError("exception from on_ws_connected")
383+
raise MyException("exception from server side on_ws_connected")
361384

362385
server = await picows.ws_create_server(lambda _: ServerClientListener(),
363386
"127.0.0.1", 0)
@@ -369,10 +392,41 @@ def on_ws_connected(self, transport: picows.WSTransport):
369392

370393
@pytest.mark.parametrize("disconnect_on_exception", [True, False],
371394
ids=["disconnect_on_exception", "no_disconnect_on_exception"])
372-
async def test_ws_on_frame_throw(disconnect_on_exception):
395+
async def test_ws_on_frame_throw_client_side(disconnect_on_exception):
396+
class ServerClientListener(picows.WSListener):
397+
def on_ws_connected(self, transport: picows.WSTransport):
398+
transport.send(WSMsgType.BINARY, b"Hello")
399+
400+
class ClientListener(picows.WSListener):
401+
def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame):
402+
raise MyException("exception from client side on_ws_frame")
403+
404+
server = await picows.ws_create_server(lambda _: ServerClientListener(),
405+
"127.0.0.1",
406+
0)
407+
408+
async with ServerAsyncContext(server) as server_ctx:
409+
transport, listener = await picows.ws_connect(ClientListener, server_ctx.plain_url,
410+
disconnect_on_exception=disconnect_on_exception)
411+
try:
412+
if disconnect_on_exception:
413+
with pytest.raises(MyException):
414+
async with async_timeout.timeout(TIMEOUT):
415+
await transport.wait_disconnected()
416+
else:
417+
with pytest.raises(asyncio.TimeoutError):
418+
async with async_timeout.timeout(TIMEOUT):
419+
await transport.wait_disconnected()
420+
finally:
421+
transport.disconnect(False)
422+
423+
424+
@pytest.mark.parametrize("disconnect_on_exception", [True, False],
425+
ids=["disconnect_on_exception", "no_disconnect_on_exception"])
426+
async def test_ws_on_frame_throw_server_side(disconnect_on_exception):
373427
class ServerClientListener(picows.WSListener):
374428
def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame):
375-
raise RuntimeError("exception from on_ws_frame")
429+
raise MyException("exception from server side on_ws_frame")
376430

377431
server = await picows.ws_create_server(lambda _: ServerClientListener(),
378432
"127.0.0.1",

0 commit comments

Comments
 (0)