Skip to content

Commit 7aa9f3d

Browse files
feat: allow before_request and after_request handlers to accept a parent argument, to wrap or override the handler from an enclosing scope
1 parent a38c6c1 commit 7aa9f3d

File tree

5 files changed

+75
-12
lines changed

5 files changed

+75
-12
lines changed

litestar/handlers/http_handlers/base.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import functools
4+
import inspect
35
from enum import Enum
46
from typing import TYPE_CHECKING, AnyStr, Mapping, Sequence, TypedDict, cast
57

@@ -62,6 +64,15 @@
6264
__all__ = ("HTTPRouteHandler", "route")
6365

6466

67+
def _wrap_layered_hooks(hooks: list[AsyncAnyCallable]) -> AsyncAnyCallable | None:
68+
"""Given a list of callables, starting from the end, set the parent= keyword argument of each to default to the preceding hook should any preceding hook exist and should that argument be accepted."""
69+
if not hooks:
70+
return None
71+
if "parent" in inspect.signature(hooks[-1]).parameters:
72+
return functools.partial(hooks[-1], parent=_wrap_layered_hooks(hooks[:-1]))
73+
return hooks[-1]
74+
75+
6576
class ResponseHandlerMap(TypedDict):
6677
default_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType
6778
response_type_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType
@@ -260,9 +271,9 @@ def __init__(
260271
)
261272

262273
self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore
263-
self.after_response = ensure_async_callable(after_response) if after_response else None
274+
self.after_response = ensure_async_callable(after_response) if after_response else None # pyright: ignore
264275
self.background = background
265-
self.before_request = ensure_async_callable(before_request) if before_request else None
276+
self.before_request = ensure_async_callable(before_request) if before_request else None # pyright: ignore
266277
self.cache = cache
267278
self.cache_control = cache_control
268279
self.cache_key_builder = cache_key_builder
@@ -400,7 +411,7 @@ def resolve_before_request(self) -> AsyncAnyCallable | None:
400411
"""
401412
if self._resolved_before_request is Empty:
402413
before_request_handlers = [layer.before_request for layer in self.ownership_layers if layer.before_request]
403-
self._resolved_before_request = before_request_handlers[-1] if before_request_handlers else None
414+
self._resolved_before_request = _wrap_layered_hooks(before_request_handlers)
404415
return cast("AsyncAnyCallable | None", self._resolved_before_request)
405416

406417
def resolve_after_response(self) -> AsyncAnyCallable | None:
@@ -418,7 +429,7 @@ def resolve_after_response(self) -> AsyncAnyCallable | None:
418429
for layer in self.ownership_layers
419430
if layer.after_response
420431
]
421-
self._resolved_after_response = after_response_handlers[-1] if after_response_handlers else None
432+
self._resolved_after_response = _wrap_layered_hooks(after_response_handlers)
422433

423434
return cast("AsyncAnyCallable | None", self._resolved_after_response)
424435

litestar/handlers/http_handlers/decorators.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@ def __init__(
121121
:class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished.
122122
Defaults to ``None``.
123123
before_request: A sync or async function called immediately before calling the route handler. Receives
124-
the :class:`.connection.Request` instance and any non-``None`` return value is used for the response,
125-
bypassing the route handler.
124+
the :class:`.connection.Request` instance (and, if it accepts a keyword argument named `parent`, the
125+
outer scope's before_request handler if any exists). Any non-``None`` return value is used for the
126+
response, bypassing the route handler.
126127
cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number
127128
of seconds (e.g. ``120``) to cache the response.
128129
cache_control: A ``cache-control`` header of type

litestar/router.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ def __init__(
168168
"""
169169

170170
self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore
171-
self.after_response = ensure_async_callable(after_response) if after_response else None
172-
self.before_request = ensure_async_callable(before_request) if before_request else None
171+
self.after_response = ensure_async_callable(after_response) if after_response else None # pyright: ignore
172+
self.before_request = ensure_async_callable(before_request) if before_request else None # pyright: ignore
173173
self.cache_control = cache_control
174174
self.dto = dto
175175
self.etag = etag

litestar/types/callable_types.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, TypeVar
3+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, Protocol, TypeVar
44

