diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..cbb948b320 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: minor + +Update `extensions/mask_errors.py` to handle pre-execution errors and execution errors separately. diff --git a/strawberry/extensions/mask_errors.py b/strawberry/extensions/mask_errors.py index c2ecb0a707..292b73abc9 100644 --- a/strawberry/extensions/mask_errors.py +++ b/strawberry/extensions/mask_errors.py @@ -1,8 +1,7 @@ from collections.abc import Iterator -from typing import Any, Callable +from typing import Callable from graphql.error import GraphQLError -from graphql.execution import ExecutionResult from strawberry.extensions.base_extension import SchemaExtension @@ -34,30 +33,29 @@ def anonymise_error(self, error: GraphQLError) -> GraphQLError: original_error=None, ) - # TODO: proper typing - def _process_result(self, result: Any) -> None: - if not result.errors: - return - + def _process_errors(self, errors: list[GraphQLError]) -> list[GraphQLError]: processed_errors: list[GraphQLError] = [] - for error in result.errors: + for error in errors: if self.should_mask_error(error): processed_errors.append(self.anonymise_error(error)) else: processed_errors.append(error) - result.errors = processed_errors + return processed_errors def on_operation(self) -> Iterator[None]: yield + pre_execution_errors = self.execution_context.pre_execution_errors or [] + self.execution_context.pre_execution_errors = self._process_errors( + pre_execution_errors + ) + result = self.execution_context.result - if isinstance(result, ExecutionResult): - self._process_result(result) - elif result: - self._process_result(result.initial_result) + if result is not None and result.errors: + result.errors = self._process_errors(result.errors) __all__ = ["MaskErrors"]