Skip to content

Commit 05acfd2

Browse files
dennisshengsym
authored and
sym
committed
fix eof_received
1 parent 7bb12a1 commit 05acfd2

File tree

3 files changed

+141
-10
lines changed

3 files changed

+141
-10
lines changed

tests/test_close_notify.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import asyncio
2+
import ssl
3+
import threading
4+
import time
5+
import unittest
6+
7+
from uvloop import _testbase as tb
8+
9+
10+
class TestCloseNotify(tb.SSLTestCase, tb.UVTestCase):
11+
12+
ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
13+
ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')
14+
15+
PAYLOAD_SIZE = 1024 * 50
16+
TIMEOUT = 10
17+
18+
HELLO_MSG = b'A' * PAYLOAD_SIZE
19+
END_MSG = b'THE END'
20+
21+
class ClientProto(asyncio.Protocol):
22+
23+
def __init__(self, conn_lost):
24+
self.transport = None
25+
self.conn_lost = conn_lost
26+
self.buffered_bytes = 0
27+
self.total_bytes = 0
28+
29+
def connection_made(self, tr):
30+
self.transport = tr
31+
32+
def data_received(self, data):
33+
self.buffered_bytes += len(data)
34+
self.total_bytes += len(data)
35+
36+
if self.transport.is_reading() and self.buffered_bytes >= TestCloseNotify.PAYLOAD_SIZE:
37+
print("app pause_reading")
38+
self.transport.pause_reading()
39+
40+
def eof_received(self):
41+
print("app eof_received")
42+
43+
def connection_lost(self, exc):
44+
print(f"finally received: {self.total_bytes}")
45+
self.conn_lost.set_result(None)
46+
47+
def test_close_notify(self):
48+
49+
conn_lost = self.loop.create_future()
50+
51+
def server(sock):
52+
53+
incoming = ssl.MemoryBIO()
54+
outgoing = ssl.MemoryBIO()
55+
56+
server_context = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
57+
sslobj = server_context.wrap_bio(incoming, outgoing, server_side=True)
58+
59+
while True:
60+
try:
61+
sslobj.do_handshake()
62+
except ssl.SSLWantReadError:
63+
if outgoing.pending:
64+
sock.send(outgoing.read())
65+
incoming.write(sock.recv(16384))
66+
else:
67+
if outgoing.pending:
68+
sock.send(outgoing.read())
69+
break
70+
71+
# first send: 1024 * 50 bytes
72+
sslobj.write(self.HELLO_MSG)
73+
sock.send(outgoing.read())
74+
75+
time.sleep(1)
76+
77+
# then send: 7 bytes
78+
sslobj.write(self.END_MSG)
79+
sock.send(outgoing.read())
80+
81+
# send close_notify but don't wait for response
82+
with self.assertRaises(ssl.SSLWantReadError):
83+
sslobj.unwrap()
84+
sock.send(outgoing.read())
85+
86+
sock.close()
87+
88+
async def client(addr):
89+
cp = TestCloseNotify.ClientProto(conn_lost)
90+
client_context = self._create_client_ssl_context()
91+
tr, proto = await self.loop.create_connection(lambda: cp, *addr, ssl=client_context)
92+
93+
# app read buffer and do some logic in 3 seconds
94+
await asyncio.sleep(3)
95+
cp.buffered_bytes = 0
96+
# app finish operation, resume reading more from buffer
97+
tr.resume_reading()
98+
99+
await asyncio.wait_for(conn_lost, timeout=self.TIMEOUT)
100+
await asyncio.sleep(3)
101+
tr.close()
102+
103+
test_server = self.tcp_server(server)
104+
port = test_server._sock.getsockname()[1]
105+
thread1 = threading.Thread(target=lambda : test_server.start())
106+
thread2 = threading.Thread(target=lambda : self.loop.run_until_complete(client(('127.0.0.1', port))))
107+
108+
thread1.start()
109+
thread2.start()
110+
111+
thread1.join()
112+
thread2.join()
113+
114+
115+
if __name__ == "__main__":
116+
unittest.main()

uvloop/sslproto.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ cdef class SSLProtocol:
101101
cdef _start_shutdown(self, object context=*)
102102
cdef _check_shutdown_timeout(self)
103103
cdef _do_read_into_void(self, object context)
104+
cdef _do_read_flush(self)
104105
cdef _do_flush(self, object context=*)
105106
cdef _do_shutdown(self, object context=*)
106107
cdef _on_shutdown_complete(self, shutdown_exc)

uvloop/sslproto.pyx

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -402,12 +402,17 @@ cdef class SSLProtocol:
402402
if self._state == DO_HANDSHAKE:
403403
self._on_handshake_complete(ConnectionResetError)
404404

405-
elif self._state == WRAPPED or self._state == FLUSHING:
406-
# We treat a low-level EOF as a critical situation similar to a
407-
# broken connection - just send whatever is in the buffer and
408-
# close. No application level eof_received() is called -
409-
# because we don't want the user to think that this is a
410-
# graceful shutdown triggered by SSL "close_notify".
405+
elif self._state == WRAPPED:
406+
self._set_state(FLUSHING)
407+
if self._app_reading_paused:
408+
return True
409+
else:
410+
self._do_read_flush()
411+
412+
elif self._state == FLUSHING:
413+
self._do_write()
414+
self._process_outgoing()
415+
self._control_app_writing()
411416
self._set_state(SHUTDOWN)
412417
self._on_shutdown_complete(None)
413418

@@ -443,9 +448,6 @@ cdef class SSLProtocol:
443448
elif self._state == WRAPPED and new_state == FLUSHING:
444449
allowed = True
445450

446-
elif self._state == WRAPPED and new_state == SHUTDOWN:
447-
allowed = True
448-
449451
elif self._state == FLUSHING and new_state == SHUTDOWN:
450452
allowed = True
451453

@@ -597,6 +599,11 @@ cdef class SSLProtocol:
597599
if close_notify:
598600
self._call_eof_received(context)
599601

602+
cdef _do_read_flush(self):
603+
self._do_read()
604+
self._set_state(SHUTDOWN)
605+
self._on_shutdown_complete(None)
606+
600607
cdef _do_flush(self, object context=None):
601608
"""Flush the write backlog, discarding new data received.
602609
@@ -701,7 +708,7 @@ cdef class SSLProtocol:
701708
# Incoming flow
702709

703710
cdef _do_read(self):
704-
if self._state != WRAPPED:
711+
if self._state != WRAPPED and self._state != FLUSHING:
705712
return
706713
try:
707714
if not self._app_reading_paused:
@@ -885,6 +892,13 @@ cdef class SSLProtocol:
885892
<method_t>self._do_read,
886893
context,
887894
self))
895+
elif self._state == FLUSHING:
896+
self._loop._call_soon_handle(
897+
new_MethodHandle(self._loop,
898+
"SSLProtocol._do_read_flush",
899+
<method_t> self._do_read_flush,
900+
context,
901+
self))
888902

889903
# Flow control for reads from SSL socket
890904

0 commit comments

Comments
 (0)