55
if TYPE_CHECKING:
66
from typing_extensions import TypeAlias
@@ -23,12 +23,29 @@
2323
AfterRequestHookHandler: TypeAlias = (
2424
"Callable[[ASGIApp], SyncOrAsyncUnion[ASGIApp]] | Callable[[Response], SyncOrAsyncUnion[Response]]"
2525
)
26-
AfterResponseHookHandler: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]"
26+
27+
AfterResponseHookHandlerSimple: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]"
28+
29+
30+
class AfterResponseHookHandlerWithParent(Protocol):
31+
async def __call__(self, request: Request, /, *, parent: AfterResponseHookHandler | None = None) -> None: ...
32+
33+
34+
AfterResponseHookHandler: TypeAlias = "AfterResponseHookHandlerSimple | AfterResponseHookHandlerWithParent"
35+
2736
AsyncAnyCallable: TypeAlias = Callable[..., Awaitable[Any]]
2837
AnyCallable: TypeAlias = Callable[..., Any]
2938
AnyGenerator: TypeAlias = "Generator[Any, Any, Any] | AsyncGenerator[Any, Any]"
3039
BeforeMessageSendHookHandler: TypeAlias = "Callable[[Message, Scope], SyncOrAsyncUnion[None]]"
31-
BeforeRequestHookHandler: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]"
40+
41+
42+
class BeforeRequestHookHandlerWithParent(Protocol):
43+
async def __call__(self, request: Request, /, *, parent: BeforeRequestHookHandler | None = None) -> Any: ...
44+
45+
46+
BeforeRequestHookHandlerSimple: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]"
47+
BeforeRequestHookHandler: TypeAlias = "BeforeRequestHookHandlerSimple | BeforeRequestHookHandlerWithParent"
48+
3249
CacheKeyBuilder: TypeAlias = "Callable[[Request], str]"
3350
ExceptionHandler: TypeAlias = "Callable[[Request, ExceptionT], Response]"
3451
ExceptionLoggingHandler: TypeAlias = "Callable[[Logger, Scope, list[str]], None]"

tests/e2e/test_life_cycle_hooks/test_before_request.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Dict, Optional
1+
import logging
2+
from typing import Any, Dict, Optional, Union
23

34
import pytest
45

@@ -7,6 +8,18 @@
78
from litestar.testing import create_test_client
89
from litestar.types import AnyCallable, BeforeRequestHookHandler
910

11+
logger = logging.getLogger(__name__)
12+
13+
14+
async def async_before_request_handler_with_parent(
15+
request: Request[Any, Any, State], /, *, parent: Optional[BeforeRequestHookHandler] = None
16+
) -> Optional[Dict[str, Union[str, int]]]:
17+
assert isinstance(request, Request)
18+
retval: Dict[str, Union[str, int]] = (None if parent is None else await parent(request)) or {}
19+
retval.setdefault("amended_count", 0)
20+
retval["amended_count"] += 1 # type: ignore
21+
return retval
22+
1023

1124
def sync_before_request_handler_with_return_value(request: Request[Any, Any, State]) -> Dict[str, str]:
1225
assert isinstance(request, Request)
@@ -88,6 +101,27 @@ def handler() -> Dict[str, str]:
88101
{"hello": "world"},
89102
],
90103
[None, None, None, async_before_request_handler_without_return_value, {"hello": "world"}],
104+
[
105+
sync_before_request_handler_with_return_value,
106+
None,
107+
None,
108+
async_before_request_handler_with_parent,
109+
{"hello": "moon", "amended_count": 1},
110+
],
111+
[
112+
sync_before_request_handler_with_return_value,
113+
None,
114+
async_before_request_handler_with_parent,
115+
async_before_request_handler_with_parent,
116+
{"hello": "moon", "amended_count": 2},
117+
],
118+
[
119+
sync_before_request_handler_with_return_value,
120+
async_before_request_handler_with_parent,
121+
async_before_request_handler_with_parent,
122+
async_before_request_handler_with_parent,
123+
{"hello": "moon", "amended_count": 3},
124+
],
91125
],
92126
)
93127
def test_before_request_handler_resolution(

0 commit comments

Comments
 (0)