Skip to content

Provide execution context as an argument to schema extensions #3640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ repos:
- id: alex
exclude: (CHANGELOG|TWEET).md

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
hooks:
- id: prettier
files: '^docs/.*\.mdx?$'

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
Expand Down
8 changes: 5 additions & 3 deletions strawberry/extensions/add_validation_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
if TYPE_CHECKING:
from graphql import ASTValidationRule

from strawberry.types.execution import ExecutionContext


class AddValidationRules(SchemaExtension):
"""Add graphql-core validation rules.
Expand Down Expand Up @@ -42,9 +44,9 @@
def __init__(self, validation_rules: List[Type[ASTValidationRule]]) -> None:
self.validation_rules = validation_rules

def on_operation(self) -> Iterator[None]:
self.execution_context.validation_rules = (
self.execution_context.validation_rules + tuple(self.validation_rules)
def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]:
execution_context.validation_rules = execution_context.validation_rules + tuple(

Check warning on line 48 in strawberry/extensions/add_validation_rules.py

View check run for this annotation

Codecov / codecov/patch

strawberry/extensions/add_validation_rules.py#L48

Added line #L48 was not covered by tests
self.validation_rules
)
yield

Expand Down
25 changes: 14 additions & 11 deletions strawberry/extensions/base_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,33 @@ class LifecycleStep(Enum):


class SchemaExtension:
execution_context: ExecutionContext
if not TYPE_CHECKING:
# to support extensions that still use the old signature
# we have an optional argument here for ease of initialization.
def __init__(
self, *, execution_context: ExecutionContext | None = None
) -> None: ...

# to support extensions that still use the old signature
# we have an optional argument here for ease of initialization.
def __init__(
self, *, execution_context: ExecutionContext | None = None
) -> None: ...
def on_operation( # type: ignore
self,
self, execution_context: ExecutionContext
) -> AsyncIteratorOrIterator[None]: # pragma: no cover
"""Called before and after a GraphQL operation (query / mutation) starts."""
yield None

def on_validate( # type: ignore
self,
self, execution_context: ExecutionContext
) -> AsyncIteratorOrIterator[None]: # pragma: no cover
"""Called before and after the validation step."""
yield None

def on_parse( # type: ignore
self,
self, execution_context: ExecutionContext
) -> AsyncIteratorOrIterator[None]: # pragma: no cover
"""Called before and after the parsing step."""
yield None

def on_execute( # type: ignore
self,
self, execution_context: ExecutionContext
) -> AsyncIteratorOrIterator[None]: # pragma: no cover
"""Called before and after the execution step."""
yield None
Expand All @@ -55,12 +55,15 @@ def resolve(
_next: Callable,
root: Any,
info: GraphQLResolveInfo,
execution_context: ExecutionContext,
*args: str,
**kwargs: Any,
) -> AwaitableOrValue[object]:
return _next(root, info, *args, **kwargs)

def get_results(self) -> AwaitableOrValue[Dict[str, Any]]:
def get_results(
self, execution_context: ExecutionContext
) -> AwaitableOrValue[Dict[str, Any]]:
return {}

@classmethod
Expand Down
126 changes: 41 additions & 85 deletions strawberry/extensions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import contextlib
import inspect
import types
import warnings
from asyncio import iscoroutinefunction
from typing import (
TYPE_CHECKING,
Expand All @@ -21,64 +20,58 @@
)

from strawberry.extensions import SchemaExtension
from strawberry.utils.await_maybe import AwaitableOrValue, await_maybe

if TYPE_CHECKING:
from types import TracebackType

from strawberry.extensions.base_extension import Hook
from strawberry.types.execution import ExecutionContext
from strawberry.utils.await_maybe import AwaitableOrValue


class WrappedHook(NamedTuple):
extension: SchemaExtension
hook: Callable[..., Union[AsyncContextManager[None], ContextManager[None]]]
hook: Callable[
[ExecutionContext], Union[AsyncContextManager[None], ContextManager[None]]
]
is_async: bool


class ExtensionContextManagerBase:
__slots__ = (
"hooks",
"deprecation_message",
"default_hook",
"async_exit_stack",
"exit_stack",
"execution_context",
)

HOOK_NAME: str
DEFAULT_HOOK: Hook

def __init_subclass__(cls) -> None:
cls.DEPRECATION_MESSAGE = (
f"Event driven styled extensions for "
f"{cls.LEGACY_ENTER} or {cls.LEGACY_EXIT}"
f" are deprecated, use {cls.HOOK_NAME} instead"
)
cls.DEFAULT_HOOK = getattr(SchemaExtension, cls.HOOK_NAME)

HOOK_NAME: str
DEPRECATION_MESSAGE: str
LEGACY_ENTER: str
LEGACY_EXIT: str
def __init__(
self, hooks: List[WrappedHook], execution_context: ExecutionContext
) -> None:
self.hooks = hooks
self.execution_context = execution_context

@classmethod
def get_hooks(cls, extensions: List[SchemaExtension]) -> List[WrappedHook]:
hooks = []

def __init__(self, extensions: List[SchemaExtension]) -> None:
self.hooks: List[WrappedHook] = []
self.default_hook: Hook = getattr(SchemaExtension, self.HOOK_NAME)
for extension in extensions:
hook = self.get_hook(extension)
hook = cls.get_hook(extension)
if hook:
self.hooks.append(hook)
hooks.append(hook)

def get_hook(self, extension: SchemaExtension) -> Optional[WrappedHook]:
on_start = getattr(extension, self.LEGACY_ENTER, None)
on_end = getattr(extension, self.LEGACY_EXIT, None)
return hooks

is_legacy = on_start is not None or on_end is not None
hook_fn: Optional[Hook] = getattr(type(extension), self.HOOK_NAME)
hook_fn = hook_fn if hook_fn is not self.default_hook else None
if is_legacy and hook_fn is not None:
raise ValueError(
f"{extension} defines both legacy and new style extension hooks for "
"{self.HOOK_NAME}"
)
elif is_legacy:
warnings.warn(self.DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=3)
return self.from_legacy(extension, on_start, on_end)
@classmethod
def get_hook(cls, extension: SchemaExtension) -> Optional[WrappedHook]:
hook_fn: Optional[Hook] = getattr(type(extension), cls.HOOK_NAME)
hook_fn = hook_fn if hook_fn is not cls.DEFAULT_HOOK else None

if hook_fn:
if inspect.isgeneratorfunction(hook_fn):
Expand All @@ -98,67 +91,36 @@ def get_hook(self, extension: SchemaExtension) -> Optional[WrappedHook]:
)

if callable(hook_fn):
return self.from_callable(extension, hook_fn)
return cls.from_callable(extension, hook_fn)

raise ValueError(
f"Hook {self.HOOK_NAME} on {extension} "
f"Hook {cls.HOOK_NAME} on {extension} "
f"must be callable, received {hook_fn!r}"
)

return None # Current extension does not define a hook for this lifecycle stage

@staticmethod
def from_legacy(
extension: SchemaExtension,
on_start: Optional[Callable[[], None]] = None,
on_end: Optional[Callable[[], None]] = None,
) -> WrappedHook:
if iscoroutinefunction(on_start) or iscoroutinefunction(on_end):

@contextlib.asynccontextmanager
async def iterator() -> AsyncIterator:
if on_start:
await await_maybe(on_start())

yield

if on_end:
await await_maybe(on_end())

return WrappedHook(extension=extension, hook=iterator, is_async=True)

else:

@contextlib.contextmanager
def iterator_async() -> Iterator[None]:
if on_start:
on_start()

yield

if on_end:
on_end()

return WrappedHook(extension=extension, hook=iterator_async, is_async=False)

@staticmethod
def from_callable(
extension: SchemaExtension,
func: Callable[[SchemaExtension], AwaitableOrValue[Any]],
func: Callable[[SchemaExtension, ExecutionContext], AwaitableOrValue[Any]],
) -> WrappedHook:
self_ = extension
if iscoroutinefunction(func):

@contextlib.asynccontextmanager
async def iterator() -> AsyncIterator[None]:
await func(extension)
async def iterator(
execution_context: ExecutionContext,
) -> AsyncIterator[None]:
await func(self_, execution_context)
yield

return WrappedHook(extension=extension, hook=iterator, is_async=True)
else:

@contextlib.contextmanager
def iterator() -> Iterator[None]:
func(extension)
def iterator(execution_context: ExecutionContext) -> Iterator[None]:
func(self_, execution_context)
yield

return WrappedHook(extension=extension, hook=iterator, is_async=False)
Expand All @@ -175,7 +137,7 @@ def __enter__(self) -> None:
"failed to complete synchronously."
)
else:
self.exit_stack.enter_context(hook.hook()) # type: ignore
self.exit_stack.enter_context(hook.hook(self.execution_context))

def __exit__(
self,
Expand All @@ -192,9 +154,11 @@ async def __aenter__(self) -> None:

for hook in self.hooks:
if hook.is_async:
await self.async_exit_stack.enter_async_context(hook.hook()) # type: ignore
await self.async_exit_stack.enter_async_context(
hook.hook(self.execution_context)
) # type: ignore
else:
self.async_exit_stack.enter_context(hook.hook()) # type: ignore
self.async_exit_stack.enter_context(hook.hook(self.execution_context)) # type: ignore

async def __aexit__(
self,
Expand All @@ -207,23 +171,15 @@ async def __aexit__(

class OperationContextManager(ExtensionContextManagerBase):
HOOK_NAME = SchemaExtension.on_operation.__name__
LEGACY_ENTER = "on_request_start"
LEGACY_EXIT = "on_request_end"


class ValidationContextManager(ExtensionContextManagerBase):
HOOK_NAME = SchemaExtension.on_validate.__name__
LEGACY_ENTER = "on_validation_start"
LEGACY_EXIT = "on_validation_end"


class ParsingContextManager(ExtensionContextManagerBase):
HOOK_NAME = SchemaExtension.on_parse.__name__
LEGACY_ENTER = "on_parsing_start"
LEGACY_EXIT = "on_parsing_end"


class ExecutingContextManager(ExtensionContextManagerBase):
HOOK_NAME = SchemaExtension.on_execute.__name__
LEGACY_ENTER = "on_executing_start"
LEGACY_EXIT = "on_executing_end"
5 changes: 3 additions & 2 deletions strawberry/extensions/disable_validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Iterator

from strawberry.extensions.base_extension import SchemaExtension
from strawberry.types.execution import ExecutionContext


class DisableValidation(SchemaExtension):
Expand All @@ -26,8 +27,8 @@
# some in the future
pass

def on_operation(self) -> Iterator[None]:
self.execution_context.validation_rules = () # remove all validation_rules
def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]:
execution_context.validation_rules = () # remove all validation_rules

Check warning on line 31 in strawberry/extensions/disable_validation.py

View check run for this annotation

Codecov / codecov/patch

strawberry/extensions/disable_validation.py#L31

Added line #L31 was not covered by tests
yield


Expand Down
5 changes: 3 additions & 2 deletions strawberry/extensions/mask_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from graphql.error import GraphQLError

from strawberry.extensions.base_extension import SchemaExtension
from strawberry.types.execution import ExecutionContext


def default_should_mask_error(_: GraphQLError) -> bool:
Expand Down Expand Up @@ -32,10 +33,10 @@
original_error=None,
)

def on_operation(self) -> Iterator[None]:
def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]:
yield
result = self.execution_context.result
result = execution_context.result
if result and result.errors:

Check warning on line 39 in strawberry/extensions/mask_errors.py

View check run for this annotation

Codecov / codecov/patch

strawberry/extensions/mask_errors.py#L38-L39

Added lines #L38 - L39 were not covered by tests
processed_errors: List[GraphQLError] = []
for error in result.errors:
if self.should_mask_error(error):
Expand Down
5 changes: 3 additions & 2 deletions strawberry/extensions/max_tokens.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Iterator

from strawberry.extensions.base_extension import SchemaExtension
from strawberry.types.execution import ExecutionContext


class MaxTokensLimiter(SchemaExtension):
Expand Down Expand Up @@ -34,8 +35,8 @@ def __init__(
"""
self.max_token_count = max_token_count

def on_operation(self) -> Iterator[None]:
self.execution_context.parse_options["max_tokens"] = self.max_token_count
def on_operation(self, execution_context: ExecutionContext) -> Iterator[None]:
execution_context.parse_options["max_tokens"] = self.max_token_count
yield


Expand Down
5 changes: 2 additions & 3 deletions strawberry/extensions/parser_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from strawberry.extensions.base_extension import SchemaExtension
from strawberry.schema.execute import parse_document
from strawberry.types.execution import ExecutionContext


class ParserCache(SchemaExtension):
Expand Down Expand Up @@ -33,9 +34,7 @@ def __init__(self, maxsize: Optional[int] = None) -> None:
"""
self.cached_parse_document = lru_cache(maxsize=maxsize)(parse_document)

def on_parse(self) -> Iterator[None]:
execution_context = self.execution_context

def on_parse(self, execution_context: ExecutionContext) -> Iterator[None]:
execution_context.graphql_document = self.cached_parse_document(
execution_context.query, **execution_context.parse_options
)
Expand Down
Loading
Loading