Skip to content

Commit adc478e

Browse files
feat: allow before_request and after_request handlers to accept a parent argument, to wrap or override the handler from an enclosing scope
1 parent a38c6c1 commit adc478e

File tree

3 files changed

+35
-7
lines changed

3 files changed

+35
-7
lines changed

litestar/handlers/http_handlers/base.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from enum import Enum
44
from typing import TYPE_CHECKING, AnyStr, Mapping, Sequence, TypedDict, cast
5+
import functools
6+
import inspect
57

68
from litestar._layers.utils import narrow_response_cookies, narrow_response_headers
79
from litestar.connection import Request
@@ -61,6 +63,14 @@
6163

6264
__all__ = ("HTTPRouteHandler", "route")
6365

66+
def _wrap_layered_hooks(hooks: list[Callable]) -> AsyncAnyCallable | None:
67+
"""Given a list of callables, starting from the end, """
68+
if not hooks:
69+
return None
70+
if 'parent' in inspect.signature(hooks[-1]).parameters:
71+
return functools.partial(hooks[-1], parent=_wrap_layered_hooks(hooks[:-1]))
72+
return hooks[-1]
73+
6474

6575
class ResponseHandlerMap(TypedDict):
6676
default_handler: Callable[[Any], Awaitable[ASGIApp]] | EmptyType
@@ -400,7 +410,7 @@ def resolve_before_request(self) -> AsyncAnyCallable | None:
400410
"""
401411
if self._resolved_before_request is Empty:
402412
before_request_handlers = [layer.before_request for layer in self.ownership_layers if layer.before_request]
403-
self._resolved_before_request = before_request_handlers[-1] if before_request_handlers else None
413+
self._resolved_before_request = _wrap_layered_hooks(before_request_handlers)
404414
return cast("AsyncAnyCallable | None", self._resolved_before_request)
405415

406416
def resolve_after_response(self) -> AsyncAnyCallable | None:
@@ -418,7 +428,7 @@ def resolve_after_response(self) -> AsyncAnyCallable | None:
418428
for layer in self.ownership_layers
419429
if layer.after_response
420430
]
421-
self._resolved_after_response = after_response_handlers[-1] if after_response_handlers else None
431+
self._resolved_after_response = _wrap_layered_hooks(after_response_handlers)
422432

423433
return cast("AsyncAnyCallable | None", self._resolved_after_response)
424434

litestar/handlers/http_handlers/decorators.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@ def __init__(
121121
:class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished.
122122
Defaults to ``None``.
123123
before_request: A sync or async function called immediately before calling the route handler. Receives
124-
the :class:`.connection.Request` instance and any non-``None`` return value is used for the response,
125-
bypassing the route handler.
124+
the :class:`.connection.Request` instance (and, if it accepts a keyword argument named `parent`, the
125+
outer scope's before_request handler if any exists). Any non-``None`` return value is used for the
126+
response, bypassing the route handler.
126127
cache: Enables response caching if configured on the application level. Valid values are ``True`` or a number
127128
of seconds (e.g. ``120``) to cache the response.
128129
cache_control: A ``cache-control`` header of type

litestar/types/callable_types.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, TypeVar
3+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator, Protocol, TypeVar
44

55
if TYPE_CHECKING:
66
from typing_extensions import TypeAlias
@@ -23,12 +23,29 @@
2323
AfterRequestHookHandler: TypeAlias = (
2424
"Callable[[ASGIApp], SyncOrAsyncUnion[ASGIApp]] | Callable[[Response], SyncOrAsyncUnion[Response]]"
2525
)
26-
AfterResponseHookHandler: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]"
26+
27+
AfterResponseHookHandlerSimple: TypeAlias = "Callable[[Request], SyncOrAsyncUnion[None]]"
28+
29+
class AfterResponseHookHandlerWithParentAsync(Protocol):
30+
async def __call__(self, request: "Request", /, *, parent: "AfterResponseHookHandler | None") -> Any: ...
31+
class AfterResponseHookHandlerWithParentSync(Protocol):
32+
def __call__(self, request: "Request", /, *, parent: "AfterResponseHookHandler | None") -> Any: ...
33+
34+
AfterResponseHookHandler: TypeAlias = "AfterResponseHookHandlerSimple | AfterResponseHookHandlerWithParentAsync | AfterResponseHookHandlerWithParentSync"
35+
2736
AsyncAnyCallable: TypeAlias = Callable[..., Awaitable[Any]]
2837
AnyCallable: TypeAlias = Callable[..., Any]
2938
AnyGenerator: TypeAlias = "Generator[Any, Any, Any] | AsyncGenerator[Any, Any]"
3039
BeforeMessageSendHookHandler: TypeAlias = "Callable[[Message, Scope], SyncOrAsyncUnion[None]]"
31-
BeforeRequestHookHandler: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]"
40+
41+
class BeforeRequestHookHandlerWithParentAsync(Protocol):
42+
async def __call__(self, request: "Request", /, *, parent: "BeforeRequestHookHandler | None") -> Any: ...
43+
class BeforeRequestHookHandlerWithParentSync(Protocol):
44+
def __call__(self, request: "Request", /, *, parent: "BeforeRequestHookHandler | None") -> Any: ...
45+
46+
BeforeRequestHookHandlerSimple: TypeAlias = "Callable[[Request], Any | Awaitable[Any]]"
47+
BeforeRequestHookHandler: TypeAlias = "BeforeRequestHookHandlerSimple | BeforeRequestHookHandlerWithParentSync | BeforeRequestHookHandlerWithParentAsync"
48+
3249
CacheKeyBuilder: TypeAlias = "Callable[[Request], str]"
3350
ExceptionHandler: TypeAlias = "Callable[[Request, ExceptionT], Response]"
3451
ExceptionLoggingHandler: TypeAlias = "Callable[[Logger, Scope, list[str]], None]"

0 commit comments

Comments
 (0)