|
| 1 | +"""Compatibility layer for graphql-core 3.2.x and 3.3.x.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from typing import TYPE_CHECKING, Any, cast |
| 6 | + |
| 7 | +from graphql import ( |
| 8 | + FieldNode, |
| 9 | + GraphQLInterfaceType, |
| 10 | + GraphQLObjectType, |
| 11 | +) |
| 12 | +from graphql.version import VersionInfo, version_info |
| 13 | + |
| 14 | +if TYPE_CHECKING: |
| 15 | + from graphql.type.definition import GraphQLResolveInfo |
| 16 | + |
| 17 | +IS_GQL_33 = version_info >= VersionInfo.from_str("3.3.0a0") |
| 18 | +IS_GQL_32 = not IS_GQL_33 |
| 19 | + |
| 20 | + |
| 21 | +def get_sub_field_selections( |
| 22 | + info: GraphQLResolveInfo, |
| 23 | + parent_type: GraphQLObjectType | GraphQLInterfaceType, |
| 24 | +) -> dict[str, list[FieldNode]]: |
| 25 | + """Collect sub-fields, handling API differences between 3.2.x and 3.3.x.""" |
| 26 | + if IS_GQL_32: |
| 27 | + return _get_selections_gql32(info, parent_type) |
| 28 | + return _get_selections_gql33(info, parent_type) |
| 29 | + |
| 30 | + |
| 31 | +def _get_selections_gql32( |
| 32 | + info: GraphQLResolveInfo, |
| 33 | + parent_type: GraphQLObjectType | GraphQLInterfaceType, |
| 34 | +) -> dict[str, list[FieldNode]]: |
| 35 | + from graphql.execution.collect_fields import ( |
| 36 | + collect_sub_fields, |
| 37 | + ) |
| 38 | + |
| 39 | + return collect_sub_fields( |
| 40 | + info.schema, |
| 41 | + info.fragments, |
| 42 | + info.variable_values, |
| 43 | + cast("GraphQLObjectType", parent_type), |
| 44 | + info.field_nodes, |
| 45 | + ) |
| 46 | + |
| 47 | + |
| 48 | +def _get_selections_gql33( |
| 49 | + info: GraphQLResolveInfo, |
| 50 | + parent_type: GraphQLObjectType | GraphQLInterfaceType, |
| 51 | +) -> dict[str, list[FieldNode]]: |
| 52 | + from graphql.execution.collect_fields import ( |
| 53 | + FieldDetails, # type: ignore |
| 54 | + collect_subfields, # type: ignore |
| 55 | + ) |
| 56 | + |
| 57 | + field_group: list[Any] = [ |
| 58 | + FieldDetails(node=fn, defer_usage=None) for fn in info.field_nodes |
| 59 | + ] |
| 60 | + |
| 61 | + collected = collect_subfields( |
| 62 | + info.schema, |
| 63 | + info.fragments, |
| 64 | + info.variable_values, |
| 65 | + info.operation, |
| 66 | + cast("GraphQLObjectType", parent_type), |
| 67 | + field_group, |
| 68 | + ) |
| 69 | + |
| 70 | + return { |
| 71 | + key: [fd.node for fd in field_details] |
| 72 | + for key, field_details in collected.grouped_field_set.items() |
| 73 | + } |
0 commit comments