Skip to content

Commit a91e4b2

Browse files
committed
Remove defer/stream support from subscriptions
Replicates graphql/graphql-js@1bf71ee
1 parent 98b44cc commit a91e4b2

12 files changed

+596
-485
lines changed

docs/modules/execution.rst

-2
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,6 @@ Execution
5353

5454
.. autofunction:: subscribe
5555

56-
.. autofunction:: experimental_subscribe_incrementally
57-
5856
.. autofunction:: create_source_event_stream
5957

6058
.. autoclass:: Middleware

src/graphql/execution/__init__.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
default_field_resolver,
1414
default_type_resolver,
1515
subscribe,
16-
experimental_subscribe_incrementally,
1716
ExecutionContext,
1817
ExecutionResult,
1918
ExperimentalIncrementalExecutionResults,
@@ -30,7 +29,7 @@
3029
FormattedIncrementalResult,
3130
Middleware,
3231
)
33-
from .async_iterables import flatten_async_iterable, map_async_iterable
32+
from .async_iterables import map_async_iterable
3433
from .middleware import MiddlewareManager
3534
from .values import get_argument_values, get_directive_values, get_variable_values
3635

@@ -43,7 +42,6 @@
4342
"default_field_resolver",
4443
"default_type_resolver",
4544
"subscribe",
46-
"experimental_subscribe_incrementally",
4745
"ExecutionContext",
4846
"ExecutionResult",
4947
"ExperimentalIncrementalExecutionResults",
@@ -58,7 +56,6 @@
5856
"FormattedIncrementalDeferResult",
5957
"FormattedIncrementalStreamResult",
6058
"FormattedIncrementalResult",
61-
"flatten_async_iterable",
6259
"map_async_iterable",
6360
"Middleware",
6461
"MiddlewareManager",

src/graphql/execution/async_iterables.py

+1-16
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Union,
1313
)
1414

15-
__all__ = ["aclosing", "flatten_async_iterable", "map_async_iterable"]
15+
__all__ = ["aclosing", "map_async_iterable"]
1616

1717
T = TypeVar("T")
1818
V = TypeVar("V")
@@ -42,21 +42,6 @@ async def __aexit__(self, *_exc_info: object) -> None:
4242
await aclose()
4343

4444

45-
async def flatten_async_iterable(
46-
iterable: AsyncIterableOrGenerator[AsyncIterableOrGenerator[T]],
47-
) -> AsyncGenerator[T, None]:
48-
"""Flatten async iterables.
49-
50-
Given an AsyncIterable of AsyncIterables, flatten all yielded results into a
51-
single AsyncIterable.
52-
"""
53-
async with aclosing(iterable) as sub_iterators: # type: ignore
54-
async for sub_iterator in sub_iterators:
55-
async with aclosing(sub_iterator) as items: # type: ignore
56-
async for item in items:
57-
yield item
58-
59-
6045
async def map_async_iterable(
6146
iterable: AsyncIterableOrGenerator[T], callback: Callable[[T], Awaitable[V]]
6247
) -> AsyncGenerator[V, None]:

src/graphql/execution/collect_fields.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
FragmentDefinitionNode,
99
FragmentSpreadNode,
1010
InlineFragmentNode,
11+
OperationDefinitionNode,
12+
OperationType,
1113
SelectionSetNode,
1214
)
1315
from ..type import (
@@ -43,7 +45,7 @@ def collect_fields(
4345
fragments: Dict[str, FragmentDefinitionNode],
4446
variable_values: Dict[str, Any],
4547
runtime_type: GraphQLObjectType,
46-
selection_set: SelectionSetNode,
48+
operation: OperationDefinitionNode,
4749
) -> FieldsAndPatches:
4850
"""Collect fields.
4951
@@ -61,8 +63,9 @@ def collect_fields(
6163
schema,
6264
fragments,
6365
variable_values,
66+
operation,
6467
runtime_type,
65-
selection_set,
68+
operation.selection_set,
6669
fields,
6770
patches,
6871
set(),
@@ -74,6 +77,7 @@ def collect_subfields(
7477
schema: GraphQLSchema,
7578
fragments: Dict[str, FragmentDefinitionNode],
7679
variable_values: Dict[str, Any],
80+
operation: OperationDefinitionNode,
7781
return_type: GraphQLObjectType,
7882
field_nodes: List[FieldNode],
7983
) -> FieldsAndPatches:
@@ -100,6 +104,7 @@ def collect_subfields(
100104
schema,
101105
fragments,
102106
variable_values,
107+
operation,
103108
return_type,
104109
node.selection_set,
105110
sub_field_nodes,
@@ -113,6 +118,7 @@ def collect_fields_impl(
113118
schema: GraphQLSchema,
114119
fragments: Dict[str, FragmentDefinitionNode],
115120
variable_values: Dict[str, Any],
121+
operation: OperationDefinitionNode,
116122
runtime_type: GraphQLObjectType,
117123
selection_set: SelectionSetNode,
118124
fields: Dict[str, List[FieldNode]],
@@ -133,13 +139,14 @@ def collect_fields_impl(
133139
) or not does_fragment_condition_match(schema, selection, runtime_type):
134140
continue
135141

136-
defer = get_defer_values(variable_values, selection)
142+
defer = get_defer_values(operation, variable_values, selection)
137143
if defer:
138144
patch_fields = defaultdict(list)
139145
collect_fields_impl(
140146
schema,
141147
fragments,
142148
variable_values,
149+
operation,
143150
runtime_type,
144151
selection.selection_set,
145152
patch_fields,
@@ -152,6 +159,7 @@ def collect_fields_impl(
152159
schema,
153160
fragments,
154161
variable_values,
162+
operation,
155163
runtime_type,
156164
selection.selection_set,
157165
fields,
@@ -164,7 +172,7 @@ def collect_fields_impl(
164172
if not should_include_node(variable_values, selection):
165173
continue
166174

167-
defer = get_defer_values(variable_values, selection)
175+
defer = get_defer_values(operation, variable_values, selection)
168176
if frag_name in visited_fragment_names and not defer:
169177
continue
170178

@@ -183,6 +191,7 @@ def collect_fields_impl(
183191
schema,
184192
fragments,
185193
variable_values,
194+
operation,
186195
runtime_type,
187196
fragment.selection_set,
188197
patch_fields,
@@ -195,6 +204,7 @@ def collect_fields_impl(
195204
schema,
196205
fragments,
197206
variable_values,
207+
operation,
198208
runtime_type,
199209
fragment.selection_set,
200210
fields,
@@ -210,7 +220,9 @@ class DeferValues(NamedTuple):
210220

211221

212222
def get_defer_values(
213-
variable_values: Dict[str, Any], node: Union[FragmentSpreadNode, InlineFragmentNode]
223+
operation: OperationDefinitionNode,
224+
variable_values: Dict[str, Any],
225+
node: Union[FragmentSpreadNode, InlineFragmentNode],
214226
) -> Optional[DeferValues]:
215227
"""Get values of defer directive if active.
216228
@@ -223,6 +235,13 @@ def get_defer_values(
223235
if not defer or defer.get("if") is False:
224236
return None
225237

238+
if operation.operation == OperationType.SUBSCRIPTION:
239+
msg = (
240+
"`@defer` directive not supported on subscription operations."
241+
" Disable `@defer` by setting the `if` argument to `false`."
242+
)
243+
raise TypeError(msg)
244+
226245
return DeferValues(defer.get("label"))
227246

228247

0 commit comments

Comments
 (0)