Skip to content

Commit b0b45a6

Browse files
feat: allow before_request and after_request handlers to accept a parent argument, to wrap or override the handler from an enclosing scope
1 parent a38c6c1 commit b0b45a6

File tree

5 files changed

+48
-11
lines changed

5 files changed

+48
-11
lines changed

litestar/handlers/http_handlers/base.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from enum import Enum
44
from typing import TYPE_CHECKING, AnyStr, Mapping, Sequence, TypedDict, cast
5+
import functools
6+
import inspect
57

68
from litestar._layers.utils import narrow_response_cookies, narrow_response_headers
79
from litestar.connection import Request
@@ -61,6 +63,14 @@
6163

6264
__all__ = ("HTTPRouteHandler", "route")
6365

66+
def _wrap_layered_hooks(hooks: list[AsyncAnyCallable]) -> AsyncAnyCallable | None:
67+
"""Given a list of callables, starting from the end, """
68+
if not hooks:
69+
return None
70+
if 'parent' in inspect.signature(hooks[-1]).parameters:
71+
return functools.partial(hooks[-1], parent=_wrap_layered_hooks(hooks[:-1]))
72+
return hooks[-1]
73+
6474

6575
class ResponseHandlerMap(TypedDict):
6676
default_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType
@@ -260,9 +270,9 @@ def __init__(
260270
)
261271

262272
self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore
263-
self.after_response = ensure_async_callable(after_response) if after_response else None
273+
self.after_response = ensure_async_callable(after_response) if after_response else None # pyright: ignore
264274
self.background = background
265-
self.before_request = ensure_async_callable(before_request) if before_request else None
275+
self.before_request = ensure_async_callable(before_request) if before_request else None # pyright: ignore
266276
self.cache = cache
267277
self.cache_control = cache_control
268278
self.cache_key_builder = cache_key_builder
@@ -400,7 +410,7 @@ def resolve_before_request(self) -> AsyncAnyCallable | None:
400410
"""
401411
if self._resolved_before_request is Empty:
402412
before_request_handlers = [layer.before_request for layer in self.ownership_layers if layer.before_request]
403-
self._resolved_before_request = before_request_handlers[-1] if before_request_handlers else None
413+
self._resolved_before_request = _wrap_layered_hooks(before_request_handlers)
404414
return cast("AsyncAnyCallable | None", self._resolved_before_request)
405415

406416
def resolve_after_response(self) -> AsyncAnyCallable | None:
@@ -418,7 +428,7 @@ def resolve_after_response(self) -> AsyncAnyCallable | None:
418428
for layer in self.ownership_layers
419429
if layer.after_response
420430
]
421-
self._resolved_after_response = after_response_handlers[-1] if after_response_handlers else None
431+
self._resolved_after_response = _wrap_layered_hooks(after_response_handlers)
422432

423433
return cast("AsyncAnyCallable | None", self._resolved_after_response)
424434

litestar/handlers/http_handlers/decorators.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@ def __init__(
121121
:class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished.
122122
Defaults to ``None``.
123123
before_request: A sync or async function called immediately before calling the route handler. Receives
124-
the :class:`.connection.Request` instance and any non-``None`` return value is used for the response,
125-
bypassing the route handler.
124+
the :class:`.connection.Request` instance (and, if it accepts a keyword argument named `parent`, the
125+
outer scope's before_request handler if any exists). Any non-``None`` return value is used for the
126+
response, bypassing the route handler.
126127
cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number
127128
of seconds (e.g. ``120``) to cache the response.
128129
cache_control: A ``cache-control`` header of type

litestar/router.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ def __init__(
168168
"""
169169

170170
self.after_request = ensure_async_callable(after_request) if after_request else None # pyright: ignore
171-
self.after_response = ensure_async_callable(after_response) if after_response else None
172-
self.before_request = ensure_async_callable(before_request) if before_request else None
171+
self.after_response = ensure_async_callable(after_response) if after_response else None # pyright: ignore
172+
self.before_request = ensure_async_callable(before_request) if before_request else None # pyright: ignore
173173
self.cache_control = cache_control
174174
self.dto = dto
175175
self.etag = etag

