11# Tests for aiohttp/http_writer.py
22import array
33import asyncio
4+ import zlib
5+ from typing import Iterable
46from unittest import mock
57
68import pytest
@@ -23,7 +25,12 @@ def transport(buf):
2325 def write (chunk ):
2426 buf .extend (chunk )
2527
28+ def writelines (chunks : Iterable [bytes ]) -> None :
29+ for chunk in chunks :
30+ buf .extend (chunk )
31+
2632 transport .write .side_effect = write
33+ transport .writelines .side_effect = writelines
2734 transport .is_closing .return_value = False
2835 return transport
2936
@@ -85,21 +92,53 @@ async def test_write_payload_length(protocol, transport, loop) -> None:
8592 assert b"da" == content .split (b"\r \n \r \n " , 1 )[- 1 ]
8693
8794
88- async def test_write_payload_chunked_filter (protocol , transport , loop ) -> None :
89- write = transport .write = mock .Mock ()
95+ async def test_write_large_payload_deflate_compression_data_in_eof (
96+ protocol : BaseProtocol ,
97+ transport : asyncio .Transport ,
98+ loop : asyncio .AbstractEventLoop ,
99+ ) -> None :
100+ msg = http .StreamWriter (protocol , loop )
101+ msg .enable_compression ("deflate" )
102+
103+ await msg .write (b"data" * 4096 )
104+ assert transport .write .called # type: ignore[attr-defined]
105+ chunks = [c [1 ][0 ] for c in list (transport .write .mock_calls )] # type: ignore[attr-defined]
106+ transport .write .reset_mock () # type: ignore[attr-defined]
107+ assert not transport .writelines .called # type: ignore[attr-defined]
90108
109+ # This payload compresses to 20447 bytes
110+ payload = b"" .join (
111+ [bytes ((* range (0 , i ), * range (i , 0 , - 1 ))) for i in range (255 ) for _ in range (64 )]
112+ )
113+ await msg .write_eof (payload )
114+ assert not transport .write .called # type: ignore[attr-defined]
115+ assert transport .writelines .called # type: ignore[attr-defined]
116+ chunks .extend (transport .writelines .mock_calls [0 ][1 ][0 ]) # type: ignore[attr-defined]
117+ content = b"" .join (chunks )
118+ assert zlib .decompress (content ) == (b"data" * 4096 ) + payload
119+
120+
121+ async def test_write_payload_chunked_filter (
122+ protocol : BaseProtocol ,
123+ transport : asyncio .Transport ,
124+ loop : asyncio .AbstractEventLoop ,
125+ ) -> None :
91126 msg = http .StreamWriter (protocol , loop )
92127 msg .enable_chunking ()
93128 await msg .write (b"da" )
94129 await msg .write (b"ta" )
95130 await msg .write_eof ()
96131
97- content = b"" .join ([c [1 ][0 ] for c in list (write .mock_calls )])
132+ content = b"" .join ([b"" .join (c [1 ][0 ]) for c in list (transport .writelines .mock_calls )]) # type: ignore[attr-defined]
133+ content += b"" .join ([c [1 ][0 ] for c in list (transport .write .mock_calls )]) # type: ignore[attr-defined]
98134 assert content .endswith (b"2\r \n da\r \n 2\r \n ta\r \n 0\r \n \r \n " )
99135
100136
101- async def test_write_payload_chunked_filter_mutiple_chunks (protocol , transport , loop ):
102- write = transport .write = mock .Mock ()
137+ async def test_write_payload_chunked_filter_multiple_chunks (
138+ protocol : BaseProtocol ,
139+ transport : asyncio .Transport ,
140+ loop : asyncio .AbstractEventLoop ,
141+ ) -> None :
103142 msg = http .StreamWriter (protocol , loop )
104143 msg .enable_chunking ()
105144 await msg .write (b"da" )
@@ -108,14 +147,14 @@ async def test_write_payload_chunked_filter_mutiple_chunks(protocol, transport,
108147 await msg .write (b"at" )
109148 await msg .write (b"a2" )
110149 await msg .write_eof ()
111- content = b"" .join ([c [1 ][0 ] for c in list (write .mock_calls )])
150+ content = b"" .join ([b"" .join (c [1 ][0 ]) for c in list (transport .writelines .mock_calls )]) # type: ignore[attr-defined]
151+ content += b"" .join ([c [1 ][0 ] for c in list (transport .write .mock_calls )]) # type: ignore[attr-defined]
112152 assert content .endswith (
113153 b"2\r \n da\r \n 2\r \n ta\r \n 2\r \n 1d\r \n 2\r \n at\r \n 2\r \n a2\r \n 0\r \n \r \n "
114154 )
115155
116156
117157async def test_write_payload_deflate_compression (protocol , transport , loop ) -> None :
118-
119158 COMPRESSED = b"x\x9c KI,I\x04 \x00 \x04 \x00 \x01 \x9b "
120159 write = transport .write = mock .Mock ()
121160 msg = http .StreamWriter (protocol , loop )
@@ -129,7 +168,30 @@ async def test_write_payload_deflate_compression(protocol, transport, loop) -> N
129168 assert COMPRESSED == content .split (b"\r \n \r \n " , 1 )[- 1 ]
130169
131170
132- async def test_write_payload_deflate_and_chunked (buf , protocol , transport , loop ):
171+ async def test_write_payload_deflate_compression_chunked (
172+ protocol : BaseProtocol ,
173+ transport : asyncio .Transport ,
174+ loop : asyncio .AbstractEventLoop ,
175+ ) -> None :
176+ expected = b"2\r \n x\x9c \r \n a\r \n KI,I\x04 \x00 \x04 \x00 \x01 \x9b \r \n 0\r \n \r \n "
177+ msg = http .StreamWriter (protocol , loop )
178+ msg .enable_compression ("deflate" )
179+ msg .enable_chunking ()
180+ await msg .write (b"data" )
181+ await msg .write_eof ()
182+
183+ chunks = [b"" .join (c [1 ][0 ]) for c in list (transport .writelines .mock_calls )] # type: ignore[attr-defined]
184+ assert all (chunks )
185+ content = b"" .join (chunks )
186+ assert content == expected
187+
188+
189+ async def test_write_payload_deflate_and_chunked (
190+ buf : bytearray ,
191+ protocol : BaseProtocol ,
192+ transport : asyncio .Transport ,
193+ loop : asyncio .AbstractEventLoop ,
194+ ) -> None :
133195 msg = http .StreamWriter (protocol , loop )
134196 msg .enable_compression ("deflate" )
135197 msg .enable_chunking ()
@@ -142,8 +204,71 @@ async def test_write_payload_deflate_and_chunked(buf, protocol, transport, loop)
142204 assert thing == buf
143205
144206
145- async def test_write_payload_bytes_memoryview (buf , protocol , transport , loop ):
207+ async def test_write_payload_deflate_compression_chunked_data_in_eof (
208+ protocol : BaseProtocol ,
209+ transport : asyncio .Transport ,
210+ loop : asyncio .AbstractEventLoop ,
211+ ) -> None :
212+ expected = b"2\r \n x\x9c \r \n d\r \n KI,IL\xcd K\x01 \x00 \x0b @\x02 \xd2 \r \n 0\r \n \r \n "
213+ msg = http .StreamWriter (protocol , loop )
214+ msg .enable_compression ("deflate" )
215+ msg .enable_chunking ()
216+ await msg .write (b"data" )
217+ await msg .write_eof (b"end" )
218+
219+ chunks = [b"" .join (c [1 ][0 ]) for c in list (transport .writelines .mock_calls )] # type: ignore[attr-defined]
220+ assert all (chunks )
221+ content = b"" .join (chunks )
222+ assert content == expected
223+
224+
225+ async def test_write_large_payload_deflate_compression_chunked_data_in_eof (
226+ protocol : BaseProtocol ,
227+ transport : asyncio .Transport ,
228+ loop : asyncio .AbstractEventLoop ,
229+ ) -> None :
230+ msg = http .StreamWriter (protocol , loop )
231+ msg .enable_compression ("deflate" )
232+ msg .enable_chunking ()
233+
234+ await msg .write (b"data" * 4096 )
235+ # This payload compresses to 1111 bytes
236+ payload = b"" .join ([bytes ((* range (0 , i ), * range (i , 0 , - 1 ))) for i in range (255 )])
237+ await msg .write_eof (payload )
238+ assert not transport .write .called # type: ignore[attr-defined]
146239
240+ chunks = []
241+ for write_lines_call in transport .writelines .mock_calls : # type: ignore[attr-defined]
242+ chunked_payload = list (write_lines_call [1 ][0 ])[1 :]
243+ chunked_payload .pop ()
244+ chunks .extend (chunked_payload )
245+
246+ assert all (chunks )
247+ content = b"" .join (chunks )
248+ assert zlib .decompress (content ) == (b"data" * 4096 ) + payload
249+
250+
251+ async def test_write_payload_deflate_compression_chunked_connection_lost (
252+ protocol : BaseProtocol ,
253+ transport : asyncio .Transport ,
254+ loop : asyncio .AbstractEventLoop ,
255+ ) -> None :
256+ msg = http .StreamWriter (protocol , loop )
257+ msg .enable_compression ("deflate" )
258+ msg .enable_chunking ()
259+ await msg .write (b"data" )
260+ with pytest .raises (
261+ ClientConnectionResetError , match = "Cannot write to closing transport"
262+ ), mock .patch .object (transport , "is_closing" , return_value = True ):
263+ await msg .write_eof (b"end" )
264+
265+
266+ async def test_write_payload_bytes_memoryview (
267+ buf : bytearray ,
268+ protocol : BaseProtocol ,
269+ transport : asyncio .Transport ,
270+ loop : asyncio .AbstractEventLoop ,
271+ ) -> None :
147272 msg = http .StreamWriter (protocol , loop )
148273
149274 mv = memoryview (b"abcd" )
0 commit comments