Skip to content

Commit e24ac54

Browse files
Add UnforgivingExecutionContext (#1255)
1 parent a53b782 commit e24ac54

File tree

2 files changed

+215
-3
lines changed

2 files changed

+215
-3
lines changed

graphene/types/schema.py

+100-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
GraphQLString,
2929
Undefined,
3030
)
31+
from graphql.execution import ExecutionContext
32+
from graphql.execution.values import get_argument_values
3133

3234
from ..utils.str_converters import to_camel_case
3335
from ..utils.get_unbound_function import get_unbound_function
@@ -317,7 +319,7 @@ def create_fields_for_type(self, graphene_type, is_input_type=False):
317319
)
318320
subscribe = field.wrap_subscribe(
319321
self.get_function_for_type(
320-
graphene_type, f"subscribe_{name}", name, field.default_value,
322+
graphene_type, f"subscribe_{name}", name, field.default_value
321323
)
322324
)
323325

@@ -394,6 +396,101 @@ def resolve_type(self, resolve_type_func, type_name, root, info, _type):
394396
return type_
395397

396398

399+
class UnforgivingExecutionContext(ExecutionContext):
400+
"""An execution context which doesn't swallow exceptions.
401+
402+
The only difference between this execution context and the one it inherits from is
403+
that ``except Exception`` is commented out within ``resolve_field_value_or_error``.
404+
By removing that exception handling, only ``GraphQLError``'s are caught.
405+
"""
406+
407+
def resolve_field_value_or_error(
408+
self, field_def, field_nodes, resolve_fn, source, info
409+
):
410+
"""Resolve field to a value or an error.
411+
412+
Isolates the "ReturnOrAbrupt" behavior to not de-opt the resolve_field()
413+
method. Returns the result of resolveFn or the abrupt-return Error object.
414+
415+
For internal use only.
416+
"""
417+
try:
418+
# Build a dictionary of arguments from the field.arguments AST, using the
419+
# variables scope to fulfill any variable references.
420+
args = get_argument_values(field_def, field_nodes[0], self.variable_values)
421+
422+
# Note that contrary to the JavaScript implementation, we pass the context
423+
# value as part of the resolve info.
424+
result = resolve_fn(source, info, **args)
425+
if self.is_awaitable(result):
426+
# noinspection PyShadowingNames
427+
async def await_result():
428+
try:
429+
return await result
430+
except GraphQLError as error:
431+
return error
432+
# except Exception as error:
433+
# return GraphQLError(str(error), original_error=error)
434+
435+
# Yes, this is commented out code. It's been intentionally
436+
# _not_ removed to show what has changed from the original
437+
# implementation.
438+
439+
return await_result()
440+
return result
441+
except GraphQLError as error:
442+
return error
443+
# except Exception as error:
444+
# return GraphQLError(str(error), original_error=error)
445+
446+
# Yes, this is commented out code. It's been intentionally _not_
447+
# removed to show what has changed from the original implementation.
448+
449+
def complete_value_catching_error(
450+
self, return_type, field_nodes, info, path, result
451+
):
452+
"""Complete a value while catching an error.
453+
454+
This is a small wrapper around completeValue which detects and logs errors in
455+
the execution context.
456+
"""
457+
try:
458+
if self.is_awaitable(result):
459+
460+
async def await_result():
461+
value = self.complete_value(
462+
return_type, field_nodes, info, path, await result
463+
)
464+
if self.is_awaitable(value):
465+
return await value
466+
return value
467+
468+
completed = await_result()
469+
else:
470+
completed = self.complete_value(
471+
return_type, field_nodes, info, path, result
472+
)
473+
if self.is_awaitable(completed):
474+
# noinspection PyShadowingNames
475+
async def await_completed():
476+
try:
477+
return await completed
478+
479+
# CHANGE WAS MADE HERE
480+
# ``GraphQLError`` was swapped in for ``except Exception``
481+
except GraphQLError as error:
482+
self.handle_field_error(error, field_nodes, path, return_type)
483+
484+
return await_completed()
485+
return completed
486+
487+
# CHANGE WAS MADE HERE
488+
# ``GraphQLError`` was swapped in for ``except Exception``
489+
except GraphQLError as error:
490+
self.handle_field_error(error, field_nodes, path, return_type)
491+
return None
492+
493+
397494
class Schema:
398495
"""Schema Definition.
399496
@@ -481,6 +578,8 @@ def execute(self, *args, **kwargs):
481578
request_string, an operation name must be provided for the result to be provided.
482579
middleware (List[SupportsGraphQLMiddleware]): Supply request level middleware as
483580
defined in `graphql-core`.
581+
execution_context_class (ExecutionContext, optional): The execution context class
582+
to use when resolving queries and mutations.
484583
485584
Returns:
486585
:obj:`ExecutionResult` containing any data and errors for the operation.

graphene/types/tests/test_schema.py

+115-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from graphql.type import GraphQLObjectType, GraphQLSchema
2-
from pytest import raises
2+
from graphql import GraphQLError
3+
from pytest import mark, raises, fixture
34

45
from graphene.tests.utils import dedent
56

67
from ..field import Field
78
from ..objecttype import ObjectType
89
from ..scalars import String
9-
from ..schema import Schema
10+
from ..schema import Schema, UnforgivingExecutionContext
1011

1112

1213
class MyOtherType(ObjectType):
@@ -68,3 +69,115 @@ def test_schema_requires_query_type():
6869
assert len(result.errors) == 1
6970
error = result.errors[0]
7071
assert error.message == "Query root type must be provided."
72+
73+
74+
class TestUnforgivingExecutionContext:
75+
@fixture
76+
def schema(self):
77+
class ErrorFieldsMixin:
78+
sanity_field = String()
79+
expected_error_field = String()
80+
unexpected_value_error_field = String()
81+
unexpected_type_error_field = String()
82+
unexpected_attribute_error_field = String()
83+
unexpected_key_error_field = String()
84+
85+
@staticmethod
86+
def resolve_sanity_field(obj, info):
87+
return "not an error"
88+
89+
@staticmethod
90+
def resolve_expected_error_field(obj, info):
91+
raise GraphQLError("expected error")
92+
93+
@staticmethod
94+
def resolve_unexpected_value_error_field(obj, info):
95+
raise ValueError("unexpected error")
96+
97+
@staticmethod
98+
def resolve_unexpected_type_error_field(obj, info):
99+
raise TypeError("unexpected error")
100+
101+
@staticmethod
102+
def resolve_unexpected_attribute_error_field(obj, info):
103+
raise AttributeError("unexpected error")
104+
105+
@staticmethod
106+
def resolve_unexpected_key_error_field(obj, info):
107+
return {}["fails"]
108+
109+
class NestedObject(ErrorFieldsMixin, ObjectType):
110+
pass
111+
112+
class MyQuery(ErrorFieldsMixin, ObjectType):
113+
nested_object = Field(NestedObject)
114+
nested_object_error = Field(NestedObject)
115+
116+
@staticmethod
117+
def resolve_nested_object(obj, info):
118+
return object()
119+
120+
@staticmethod
121+
def resolve_nested_object_error(obj, info):
122+
raise TypeError()
123+
124+
schema = Schema(query=MyQuery)
125+
return schema
126+
127+
def test_sanity_check(self, schema):
128+
# this should pass with no errors (sanity check)
129+
result = schema.execute(
130+
"query { sanityField }",
131+
execution_context_class=UnforgivingExecutionContext,
132+
)
133+
assert not result.errors
134+
assert result.data == {"sanityField": "not an error"}
135+
136+
def test_nested_sanity_check(self, schema):
137+
# this should pass with no errors (sanity check)
138+
result = schema.execute(
139+
r"query { nestedObject { sanityField } }",
140+
execution_context_class=UnforgivingExecutionContext,
141+
)
142+
assert not result.errors
143+
assert result.data == {"nestedObject": {"sanityField": "not an error"}}
144+
145+
def test_graphql_error(self, schema):
146+
result = schema.execute(
147+
"query { expectedErrorField }",
148+
execution_context_class=UnforgivingExecutionContext,
149+
)
150+
assert len(result.errors) == 1
151+
assert result.errors[0].message == "expected error"
152+
assert result.data == {"expectedErrorField": None}
153+
154+
def test_nested_graphql_error(self, schema):
155+
result = schema.execute(
156+
r"query { nestedObject { expectedErrorField } }",
157+
execution_context_class=UnforgivingExecutionContext,
158+
)
159+
assert len(result.errors) == 1
160+
assert result.errors[0].message == "expected error"
161+
assert result.data == {"nestedObject": {"expectedErrorField": None}}
162+
163+
@mark.parametrize(
164+
"field,exception",
165+
[
166+
("unexpectedValueErrorField", ValueError),
167+
("unexpectedTypeErrorField", TypeError),
168+
("unexpectedAttributeErrorField", AttributeError),
169+
("unexpectedKeyErrorField", KeyError),
170+
("nestedObject { unexpectedValueErrorField }", ValueError),
171+
("nestedObject { unexpectedTypeErrorField }", TypeError),
172+
("nestedObject { unexpectedAttributeErrorField }", AttributeError),
173+
("nestedObject { unexpectedKeyErrorField }", KeyError),
174+
("nestedObjectError { __typename }", TypeError),
175+
],
176+
)
177+
def test_unexpected_error(self, field, exception, schema):
178+
with raises(exception):
179+
# no result, but the exception should be propagated
180+
schema.execute(
181+
f"query {{ {field} }}",
182+
execution_context_class=UnforgivingExecutionContext,
183+
)

0 commit comments

Comments
 (0)