Skip to content

Commit a32cac8

Browse files
committed
Address python/cpython#118950 in uvloop by porting fix and adding tests to ensure asyncio.streams code effectively can schedule connection_lost and raise ConnectionResetError
1 parent 7bb12a1 commit a32cac8

File tree

3 files changed

+116
-2
lines changed

3 files changed

+116
-2
lines changed

tests/test_aiohttp.py

+60-1
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
skip_tests = False
88

99
import asyncio
10+
import os
1011
import sys
1112
import unittest
1213
import weakref
1314

1415
from uvloop import _testbase as tb
1516

1617

17-
class _TestAioHTTP:
18+
class _TestAioHTTP(tb.SSLTestCase):
1819

1920
def test_aiohttp_basic_1(self):
2021

@@ -115,6 +116,64 @@ async def stop():
115116

116117
self.loop.run_until_complete(stop())
117118

119+
def test_aiohttp_connection_lost_when_busy(self):
120+
if self.implementation == 'asyncio':
121+
raise unittest.SkipTest('bug in asyncio #118950 tests in CPython.')
122+
123+
cert = tb._cert_fullname(__file__, 'ssl_cert.pem')
124+
key = tb._cert_fullname(__file__, 'ssl_key.pem')
125+
ssl_context = self._create_server_ssl_context(cert, key)
126+
client_ssl_context = self._create_client_ssl_context()
127+
128+
asyncio.set_event_loop(self.loop)
129+
app = aiohttp.web.Application()
130+
131+
async def handler(request):
132+
ws = aiohttp.web.WebSocketResponse()
133+
await ws.prepare(request)
134+
async for msg in ws:
135+
print("Received:", msg.data)
136+
return ws
137+
138+
app.router.add_get('/', handler)
139+
140+
runner = aiohttp.web.AppRunner(app)
141+
self.loop.run_until_complete(runner.setup())
142+
host = '0.0.0.0'
143+
site = aiohttp.web.TCPSite(runner, host, '0', ssl_context=ssl_context)
144+
self.loop.run_until_complete(site.start())
145+
port = site._server.sockets[0].getsockname()[1]
146+
session = aiohttp.ClientSession(loop=self.loop)
147+
148+
async def test():
149+
async with session.ws_connect(
150+
f"wss://{host}:{port}/",
151+
ssl=client_ssl_context
152+
) as ws:
153+
transport = ws._writer.transport
154+
s = transport.get_extra_info('socket')
155+
156+
if self.implementation == 'asyncio':
157+
s._sock.close()
158+
else:
159+
os.close(s.fileno())
160+
161+
# FLOW_CONTROL_HIGH_WATER * 1024
162+
bytes_to_send = 64 * 1024
163+
iterations = 10
164+
msg = b'Hello world, still there?'
165+
166+
# Send enough messages to trigger a socket write + one extra
167+
for _ in range(iterations + 1):
168+
await ws.send_bytes(
169+
msg * ((bytes_to_send // len(msg)) // iterations))
170+
171+
self.assertRaises(
172+
ConnectionResetError, self.loop.run_until_complete, test())
173+
174+
self.loop.run_until_complete(session.close())
175+
self.loop.run_until_complete(runner.cleanup())
176+
118177

119178
@unittest.skipIf(skip_tests, "no aiohttp module")
120179
class Test_UV_AioHTTP(_TestAioHTTP, tb.UVTestCase):

tests/test_tcp.py

+52
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import asyncio.sslproto
3+
import contextlib
34
import gc
45
import os
56
import select
@@ -3192,6 +3193,57 @@ async def run_main():
31923193

31933194
self.loop.run_until_complete(run_main())
31943195

3196+
def test_connection_lost_when_busy(self):
3197+
if self.implementation == 'asyncio':
3198+
raise unittest.SkipTest('bug in asyncio #118950 tests in CPython.')
3199+
3200+
ssl_context = self._create_server_ssl_context(
3201+
self.ONLYCERT, self.ONLYKEY)
3202+
client_ssl_context = self._create_client_ssl_context()
3203+
port = tb.find_free_port()
3204+
3205+
@contextlib.asynccontextmanager
3206+
async def server():
3207+
async def client_handler(reader, writer):
3208+
...
3209+
3210+
srv = await asyncio.start_server(
3211+
client_handler, '0.0.0.0',
3212+
port, ssl=ssl_context, reuse_port=True)
3213+
3214+
try:
3215+
yield
3216+
finally:
3217+
srv.close()
3218+
3219+
async def client():
3220+
reader, writer = await asyncio.open_connection(
3221+
'0.0.0.0', port, ssl=client_ssl_context)
3222+
transport = writer.transport
3223+
s = transport.get_extra_info('socket')
3224+
3225+
if self.implementation == 'asyncio':
3226+
s._sock.close()
3227+
else:
3228+
os.close(s.fileno())
3229+
3230+
# FLOW_CONTROL_HIGH_WATER * 1024
3231+
bytes_to_send = 64 * 1024
3232+
iterations = 10
3233+
msg = b'An really important message :)'
3234+
3235+
# Busy drain loop
3236+
for _ in range(iterations + 1):
3237+
writer.write(msg * ((bytes_to_send // len(msg)) // iterations))
3238+
await writer.drain()
3239+
3240+
async def test():
3241+
async with server():
3242+
await client()
3243+
3244+
self.assertRaises(
3245+
ConnectionResetError, self.loop.run_until_complete, test())
3246+
31953247

31963248
class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase):
31973249
pass

uvloop/sslproto.pyx

+4-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ cdef class _SSLProtocolTransport:
3737
return self._ssl_protocol._app_protocol
3838

3939
def is_closing(self):
40-
return self._closed
40+
return self._closed or self._ssl_protocol._is_transport_closing()
4141

4242
def close(self):
4343
"""Close the transport.
@@ -316,6 +316,9 @@ cdef class SSLProtocol:
316316
self._app_transport_created = True
317317
return self._app_transport
318318

319+
def _is_transport_closing(self):
320+
return self._transport is not None and self._transport.is_closing()
321+
319322
def connection_made(self, transport):
320323
"""Called when the low-level connection is made.
321324

0 commit comments

Comments
 (0)