diff --git a/docs/source/api.rst b/docs/source/api.rst index 4b6798e..ebee14a 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -128,6 +128,18 @@ have a newline inside a header value, and ``Content-Length: hello`` is an error because `Content-Length` should always be an integer. We may add additional checks in the future. +It is possible to get the first or all headers for a given name. + +.. ipython:: python + + res = h11.Response(status_code=200, headers=[ + ("Date", b"Thu, 09 Jan 2025 18:37:23 GMT"), + ("Set-Cookie", b"sid=1234"), + ("Set-Cookie", b"lang=en_US"), + ]) + res.headers.get(b"date") + res.headers.getlist(b"set-cookie") + While we make sure to expose header names as lowercased bytes, we also preserve the original header casing that is used. Compliant HTTP agents should always treat headers in a case insensitive manner, but diff --git a/h11/_headers.py b/h11/_headers.py index b97d020..60f4bb2 100644 --- a/h11/_headers.py +++ b/h11/_headers.py @@ -1,5 +1,5 @@ import re -from typing import AnyStr, cast, List, overload, Sequence, Tuple, TYPE_CHECKING, Union +from typing import List, overload, Sequence, Tuple, TYPE_CHECKING, TypeVar, Union from ._abnf import field_name, field_value from ._util import bytesify, LocalProtocolError, validate @@ -13,6 +13,8 @@ from typing_extensions import Literal # type: ignore +T = TypeVar("T") + # Facts # ----- # @@ -84,19 +86,29 @@ class Headers(Sequence[Tuple[bytes, bytes]]): r = Request( method="GET", target="/", - headers=[("Host", "example.org"), ("Connection", "keep-alive")], + headers=[ + ("Host", "example.org"), + ("Connection", "keep-alive"), + ("Cookie", "session=1234"), + ("Cookie", "lang=en_US"), + ], http_version="1.1", ) assert r.headers == [ (b"host", b"example.org"), - (b"connection", b"keep-alive") + (b"connection", b"keep-alive"), + (b"cookie", b"session=1234"), + (b"cookie", b"lang=en_US"), ] assert r.headers.raw_items() == [ (b"Host", b"example.org"), - (b"Connection", b"keep-alive") + (b"Connection", b"keep-alive"), + (b"Cookie", b"session=1234"), + (b"Cookie", b"lang=en_US"), ] + assert r.headers.get(b"host") == b"example.org" + assert r.headers.getlist(b"cookie") == [b"session=1234", b"lang=en_US"] """ - __slots__ = "_full_items" def __init__(self, full_items: List[Tuple[bytes, bytes, bytes]]) -> None: @@ -118,6 +130,27 @@ def __getitem__(self, idx: int) -> Tuple[bytes, bytes]: # type: ignore[override _, name, value = self._full_items[idx] return (name, value) + def get(self, name: bytes, default: T = None) -> Union[bytes, T]: + """Find the first header with lowercased-name :param:`name`, it returns + its value when found, and :param:`default` otherwise. + + Args: + name (bytes): The lowercased header name to find. + + default: The value to return when the header is not found. + """ + return next((value for name_, value in self if name_ == name), default) + + def getlist(self, name: bytes) -> List[bytes]: + """Find the all the headers with lowercased-name :param:`name`, + it returns their values in a list. It returns an empty list when + no header matched. + + Args: + name (bytes): The lowercased header name to find. + """ + return [value for name_, value in self if name_ == name] + def raw_items(self) -> List[Tuple[bytes, bytes]]: return [(raw_name, value) for raw_name, _, value in self._full_items]