diff --git a/aiohttp/_cparser.pxd b/aiohttp/_cparser.pxd index 1b3be6d4efb..cc7ef58d664 100644 --- a/aiohttp/_cparser.pxd +++ b/aiohttp/_cparser.pxd @@ -145,6 +145,7 @@ cdef extern from "llhttp.h": int llhttp_should_keep_alive(const llhttp_t* parser) + void llhttp_resume(llhttp_t* parser) void llhttp_resume_after_upgrade(llhttp_t* parser) llhttp_errno_t llhttp_get_errno(const llhttp_t* parser) diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx index f7c393ed42a..83afac5d3b5 100644 --- a/aiohttp/_http_parser.pyx +++ b/aiohttp/_http_parser.pyx @@ -291,15 +291,19 @@ cdef class HttpParser: bint _response_with_body bint _read_until_eof + bytes _tail bint _started object _url bytearray _buf str _path str _reason list _headers + bint _last_had_more_data list _raw_headers bint _upgraded list _messages + bint _more_data_available + bint _paused object _payload bint _payload_error object _payload_exception @@ -345,6 +349,9 @@ cdef class HttpParser: self._timer = timer self._buf = bytearray() + self._last_had_more_data = False + self._more_data_available = False + self._paused = False self._payload = None self._payload_error = 0 self._payload_exception = payload_exception @@ -352,6 +359,7 @@ cdef class HttpParser: self._raw_name = EMPTY_BYTES self._raw_value = EMPTY_BYTES + self._tail = b"" self._has_value = False self._header_name_size = 0 @@ -503,6 +511,10 @@ cdef class HttpParser: ### Public API ### + def pause_reading(self): + assert self._payload is not None + self._paused = True + def feed_eof(self): cdef bytes desc @@ -529,6 +541,23 @@ cdef class HttpParser: size_t nb cdef cparser.llhttp_errno_t errno + if self._tail: + data, self._tail = self._tail + data, EMPTY_BYTES + + had_more_data = self._more_data_available + if self._more_data_available: + result = cb_on_body(self._cparser, EMPTY_BYTES, 0) + if result is cparser.HPE_PAUSED: + self._last_had_more_data = had_more_data + self._tail = data + return (), False, EMPTY_BYTES + # TODO: Do we need to handle error case (-1)? + # If the last pause had more data, then we probably paused at the + # end of the body. Therefore we need to continue with empty bytes. 
+ if not data and not self._last_had_more_data: + return (), False, EMPTY_BYTES + self._last_had_more_data = False + PyObject_GetBuffer(data, &self.py_buf, PyBUF_SIMPLE) data_len = self.py_buf.len @@ -539,12 +568,15 @@ cdef class HttpParser: if errno is cparser.HPE_PAUSED_UPGRADE: cparser.llhttp_resume_after_upgrade(self._cparser) - nb = cparser.llhttp_get_error_pos(self._cparser) - self.py_buf.buf + elif errno is cparser.HPE_PAUSED: + cparser.llhttp_resume(self._cparser) + pos = cparser.llhttp_get_error_pos(self._cparser) - self.py_buf.buf + self._tail = data[pos:] PyBuffer_Release(&self.py_buf) - if errno not in (cparser.HPE_OK, cparser.HPE_PAUSED_UPGRADE): + if errno not in (cparser.HPE_OK, cparser.HPE_PAUSED, cparser.HPE_PAUSED_UPGRADE): if self._payload_error == 0: if self._last_error is not None: ex = self._last_error @@ -569,7 +601,7 @@ cdef class HttpParser: if self._upgraded: return messages, True, data[nb:] else: - return messages, False, b"" + return messages, False, EMPTY_BYTES def set_upgraded(self, val): self._upgraded = val @@ -762,19 +794,26 @@ cdef int cb_on_body(cparser.llhttp_t* parser, const char *at, size_t length) except -1: cdef HttpParser pyparser = parser.data cdef bytes body = at[:length] - try: - pyparser._payload.feed_data(body) - except BaseException as underlying_exc: - reraised_exc = underlying_exc - if pyparser._payload_exception is not None: - reraised_exc = pyparser._payload_exception(str(underlying_exc)) - - set_exception(pyparser._payload, reraised_exc, underlying_exc) - - pyparser._payload_error = 1 - return -1 - else: - return 0 + while body or pyparser._more_data_available: + try: + pyparser._more_data_available = pyparser._payload.feed_data(body) + except BaseException as underlying_exc: + reraised_exc = underlying_exc + if pyparser._payload_exception is not None: + reraised_exc = pyparser._payload_exception(str(underlying_exc)) + + set_exception(pyparser._payload, reraised_exc, underlying_exc) + + pyparser._payload_error = 1 + pyparser._paused = False + return -1 + body = EMPTY_BYTES + + if pyparser._paused: + pyparser._paused = False + return cparser.HPE_PAUSED + pyparser._paused = False + return 0 cdef int cb_on_message_complete(cparser.llhttp_t* parser) except -1: diff --git a/aiohttp/base_protocol.py b/aiohttp/base_protocol.py index 7f01830f4e9..179391cd3a0 100644 --- a/aiohttp/base_protocol.py +++ b/aiohttp/base_protocol.py @@ -1,26 +1,35 @@ import asyncio -from typing import cast +from typing import TYPE_CHECKING, Any, cast from .client_exceptions import ClientConnectionResetError from .helpers import set_exception from .tcp_helpers import tcp_nodelay +if TYPE_CHECKING: + from .http_parser import HttpParser + class BaseProtocol(asyncio.Protocol): __slots__ = ( "_loop", "_paused", + "_parser", "_drain_waiter", "_connection_lost", "_reading_paused", + "_upgraded", "transport", ) - def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, loop: asyncio.AbstractEventLoop, parser: "HttpParser[Any] | None" = None + ) -> None: self._loop: asyncio.AbstractEventLoop = loop self._paused = False self._drain_waiter: asyncio.Future[None] | None = None self._reading_paused = False + self._parser = parser + self._upgraded = False self.transport: asyncio.Transport | None = None @@ -48,15 +57,27 @@ def resume_writing(self) -> None: waiter.set_result(None) def pause_reading(self) -> None: - if not self._reading_paused and self.transport is not None: + self._reading_paused = True + # Parser shouldn't be paused on websockets. 
+ if not self._upgraded: + assert self._parser is not None + self._parser.pause_reading() + if self.transport is not None: try: self.transport.pause_reading() except (AttributeError, NotImplementedError, RuntimeError): pass - self._reading_paused = True - def resume_reading(self) -> None: - if self._reading_paused and self.transport is not None: + def resume_reading(self, resume_parser: bool = True) -> None: + self._reading_paused = False + + # This will resume parsing any unprocessed data from the last pause. + if not self._upgraded and resume_parser: + self.data_received(b"") + + # Reading may have been paused again in the above call if there was a lot of + # compressed data still pending. + if not self._reading_paused and self.transport is not None: try: self.transport.resume_reading() except (AttributeError, NotImplementedError, RuntimeError): diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index 601b545c82a..fd1fc2a0ad1 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -26,7 +26,7 @@ class ResponseHandler(BaseProtocol, DataQueue[tuple[RawResponseMessage, StreamRe """Helper class to adapt between Protocol and StreamReader.""" def __init__(self, loop: asyncio.AbstractEventLoop) -> None: - BaseProtocol.__init__(self, loop=loop) + BaseProtocol.__init__(self, loop=loop, parser=None) DataQueue.__init__(self, loop) self._should_close = False @@ -36,10 +36,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._payload_parser: WebSocketReader | None = None self._timer = None - self._tail = b"" - self._upgraded = False - self._parser: HttpResponseParser | None = None self._read_timeout: float | None = None self._read_timeout_handle: asyncio.TimerHandle | None = None @@ -190,8 +187,8 @@ def pause_reading(self) -> None: super().pause_reading() self._drop_timeout() - def resume_reading(self) -> None: - super().resume_reading() + def resume_reading(self, resume_parser: bool = True) -> None: + super().resume_reading(resume_parser) self._reschedule_timeout() def set_exception( @@ -293,9 +290,6 @@ def _on_read_timeout(self) -> None: def data_received(self, data: bytes) -> None: self._reschedule_timeout() - if not data: - return - # custom payload parser - currently always WebSocketReader if self._payload_parser is not None: eof, tail = self._payload_parser.feed_data(data) diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py index 0bc4a30d8ed..0e3aa99fd71 100644 --- a/aiohttp/compression_utils.py +++ b/aiohttp/compression_utils.py @@ -34,7 +34,9 @@ MAX_SYNC_CHUNK_SIZE = 4096 -DEFAULT_MAX_DECOMPRESS_SIZE = 2**25 # 32MiB +# Matches the max size we receive from sockets: +# https://github.com/python/cpython/blob/1857a40807daeae3a1bf5efb682de9c9ae6df845/Lib/asyncio/selector_events.py#L766 +DEFAULT_MAX_DECOMPRESS_SIZE = 256 * 1024 # Unlimited decompression constants - different libraries use different conventions ZLIB_MAX_LENGTH_UNLIMITED = 0 # zlib uses 0 to mean unlimited @@ -53,6 +55,9 @@ def flush(self, length: int = ..., /) -> bytes: ... @property def eof(self) -> bool: ... + @property + def unconsumed_tail(self) -> bytes: ... 
+ class ZLibBackendProtocol(Protocol): MAX_WBITS: int @@ -179,6 +184,11 @@ async def decompress( ) return self.decompress_sync(data, max_length) + @property + @abstractmethod + def data_available(self) -> bool: + """Return True if more output is available by passing b"".""" + class ZLibCompressor: def __init__( @@ -271,7 +281,9 @@ def __init__( def decompress_sync( self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED ) -> bytes: - return self._decompressor.decompress(data, max_length) + return self._decompressor.decompress( + self._decompressor.unconsumed_tail + data, max_length + ) def flush(self, length: int = 0) -> bytes: return ( @@ -280,6 +292,10 @@ def flush(self, length: int = 0) -> bytes: else self._decompressor.flush() ) + @property + def data_available(self) -> bool: + return bool(self._decompressor.unconsumed_tail) + @property def eof(self) -> bool: return self._decompressor.eof @@ -301,6 +317,7 @@ def __init__( "Please install `Brotli` module" ) self._obj = brotli.Decompressor() + self._last_empty = False super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size) def decompress_sync( @@ -308,8 +325,12 @@ def decompress_sync( ) -> bytes: """Decompress the given data.""" if hasattr(self._obj, "decompress"): - return cast(bytes, self._obj.decompress(data, max_length)) - return cast(bytes, self._obj.process(data, max_length)) + result = cast(bytes, self._obj.decompress(data, max_length)) + else: + result = cast(bytes, self._obj.process(data, max_length)) + # Only way to know that brotli has no further data is checking we get no output + self._last_empty = result == b"" + return result def flush(self) -> bytes: """Flush the decompressor.""" @@ -317,6 +338,10 @@ def flush(self) -> bytes: return cast(bytes, self._obj.flush()) return b"" + @property + def data_available(self) -> bool: + return not self._obj.is_finished() and not self._last_empty + class ZSTDDecompressor(DecompressionBaseHandler): def __init__( @@ -346,3 +371,7 @@ def decompress_sync( def flush(self) -> bytes: return b"" + + @property + def data_available(self) -> bool: + return not self._obj.needs_input and not self._obj.eof diff --git a/aiohttp/http_exceptions.py b/aiohttp/http_exceptions.py index cf3c05434c5..95d0d6373ae 100644 --- a/aiohttp/http_exceptions.py +++ b/aiohttp/http_exceptions.py @@ -73,10 +73,6 @@ class ContentLengthError(PayloadEncodingError): """Not enough data to satisfy content length header.""" -class DecompressSizeError(PayloadEncodingError): - """Decompressed size exceeds the configured limit.""" - - class LineTooLong(BadHttpMessage): def __init__(self, line: bytes, limit: int) -> None: super().__init__(f"Got more than {limit} bytes when reading: {line!r}.") diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index c5560b7a5ac..864a877b6fc 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -35,7 +35,6 @@ BadStatusLine, ContentEncodingError, ContentLengthError, - DecompressSizeError, InvalidHeader, InvalidURLError, LineTooLong, @@ -100,6 +99,12 @@ class RawResponseMessage(NamedTuple): _MsgT = TypeVar("_MsgT", RawRequestMessage, RawResponseMessage) +class PayloadState(IntEnum): + PAYLOAD_COMPLETE = 0 + PAYLOAD_NEEDS_INPUT = 1 + PAYLOAD_HAS_PENDING_INPUT = 2 + + class ParseState(IntEnum): PARSE_NONE = 0 PARSE_LENGTH = 1 @@ -236,6 +241,7 @@ def __init__( self._upgraded = False self._payload = None self._payload_parser: HttpPayloadParser | None = None + self._payload_has_more_data = False self._auto_decompress = auto_decompress self._limit = limit 
self._headers_parser = HeadersParser(max_field_size, self.lax) @@ -246,6 +252,10 @@ def parse_message(self, lines: list[bytes]) -> _MsgT: ... @abc.abstractmethod def _is_chunked_te(self, te: str) -> bool: ... + def pause_reading(self) -> None: + assert self._payload_parser is not None + self._payload_parser.pause_reading() + def feed_eof(self) -> _MsgT | None: if self._payload_parser is not None: self._payload_parser.feed_eof() @@ -282,7 +292,7 @@ def feed_data( max_line_length = self.max_line_size should_close = False - while start_pos < data_len: + while start_pos < data_len or self._payload_has_more_data: # read HTTP message (request/response line + headers), \r\n\r\n # and split by lines if self._payload_parser is None and not self._upgraded: @@ -441,11 +451,13 @@ def get_content_length() -> int | None: break # feed payload - elif data and start_pos < data_len: + elif self._payload_has_more_data or (data and start_pos < data_len): assert not self._lines assert self._payload_parser is not None try: - eof, data = self._payload_parser.feed_data(data[start_pos:], SEP) + payload_state, data = self._payload_parser.feed_data( + data[start_pos:], SEP + ) except BaseException as underlying_exc: reraised_exc = underlying_exc if self.payload_exception is not None: @@ -457,18 +469,25 @@ def get_content_length() -> int | None: underlying_exc, ) - eof = True + payload_state = PayloadState.PAYLOAD_COMPLETE data = b"" if isinstance( underlying_exc, (InvalidHeader, TransferEncodingError) ): raise - if eof: - start_pos = 0 - data_len = len(data) - self._payload_parser = None - continue + self._payload_has_more_data = ( + payload_state == PayloadState.PAYLOAD_HAS_PENDING_INPUT + ) + + if payload_state is not PayloadState.PAYLOAD_COMPLETE: + # We've either consumed all available data, or we're pausing + # until the reader buffer is freed up. + break + + start_pos = 0 + data_len = len(data) + self._payload_parser = None else: break @@ -751,6 +770,7 @@ def __init__( max_trailers: int = 128, ) -> None: self._length = 0 + self._paused = False self._type = ParseState.PARSE_UNTIL_EOF self._chunk = ChunkState.PARSE_CHUNKED_SIZE self._chunk_size = 0 @@ -761,6 +781,7 @@ def __init__( self._max_line_size = max_line_size self._max_field_size = max_field_size self._max_trailers = max_trailers + self._more_data_available = False self._trailer_lines: list[bytes] = [] self.done = False @@ -789,6 +810,9 @@ def __init__( self.payload = real_payload + def pause_reading(self) -> None: + self._paused = True + def feed_eof(self) -> None: if self._type == ParseState.PARSE_UNTIL_EOF: self.payload.feed_eof() @@ -803,32 +827,52 @@ def feed_eof(self) -> None: def feed_data( self, chunk: bytes, SEP: _SEP = b"\r\n", CHUNK_EXT: bytes = b";" - ) -> tuple[bool, bytes]: + ) -> tuple[PayloadState, bytes]: + """Receive a chunk of data to process. + + Return: + PayloadState - The current state of payload processing. + This function may be called with empty bytes after returning + PAYLOAD_HAS_PENDING_INPUT to continue processing after a pause. + bytes - If payload is complete, this is the unconsumed bytes intended for the + next message/payload, b"" otherwise. 
+ """ # Read specified amount of bytes if self._type == ParseState.PARSE_LENGTH: + if self._chunk_tail: + chunk = self._chunk_tail + chunk + self._chunk_tail = b"" + required = self._length self._length = max(required - len(chunk), 0) - self.payload.feed_data(chunk[:required]) + self._more_data_available = self.payload.feed_data(chunk[:required]) + while self._more_data_available: + if self._paused: + self._paused = False + self._chunk_tail = chunk[required:] + return PayloadState.PAYLOAD_HAS_PENDING_INPUT, b"" + self._more_data_available = self.payload.feed_data(b"") + if self._length == 0: self.payload.feed_eof() - return True, chunk[required:] - + return PayloadState.PAYLOAD_COMPLETE, chunk[required:] # Chunked transfer encoding parser elif self._type == ParseState.PARSE_CHUNKED: if self._chunk_tail: - # We should never have a tail if we're inside the payload body. - assert self._chunk != ChunkState.PARSE_CHUNKED_CHUNK - # We should check the length is sane. - max_line_length = self._max_line_size - if self._chunk == ChunkState.PARSE_TRAILERS: - max_line_length = self._max_field_size - if len(self._chunk_tail) > max_line_length: - raise LineTooLong(self._chunk_tail[:100] + b"...", max_line_length) + # We should check the length is sane when not processing payload body. + if self._chunk != ChunkState.PARSE_CHUNKED_CHUNK: + max_line_length = self._max_line_size + if self._chunk == ChunkState.PARSE_TRAILERS: + max_line_length = self._max_field_size + if len(self._chunk_tail) > max_line_length: + raise LineTooLong( + self._chunk_tail[:100] + b"...", max_line_length + ) chunk = self._chunk_tail + chunk self._chunk_tail = b"" - while chunk: + while chunk or self._more_data_available: # read next chunk size if self._chunk == ChunkState.PARSE_CHUNKED_SIZE: pos = chunk.find(SEP) @@ -868,17 +912,26 @@ def feed_data( self.payload.begin_http_chunk_receiving() else: self._chunk_tail = chunk - return False, b"" + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" # read chunk and feed buffer if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK: + if self._paused: + self._paused = False + self._chunk_tail = chunk + return PayloadState.PAYLOAD_HAS_PENDING_INPUT, b"" + required = self._chunk_size self._chunk_size = max(required - len(chunk), 0) - self.payload.feed_data(chunk[:required]) + self._more_data_available = self.payload.feed_data(chunk[:required]) + chunk = chunk[required:] + + if self._more_data_available: + continue if self._chunk_size: - return False, b"" - chunk = chunk[required:] + self._paused = False + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF self.payload.end_http_chunk_receiving() @@ -891,13 +944,13 @@ def feed_data( self._chunk = ChunkState.PARSE_CHUNKED_SIZE else: self._chunk_tail = chunk - return False, b"" + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" if self._chunk == ChunkState.PARSE_TRAILERS: pos = chunk.find(SEP) if pos < 0: # No line found self._chunk_tail = chunk - return False, b"" + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" line = chunk[:pos] chunk = chunk[pos + len(SEP) :] @@ -923,13 +976,18 @@ def feed_data( finally: self._trailer_lines.clear() self.payload.feed_eof() - return True, chunk + return PayloadState.PAYLOAD_COMPLETE, chunk # Read all bytes until eof elif self._type == ParseState.PARSE_UNTIL_EOF: - self.payload.feed_data(chunk) + self._more_data_available = self.payload.feed_data(chunk) + while self._more_data_available: + if self._paused: + self._paused = False + return PayloadState.PAYLOAD_HAS_PENDING_INPUT, b"" + 
self._more_data_available = self.payload.feed_data(b"") - return False, b"" + return PayloadState.PAYLOAD_NEEDS_INPUT, b"" class DeflateBuffer: @@ -974,10 +1032,8 @@ def set_exception( ) -> None: set_exception(self.out, exc, exc_cause) - def feed_data(self, chunk: bytes) -> None: - if not chunk: - return - + def feed_data(self, chunk: bytes) -> bool: + """Return True if more data is available and this method should be called again with b"".""" self.size += len(chunk) self.out.total_compressed_bytes = self.size @@ -996,9 +1052,8 @@ def feed_data(self, chunk: bytes) -> None: ) try: - # Decompress with limit + 1 so we can detect if output exceeds limit chunk = self.decompressor.decompress_sync( - chunk, max_length=self._max_decompress_size + 1 + chunk, max_length=self._max_decompress_size ) except Exception: raise ContentEncodingError( @@ -1007,15 +1062,9 @@ def feed_data(self, chunk: bytes) -> None: self._started_decoding = True - # Check if decompression limit was exceeded - if len(chunk) > self._max_decompress_size: - raise DecompressSizeError( - "Decompressed data exceeds the configured limit of %d bytes" - % self._max_decompress_size - ) - if chunk: self.out.feed_data(chunk) + return self.decompressor.data_available def feed_eof(self) -> None: chunk = self.decompressor.flush() diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 97fdae77d87..dac71484261 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -1,6 +1,7 @@ import base64 import binascii import json +import math import re import sys import uuid @@ -39,6 +40,7 @@ payload_type, ) from .streams import StreamReader +from .web_exceptions import HTTPRequestEntityTooLarge if sys.version_info >= (3, 11): from typing import Self @@ -267,6 +269,7 @@ def __init__( subtype: str = "mixed", default_charset: str | None = None, max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE, + client_max_size: int = math.inf, ) -> None: self.headers = headers self._boundary = boundary @@ -284,6 +287,7 @@ def __init__( self._content_eof = 0 self._cache: dict[str, Any] = {} self._max_decompress_size = max_decompress_size + self._client_max_size = client_max_size def __aiter__(self) -> Self: return self @@ -312,11 +316,19 @@ async def read(self, *, decode: bool = False) -> bytes: data = bytearray() while not self._at_eof: data.extend(await self.read_chunk(self.chunk_size)) + if len(data) > self._client_max_size: + raise HTTPRequestEntityTooLarge( + max_size=self._client_max_size, actual_size=len(data) + ) # https://github.com/python/mypy/issues/17537 if decode: # type: ignore[unreachable] decoded_data = bytearray() async for d in self.decode_iter(data): decoded_data.extend(d) + if len(decoded_data) > self._client_max_size: + raise HTTPRequestEntityTooLarge( + max_size=self._client_max_size, actual_size=len(decoded_data) + ) return decoded_data return data @@ -558,6 +570,8 @@ async def _decode_content_async(self, data: bytes) -> AsyncIterator[bytes]: suppress_deflate_header=True, ) yield await d.decompress(data, max_length=self._max_decompress_size) + while d.data_available: + yield await d.decompress(b"", max_length=self._max_decompress_size) else: raise RuntimeError(f"unknown content encoding: {encoding}") @@ -646,7 +660,13 @@ class MultipartReader: #: Body part reader class for non multipart/* content types. 
part_reader_cls = BodyPartReader - def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None: + def __init__( + self, + headers: Mapping[str, str], + content: StreamReader, + *, + client_max_size: int = math.inf, + ) -> None: self._mimetype = parse_mimetype(headers[CONTENT_TYPE]) assert self._mimetype.type == "multipart", "multipart/* content type expected" if "boundary" not in self._mimetype.parameters: @@ -656,6 +676,7 @@ def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None: self.headers = headers self._boundary = ("--" + self._get_boundary()).encode() + self._client_max_size = client_max_size self._content = content self._default_charset: str | None = None self._last_part: MultipartReader | BodyPartReader | None = None @@ -758,8 +779,12 @@ def _get_part_reader( if mimetype.type == "multipart": if self.multipart_reader_cls is None: - return type(self)(headers, self._content) - return self.multipart_reader_cls(headers, self._content) + return type(self)( + headers, self._content, client_max_size=self._client_max_size + ) + return self.multipart_reader_cls( + headers, self._content, client_max_size=self._client_max_size + ) else: return self.part_reader_cls( self._boundary, @@ -767,6 +792,7 @@ def _get_part_reader( self._content, subtype=self._mimetype.subtype, default_charset=self._default_charset, + client_max_size=self._client_max_size, ) def _get_boundary(self) -> str: diff --git a/aiohttp/streams.py b/aiohttp/streams.py index 034fcc540c0..123c974970d 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -218,8 +218,8 @@ def feed_eof(self) -> None: self._eof_waiter = None set_result(waiter, None) - if self._protocol._reading_paused: - self._protocol.resume_reading() + # At EOF the parser is done, there won't be unprocessed data. + self._protocol.resume_reading(resume_parser=False) for cb in self._eof_callbacks: try: @@ -273,11 +273,11 @@ def unread_data(self, data: bytes) -> None: self._buffer.appendleft(data) self._eof_counter = 0 - def feed_data(self, data: bytes) -> None: + def feed_data(self, data: bytes) -> bool: assert not self._eof, "feed_data after feed_eof" if not data: - return + return False data_len = len(data) self._size += data_len @@ -289,8 +289,9 @@ def feed_data(self, data: bytes) -> None: self._waiter = None set_result(waiter, None) - if self._size > self._high_water and not self._protocol._reading_paused: + if self._size > self._high_water: self._protocol.pause_reading() + return False def begin_http_chunk_receiving(self) -> None: if self._http_chunk_splits is None: @@ -327,10 +328,7 @@ def end_http_chunk_receiving(self) -> None: # If we get too many small chunks before self._high_water is reached, then any # .read() call becomes computationally expensive, and could block the event loop # for too long, hence an additional self._high_water_chunks here. 
- if ( - len(self._http_chunk_splits) > self._high_water_chunks - and not self._protocol._reading_paused - ): + if len(self._http_chunk_splits) > self._high_water_chunks: self._protocol.pause_reading() # wake up readchunk when end of http chunk received @@ -527,13 +525,9 @@ def _read_nowait_chunk(self, n: int) -> bytes: while chunk_splits and chunk_splits[0] < self._cursor: chunk_splits.popleft() - if ( - self._protocol._reading_paused - and self._size < self._low_water - and ( - self._http_chunk_splits is None - or len(self._http_chunk_splits) < self._low_water_chunks - ) + if self._size < self._low_water and ( + self._http_chunk_splits is None + or len(self._http_chunk_splits) < self._low_water_chunks ): self._protocol.resume_reading() return data @@ -593,8 +587,8 @@ def at_eof(self) -> bool: async def wait_eof(self) -> None: return - def feed_data(self, data: bytes) -> None: - pass + def feed_data(self, data: bytes) -> bool: + return False async def readline(self) -> bytes: return b"" diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index bd39c48050d..973d6ba298a 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -22,6 +22,7 @@ HttpVersion10, RawRequestMessage, StreamWriter, + WebSocketReader, ) from .http_exceptions import BadHttpMethod from .log import access_logger, server_logger @@ -168,9 +169,7 @@ class RequestHandler(BaseProtocol, Generic[_Request]): "_handler_waiter", "_waiter", "_task_handler", - "_upgrade", "_payload_parser", - "_request_parser", "logger", "access_log", "access_logger", @@ -203,7 +202,17 @@ def __init__( auto_decompress: bool = True, timeout_ceil_threshold: float = 5, ): - super().__init__(loop) + parser = HttpRequestParser( + self, + loop, + read_bufsize, + max_line_size=max_line_size, + max_field_size=max_field_size, + max_headers=max_headers, + payload_exception=RequestPayloadError, + auto_decompress=auto_decompress, + ) + super().__init__(loop, parser) # _request_count is the number of requests processed with the same connection. 
self._request_count = 0 @@ -230,19 +239,7 @@ def __init__( self._waiter: asyncio.Future[None] | None = None self._handler_waiter: asyncio.Future[None] | None = None self._task_handler: asyncio.Task[None] | None = None - - self._upgrade = False self._payload_parser: Any = None - self._request_parser: HttpRequestParser | None = HttpRequestParser( - self, - loop, - read_bufsize, - max_line_size=max_line_size, - max_field_size=max_field_size, - max_headers=max_headers, - payload_exception=RequestPayloadError, - auto_decompress=auto_decompress, - ) self._timeout_ceil_threshold: float = 5 try: @@ -383,7 +380,7 @@ def connection_lost(self, exc: BaseException | None) -> None: self._manager = None self._request_factory = None self._request_handler = None - self._request_parser = None + self._parser = None if self._keepalive_handle is not None: self._keepalive_handle.cancel() @@ -402,8 +399,7 @@ def connection_lost(self, exc: BaseException | None) -> None: self._payload_parser.feed_eof() self._payload_parser = None - def set_parser(self, parser: Any) -> None: - # Actual type is WebReader + def set_parser(self, parser: WebSocketReader) -> None: assert self._payload_parser is None self._payload_parser = parser @@ -420,10 +416,10 @@ def data_received(self, data: bytes) -> None: return # parse http messages messages: Sequence[_MsgType] - if self._payload_parser is None and not self._upgrade: - assert self._request_parser is not None + if self._payload_parser is None and not self._upgraded: + assert self._parser is not None try: - messages, upgraded, tail = self._request_parser.feed_data(data) + messages, upgraded, tail = self._parser.feed_data(data) except HttpProcessingError as exc: messages = [ (_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD) @@ -440,12 +436,12 @@ def data_received(self, data: bytes) -> None: # don't set result twice waiter.set_result(None) - self._upgrade = upgraded + self._upgraded = upgraded if upgraded and tail: self._message_tail = tail # no parser, just store - elif self._payload_parser is None and self._upgrade and data: + elif self._payload_parser is None and self._upgraded and data: self._message_tail += data # feed payload @@ -705,11 +701,11 @@ async def finish_response( prematurely. 
""" request._finish() - if self._request_parser is not None: - self._request_parser.set_upgraded(False) - self._upgrade = False + if self._parser is not None: + self._parser.set_upgraded(False) + self._upgraded = False if self._message_tail: - self._request_parser.feed_data(self._message_tail) + self._parser.feed_data(self._message_tail) self._message_tail = b"" try: prepare_meth = resp.prepare diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 260b822482d..ab25b7398c2 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -672,7 +672,9 @@ async def json( async def multipart(self) -> MultipartReader: """Return async iterator to process BODY as multipart.""" - return MultipartReader(self._headers, self._payload) + return MultipartReader( + self._headers, self._payload, client_max_size=self._client_max_size + ) async def post(self) -> "MultiDictProxy[str | bytes | FileField]": """Return POST parameters.""" diff --git a/tests/test_base_protocol.py b/tests/test_base_protocol.py index 713dba2d0c2..234e9927c02 100644 --- a/tests/test_base_protocol.py +++ b/tests/test_base_protocol.py @@ -5,6 +5,7 @@ import pytest from aiohttp.base_protocol import BaseProtocol +from aiohttp.http_parser import HttpParser async def test_loop() -> None: @@ -26,33 +27,28 @@ async def test_pause_writing() -> None: async def test_pause_reading_no_transport() -> None: loop = asyncio.get_event_loop() - pr = BaseProtocol(loop) - assert not pr._reading_paused + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) + pr = BaseProtocol(loop, parser=parser) pr.pause_reading() - assert not pr._reading_paused + parser.pause_reading.assert_called_once() async def test_pause_reading_stub_transport() -> None: loop = asyncio.get_event_loop() - pr = BaseProtocol(loop) + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) + pr = BaseProtocol(loop, parser=parser) tr = asyncio.Transport() pr.transport = tr assert not pr._reading_paused pr.pause_reading() assert pr._reading_paused - - -async def test_resume_reading_no_transport() -> None: - loop = asyncio.get_event_loop() - pr = BaseProtocol(loop) - pr._reading_paused = True - pr.resume_reading() - assert pr._reading_paused + parser.pause_reading.assert_called_once() # type: ignore[unreachable] async def test_resume_reading_stub_transport() -> None: loop = asyncio.get_event_loop() - pr = BaseProtocol(loop) + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) + pr = BaseProtocol(loop, parser=parser) tr = asyncio.Transport() pr.transport = tr pr._reading_paused = True diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index fec3d3c6c3e..efded32169f 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -53,7 +53,6 @@ ) from aiohttp.client_reqrep import ClientRequest from aiohttp.compression_utils import DEFAULT_MAX_DECOMPRESS_SIZE -from aiohttp.http_exceptions import DecompressSizeError from aiohttp.payload import ( AsyncIterablePayload, BufferedReaderPayload, @@ -2386,10 +2385,9 @@ async def test_payload_decompress_size_limit(aiohttp_client: AiohttpClient) -> N When a compressed payload expands beyond the configured limit, we raise DecompressSizeError. """ - # Create a highly compressible payload that exceeds the decompression limit. - # 64MiB of repeated bytes compresses to ~32KB but expands beyond the - # 32MiB per-call limit. - original = b"A" * (64 * 2**20) + # Create a highly compressible payload. 
+ payload_size = 64 * 2**20 + original = b"A" * payload_size compressed = zlib.compress(original) assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE @@ -2406,11 +2404,11 @@ async def handler(request: web.Request) -> web.Response: async with client.get("/") as resp: assert resp.status == 200 - with pytest.raises(aiohttp.ClientPayloadError) as exc_info: - await resp.read() + received = 0 + async for chunk in resp.content.iter_chunked(1024): + received += len(chunk) - assert isinstance(exc_info.value.__cause__, DecompressSizeError) - assert "Decompressed data exceeds" in str(exc_info.value.__cause__) + assert received == payload_size @pytest.mark.skipif(brotli is None, reason="brotli is not installed") @@ -2419,8 +2417,9 @@ async def test_payload_decompress_size_limit_brotli( ) -> None: """Test that brotli decompression size limit triggers DecompressSizeError.""" assert brotli is not None - # Create a highly compressible payload that exceeds the decompression limit. - original = b"A" * (64 * 2**20) + # Create a highly compressible payload + payload_size = 64 * 2**20 + original = b"A" * payload_size compressed = brotli.compress(original) assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE @@ -2436,11 +2435,11 @@ async def handler(request: web.Request) -> web.Response: async with client.get("/") as resp: assert resp.status == 200 - with pytest.raises(aiohttp.ClientPayloadError) as exc_info: - await resp.read() + received = 0 + async for chunk in resp.content.iter_chunked(1024): + received += len(chunk) - assert isinstance(exc_info.value.__cause__, DecompressSizeError) - assert "Decompressed data exceeds" in str(exc_info.value.__cause__) + assert received == payload_size @pytest.mark.skipif(ZstdCompressor is None, reason="backports.zstd is not installed") @@ -2449,8 +2448,9 @@ async def test_payload_decompress_size_limit_zstd( ) -> None: """Test that zstd decompression size limit triggers DecompressSizeError.""" assert ZstdCompressor is not None - # Create a highly compressible payload that exceeds the decompression limit. - original = b"A" * (64 * 2**20) + # Create a highly compressible payload. 
+ payload_size = 64 * 2**20 + original = b"A" * payload_size compressor = ZstdCompressor() compressed = compressor.compress(original) + compressor.flush() assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE @@ -2467,11 +2467,11 @@ async def handler(request: web.Request) -> web.Response: async with client.get("/") as resp: assert resp.status == 200 - with pytest.raises(aiohttp.ClientPayloadError) as exc_info: - await resp.read() + received = 0 + async for chunk in resp.content.iter_chunked(1024): + received += len(chunk) - assert isinstance(exc_info.value.__cause__, DecompressSizeError) - assert "Decompressed data exceeds" in str(exc_info.value.__cause__) + assert received == payload_size async def test_bad_payload_chunked_encoding(aiohttp_client: AiohttpClient) -> None: diff --git a/tests/test_client_proto.py b/tests/test_client_proto.py index 49a81c8dbb3..0a26a211453 100644 --- a/tests/test_client_proto.py +++ b/tests/test_client_proto.py @@ -10,7 +10,7 @@ from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ClientResponse from aiohttp.helpers import TimerNoop -from aiohttp.http_parser import RawResponseMessage +from aiohttp.http_parser import HttpParser, RawResponseMessage async def test_force_close(loop: asyncio.AbstractEventLoop) -> None: @@ -35,7 +35,9 @@ async def test_oserror(loop: asyncio.AbstractEventLoop) -> None: async def test_pause_resume_on_error(loop: asyncio.AbstractEventLoop) -> None: + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) proto = ResponseHandler(loop=loop) + proto._parser = parser transport = mock.Mock() proto.connection_made(transport) diff --git a/tests/test_flowcontrol_streams.py b/tests/test_flowcontrol_streams.py index 9e21f786610..3654ba4aad2 100644 --- a/tests/test_flowcontrol_streams.py +++ b/tests/test_flowcontrol_streams.py @@ -5,6 +5,7 @@ from aiohttp import streams from aiohttp.base_protocol import BaseProtocol +from aiohttp.http_parser import HttpParser @pytest.fixture @@ -38,7 +39,6 @@ async def test_readline(self, stream: streams.StreamReader) -> None: stream.feed_data(b"d\n") res = await stream.readline() assert res == b"d\n" - assert not stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readline_resume_paused(self, stream: streams.StreamReader) -> None: stream._protocol._reading_paused = True @@ -51,7 +51,6 @@ async def test_readany(self, stream: streams.StreamReader) -> None: stream.feed_data(b"data") res = await stream.readany() assert res == b"data" - assert not stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readany_resume_paused(self, stream: streams.StreamReader) -> None: stream._protocol._reading_paused = True @@ -65,7 +64,6 @@ async def test_readchunk(self, stream: streams.StreamReader) -> None: res, end_of_http_chunk = await stream.readchunk() assert res == b"data" assert not end_of_http_chunk - assert not stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readchunk_resume_paused(self, stream: streams.StreamReader) -> None: stream._protocol._reading_paused = True @@ -120,7 +118,8 @@ async def test_resumed_on_eof(self, stream: streams.StreamReader) -> None: async def test_stream_reader_eof_when_full() -> None: loop = asyncio.get_event_loop() - protocol = BaseProtocol(loop=loop) + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) + protocol = BaseProtocol(loop=loop, parser=parser) protocol.transport = asyncio.Transport() stream = streams.StreamReader(protocol, 1024, 
loop=loop) diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index bfd84aae0d2..191a9a4e8a5 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -4,7 +4,7 @@ import re import sys import zlib -from collections.abc import Iterable +from collections.abc import Iterable, Iterator from contextlib import suppress from typing import Any from unittest import mock @@ -17,6 +17,7 @@ import aiohttp from aiohttp import http_exceptions, streams from aiohttp.base_protocol import BaseProtocol +from aiohttp.client_proto import ResponseHandler from aiohttp.helpers import NO_EXTENSIONS from aiohttp.http_parser import ( DeflateBuffer, @@ -27,8 +28,12 @@ HttpRequestParserPy, HttpResponseParser, HttpResponseParserPy, + PayloadState, ) from aiohttp.http_writer import HttpVersion +from aiohttp.web_protocol import RequestHandler +from aiohttp.web_request import Request +from aiohttp.web_server import Server try: try: @@ -56,9 +61,23 @@ RESPONSE_PARSERS.append(HttpResponseParserC) +@pytest.fixture +def server() -> Any: + return mock.create_autospec( + Server, + request_factory=mock.Mock(), + request_handler=mock.AsyncMock(), + instance=True, + ) + + @pytest.fixture def protocol() -> Any: - return mock.create_autospec(BaseProtocol, spec_set=True, instance=True) + return mock.create_autospec( + BaseProtocol, + spec_set=True, + instance=True, + ) def _gen_ids(parsers: Iterable[type[HttpParser[Any]]]) -> list[str]: @@ -71,11 +90,13 @@ def _gen_ids(parsers: Iterable[type[HttpParser[Any]]]) -> list[str]: @pytest.fixture(params=REQUEST_PARSERS, ids=_gen_ids(REQUEST_PARSERS)) def parser( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, + server: Server[Request], request: pytest.FixtureRequest, -) -> HttpRequestParser: +) -> Iterator[HttpRequestParser]: + protocol = RequestHandler(server, loop=loop) + # Parser implementations - return request.param( # type: ignore[no-any-return] + parser = request.param( protocol, loop, 2**16, @@ -83,6 +104,10 @@ def parser( max_headers=128, max_field_size=8190, ) + protocol._force_close = False + protocol._parser = parser + with mock.patch.object(protocol, "transport", True): + yield parser @pytest.fixture(params=REQUEST_PARSERS, ids=_gen_ids(REQUEST_PARSERS)) @@ -94,11 +119,12 @@ def request_cls(request: pytest.FixtureRequest) -> type[HttpRequestParser]: @pytest.fixture(params=RESPONSE_PARSERS, ids=_gen_ids(RESPONSE_PARSERS)) def response( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, request: pytest.FixtureRequest, ) -> HttpResponseParser: + protocol = ResponseHandler(loop) + # Parser implementations - return request.param( # type: ignore[no-any-return] + parser = request.param( protocol, loop, 2**16, @@ -107,6 +133,8 @@ def response( max_field_size=8190, read_until_eof=True, ) + protocol._parser = parser + return parser # type: ignore[no-any-return] @pytest.fixture(params=RESPONSE_PARSERS, ids=_gen_ids(RESPONSE_PARSERS)) @@ -154,9 +182,11 @@ def test_reject_obsolete_line_folding(parser: HttpRequestParser) -> None: @pytest.mark.skipif(NO_EXTENSIONS, reason="Only tests C parser.") def test_invalid_character( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, + server: Server[Request], request: pytest.FixtureRequest, ) -> None: + protocol = RequestHandler(server, loop=loop) + parser = HttpRequestParserC( protocol, loop, @@ -164,6 +194,7 @@ def test_invalid_character( max_line_size=8190, max_field_size=8190, ) + protocol._parser = parser text = b"POST / HTTP/1.1\r\nHost: localhost:8080\r\nSet-Cookie: abc\x01def\r\n\r\n" 
error_detail = re.escape( r""": @@ -178,9 +209,11 @@ def test_invalid_character( @pytest.mark.skipif(NO_EXTENSIONS, reason="Only tests C parser.") def test_invalid_linebreak( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, + server: Server[Request], request: pytest.FixtureRequest, ) -> None: + protocol = RequestHandler(server, loop=loop) + parser = HttpRequestParserC( protocol, loop, @@ -188,6 +221,7 @@ def test_invalid_linebreak( max_line_size=8190, max_field_size=8190, ) + protocol._parser = parser text = b"GET /world HTTP/1.1\r\nHost: 127.0.0.1\n\r\n" error_detail = re.escape( r""": @@ -244,8 +278,10 @@ def test_bad_headers(parser: HttpRequestParser, hdr: str) -> None: def test_unpaired_surrogate_in_header_py( - loop: asyncio.AbstractEventLoop, protocol: BaseProtocol + loop: asyncio.AbstractEventLoop, server: Server[Request] ) -> None: + protocol = RequestHandler(server, loop=loop) + parser = HttpRequestParserPy( protocol, loop, @@ -253,6 +289,7 @@ def test_unpaired_surrogate_in_header_py( max_line_size=8190, max_field_size=8190, ) + protocol._parser = parser text = b"POST / HTTP/1.1\r\n\xff\r\n\r\n" message = None try: @@ -827,6 +864,24 @@ def test_max_header_value_size_under_limit(parser: HttpRequestParser) -> None: assert msg.url == URL("/test") +async def test_chunk_splits_after_pause(parser: HttpRequestParser) -> None: + text = ( + b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + + b"1\r\nb\r\n" * 50000 + + b"0\r\n\r\n" + ) + + messages, upgrade, tail = parser.feed_data(text) + payload = messages[0][-1] + # Payload should have paused reading and stopped receiving new chunks after 16k. + assert payload._http_chunk_splits is not None + assert len(payload._http_chunk_splits) == 16385 + # We should still get the full result after read(), as it will continue processing. + result = await payload.read() + assert len(result) == 50000 # Compare len first, as it's easier to debug in diff. 
+ assert result == b"b" * 50000 + + @pytest.mark.parametrize("size", [40965, 8191]) def test_max_header_value_size_continuation( response: HttpResponseParser, size: int @@ -1246,8 +1301,10 @@ async def test_http_response_parser_bad_chunked_lax( @pytest.mark.dev_mode async def test_http_response_parser_bad_chunked_strict_py( - loop: asyncio.AbstractEventLoop, protocol: BaseProtocol + loop: asyncio.AbstractEventLoop, ) -> None: + protocol = ResponseHandler(loop) + response = HttpResponseParserPy( protocol, loop, @@ -1255,6 +1312,7 @@ async def test_http_response_parser_bad_chunked_strict_py( max_line_size=8190, max_field_size=8190, ) + protocol._parser = response text = ( b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5 \r\nabcde\r\n0\r\n\r\n" ) @@ -1268,8 +1326,10 @@ async def test_http_response_parser_bad_chunked_strict_py( reason="C based HTTP parser not available", ) async def test_http_response_parser_bad_chunked_strict_c( - loop: asyncio.AbstractEventLoop, protocol: BaseProtocol + loop: asyncio.AbstractEventLoop, ) -> None: + protocol = ResponseHandler(loop) + response = HttpResponseParserC( protocol, loop, @@ -1277,6 +1337,7 @@ async def test_http_response_parser_bad_chunked_strict_c( max_line_size=8190, max_field_size=8190, ) + protocol._parser = response text = ( b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5 \r\nabcde\r\n0\r\n\r\n" ) @@ -1427,10 +1488,12 @@ async def test_request_chunked_reject_bad_trailer(parser: HttpRequestParser) -> def test_parse_no_length_or_te_on_post( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, + server: Server[Request], request_cls: type[HttpRequestParser], ) -> None: + protocol = RequestHandler(server, loop=loop) parser = request_cls(protocol, loop, limit=2**16) + protocol._parser = parser text = b"POST /test HTTP/1.1\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] @@ -1439,10 +1502,11 @@ def test_parse_no_length_or_te_on_post( def test_parse_payload_response_without_body( loop: asyncio.AbstractEventLoop, - protocol: BaseProtocol, response_cls: type[HttpResponseParser], ) -> None: + protocol = ResponseHandler(loop) parser = response_cls(protocol, loop, 2**16, response_with_body=False) + protocol._parser = parser text = b"HTTP/1.1 200 Ok\r\ncontent-length: 10\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] @@ -1703,8 +1767,10 @@ def test_parse_uri_utf8_percent_encoded(parser: HttpRequestParser) -> None: reason="C based HTTP parser not available", ) def test_parse_bad_method_for_c_parser_raises( - loop: asyncio.AbstractEventLoop, protocol: BaseProtocol + loop: asyncio.AbstractEventLoop, server: Server[Request] ) -> None: + protocol = RequestHandler(server, loop=loop) + payload = b"GET1 /test HTTP/1.1\r\n\r\n" parser = HttpRequestParserC( protocol, @@ -1714,6 +1780,7 @@ def test_parse_bad_method_for_c_parser_raises( max_headers=128, max_field_size=8190, ) + protocol._parser = parser with pytest.raises(aiohttp.http_exceptions.BadStatusLine): messages, upgrade, tail = parser.feed_data(payload) @@ -1818,8 +1885,8 @@ async def test_parse_chunked_payload_split_end_trailers4( async def test_http_payload_parser_length(self, protocol: BaseProtocol) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, length=2, headers_parser=HeadersParser()) - eof, tail = p.feed_data(b"1245") - assert eof + state, tail = p.feed_data(b"1245") + assert state is PayloadState.PAYLOAD_COMPLETE assert b"12" == out._buffer[0] assert b"45" == tail diff --git a/tests/test_multipart.py 
b/tests/test_multipart.py index 5444817d5a4..0c266b48fc3 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -354,12 +354,17 @@ async def test_read_with_content_encoding_gzip(self) -> None: result = await obj.read(decode=True) assert b"Time to Relax!" == result + @pytest.mark.skipif(sys.version_info < (3, 11), reason="wbits not available") async def test_read_with_content_encoding_deflate(self) -> None: + content = b"A" * 1_000_000 # Large enough to exceed max_length. + compressed = ZLibBackend.compress(content, wbits=-ZLibBackend.MAX_WBITS) + h = CIMultiDictProxy(CIMultiDict({CONTENT_ENCODING: "deflate"})) - with Stream(b"\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--") as stream: + with Stream(compressed + b"\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.read(decode=True) - assert b"Time to Relax!" == result + assert len(result) == len(content) # Simplifies diff on failure + assert result == content async def test_read_with_content_encoding_identity(self) -> None: thing = ( @@ -1721,6 +1726,28 @@ async def test_body_part_reader_payload_as_bytes() -> None: payload.decode() +async def test_body_part_reader_payload_write() -> None: + content = b"A" * 1_000_000 # Large enough to exceed max_length. + compressed = ZLibBackend.compress(content, wbits=-ZLibBackend.MAX_WBITS) + output = b"" + + async def write(inp: bytes) -> None: + nonlocal output + output += inp + + h = CIMultiDictProxy(CIMultiDict({CONTENT_ENCODING: "deflate"})) + writer = mock.create_autospec( + AbstractStreamWriter, write=write, spec_set=True, instance=True + ) + with Stream(compressed + b"\r\n--:--") as stream: + body_part = aiohttp.BodyPartReader(BOUNDARY, h, stream) + payload = BodyPartReaderPayload(body_part) + await payload.write(writer) + + assert len(output) == len(content) # Simplifies diff on failure + assert output == content + + async def test_multipart_writer_close_with_exceptions() -> None: """Test that MultipartWriter.close() continues closing all parts even if one raises.""" writer = aiohttp.MultipartWriter() diff --git a/tests/test_streams.py b/tests/test_streams.py index e2fd1659191..3e2242a8768 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -1109,7 +1109,7 @@ async def test_empty_stream_reader() -> None: assert s.set_exception(ValueError()) is None # type: ignore[func-returns-value] assert s.exception() is None assert s.feed_eof() is None # type: ignore[func-returns-value] - assert s.feed_data(b"data") is None # type: ignore[func-returns-value] + assert s.feed_data(b"data") is False assert s.at_eof() await s.wait_eof() assert await s.read() == b"" diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 71dc53b500e..dae505cb6d0 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -328,6 +328,27 @@ async def handler(request: web.Request) -> web.Response: resp.release() +async def test_multipart_client_max_size(aiohttp_client: AiohttpClient) -> None: + with multipart.MultipartWriter() as writer: + writer.append("A" * 1020) + + async def handler(request: web.Request) -> web.Response: + reader = await request.multipart() + assert isinstance(reader, multipart.MultipartReader) + + part = await reader.next() + assert isinstance(part, multipart.BodyPartReader) + await part.text() # Should raise HttpRequestEntityTooLarge + assert False + + app = web.Application(client_max_size=1000) + app.router.add_post("/", handler) + client = await aiohttp_client(app) + + async with 
client.post("/", data=writer) as resp: + assert resp.status == 413 + + async def test_multipart_empty(aiohttp_client: AiohttpClient) -> None: with multipart.MultipartWriter() as writer: pass diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index 6de09a2cb00..c965a9b4395 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -19,7 +19,7 @@ from aiohttp._websocket.reader import WebSocketDataQueue from aiohttp.base_protocol import BaseProtocol from aiohttp.compression_utils import ZLibBackend, ZLibBackendWrapper -from aiohttp.http import WebSocketError, WSCloseCode, WSMsgType +from aiohttp.http import HttpParser, WebSocketError, WSCloseCode, WSMsgType from aiohttp.http_websocket import ( WebSocketReader, WSMessageBinary, @@ -113,8 +113,9 @@ def build_close_frame( @pytest.fixture() def protocol(loop: asyncio.AbstractEventLoop) -> BaseProtocol: + parser = mock.create_autospec(HttpParser, spec_set=True, instance=True) transport = mock.Mock(spec_set=asyncio.Transport) - protocol = BaseProtocol(loop) + protocol = BaseProtocol(loop, parser=parser) protocol.connection_made(transport) return protocol
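
Note (illustrative, not part of the patch): the decompression changes above lean on zlib's bounded-output API — decompress(data, max_length) together with unconsumed_tail — which ZLibDecompressor now feeds back on the next call and surfaces through the new data_available property. Below is a minimal standalone sketch of that drain loop using only the zlib stdlib; the names drain_decompress and MAX_DECOMPRESS_CHUNK are made up for illustration and do not appear in the patch.

import zlib

MAX_DECOMPRESS_CHUNK = 256 * 1024  # same bound the patch uses for DEFAULT_MAX_DECOMPRESS_SIZE


def drain_decompress(compressed: bytes, max_length: int = MAX_DECOMPRESS_CHUNK):
    """Yield decompressed output in pieces of at most max_length bytes."""
    d = zlib.decompressobj()
    pending = compressed
    while True:
        # Feed back whatever zlib withheld last time, mirroring what
        # ZLibDecompressor.decompress_sync() does in the patch.
        chunk = d.decompress(d.unconsumed_tail + pending, max_length)
        pending = b""
        if chunk:
            yield chunk
        # An empty unconsumed_tail means no more output is possible without
        # new input -- the condition the data_available property reports.
        if not d.unconsumed_tail:
            return


original = b"A" * (4 * 2**20)
pieces = list(drain_decompress(zlib.compress(original)))
assert all(len(p) <= MAX_DECOMPRESS_CHUNK for p in pieces)
assert b"".join(pieces) == original

This bounded-drain pattern is also why DEFAULT_MAX_DECOMPRESS_SIZE can drop to 256 KiB and DecompressSizeError can go away: instead of rejecting oversized output, each call produces at most max_length bytes and the remainder is pulled on subsequent calls while reading stays paused.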