Skip to content

Commit f408f12

Browse files
committed
Add convert_case support to MutationType, make convert_case util less magic
1 parent 71443e8 commit f408f12

File tree

6 files changed

+97
-54
lines changed

6 files changed

+97
-54
lines changed
+43-45
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,42 @@
1-
from typing import Dict, Optional, Union, cast
1+
from functools import partial
2+
from typing import Optional, Union
23

34
from ariadne import convert_camel_case_to_snake
4-
from graphql import DefinitionNode
5+
from graphql import FieldDefinitionNode
56

6-
from .types import FieldsDict
7-
8-
Overrides = Dict[str, str]
9-
ArgsOverrides = Dict[str, Overrides]
7+
from .types import FieldsDict, InputFieldsDict
108

119

1210
def convert_case(
13-
overrides_or_fields: Optional[Union[FieldsDict, dict]] = None,
14-
map_fields_args=False,
11+
overrides: Optional[dict] = None,
12+
*,
13+
object_fields: Optional[Union[FieldsDict, InputFieldsDict]] = None,
14+
fields_args: Optional[FieldsDict] = None,
15+
field_args: Optional[FieldDefinitionNode] = None,
1516
):
16-
no_args_call = convert_case_call_without_args(overrides_or_fields)
17-
18-
overrides = {}
19-
if not no_args_call:
20-
overrides = cast(dict, overrides_or_fields)
21-
22-
def create_case_mappings(fields: FieldsDict, map_fields_args=False):
23-
if map_fields_args:
24-
return convert_args_cas(fields, overrides)
25-
26-
return convert_aliases_case(fields, overrides)
27-
28-
if no_args_call:
29-
fields = cast(FieldsDict, overrides_or_fields)
30-
return create_case_mappings(fields, map_fields_args)
31-
32-
return create_case_mappings
17+
if overrides and not object_fields and not fields_args and not field_args:
18+
return partial(convert_case, overrides)
3319

20+
if object_fields:
21+
return convert_object_fields_case(object_fields, overrides or {})
3422

35-
def convert_case_call_without_args(
36-
overrides_or_fields: Optional[Union[FieldsDict, dict]] = None
37-
) -> bool:
38-
if overrides_or_fields is None:
39-
return True
23+
if fields_args:
24+
return convert_fields_args_case(fields_args, overrides or {})
4025

41-
if isinstance(list(overrides_or_fields.values())[0], DefinitionNode):
42-
return True
26+
if field_args:
27+
return convert_field_args_case(field_args, overrides or {})
4328

44-
return False
29+
raise ValueError(
30+
"convert_case was called without any arguments. "
31+
"If you meant to use it for automatic case conversion, remove call "
32+
"(convert_case() -> convert_case) or call it with dict of overrides "
33+
"as only argument."
34+
)
4535

4636

47-
def convert_aliases_case(fields: FieldsDict, overrides: Overrides) -> Overrides:
37+
def convert_object_fields_case(
38+
fields: Union[FieldsDict, InputFieldsDict], overrides: dict
39+
):
4840
final_mappings = {}
4941
for field_name in fields:
5042
if field_name in overrides:
@@ -56,19 +48,25 @@ def convert_aliases_case(fields: FieldsDict, overrides: Overrides) -> Overrides:
5648
return final_mappings
5749

5850

59-
def convert_args_cas(fields: FieldsDict, overrides: ArgsOverrides) -> ArgsOverrides:
51+
def convert_fields_args_case(fields: FieldsDict, overrides: dict):
6052
final_mappings = {}
6153
for field_name, field_def in fields.items():
62-
arg_overrides: Overrides = overrides.get(field_name, {})
63-
arg_mappings = {}
64-
for arg in field_def.arguments:
65-
arg_name = arg.name.value
66-
if arg_name in arg_overrides:
67-
arg_name_final = arg_overrides[arg_name]
68-
else:
69-
arg_name_final = convert_camel_case_to_snake(arg_name)
70-
if arg_name != arg_name_final:
71-
arg_mappings[arg_name] = arg_name_final
54+
arg_mappings = convert_field_args_case(
55+
field_def, overrides.get(field_name) or {}
56+
)
7257
if arg_mappings:
7358
final_mappings[field_name] = arg_mappings
7459
return final_mappings
60+
61+
62+
def convert_field_args_case(field: FieldDefinitionNode, overrides: dict):
63+
final_mappings = {}
64+
for arg in field.arguments:
65+
arg_name = arg.name.value
66+
if arg_name in overrides:
67+
arg_name_final = overrides[arg_name]
68+
else:
69+
arg_name_final = convert_camel_case_to_snake(arg_name)
70+
if arg_name != arg_name_final:
71+
final_mappings[arg_name] = arg_name_final
72+
return final_mappings

ariadne_graphql_modules/input_type.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init_subclass__(cls) -> None:
3939

4040
if callable(cls.__args__):
4141
# pylint: disable=not-callable
42-
cls.__args__ = cls.__args__(cls.graphql_fields)
42+
cls.__args__ = cls.__args__(object_fields=cls.graphql_fields)
4343

