Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES/11766.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Added ``RequestKey`` and ``ResponseKey`` classes,
which enable static type checking for request & response
context storages similarly to ``AppKey``
-- by :user:`gsoldatov`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ Gennady Andreyev
Georges Dubus
Greg Holt
Gregory Haynes
Grigoriy Soldatov
Gus Goulart
Gustavo Carneiro
Günther Jena
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def __init__(

def __init_subclass__(cls: type["ClientSession"]) -> None:
raise TypeError(
f"Inheritance class {cls.__name__} from ClientSession " "is forbidden"
f"Inheritance class {cls.__name__} from ClientSession is forbidden"
)

def __del__(self, _warnings: Any = warnings) -> None:
Expand Down
61 changes: 54 additions & 7 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import time
import warnings
import weakref
from collections import namedtuple
from collections import deque, namedtuple
from collections.abc import Callable, Iterable, Iterator, Mapping
from contextlib import suppress
from email.message import EmailMessage
Expand Down Expand Up @@ -834,8 +834,11 @@ def set_exception(


@functools.total_ordering
class AppKey(Generic[_T]):
"""Keys for static typing support in Application."""
class BaseKey(Generic[_T]):
"""Base for concrete context storage key classes.

Each storage is provided with its own sub-class for the sake of some additional type safety.
"""

__slots__ = ("_name", "_t", "__orig_class__")

Expand All @@ -861,9 +864,9 @@ def __init__(self, name: str, t: type[_T] | None = None):
self._t = t

def __lt__(self, other: object) -> bool:
if isinstance(other, AppKey):
if isinstance(other, BaseKey):
return self._name < other._name
return True # Order AppKey above other types.
return True # Order BaseKey above other types.

def __repr__(self) -> str:
t = self._t
Expand All @@ -881,7 +884,19 @@ def __repr__(self) -> str:
t_repr = f"{t.__module__}.{t.__qualname__}"
else:
t_repr = repr(t) # type: ignore[unreachable]
return f"<AppKey({self._name}, type={t_repr})>"
return f"<{self.__class__.__name__}({self._name}, type={t_repr})>"


class AppKey(BaseKey[_T]):
"""Keys for static typing support in Application."""


class RequestKey(BaseKey[_T]):
"""Keys for static typing support in Request."""


class ResponseKey(BaseKey[_T]):
"""Keys for static typing support in Response."""


@final
Expand All @@ -893,7 +908,7 @@ def __init__(self, maps: Iterable[Mapping[str | AppKey[Any], Any]]) -> None:

def __init_subclass__(cls) -> None:
raise TypeError(
f"Inheritance class {cls.__name__} from ChainMapProxy " "is forbidden"
f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
)

@overload # type: ignore[override]
Expand Down Expand Up @@ -1106,3 +1121,35 @@ def should_remove_content_length(method: str, code: int) -> bool:
return code in EMPTY_BODY_STATUS_CODES or (
200 <= code < 300 and method in hdrs.METH_CONNECT_ALL
)


class DebounceException(Exception):
"""Raised, when `DebounceContextManager` enter limit is exceeded."""


class DebounceContextManager:
"""Limits the number of times its context can be entered over an interval.

:param max_entries: Number of times the class context can be entered.
:params interval: Time interval in seconds, for which entry limit
is counted.
:raises DebounceException: if number of times the context is entered
over the specified interval.
"""

def __init__(self, max_entries: int, interval: int | float) -> None:
self._max_entries = max_entries
self._interval = interval
self._timestamps = deque()

def __enter__(self) -> None:
now = time.monotonic()
interval_start = now - self._interval
while self._timestamps and self._timestamps[0] <= interval_start:
self._timestamps.popleft()
if len(self._timestamps) >= self._max_entries:
raise DebounceException
self._timestamps.append(now)

def __exit__(self, exc_type, exc_value, traceback) -> None:
pass
4 changes: 3 additions & 1 deletion aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Any, cast

from .abc import AbstractAccessLogger
from .helpers import AppKey
from .helpers import AppKey, RequestKey, ResponseKey
from .log import access_logger
from .typedefs import PathLike
from .web_app import Application, CleanupError
Expand Down Expand Up @@ -203,11 +203,13 @@
"BaseRequest",
"FileField",
"Request",
"RequestKey",
# web_response
"ContentCoding",
"Response",
"StreamResponse",
"json_response",
"ResponseKey",
# web_routedef
"AbstractRouteDef",
"RouteDef",
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/web_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(

def __init_subclass__(cls: type["Application"]) -> None:
raise TypeError(
f"Inheritance class {cls.__name__} from web.Application " "is forbidden"
f"Inheritance class {cls.__name__} from web.Application is forbidden"
)

# MutableMapping API
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ async def finish_response(
self.log_exception("Missing return statement on request handler") # type: ignore[unreachable]
else:
self.log_exception(
"Web-handler should return a response instance, " f"got {resp!r}"
f"Web-handler should return a response instance, got {resp!r}"
)
exc = HTTPInternalServerError()
resp = Response(
Expand Down
48 changes: 41 additions & 7 deletions aiohttp/web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import sys
import tempfile
import types
import warnings
from collections.abc import Iterator, Mapping, MutableMapping
from re import Pattern
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Final, Optional, cast
from typing import TYPE_CHECKING, Any, Final, Optional, TypeVar, cast, overload
from urllib.parse import parse_qsl

from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
Expand All @@ -24,8 +25,11 @@
ETAG_ANY,
LIST_QUOTED_ETAG_RE,
ChainMapProxy,
DebounceContextManager,
DebounceException,
ETag,
HeadersMixin,
RequestKey,
frozen_dataclass_decorator,
is_expected_content_type,
parse_http_date,
Expand All @@ -48,6 +52,7 @@
HTTPBadRequest,
HTTPRequestEntityTooLarge,
HTTPUnsupportedMediaType,
NotAppKeyWarning,
)
from .web_response import StreamResponse

Expand All @@ -65,6 +70,9 @@
from .web_urldispatcher import UrlMappingMatchInfo


_T = TypeVar("_T")


@frozen_dataclass_decorator
class FileField:
name: str
Expand Down Expand Up @@ -101,7 +109,7 @@
############################################################


class BaseRequest(MutableMapping[str, Any], HeadersMixin):
class BaseRequest(MutableMapping[str | RequestKey[Any], Any], HeadersMixin):
POST_METHODS = {
hdrs.METH_PATCH,
hdrs.METH_POST,
Expand All @@ -113,6 +121,8 @@
_post: MultiDictProxy[str | bytes | FileField] | None = None
_read_bytes: bytes | None = None

_warning_debounce_context_manager = DebounceContextManager(10, 600)

def __init__(
self,
message: RawRequestMessage,
Expand All @@ -123,7 +133,7 @@
loop: asyncio.AbstractEventLoop,
*,
client_max_size: int = 1024**2,
state: dict[str, Any] | None = None,
state: dict[RequestKey[Any] | str, Any] | None = None,
scheme: str | None = None,
host: str | None = None,
remote: str | None = None,
Expand Down Expand Up @@ -253,19 +263,43 @@

# MutableMapping API

def __getitem__(self, key: str) -> Any:
@overload # type: ignore[override]
def __getitem__(self, key: RequestKey[_T]) -> _T: ...

@overload
def __getitem__(self, key: str) -> Any: ...

def __getitem__(self, key: str | RequestKey[_T]) -> Any:
return self._state[key]

def __setitem__(self, key: str, value: Any) -> None:
@overload # type: ignore[override]
def __setitem__(self, key: RequestKey[_T], value: _T) -> None: ...

@overload
def __setitem__(self, key: str, value: Any) -> None: ...

def __setitem__(self, key: str | RequestKey[_T], value: Any) -> None:
try:
if not isinstance(key, RequestKey):
with BaseRequest._warning_debounce_context_manager:
warnings.warn(
"It is recommended to use web.RequestKey instances for keys.\n"
+ "https://docs.aiohttp.org/en/stable/web_advanced.html"
+ "#request-s-storage",
category=NotAppKeyWarning,
stacklevel=2,
)
except DebounceException:
pass
self._state[key] = value

def __delitem__(self, key: str) -> None:
def __delitem__(self, key: str | RequestKey[_T]) -> None:
del self._state[key]

def __len__(self) -> int:
return len(self._state)

def __iter__(self) -> Iterator[str]:
def __iter__(self) -> Iterator[str | RequestKey[Any]]:
return iter(self._state)

########
Expand Down
48 changes: 41 additions & 7 deletions aiohttp/web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections.abc import Iterator, MutableMapping
from concurrent.futures import Executor
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast, overload

from multidict import CIMultiDict, istr

Expand All @@ -19,8 +19,11 @@
ETAG_ANY,
QUOTED_ETAG_RE,
CookieMixin,
DebounceContextManager,
DebounceException,
ETag,
HeadersMixin,
ResponseKey,
must_be_empty_body,
parse_http_date,
populate_with_cookies,
Expand All @@ -32,6 +35,7 @@
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11
from .payload import Payload
from .typedefs import JSONEncoder, LooseHeaders
from .web_exceptions import NotAppKeyWarning

REASON_PHRASES = {http_status.value: http_status.phrase for http_status in HTTPStatus}
LARGE_BODY_SIZE = 1024**2
Expand All @@ -43,6 +47,9 @@
from .web_request import BaseRequest


_T = TypeVar("_T")


# TODO(py311): Convert to StrEnum for wider use
class ContentCoding(enum.Enum):
# The content codings that we have support for.
Expand All @@ -61,7 +68,9 @@
############################################################


class StreamResponse(MutableMapping[str, Any], HeadersMixin, CookieMixin):
class StreamResponse(
MutableMapping[str | ResponseKey[Any], Any], HeadersMixin, CookieMixin
):

_body: None | bytes | bytearray | Payload
_length_check = True
Expand All @@ -77,6 +86,7 @@
_must_be_empty_body: bool | None = None
_body_length = 0
_send_headers_immediately = True
_warning_debounce_context_manager = DebounceContextManager(10, 600)

def __init__(
self,
Expand All @@ -93,7 +103,7 @@
the headers when creating a new response object. It is not intended
to be used by external code.
"""
self._state: dict[str, Any] = {}
self._state: dict[str | ResponseKey[Any], Any] = {}

if _real_headers is not None:
self._headers = _real_headers
Expand Down Expand Up @@ -483,19 +493,43 @@
info = "not prepared"
return f"<{self.__class__.__name__} {self.reason} {info}>"

def __getitem__(self, key: str) -> Any:
@overload # type: ignore[override]
def __getitem__(self, key: ResponseKey[_T]) -> _T: ...

@overload
def __getitem__(self, key: str) -> Any: ...

def __getitem__(self, key: str | ResponseKey[_T]) -> Any:
return self._state[key]

def __setitem__(self, key: str, value: Any) -> None:
@overload # type: ignore[override]
def __setitem__(self, key: ResponseKey[_T], value: _T) -> None: ...

@overload
def __setitem__(self, key: str, value: Any) -> None: ...

def __setitem__(self, key: str | ResponseKey[_T], value: Any) -> None:
try:
with StreamResponse._warning_debounce_context_manager:
if not isinstance(key, ResponseKey):
warnings.warn(
"It is recommended to use web.ResponseKey instances for keys.\n"
+ "https://docs.aiohttp.org/en/stable/web_advanced.html"
+ "#response-s-storage",
category=NotAppKeyWarning,
stacklevel=2,
)
except DebounceException:
pass
self._state[key] = value

def __delitem__(self, key: str) -> None:
def __delitem__(self, key: str | ResponseKey[_T]) -> None:
del self._state[key]

def __len__(self) -> int:
return len(self._state)

def __iter__(self) -> Iterator[str]:
def __iter__(self) -> Iterator[str | ResponseKey[Any]]:
return iter(self._state)

def __hash__(self) -> int:
Expand Down
Loading
Loading