Skip to content

Commit cf7f95e

Browse files
committed
Deep copy schema with directive with args of custom type
1 parent 8a02866 commit cf7f95e

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

src/graphql/type/schema.py

+22-11
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@
1313
)
1414

1515
from ..error import GraphQLError
16-
from ..language import ast, OperationType
16+
from ..language import OperationType, ast
1717
from ..pyutils import inspect, is_collection, is_description
1818
from .definition import (
1919
GraphQLAbstractType,
20-
GraphQLInterfaceType,
2120
GraphQLInputObjectType,
21+
GraphQLInputType,
22+
GraphQLInterfaceType,
2223
GraphQLNamedType,
2324
GraphQLObjectType,
24-
GraphQLUnionType,
2525
GraphQLType,
26+
GraphQLUnionType,
2627
GraphQLWrappingType,
2728
get_named_type,
2829
is_input_object_type,
@@ -31,7 +32,7 @@
3132
is_union_type,
3233
is_wrapping_type,
3334
)
34-
from .directives import GraphQLDirective, specified_directives, is_directive
35+
from .directives import GraphQLDirective, is_directive, specified_directives
3536
from .introspection import introspection_types
3637

3738
try:
@@ -310,8 +311,8 @@ def __copy__(self) -> "GraphQLSchema": # pragma: no cover
310311
def __deepcopy__(self, memo_: Dict) -> "GraphQLSchema":
311312
from ..type import (
312313
is_introspection_type,
313-
is_specified_scalar_type,
314314
is_specified_directive,
315+
is_specified_scalar_type,
315316
)
316317

317318
type_map: TypeMap = {
@@ -326,6 +327,8 @@ def __deepcopy__(self, memo_: Dict) -> "GraphQLSchema":
326327
directive if is_specified_directive(directive) else copy(directive)
327328
for directive in self.directives
328329
]
330+
for directive in directives:
331+
remap_directive(directive, type_map)
329332
return self.__class__(
330333
self.query_type and cast(GraphQLObjectType, type_map[self.query_type.name]),
331334
self.mutation_type
@@ -461,12 +464,7 @@ def remapped_type(type_: GraphQLType, type_map: TypeMap) -> GraphQLType:
461464

462465
def remap_named_type(type_: GraphQLNamedType, type_map: TypeMap) -> None:
463466
"""Change all references in the given named type to use this type map."""
464-
if is_union_type(type_):
465-
type_ = cast(GraphQLUnionType, type_)
466-
type_.types = [
467-
type_map.get(member_type.name, member_type) for member_type in type_.types
468-
]
469-
elif is_object_type(type_) or is_interface_type(type_):
467+
if is_object_type(type_) or is_interface_type(type_):
470468
type_ = cast(Union[GraphQLObjectType, GraphQLInterfaceType], type_)
471469
type_.interfaces = [
472470
type_map.get(interface_type.name, interface_type)
@@ -482,10 +480,23 @@ def remap_named_type(type_: GraphQLNamedType, type_map: TypeMap) -> None:
482480
arg.type = remapped_type(arg.type, type_map)
483481
args[arg_name] = arg
484482
fields[field_name] = field
483+
elif is_union_type(type_):
484+
type_ = cast(GraphQLUnionType, type_)
485+
type_.types = [
486+
type_map.get(member_type.name, member_type) for member_type in type_.types
487+
]
485488
elif is_input_object_type(type_):
486489
type_ = cast(GraphQLInputObjectType, type_)
487490
fields = type_.fields
488491
for field_name, field in fields.items():
489492
field = copy(field)
490493
field.type = remapped_type(field.type, type_map)
491494
fields[field_name] = field
495+
496+
def remap_directive(directive: GraphQLDirective, type_map: TypeMap) -> None:
497+
"""Change all references in the given directive to use this type map."""
498+
args = directive.args
499+
for arg_name, arg in args.items():
500+
arg = copy(arg) # noqa: PLW2901
501+
arg.type = cast(GraphQLInputType, remapped_type(arg.type, type_map))
502+
args[arg_name] = arg

0 commit comments

Comments
 (0)