Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES/11766.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Added ``RequestKey`` and ``ResponseKey`` classes,
which enable static type checking for request & response
context storages similarly to ``AppKey``
-- by :user:`gsoldatov`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ Gennady Andreyev
Georges Dubus
Greg Holt
Gregory Haynes
Grigoriy Soldatov
Gus Goulart
Gustavo Carneiro
Günther Jena
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def __init__(

def __init_subclass__(cls: type["ClientSession"]) -> None:
raise TypeError(
f"Inheritance class {cls.__name__} from ClientSession " "is forbidden"
f"Inheritance class {cls.__name__} from ClientSession is forbidden"
)

def __del__(self, _warnings: Any = warnings) -> None:
Expand Down
27 changes: 21 additions & 6 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,8 +834,11 @@ def set_exception(


@functools.total_ordering
class AppKey(Generic[_T]):
"""Keys for static typing support in Application."""
class BaseKey(Generic[_T]):
"""Base for concrete context storage key classes.

Each storage is provided with its own sub-class for the sake of some additional type safety.
"""

__slots__ = ("_name", "_t", "__orig_class__")

Expand All @@ -861,9 +864,9 @@ def __init__(self, name: str, t: type[_T] | None = None):
self._t = t

def __lt__(self, other: object) -> bool:
if isinstance(other, AppKey):
if isinstance(other, BaseKey):
return self._name < other._name
return True # Order AppKey above other types.
return True # Order BaseKey above other types.

def __repr__(self) -> str:
t = self._t
Expand All @@ -881,7 +884,19 @@ def __repr__(self) -> str:
t_repr = f"{t.__module__}.{t.__qualname__}"
else:
t_repr = repr(t) # type: ignore[unreachable]
return f"<AppKey({self._name}, type={t_repr})>"
return f"<{self.__class__.__name__}({self._name}, type={t_repr})>"


class AppKey(BaseKey[_T]):
"""Keys for static typing support in Application."""


class RequestKey(BaseKey[_T]):
"""Keys for static typing support in Request."""


class ResponseKey(BaseKey[_T]):
"""Keys for static typing support in Response."""


@final
Expand All @@ -893,7 +908,7 @@ def __init__(self, maps: Iterable[Mapping[str | AppKey[Any], Any]]) -> None:

def __init_subclass__(cls) -> None:
raise TypeError(
f"Inheritance class {cls.__name__} from ChainMapProxy " "is forbidden"
f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
)

@overload # type: ignore[override]
Expand Down
4 changes: 3 additions & 1 deletion aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Any, cast

from .abc import AbstractAccessLogger
from .helpers import AppKey
from .helpers import AppKey, RequestKey, ResponseKey
from .log import access_logger
from .typedefs import PathLike
from .web_app import Application, CleanupError
Expand Down Expand Up @@ -203,11 +203,13 @@
"BaseRequest",
"FileField",
"Request",
"RequestKey",
# web_response
"ContentCoding",
"Response",
"StreamResponse",
"json_response",
"ResponseKey",
# web_routedef
"AbstractRouteDef",
"RouteDef",
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/web_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(

def __init_subclass__(cls: type["Application"]) -> None:
raise TypeError(
f"Inheritance class {cls.__name__} from web.Application " "is forbidden"
f"Inheritance class {cls.__name__} from web.Application is forbidden"
)

# MutableMapping API
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ async def finish_response(
self.log_exception("Missing return statement on request handler") # type: ignore[unreachable]
else:
self.log_exception(
"Web-handler should return a response instance, " f"got {resp!r}"
f"Web-handler should return a response instance, got {resp!r}"
)
exc = HTTPInternalServerError()
resp = Response(
Expand Down
42 changes: 35 additions & 7 deletions aiohttp/web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import sys
import tempfile
import types
import warnings
from collections.abc import Iterator, Mapping, MutableMapping
from re import Pattern
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Final, Optional, cast
from typing import TYPE_CHECKING, Any, Final, Optional, TypeVar, cast, overload
from urllib.parse import parse_qsl

from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
Expand All @@ -26,6 +27,7 @@
ChainMapProxy,
ETag,
HeadersMixin,
RequestKey,
frozen_dataclass_decorator,
is_expected_content_type,
parse_http_date,
Expand All @@ -48,6 +50,7 @@
HTTPBadRequest,
HTTPRequestEntityTooLarge,
HTTPUnsupportedMediaType,
NotAppKeyWarning,
)
from .web_response import StreamResponse

Expand All @@ -65,6 +68,9 @@
from .web_urldispatcher import UrlMappingMatchInfo


_T = TypeVar("_T")


@frozen_dataclass_decorator
class FileField:
name: str
Expand Down Expand Up @@ -101,7 +107,7 @@ class FileField:
############################################################


class BaseRequest(MutableMapping[str, Any], HeadersMixin):
class BaseRequest(MutableMapping[str | RequestKey[Any], Any], HeadersMixin):
POST_METHODS = {
hdrs.METH_PATCH,
hdrs.METH_POST,
Expand All @@ -112,6 +118,7 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):

_post: MultiDictProxy[str | bytes | FileField] | None = None
_read_bytes: bytes | None = None
_seen_str_keys: set[str] = set()

def __init__(
self,
Expand All @@ -123,7 +130,7 @@ def __init__(
loop: asyncio.AbstractEventLoop,
*,
client_max_size: int = 1024**2,
state: dict[str, Any] | None = None,
state: dict[RequestKey[Any] | str, Any] | None = None,
scheme: str | None = None,
host: str | None = None,
remote: str | None = None,
Expand Down Expand Up @@ -253,19 +260,40 @@ def rel_url(self) -> URL:

# MutableMapping API

def __getitem__(self, key: str) -> Any:
@overload # type: ignore[override]
def __getitem__(self, key: RequestKey[_T]) -> _T: ...

@overload
def __getitem__(self, key: str) -> Any: ...

def __getitem__(self, key: str | RequestKey[_T]) -> Any:
return self._state[key]

def __setitem__(self, key: str, value: Any) -> None:
@overload # type: ignore[override]
def __setitem__(self, key: RequestKey[_T], value: _T) -> None: ...

@overload
def __setitem__(self, key: str, value: Any) -> None: ...

def __setitem__(self, key: str | RequestKey[_T], value: Any) -> None:
if not isinstance(key, RequestKey) and key not in BaseRequest._seen_str_keys:
BaseRequest._seen_str_keys.add(key)
warnings.warn(
"It is recommended to use web.RequestKey instances for keys.\n"
+ "https://docs.aiohttp.org/en/stable/web_advanced.html"
+ "#request-s-storage",
category=NotAppKeyWarning,
stacklevel=2,
)
self._state[key] = value

def __delitem__(self, key: str) -> None:
def __delitem__(self, key: str | RequestKey[_T]) -> None:
del self._state[key]

def __len__(self) -> int:
return len(self._state)

def __iter__(self) -> Iterator[str]:
def __iter__(self) -> Iterator[str | RequestKey[Any]]:
return iter(self._state)

########
Expand Down
46 changes: 39 additions & 7 deletions aiohttp/web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections.abc import Iterator, MutableMapping
from concurrent.futures import Executor
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast, overload

from multidict import CIMultiDict, istr

Expand All @@ -21,6 +21,7 @@
CookieMixin,
ETag,
HeadersMixin,
ResponseKey,
must_be_empty_body,
parse_http_date,
populate_with_cookies,
Expand All @@ -32,6 +33,7 @@
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11
from .payload import Payload
from .typedefs import JSONEncoder, LooseHeaders
from .web_exceptions import NotAppKeyWarning

REASON_PHRASES = {http_status.value: http_status.phrase for http_status in HTTPStatus}
LARGE_BODY_SIZE = 1024**2
Expand All @@ -43,6 +45,9 @@
from .web_request import BaseRequest


_T = TypeVar("_T")


# TODO(py311): Convert to StrEnum for wider use
class ContentCoding(enum.Enum):
# The content codings that we have support for.
Expand All @@ -61,7 +66,9 @@ class ContentCoding(enum.Enum):
############################################################


class StreamResponse(MutableMapping[str, Any], HeadersMixin, CookieMixin):
class StreamResponse(
MutableMapping[str | ResponseKey[Any], Any], HeadersMixin, CookieMixin
):

_body: None | bytes | bytearray | Payload
_length_check = True
Expand All @@ -77,6 +84,7 @@ class StreamResponse(MutableMapping[str, Any], HeadersMixin, CookieMixin):
_must_be_empty_body: bool | None = None
_body_length = 0
_send_headers_immediately = True
_seen_str_keys: set[str] = set()

def __init__(
self,
Expand All @@ -93,7 +101,7 @@ def __init__(
the headers when creating a new response object. It is not intended
to be used by external code.
"""
self._state: dict[str, Any] = {}
self._state: dict[str | ResponseKey[Any], Any] = {}

if _real_headers is not None:
self._headers = _real_headers
Expand Down Expand Up @@ -483,19 +491,43 @@ def __repr__(self) -> str:
info = "not prepared"
return f"<{self.__class__.__name__} {self.reason} {info}>"

def __getitem__(self, key: str) -> Any:
@overload # type: ignore[override]
def __getitem__(self, key: ResponseKey[_T]) -> _T: ...

@overload
def __getitem__(self, key: str) -> Any: ...

def __getitem__(self, key: str | ResponseKey[_T]) -> Any:
return self._state[key]

def __setitem__(self, key: str, value: Any) -> None:
@overload # type: ignore[override]
def __setitem__(self, key: ResponseKey[_T], value: _T) -> None: ...

@overload
def __setitem__(self, key: str, value: Any) -> None: ...

def __setitem__(self, key: str | ResponseKey[_T], value: Any) -> None:
if (
not isinstance(key, ResponseKey)
and key not in StreamResponse._seen_str_keys
):
StreamResponse._seen_str_keys.add(key)
warnings.warn(
"It is recommended to use web.ResponseKey instances for keys.\n"
+ "https://docs.aiohttp.org/en/stable/web_advanced.html"
+ "#response-s-storage",
category=NotAppKeyWarning,
stacklevel=2,
)
self._state[key] = value

def __delitem__(self, key: str) -> None:
def __delitem__(self, key: str | ResponseKey[_T]) -> None:
del self._state[key]

def __len__(self) -> int:
return len(self._state)

def __iter__(self) -> Iterator[str]:
def __iter__(self) -> Iterator[str | ResponseKey[Any]]:
return iter(self._state)

def __hash__(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/web_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def __init__(
) -> None:
if not isinstance(app, Application):
raise TypeError(
"The first argument should be web.Application " f"instance, got {app!r}"
f"The first argument should be web.Application instance, got {app!r}"
)
kwargs["access_log_class"] = access_log_class

Expand Down
2 changes: 1 addition & 1 deletion aiohttp/web_urldispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def __init__(
pass
else:
raise TypeError(
"Only async functions are allowed as web-handlers " f", got {handler!r}"
f"Only async functions are allowed as web-handlers, got {handler!r}"
)

self._method = method
Expand Down
11 changes: 9 additions & 2 deletions docs/faq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,15 @@ support the :class:`dict` interface.

Therefore, data may be stored inside a request object. ::

async def handler(request):
request['unique_key'] = data
request_id_key = web.RequestKey("request_id_key", str)

@web.middleware
async def request_id_middleware(request, handler):
request[request_id_key] = "some_request_id"
return await handler(request)

async def handler(request):
request_id = request[request_id_key]

See https://github.com/aio-libs/aiohttp_session code for an example.
The ``aiohttp_session.get_session(request)`` method uses ``SESSION_KEY``
Expand Down
Loading
Loading