@@ -1263,6 +1263,95 @@ class _TestSSL(tb.SSLTestCase):
12631263 PAYLOAD_SIZE = 1024 * 100
12641264 TIMEOUT = 60
12651265
1266+ def test_start_tls_buffer_transfer (self ):
1267+ if self .implementation == 'asyncio' :
1268+ raise unittest .SkipTest ()
1269+
1270+ HELLO_MSG = b'1' * self .PAYLOAD_SIZE
1271+ BUFFERED_MSG = b'buffered data before TLS'
1272+
1273+ server_context = self ._create_server_ssl_context (
1274+ self .ONLYCERT , self .ONLYKEY )
1275+ client_context = self ._create_client_ssl_context ()
1276+
1277+ async def handle_client (reader , writer ):
1278+ # Send data before TLS upgrade
1279+ writer .write (BUFFERED_MSG )
1280+ await writer .drain ()
1281+ await asyncio .sleep (0.2 )
1282+
1283+ # Read pre-TLS data
1284+ data = await reader .readexactly (len (HELLO_MSG ))
1285+ self .assertEqual (len (data ), len (HELLO_MSG ))
1286+
1287+ # Upgrade to TLS (server side)
1288+ try :
1289+ # We need the wait_for because the broken version hangs here
1290+ await asyncio .wait_for (writer .start_tls (server_context ),
1291+ timeout = 2
1292+ )
1293+ self .assertIsNotNone (writer .get_extra_info ('sslcontext' ))
1294+ except asyncio .TimeoutError :
1295+ self .assertIsNotNone (writer .get_extra_info ('sslcontext' ))
1296+
1297+ # Send/receive over TLS
1298+ writer .write (b'OK' )
1299+ await writer .drain ()
1300+
1301+ data = await reader .readexactly (len (HELLO_MSG ))
1302+ self .assertEqual (len (data ), len (HELLO_MSG ))
1303+
1304+ writer .close ()
1305+ await self .wait_closed (writer )
1306+
1307+ async def client (addr ):
1308+ # Use open_connection for StreamReader/StreamWriter
1309+ reader , writer = await asyncio .open_connection (* addr )
1310+
1311+ # Read buffered data before TLS
1312+ buffered = await reader .readexactly (len (BUFFERED_MSG ))
1313+ self .assertEqual (buffered , BUFFERED_MSG ,
1314+ "Client didn't receive buffered data before TLS upgrade" )
1315+
1316+ # Write before TLS upgrade
1317+ writer .write (HELLO_MSG )
1318+ await writer .drain ()
1319+
1320+ # Upgrade to TLS
1321+ try :
1322+ # We need the wait_for because the broken version hangs here
1323+ await asyncio .wait_for (writer .start_tls (client_context ),
1324+ timeout = 2
1325+ )
1326+ self .assertIsNotNone (writer .get_extra_info ('sslcontext' ))
1327+ except asyncio .TimeoutError :
1328+ self .assertIsNotNone (writer .get_extra_info ('sslcontext' ))
1329+
1330+ # Verify communication over TLS
1331+ tls_data = await reader .readexactly (2 )
1332+ self .assertEqual (tls_data , b'OK' ,
1333+ "Client didn't receive TLS response correctly" )
1334+
1335+ # Continue over TLS
1336+ writer .write (HELLO_MSG )
1337+ await writer .drain ()
1338+
1339+ writer .close ()
1340+ await self .wait_closed (writer )
1341+
1342+ async def run_test ():
1343+ srv = await asyncio .start_server (
1344+ handle_client , '127.0.0.1' , 0 , family = socket .AF_INET )
1345+
1346+ addr = srv .sockets [0 ].getsockname ()
1347+
1348+ await asyncio .wait_for (client (addr ), timeout = 10 )
1349+
1350+ srv .close ()
1351+ await srv .wait_closed ()
1352+
1353+ self .loop .run_until_complete (run_test ())
1354+
12661355 def test_create_server_ssl_1 (self ):
12671356 CNT = 0 # number of clients that were successful
12681357 TOTAL_CNT = 25 # total number of clients that test will create
0 commit comments