Skip to content

Commit 1b7c5dd

Browse files
committed
Transfer buffered data from old protocol to SSL layer in start_tls()
1 parent 74f4c96 commit 1b7c5dd

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

tests/test_tcp.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,95 @@ class _TestSSL(tb.SSLTestCase):
12631263
PAYLOAD_SIZE = 1024 * 100
12641264
TIMEOUT = 60
12651265

1266+
def test_start_tls_buffer_transfer(self):
1267+
if self.implementation == 'asyncio':
1268+
raise unittest.SkipTest()
1269+
1270+
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
1271+
BUFFERED_MSG = b'buffered data before TLS'
1272+
1273+
server_context = self._create_server_ssl_context(
1274+
self.ONLYCERT, self.ONLYKEY)
1275+
client_context = self._create_client_ssl_context()
1276+
1277+
async def handle_client(reader, writer):
1278+
# Send data before TLS upgrade
1279+
writer.write(BUFFERED_MSG)
1280+
await writer.drain()
1281+
await asyncio.sleep(0.2)
1282+
1283+
# Read pre-TLS data
1284+
data = await reader.readexactly(len(HELLO_MSG))
1285+
self.assertEqual(len(data), len(HELLO_MSG))
1286+
1287+
# Upgrade to TLS (server side)
1288+
try:
1289+
# We need the wait_for because the broken version hangs here
1290+
await asyncio.wait_for(writer.start_tls(server_context),
1291+
timeout=2
1292+
)
1293+
self.assertIsNotNone(writer.get_extra_info('sslcontext'))
1294+
except asyncio.TimeoutError:
1295+
self.assertIsNotNone(writer.get_extra_info('sslcontext'))
1296+
1297+
# Send/receive over TLS
1298+
writer.write(b'OK')
1299+
await writer.drain()
1300+
1301+
data = await reader.readexactly(len(HELLO_MSG))
1302+
self.assertEqual(len(data), len(HELLO_MSG))
1303+
1304+
writer.close()
1305+
await self.wait_closed(writer)
1306+
1307+
async def client(addr):
1308+
# Use open_connection for StreamReader/StreamWriter
1309+
reader, writer = await asyncio.open_connection(*addr)
1310+
1311+
# Read buffered data before TLS
1312+
buffered = await reader.readexactly(len(BUFFERED_MSG))
1313+
self.assertEqual(buffered, BUFFERED_MSG,
1314+
"Client didn't receive buffered data before TLS upgrade")
1315+
1316+
# Write before TLS upgrade
1317+
writer.write(HELLO_MSG)
1318+
await writer.drain()
1319+
1320+
# Upgrade to TLS
1321+
try:
1322+
# We need the wait_for because the broken version hangs here
1323+
await asyncio.wait_for(writer.start_tls(client_context),
1324+
timeout=2
1325+
)
1326+
self.assertIsNotNone(writer.get_extra_info('sslcontext'))
1327+
except asyncio.TimeoutError:
1328+
self.assertIsNotNone(writer.get_extra_info('sslcontext'))
1329+
1330+
# Verify communication over TLS
1331+
tls_data = await reader.readexactly(2)
1332+
self.assertEqual(tls_data, b'OK',
1333+
"Client didn't receive TLS response correctly")
1334+
1335+
# Continue over TLS
1336+
writer.write(HELLO_MSG)
1337+
await writer.drain()
1338+
1339+
writer.close()
1340+
await self.wait_closed(writer)
1341+
1342+
async def run_test():
1343+
srv = await asyncio.start_server(
1344+
handle_client, '127.0.0.1', 0, family=socket.AF_INET)
1345+
1346+
addr = srv.sockets[0].getsockname()
1347+
1348+
await asyncio.wait_for(client(addr), timeout=10)
1349+
1350+
srv.close()
1351+
await srv.wait_closed()
1352+
1353+
self.loop.run_until_complete(run_test())
1354+
12661355
def test_create_server_ssl_1(self):
12671356
CNT = 0 # number of clients that were successful
12681357
TOTAL_CNT = 25 # total number of clients that test will create

uvloop/loop.pyx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,6 +1616,19 @@ cdef class Loop:
16161616
ssl_shutdown_timeout=ssl_shutdown_timeout,
16171617
call_connection_made=False)
16181618

1619+
# Transfer buffered data from the old protocol to the new one.
1620+
if not hasattr(protocol, '_stream_reader'):
1621+
return
1622+
1623+
stream_reader = protocol._stream_reader
1624+
if stream_reader is None:
1625+
return
1626+
1627+
buffer = stream_reader._buffer
1628+
if buffer:
1629+
ssl_protocol._incoming.write(buffer)
1630+
buffer.clear()
1631+
16191632
# Pause early so that "ssl_protocol.data_received()" doesn't
16201633
# have a chance to get called before "ssl_protocol.connection_made()".
16211634
transport.pause_reading()

0 commit comments

Comments
 (0)