4444
cls.__validate_args__()
4545

ariadne_graphql_modules/interface_type.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ def __init_subclass__(cls) -> None:
5656
cls.__validate_requirements__(requirements, dependencies)
5757

5858
if callable(cls.__fields_args__):
59-
cls.__fields_args__ = cls.__fields_args__(cls.graphql_fields, True)
59+
cls.__fields_args__ = cls.__fields_args__(fields_args=cls.graphql_fields)
6060

6161
cls.__validate_fields_args__()
6262

6363
if callable(cls.__aliases__):
64-
cls.__aliases__ = cls.__aliases__(cls.graphql_fields)
64+
cls.__aliases__ = cls.__aliases__(object_fields=cls.graphql_fields)
6565

6666
cls.__validate_aliases__()
6767
cls.resolvers = cls.__get_resolvers__()

ariadne_graphql_modules/mutation_type.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional, Type, Union, cast
1+
from typing import Callable, Dict, List, Optional, Type, Union, cast
22

33
from graphql import (
44
DefinitionNode,
@@ -13,12 +13,13 @@
1313
from .types import RequirementsDict
1414
from .utils import parse_definition
1515

16+
MutationArgs = Dict[str, str]
1617
ObjectNodeType = Union[ObjectTypeDefinitionNode, ObjectTypeExtensionNode]
1718

1819

1920
class MutationType(BindableType):
2021
__abstract__ = True
21-
__args__: Optional[Dict[str, str]] = None
22+
__args__: Optional[Union[MutationArgs, Callable[..., MutationArgs]]] = None
2223

2324
graphql_name = "Mutation"
2425
graphql_type: Union[Type[ObjectTypeDefinitionNode], Type[ObjectTypeExtensionNode]]
@@ -50,7 +51,12 @@ def __init_subclass__(cls) -> None:
5051
dependencies = cls.__get_dependencies__(graphql_def)
5152
cls.__validate_requirements__(requirements, dependencies)
5253

54+
if callable(cls.__args__):
55+
# pylint: disable=not-callable
56+
cls.__args__ = cls.__args__(field_args=field)
57+
5358
cls.__validate_args__(field)
59+
5460
cls.__validate_resolve_mutation__()
5561

5662
@classmethod
@@ -117,7 +123,7 @@ def __validate_args__(cls, field: FieldDefinitionNode):
117123
return
118124

119125
field_args = [arg.name.value for arg in field.arguments]
120-
invalid_args = set(cls.__args__) - set(field_args)
126+
invalid_args = set(cast(List[str], cls.__args__)) - set(field_args)
121127
if invalid_args:
122128
raise ValueError(
123129
f"{cls.__name__} class was defined with args not on "

ariadne_graphql_modules/object_type.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,12 @@ def __init_subclass__(cls) -> None:
4747
cls.__validate_requirements__(requirements, dependencies)
4848

4949
if callable(cls.__fields_args__):
50-
cls.__fields_args__ = cls.__fields_args__(cls.graphql_fields, True)
50+
cls.__fields_args__ = cls.__fields_args__(fields_args=cls.graphql_fields)
5151

5252
cls.__validate_fields_args__()
5353

5454
if callable(cls.__aliases__):
55-
cls.__aliases__ = cls.__aliases__(cls.graphql_fields)
55+
cls.__aliases__ = cls.__aliases__(object_fields=cls.graphql_fields)
5656

5757
cls.__validate_aliases__()
5858
cls.resolvers = cls.__get_resolvers__()

tests/test_convert_case.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ariadne_graphql_modules import InputType, ObjectType, convert_case
1+
from ariadne_graphql_modules import InputType, MutationType, ObjectType, convert_case
22

33

44
def test_cases_are_mapped_for_aliases():
@@ -104,3 +104,42 @@ class ExampleObject(ObjectType):
104104
"thirdArg": "third_arg",
105105
},
106106
}
107+
108+
109+
def test_cases_are_mapped_for_mutation_args():
110+
class ExampleMutation(MutationType):
111+
__schema__ = """
112+
type Mutation {
113+
field(arg: Int, secondArg: Int): Int
114+
}
115+
"""
116+
__args__ = convert_case
117+
118+
@staticmethod
119+
def resolve_mutation(*_, **__):
120+
return 42
121+
122+
assert ExampleMutation.__args__ == {"secondArg": "second_arg"}
123+
124+
125+
def test_cases_are_mapped_for_mutation_args_with_overrides():
126+
class ExampleMutation(MutationType):
127+
__schema__ = """
128+
type Mutation {
129+
field(arg: Int, secondArg: Int): Int
130+
}
131+
"""
132+
__args__ = convert_case(
133+
{
134+
"arg": "override",
135+
}
136+
)
137+
138+
@staticmethod
139+
def resolve_mutation(*_, **__):
140+
return 42
141+
142+
assert ExampleMutation.__args__ == {
143+
"arg": "override",
144+
"secondArg": "second_arg",
145+
}

0 commit comments

Comments
 (0)