litestar/types/callable_types.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, TypeVar
3+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, Protocol, TypeVar
44

55
if TYPE_CHECKING:
66
from typing_extensions import TypeAlias
@@ -23,12 +23,25 @@
2323
AfterRequestHookHandler: TypeAlias = (
2424
"Callable[[ASGIApp], SyncOrAsyncUnion[ASGIApp]] | Callable[[Response], SyncOrAsyncUnion[Response]]"
2525
)
26-
AfterResponseHookHandler: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]"
26+
27+
AfterResponseHookHandlerSimple: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]"
28+
29+
class AfterResponseHookHandlerWithParent(Protocol):
30+
async def __call__(self, request: "Request", /, *, parent: "AfterResponseHookHandler | None" = None) -> None: ...
31+
32+
AfterResponseHookHandler: TypeAlias = "AfterResponseHookHandlerSimple | AfterResponseHookHandlerWithParent"
33+
2734
AsyncAnyCallable: TypeAlias = Callable[..., Awaitable[Any]]
2835
AnyCallable: TypeAlias = Callable[..., Any]
2936
AnyGenerator: TypeAlias = "Generator[Any, Any, Any] | AsyncGenerator[Any, Any]"
3037
BeforeMessageSendHookHandler: TypeAlias = "Callable[[Message, Scope], SyncOrAsyncUnion[None]]"
31-
BeforeRequestHookHandler: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]"
38+
39+
class BeforeRequestHookHandlerWithParent(Protocol):
40+
async def __call__(self, request: "Request", /, *, parent: "BeforeRequestHookHandler | None" = None) -> Any: ...
41+
42+
BeforeRequestHookHandlerSimple: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]"
43+
BeforeRequestHookHandler: TypeAlias = "BeforeRequestHookHandlerSimple | BeforeRequestHookHandlerWithParent"
44+
3245
CacheKeyBuilder: TypeAlias = "Callable[[Request], str]"
3346
ExceptionHandler: TypeAlias = "Callable[[Request, ExceptionT], Response]"
3447
ExceptionLoggingHandler: TypeAlias = "Callable[[Logger, Scope, list[str]], None]"

tests/e2e/test_life_cycle_hooks/test_before_request.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
from typing import Any, Dict, Optional
2+
import logging
23

34
import pytest
45

56
from litestar import Controller, Request, Response, Router, get
7+
from litestar.handlers.http_handlers.base import _wrap_layered_hooks
68
from litestar.datastructures import State
79
from litestar.testing import create_test_client
810
from litestar.types import AnyCallable, BeforeRequestHookHandler
911

12+
logger = logging.getLogger(__name__)
13+
14+
async def async_before_request_handler_with_parent(request: Request[Any, Any, State], /, *, parent: Optional[BeforeRequestHookHandler] = None):
15+
assert isinstance(request, Request)
16+
retval = (None if parent is None else await parent(request)) or {}
17+
retval.setdefault('amended_count', 0)
18+
retval['amended_count'] += 1
19+
return retval
1020

1121
def sync_before_request_handler_with_return_value(request: Request[Any, Any, State]) -> Dict[str, str]:
1222
assert isinstance(request, Request)
@@ -88,6 +98,9 @@ def handler() -> Dict[str, str]:
8898
{"hello": "world"},
8999
],
90100
[None, None, None, async_before_request_handler_without_return_value, {"hello": "world"}],
101+
[sync_before_request_handler_with_return_value, None, None, async_before_request_handler_with_parent, {"hello": "moon", "amended_count": 1}],
102+
[sync_before_request_handler_with_return_value, None, async_before_request_handler_with_parent, async_before_request_handler_with_parent, {"hello": "moon", "amended_count": 2}],
103+
[sync_before_request_handler_with_return_value, async_before_request_handler_with_parent, async_before_request_handler_with_parent, async_before_request_handler_with_parent, {"hello": "moon", "amended_count": 3}],
91104
],
92105
)
93106
def test_before_request_handler_resolution(

0 commit comments

Comments
 (0)