Skip to content

Commit d085c88

Browse files
authored
Subscription revamp (#1235)
* Integrate async tests into main code * Added full support for subscriptions * Fixed syntax using black * Fixed typo
1 parent 2130005 commit d085c88

11 files changed

+140
-64
lines changed

Makefile

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ install-dev:
88
pip install -e ".[dev]"
99

1010
test:
11-
py.test graphene examples tests_asyncio
11+
py.test graphene examples
1212

1313
.PHONY: docs ## Generate docs
1414
docs: install-dev
@@ -20,8 +20,8 @@ docs-live: install-dev
2020

2121
.PHONY: format
2222
format:
23-
black graphene examples setup.py tests_asyncio
23+
black graphene examples setup.py
2424

2525
.PHONY: lint
2626
lint:
27-
flake8 graphene examples setup.py tests_asyncio
27+
flake8 graphene examples setup.py

graphene/relay/connection.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,8 @@ def connection_resolver(cls, resolver, connection_type, root, info, **args):
171171
on_resolve = partial(cls.resolve_connection, connection_type, args)
172172
return maybe_thenable(resolved, on_resolve)
173173

174-
def get_resolver(self, parent_resolver):
175-
resolver = super(IterableConnectionField, self).get_resolver(parent_resolver)
174+
def wrap_resolve(self, parent_resolver):
175+
resolver = super(IterableConnectionField, self).wrap_resolve(parent_resolver)
176176
return partial(self.connection_resolver, resolver, self.type)
177177

178178

graphene/relay/node.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def id_resolver(parent_resolver, node, root, info, parent_type_name=None, **args
3737
parent_type_name = parent_type_name or info.parent_type.name
3838
return node.to_global_id(parent_type_name, type_id) # root._meta.name
3939

40-
def get_resolver(self, parent_resolver):
40+
def wrap_resolve(self, parent_resolver):
4141
return partial(
4242
self.id_resolver,
4343
parent_resolver,
@@ -60,7 +60,7 @@ def __init__(self, node, type_=False, **kwargs):
6060
**kwargs,
6161
)
6262

63-
def get_resolver(self, parent_resolver):
63+
def wrap_resolve(self, parent_resolver):
6464
return partial(self.node_type.node_resolver, get_type(self.field_type))
6565

6666

graphene/relay/tests/test_global_id.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,14 @@ def test_global_id_allows_overriding_of_node_and_required():
4545
def test_global_id_defaults_to_info_parent_type():
4646
my_id = "1"
4747
gid = GlobalID()
48-
id_resolver = gid.get_resolver(lambda *_: my_id)
48+
id_resolver = gid.wrap_resolve(lambda *_: my_id)
4949
my_global_id = id_resolver(None, Info(User))
5050
assert my_global_id == to_global_id(User._meta.name, my_id)
5151

5252

5353
def test_global_id_allows_setting_customer_parent_type():
5454
my_id = "1"
5555
gid = GlobalID(parent_type=User)
56-
id_resolver = gid.get_resolver(lambda *_: my_id)
56+
id_resolver = gid.wrap_resolve(lambda *_: my_id)
5757
my_global_id = id_resolver(None, None)
5858
assert my_global_id == to_global_id(User._meta.name, my_id)

graphene/types/field.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .structures import NonNull
99
from .unmountedtype import UnmountedType
1010
from .utils import get_type
11+
from ..utils.deprecated import warn_deprecation
1112

1213
base_type = type
1314

@@ -114,5 +115,24 @@ def __init__(
114115
def type(self):
115116
return get_type(self._type)
116117

117-
def get_resolver(self, parent_resolver):
118+
get_resolver = None
119+
120+
def wrap_resolve(self, parent_resolver):
121+
"""
122+
Wraps a function resolver, using the ObjectType resolve_{FIELD_NAME}
123+
(parent_resolver) if the Field definition has no resolver.
124+
"""
125+
if self.get_resolver is not None:
126+
warn_deprecation(
127+
"The get_resolver method is being deprecated, please rename it to wrap_resolve."
128+
)
129+
return self.get_resolver(parent_resolver)
130+
118131
return self.resolver or parent_resolver
132+
133+
def wrap_subscribe(self, parent_subscribe):
134+
"""
135+
Wraps a function subscribe, using the ObjectType subscribe_{FIELD_NAME}
136+
(parent_subscribe) if the Field definition has no subscribe.
137+
"""
138+
return parent_subscribe

graphene/types/schema.py

+53-20
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
parse,
1111
print_schema,
1212
subscribe,
13+
validate,
14+
ExecutionResult,
1315
GraphQLArgument,
1416
GraphQLBoolean,
17+
GraphQLError,
1518
GraphQLEnumValue,
1619
GraphQLField,
1720
GraphQLFloat,
@@ -76,6 +79,11 @@ def is_type_of_from_possible_types(possible_types, root, _info):
7679
return isinstance(root, possible_types)
7780

7881

82+
# We use this resolver for subscriptions
83+
def identity_resolve(root, info):
84+
return root
85+
86+
7987
class TypeMap(dict):
8088
def __init__(
8189
self,
@@ -307,30 +315,48 @@ def create_fields_for_type(self, graphene_type, is_input_type=False):
307315
if isinstance(arg.type, NonNull)
308316
else arg.default_value,
309317
)
318+
subscribe = field.wrap_subscribe(
319+
self.get_function_for_type(
320+
graphene_type, f"subscribe_{name}", name, field.default_value,
321+
)
322+
)
323+
324+
# If we are in a subscription, we use (by default) an
325+
# identity-based resolver for the root, rather than the
326+
# default resolver for objects/dicts.
327+
if subscribe:
328+
field_default_resolver = identity_resolve
329+
elif issubclass(graphene_type, ObjectType):
330+
default_resolver = (
331+
graphene_type._meta.default_resolver or get_default_resolver()
332+
)
333+
field_default_resolver = partial(
334+
default_resolver, name, field.default_value
335+
)
336+
else:
337+
field_default_resolver = None
338+
339+
resolve = field.wrap_resolve(
340+
self.get_function_for_type(
341+
graphene_type, f"resolve_{name}", name, field.default_value
342+
)
343+
or field_default_resolver
344+
)
345+
310346
_field = GraphQLField(
311347
field_type,
312348
args=args,
313-
resolve=field.get_resolver(
314-
self.get_resolver_for_type(
315-
graphene_type, f"resolve_{name}", name, field.default_value
316-
)
317-
),
318-
subscribe=field.get_resolver(
319-
self.get_resolver_for_type(
320-
graphene_type,
321-
f"subscribe_{name}",
322-
name,
323-
field.default_value,
324-
)
325-
),
349+
resolve=resolve,
350+
subscribe=subscribe,
326351
deprecation_reason=field.deprecation_reason,
327352
description=field.description,
328353
)
329354
field_name = field.name or self.get_name(name)
330355
fields[field_name] = _field
331356
return fields
332357

333-
def get_resolver_for_type(self, graphene_type, func_name, name, default_value):
358+
def get_function_for_type(self, graphene_type, func_name, name, default_value):
359+
"""Gets a resolve or subscribe function for a given ObjectType"""
334360
if not issubclass(graphene_type, ObjectType):
335361
return
336362
resolver = getattr(graphene_type, func_name, None)
@@ -350,11 +376,6 @@ def get_resolver_for_type(self, graphene_type, func_name, name, default_value):
350376
if resolver:
351377
return get_unbound_function(resolver)
352378

353-
default_resolver = (
354-
graphene_type._meta.default_resolver or get_default_resolver()
355-
)
356-
return partial(default_resolver, name, default_value)
357-
358379
def resolve_type(self, resolve_type_func, type_name, root, info, _type):
359380
type_ = resolve_type_func(root, info)
360381

@@ -476,7 +497,19 @@ async def execute_async(self, *args, **kwargs):
476497
return await graphql(self.graphql_schema, *args, **kwargs)
477498

478499
async def subscribe(self, query, *args, **kwargs):
479-
document = parse(query)
500+
"""Execute a GraphQL subscription on the schema asynchronously."""
501+
# Do parsing
502+
try:
503+
document = parse(query)
504+
except GraphQLError as error:
505+
return ExecutionResult(data=None, errors=[error])
506+
507+
# Do validation
508+
validation_errors = validate(self.graphql_schema, document)
509+
if validation_errors:
510+
return ExecutionResult(data=None, errors=validation_errors)
511+
512+
# Execute the query
480513
kwargs = normalize_execute_kwargs(kwargs)
481514
return await subscribe(self.graphql_schema, document, *args, **kwargs)
482515

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from pytest import mark
2+
3+
from graphene import ObjectType, Int, String, Schema, Field
4+
5+
6+
class Query(ObjectType):
7+
hello = String()
8+
9+
def resolve_hello(root, info):
10+
return "Hello, world!"
11+
12+
13+
class Subscription(ObjectType):
14+
count_to_ten = Field(Int)
15+
16+
async def subscribe_count_to_ten(root, info):
17+
count = 0
18+
while count < 10:
19+
count += 1
20+
yield count
21+
22+
23+
schema = Schema(query=Query, subscription=Subscription)
24+
25+
26+
@mark.asyncio
27+
async def test_subscription():
28+
subscription = "subscription { countToTen }"
29+
result = await schema.subscribe(subscription)
30+
count = 0
31+
async for item in result:
32+
count = item.data["countToTen"]
33+
assert count == 10
34+
35+
36+
@mark.asyncio
37+
async def test_subscription_fails_with_invalid_query():
38+
# It fails if the provided query is invalid
39+
subscription = "subscription { "
40+
result = await schema.subscribe(subscription)
41+
assert not result.data
42+
assert result.errors
43+
assert "Syntax Error: Expected Name, found <EOF>" in str(result.errors[0])
44+
45+
46+
@mark.asyncio
47+
async def test_subscription_fails_when_query_is_not_valid():
48+
# It can't subscribe to two fields at the same time, triggering a
49+
# validation error.
50+
subscription = "subscription { countToTen, b: countToTen }"
51+
result = await schema.subscribe(subscription)
52+
assert not result.data
53+
assert result.errors
54+
assert "Anonymous Subscription must select only one top level field." in str(
55+
result.errors[0]
56+
)

tests_asyncio/test_subscribe.py

-33
This file was deleted.

tox.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ deps =
88
setenv =
99
PYTHONPATH = .:{envdir}
1010
commands =
11-
py{36,37}: pytest --cov=graphene graphene examples tests_asyncio {posargs}
11+
py{36,37}: pytest --cov=graphene graphene examples {posargs}
1212

1313
[testenv:pre-commit]
1414
basepython=python3.7

0 commit comments

Comments
 (0)