Skip to content

Commit ef3f020

Browse files
committed
Improve test
1 parent b0e3501 commit ef3f020

File tree

1 file changed

+67
-52
lines changed

1 file changed

+67
-52
lines changed

tests/test_tcp.py

Lines changed: 67 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,14 +1264,20 @@ class _TestSSL(tb.SSLTestCase):
12641264
TIMEOUT = 60
12651265

12661266
def test_start_tls_buffer_transfer(self):
1267+
TIMEOUT = 10
1268+
12671269
if (
12681270
self.implementation == 'asyncio'
1269-
and sys.version_info[:2] <= (3, 11)
1271+
or sys.version_info[:2] < (3, 11)
12701272
):
12711273
# StreamWriter.start_tls() introduced in Python 3.11
12721274
raise unittest.SkipTest(
12731275
'StreamWriter.start_tls() not supported'
12741276
)
1277+
self.loop.set_exception_handler(lambda loop, ctx: None)
1278+
1279+
client_read_buffered = asyncio.Event()
1280+
server_sent_ok = asyncio.Event()
12751281

12761282
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
12771283
BUFFERED_MSG = b'buffered data before TLS'
@@ -1281,80 +1287,89 @@ def test_start_tls_buffer_transfer(self):
12811287
client_context = self._create_client_ssl_context()
12821288

12831289
async def handle_client(reader, writer):
1284-
# Send data before TLS upgrade
1285-
writer.write(BUFFERED_MSG)
1286-
await writer.drain()
1287-
await asyncio.sleep(0.2)
1290+
try:
1291+
# Send data before TLS upgrade
1292+
writer.write(BUFFERED_MSG)
1293+
await writer.drain()
12881294

1289-
# Read pre-TLS data
1290-
data = await reader.readexactly(len(HELLO_MSG))
1291-
self.assertEqual(len(data), len(HELLO_MSG))
1295+
await asyncio.wait_for(
1296+
client_read_buffered.wait(),
1297+
timeout=TIMEOUT
1298+
)
12921299

1293-
# Upgrade to TLS (server side)
1294-
try:
1300+
# Read pre-TLS data
1301+
data = await asyncio.wait_for(
1302+
reader.readexactly(len(HELLO_MSG)),
1303+
timeout=TIMEOUT,
1304+
)
1305+
self.assertEqual(len(data), len(HELLO_MSG))
1306+
1307+
# Upgrade to TLS (server side)
12951308
# We need the wait_for because the broken version hangs here
12961309
await asyncio.wait_for(
12971310
writer.start_tls(server_context),
1298-
timeout=2)
1299-
self.assertIsNotNone(writer.get_extra_info('sslcontext'))
1300-
except asyncio.TimeoutError:
1311+
timeout=TIMEOUT,)
13011312
self.assertIsNotNone(writer.get_extra_info('sslcontext'))
13021313

1303-
# Send/receive over TLS
1304-
writer.write(b'OK')
1305-
await writer.drain()
1306-
1307-
data = await reader.readexactly(len(HELLO_MSG))
1308-
self.assertEqual(len(data), len(HELLO_MSG))
1314+
# Send/receive over TLS
1315+
writer.write(b'OK')
1316+
await writer.drain()
1317+
server_sent_ok.set()
13091318

1310-
writer.close()
1311-
await self.wait_closed(writer)
1319+
data = await asyncio.wait_for(
1320+
reader.readexactly(len(HELLO_MSG)),
1321+
timeout=TIMEOUT,
1322+
)
1323+
self.assertEqual(len(data), len(HELLO_MSG))
1324+
finally:
1325+
if not writer.is_closing():
1326+
writer.close()
1327+
await self.wait_closed(writer)
13121328

13131329
async def client(addr):
13141330
# Use open_connection for StreamReader/StreamWriter
13151331
reader, writer = await asyncio.open_connection(*addr)
13161332

1317-
# Read buffered data before TLS
1318-
buffered = await reader.readexactly(len(BUFFERED_MSG))
1319-
self.assertEqual(buffered, BUFFERED_MSG,
1320-
"Wrong pre-TLS buffered data from server")
1321-
1322-
# Write before TLS upgrade
1323-
writer.write(HELLO_MSG)
1324-
await writer.drain()
1325-
1326-
# Upgrade to TLS
13271333
try:
1328-
# We need the wait_for because the broken version hangs here
1329-
await asyncio.wait_for(
1330-
writer.start_tls(client_context),
1331-
timeout=2)
1334+
# Read buffered data before TLS
1335+
buffered = await reader.readexactly(len(BUFFERED_MSG))
1336+
self.assertEqual(buffered, BUFFERED_MSG,
1337+
"Wrong pre-TLS buffered data from server")
1338+
client_read_buffered.set()
1339+
1340+
# Write before TLS upgrade
1341+
writer.write(HELLO_MSG)
1342+
await writer.drain()
1343+
1344+
# Upgrade to TLS
1345+
await writer.start_tls(client_context)
13321346
self.assertIsNotNone(writer.get_extra_info('sslcontext'))
1333-
except asyncio.TimeoutError:
1334-
self.assertIsNotNone(writer.get_extra_info('sslcontext'))
1335-
1336-
# Verify communication over TLS
1337-
tls_data = await reader.readexactly(2)
1338-
self.assertEqual(tls_data, b'OK',
1339-
"Wrong data from server after TLS upgrade")
13401347

1341-
# Continue over TLS
1342-
writer.write(HELLO_MSG)
1343-
await writer.drain()
1348+
# Verify communication over TLS
1349+
await server_sent_ok.wait()
1350+
tls_data = await reader.readexactly(2)
1351+
self.assertEqual(tls_data, b'OK',
1352+
"Wrong data from server after TLS upgrade")
13441353

1345-
writer.close()
1346-
await self.wait_closed(writer)
1354+
# Continue over TLS
1355+
writer.write(HELLO_MSG)
1356+
await writer.drain()
1357+
finally:
1358+
if not writer.is_closing():
1359+
writer.close()
1360+
await self.wait_closed(writer)
13471361

13481362
async def run_test():
13491363
srv = await asyncio.start_server(
13501364
handle_client, '127.0.0.1', 0, family=socket.AF_INET)
13511365

1352-
addr = srv.sockets[0].getsockname()
1353-
1354-
await asyncio.wait_for(client(addr), timeout=10)
1366+
try:
1367+
addr = srv.sockets[0].getsockname()
13551368

1356-
srv.close()
1357-
await srv.wait_closed()
1369+
await asyncio.wait_for(client(addr), timeout=self.TIMEOUT)
1370+
finally:
1371+
srv.close()
1372+
await srv.wait_closed()
13581373

13591374
self.loop.run_until_complete(run_test())
13601375

0 commit comments

Comments
 (0)