Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions CHANGES/11898.bugfix.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
Restored :py:meth:`~aiohttp.BodyPartReader.decode` as a synchronous method
for backward compatibility. The method was inadvertently changed to async
in 3.13.3 as part of the decompression bomb security fix. A new
:py:meth:`~aiohttp.BodyPartReader.decode_async` method is now available
for non-blocking decompression of large payloads. Internal aiohttp code
uses the async variant to maintain security protections -- by :user:`bdraco`.
:py:meth:`~aiohttp.BodyPartReader.decode_iter` method is now available
for non-blocking decompression of large payloads using an async generator.
Internal aiohttp code uses the async variant to maintain security protections
-- by :user:`bdraco`.
41 changes: 24 additions & 17 deletions aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
import warnings
from collections import deque
from collections.abc import Iterator, Mapping, Sequence
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from types import TracebackType
from typing import TYPE_CHECKING, Any, Union, cast
from urllib.parse import parse_qsl, unquote, urlencode
Expand Down Expand Up @@ -314,7 +314,10 @@ async def read(self, *, decode: bool = False) -> bytes:
data.extend(await self.read_chunk(self.chunk_size))
# https://github.com/python/mypy/issues/17537
if decode: # type: ignore[unreachable]
return await self.decode_async(data)
decoded_data = bytearray()
async for d in self.decode_iter(data):
decoded_data.extend(d)
return decoded_data
return data

async def read_chunk(self, size: int = chunk_size) -> bytes:
Expand Down Expand Up @@ -509,15 +512,15 @@ def decode(self, data: bytes) -> bytes:
Decodes data according to the specified Content-Encoding
or Content-Transfer-Encoding headers value.

Note: For large payloads, consider using decode_async() instead
Note: For large payloads, consider using decode_iter() instead
to avoid blocking the event loop during decompression.
"""
data = self._apply_content_transfer_decoding(data)
if self._needs_content_decoding():
return self._decode_content(data)
return data

async def decode_async(self, data: bytes) -> bytes:
async def decode_iter(self, data: bytes) -> AsyncIterator[bytes]:
"""Decodes data asynchronously.

Decodes data according to the specified Content-Encoding
Expand All @@ -528,8 +531,10 @@ async def decode_async(self, data: bytes) -> bytes:
"""
data = self._apply_content_transfer_decoding(data)
if self._needs_content_decoding():
return await self._decode_content_async(data)
return data
async for d in self._decode_content_async(data):
yield d
else:
yield data

def _decode_content(self, data: bytes) -> bytes:
encoding = self.headers.get(CONTENT_ENCODING, "").lower()
Expand All @@ -543,17 +548,20 @@ def _decode_content(self, data: bytes) -> bytes:

raise RuntimeError(f"unknown content encoding: {encoding}")

async def _decode_content_async(self, data: bytes) -> bytes:
async def _decode_content_async(self, data: bytes) -> AsyncIterator[bytes]:
encoding = self.headers.get(CONTENT_ENCODING, "").lower()
if encoding == "identity":
return data
if encoding in {"deflate", "gzip"}:
return await ZLibDecompressor(
yield data
elif encoding in {"deflate", "gzip"}:
d = ZLibDecompressor(
encoding=encoding,
suppress_deflate_header=True,
).decompress(data, max_length=self._max_decompress_size)

raise RuntimeError(f"unknown content encoding: {encoding}")
)
yield await d.decompress(data, max_length=self._max_decompress_size)
while d.data_available:
yield await d.decompress(b"", max_length=self._max_decompress_size)
else:
raise RuntimeError(f"unknown content encoding: {encoding}")

def _decode_content_transfer(self, data: bytes) -> bytes:
encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
Expand Down Expand Up @@ -624,10 +632,9 @@ async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> byt

async def write(self, writer: AbstractStreamWriter) -> None:
field = self._value
chunk = await field.read_chunk(size=2**16)
while chunk:
await writer.write(await field.decode_async(chunk))
chunk = await field.read_chunk(size=2**16)
while chunk := await field.read_chunk(size=2**18):
async for d in field.decode_iter(chunk):
await writer.write(d)


class MultipartReader:
Expand Down
20 changes: 10 additions & 10 deletions aiohttp/web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,17 +714,17 @@ async def post(self) -> "MultiDictProxy[str | bytes | FileField]":
tmp = await self._loop.run_in_executor(
None, tempfile.TemporaryFile
)
chunk = await field.read_chunk(size=2**16)
while chunk:
chunk = await field.decode_async(chunk)
await self._loop.run_in_executor(None, tmp.write, chunk)
size += len(chunk)
if 0 < max_size < size:
await self._loop.run_in_executor(None, tmp.close)
raise HTTPRequestEntityTooLarge(
max_size=max_size, actual_size=size
while chunk := await field.read_chunk(size=2**18):
async for decoded_chunk in field.decode_iter(chunk):
await self._loop.run_in_executor(
None, tmp.write, decoded_chunk
)
chunk = await field.read_chunk(size=2**16)
size += len(decoded_chunk)
if 0 < max_size < size:
await self._loop.run_in_executor(None, tmp.close)
raise HTTPRequestEntityTooLarge(
max_size=max_size, actual_size=size
)
await self._loop.run_in_executor(None, tmp.seek, 0)

if field_ct is None:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,8 @@ async def test_decode_async_with_content_transfer_encoding_base64(self) -> None:
result = b""
while not obj.at_eof():
chunk = await obj.read_chunk(size=6)
result += await obj.decode_async(chunk)
async for decoded_chunk in obj.decode_iter(chunk):
result += decoded_chunk
assert b"Time to Relax!" == result

async def test_decode_with_content_encoding_deflate(self) -> None:
Expand Down
Loading