Skip to content

Commit 3803ff4

Browse files
committed
Address python/cpython#118950 in uvloop by porting fix and adding tests to ensure asyncio.streams code effectively can schedule connection_lost and raise ConnectionResetError
1 parent 7bb12a1 commit 3803ff4

File tree

3 files changed

+107
-3
lines changed

3 files changed

+107
-3
lines changed

tests/test_aiohttp.py

+56-2
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
skip_tests = False
88

99
import asyncio
10+
import os
1011
import sys
1112
import unittest
1213
import weakref
1314

1415
from uvloop import _testbase as tb
1516

1617

17-
class _TestAioHTTP:
18+
class _TestAioHTTP(tb.SSLTestCase):
1819

1920
def test_aiohttp_basic_1(self):
2021

@@ -95,7 +96,7 @@ async def on_shutdown(app):
9596
async def client():
9697
async with aiohttp.ClientSession() as client:
9798
async with client.ws_connect(
98-
'http://127.0.0.1:{}'.format(port)) as ws:
99+
'http://127.0.0.1:{}'.format(port)) as ws:
99100
await ws.send_str("hello")
100101
async for msg in ws:
101102
assert msg.data == "hello"
@@ -115,6 +116,59 @@ async def stop():
115116

116117
self.loop.run_until_complete(stop())
117118

119+
def test_aiohttp_connection_lost_when_busy(self):
120+
if self.implementation == 'asyncio':
121+
raise unittest.SkipTest('bug in asyncio #118950, tests in CPython.')
122+
123+
cert = tb._cert_fullname(__file__, 'ssl_cert.pem')
124+
key = tb._cert_fullname(__file__, 'ssl_key.pem')
125+
ssl_context = self._create_server_ssl_context(cert, key)
126+
client_ssl_context = self._create_client_ssl_context()
127+
128+
asyncio.set_event_loop(self.loop)
129+
app = aiohttp.web.Application()
130+
131+
async def handler(request):
132+
ws = aiohttp.web.WebSocketResponse()
133+
await ws.prepare(request)
134+
async for msg in ws:
135+
print("Received:", msg.data)
136+
return ws
137+
138+
app.router.add_get('/', handler)
139+
140+
runner = aiohttp.web.AppRunner(app)
141+
self.loop.run_until_complete(runner.setup())
142+
host = '0.0.0.0'
143+
site = aiohttp.web.TCPSite(runner, host, '0', ssl_context=ssl_context)
144+
self.loop.run_until_complete(site.start())
145+
port = site._server.sockets[0].getsockname()[1]
146+
session = aiohttp.ClientSession(loop=self.loop)
147+
148+
async def test():
149+
async with session.ws_connect(f"wss://{host}:{port}/", ssl=client_ssl_context) as ws:
150+
transport = ws._writer.transport
151+
s = transport.get_extra_info('socket')
152+
153+
if self.implementation == 'asyncio':
154+
s._sock.close()
155+
else:
156+
os.close(s.fileno())
157+
158+
# FLOW_CONTROL_HIGH_WATER * 1024
159+
bytes_to_send = 64 * 1024
160+
iterations = 10
161+
msg = b'Hello world, still there?'
162+
163+
# Send enough messages to trigger a socket write + one extra
164+
for _ in range(iterations + 1):
165+
await ws.send_bytes(msg * ((bytes_to_send // len(msg)) // iterations))
166+
167+
self.assertRaises(ConnectionResetError, self.loop.run_until_complete, test())
168+
169+
self.loop.run_until_complete(session.close())
170+
self.loop.run_until_complete(runner.cleanup())
171+
118172

119173
@unittest.skipIf(skip_tests, "no aiohttp module")
120174
class Test_UV_AioHTTP(_TestAioHTTP, tb.UVTestCase):

tests/test_tcp.py

+47
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import asyncio.sslproto
3+
import contextlib
34
import gc
45
import os
56
import select
@@ -3192,6 +3193,52 @@ async def run_main():
31923193

31933194
self.loop.run_until_complete(run_main())
31943195

3196+
def test_connection_lost_when_busy(self):
3197+
if self.implementation == 'asyncio':
3198+
raise unittest.SkipTest('bug in asyncio #118950, tests in CPython.')
3199+
3200+
ssl_context = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
3201+
client_ssl_context = self._create_client_ssl_context()
3202+
port = tb.find_free_port()
3203+
3204+
@contextlib.asynccontextmanager
3205+
async def server():
3206+
async def client_handler(reader, writer):
3207+
...
3208+
3209+
srv = await asyncio.start_server(client_handler, '0.0.0.0', port, ssl=ssl_context, reuse_port=True)
3210+
3211+
try:
3212+
yield
3213+
finally:
3214+
srv.close()
3215+
3216+
async def client():
3217+
reader, writer = await asyncio.open_connection('0.0.0.0', port, ssl=client_ssl_context)
3218+
transport = writer.transport
3219+
s = transport.get_extra_info('socket')
3220+
3221+
if self.implementation == 'asyncio':
3222+
s._sock.close()
3223+
else:
3224+
os.close(s.fileno())
3225+
3226+
# FLOW_CONTROL_HIGH_WATER * 1024
3227+
bytes_to_send = 64 * 1024
3228+
iterations = 10
3229+
msg = b'An really important message :)'
3230+
3231+
# Busy drain loop
3232+
for _ in range(iterations + 1):
3233+
writer.write(msg * ((bytes_to_send // len(msg)) // iterations))
3234+
await writer.drain()
3235+
3236+
async def test():
3237+
async with server():
3238+
await client()
3239+
3240+
self.assertRaises(ConnectionResetError, self.loop.run_until_complete, test())
3241+
31953242

31963243
class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase):
31973244
pass

uvloop/sslproto.pyx

+4-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ cdef class _SSLProtocolTransport:
3737
return self._ssl_protocol._app_protocol
3838

3939
def is_closing(self):
40-
return self._closed
40+
return self._closed or self._ssl_protocol._is_transport_closing()
4141

4242
def close(self):
4343
"""Close the transport.
@@ -316,6 +316,9 @@ cdef class SSLProtocol:
316316
self._app_transport_created = True
317317
return self._app_transport
318318

319+
def _is_transport_closing(self):
320+
return self._transport is not None and self._transport.is_closing()
321+
319322
def connection_made(self, transport):
320323
"""Called when the low-level connection is made.
321324

0 commit comments

Comments
 (0)