Skip to content

Commit 5adba4f

Browse files
authored
Merge pull request #128 from mirumee/unset_values
Remove unset arguments and fields from variables payload
2 parents f218e20 + 385b8d1 commit 5adba4f

File tree

32 files changed

+719
-517
lines changed

32 files changed

+719
-517
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
- Changed logic how custom scalar imports are generated. Deprecated `import_` key.
66
- Added escaping of GraphQL names which are Python keywords by appending `_` to them.
77
- Fixed parsing of list variables.
8+
- Changed base clients to remove unset arguments and input fields from variables payload.
89

910

1011
## 0.5.0 (2023-04-05)

EXAMPLE.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,10 @@ Generated client class inherits from `AsyncBaseClient` and has async method for
156156
```py
157157
# graphql_client/client.py
158158

159-
from typing import Optional
159+
from typing import Optional, Union
160160

161161
from .async_base_client import AsyncBaseClient
162+
from .base_model import UNSET, UnsetType
162163
from .create_user import CreateUser
163164
from .input_types import UserCreateInput
164165
from .list_all_users import ListAllUsers
@@ -207,7 +208,7 @@ class Client(AsyncBaseClient):
207208
return ListAllUsers.parse_obj(data)
208209

209210
async def list_users_by_country(
210-
self, country: Optional[str] = None
211+
self, country: Union[Optional[str], UnsetType] = UNSET
211212
) -> ListUsersByCountry:
212213
query = gql(
213214
"""

ariadne_codegen/client_generators/arguments.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import ast
2-
from typing import Dict, List, Optional, Tuple, Union
2+
from typing import Dict, List, Optional, Tuple, Union, cast
33

44
from graphql import (
55
GraphQLEnumType,
@@ -22,11 +22,12 @@
2222
generate_dict,
2323
generate_list_annotation,
2424
generate_name,
25+
generate_union_annotation,
2526
)
2627
from ..exceptions import ParsingError
2728
from ..plugins.manager import PluginManager
2829
from ..utils import process_name
29-
from .constants import ANY, OPTIONAL, SIMPLE_TYPE_MAP
30+
from .constants import ANY, OPTIONAL, SIMPLE_TYPE_MAP, UNSET_NAME, UNSET_TYPE_NAME
3031
from .scalars import ScalarData
3132

3233

@@ -63,7 +64,10 @@ def generate(
6364
)
6465

6566
arg = generate_arg(name, annotation)
66-
if self.is_nullable(annotation):
67+
if self._is_nullable(annotation):
68+
arg.annotation = self._process_optional_arg_annotation(
69+
cast(ast.Subscript, annotation)
70+
)
6771
optional_args.append(arg)
6872
else:
6973
required_args.append(arg)
@@ -73,7 +77,7 @@ def generate(
7377

7478
arguments = generate_arguments(
7579
args=required_args + optional_args,
76-
defaults=[generate_constant(None) for _ in optional_args],
80+
defaults=[generate_name(UNSET_NAME) for _ in optional_args],
7781
)
7882

7983
if self.plugin_manager:
@@ -140,13 +144,20 @@ def _parse_named_type_node(
140144

141145
return generate_annotation_name(name, nullable), used_custom_scalar
142146

143-
def is_nullable(self, annotation: Union[ast.Name, ast.Subscript]) -> bool:
147+
def _is_nullable(self, annotation: Union[ast.Name, ast.Subscript]) -> bool:
144148
return (
145149
isinstance(annotation, ast.Subscript)
146150
and isinstance(annotation.value, ast.Name)
147151
and annotation.value.id == OPTIONAL
148152
)
149153

154+
def _process_optional_arg_annotation(
155+
self, annotation: ast.Subscript
156+
) -> ast.Subscript:
157+
return generate_union_annotation(
158+
types=[annotation, generate_name(UNSET_TYPE_NAME)], nullable=False
159+
)
160+
150161
def _get_dict_value(
151162
self, name: str, used_custom_scalar: Optional[str]
152163
) -> Union[ast.Name, ast.Call]:

ariadne_codegen/client_generators/client.py

+67-31
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import ast
2-
from typing import List, Optional, Union, cast
2+
from typing import Dict, List, Optional, Union, cast
3+
4+
from graphql import OperationDefinitionNode
35

46
from ..codegen import (
57
generate_ann_assign,
@@ -22,23 +24,39 @@
2224
generate_tuple,
2325
)
2426
from ..plugins.manager import PluginManager
25-
from .constants import ANY, LIST, OPTIONAL, TYPING_MODULE
27+
from .arguments import ArgumentsGenerator
28+
from .constants import ANY, LIST, OPTIONAL, TYPING_MODULE, UNION
29+
from .scalars import ScalarData, generate_scalar_imports
2630

2731

2832
class ClientGenerator:
2933
def __init__(
3034
self,
3135
name: str,
3236
base_client: str,
37+
enums_module_name: str,
38+
input_types_module_name: str,
39+
arguments_generator: ArgumentsGenerator,
40+
base_client_import: Optional[ast.ImportFrom] = None,
41+
unset_import: Optional[ast.ImportFrom] = None,
42+
custom_scalars: Optional[Dict[str, ScalarData]] = None,
3343
plugin_manager: Optional[PluginManager] = None,
3444
) -> None:
3545
self.name = name
46+
self.enums_module_name = enums_module_name
47+
self.input_types_module_name = input_types_module_name
3648
self.plugin_manager = plugin_manager
37-
self.class_def = generate_class_def(name=name, base_names=[base_client])
38-
self.imports: list = [
39-
generate_import_from([OPTIONAL, LIST, ANY], TYPING_MODULE)
40-
]
49+
self.custom_scalars = custom_scalars if custom_scalars else {}
50+
self.arguments_generator = arguments_generator
51+
52+
self._imports: List[ast.ImportFrom] = []
53+
self._add_import(
54+
generate_import_from([OPTIONAL, LIST, ANY, UNION], TYPING_MODULE)
55+
)
56+
self._add_import(base_client_import)
57+
self._add_import(unset_import)
4158

59+
self._class_def = generate_class_def(name=name, base_names=[base_client])
4260
self._gql_func_name = "gql"
4361
self._operation_str_variable = "query"
4462
self._variables_dict_variable = "variables"
@@ -47,49 +65,56 @@ def __init__(
4765

4866
def generate(self) -> ast.Module:
4967
"""Generate module with class definition of graphql client."""
68+
self._add_import(
69+
generate_import_from(
70+
names=self.arguments_generator.get_used_inputs(),
71+
from_=self.input_types_module_name,
72+
level=1,
73+
)
74+
)
75+
self._add_import(
76+
generate_import_from(
77+
names=self.arguments_generator.get_used_enums(),
78+
from_=self.enums_module_name,
79+
level=1,
80+
)
81+
)
82+
for custom_scalar_name in self.arguments_generator.get_used_custom_scalars():
83+
scalar_data = self.custom_scalars[custom_scalar_name]
84+
for import_ in generate_scalar_imports(scalar_data):
85+
self._add_import(import_)
86+
5087
gql_func = self._generate_gql_func()
51-
gql_func.lineno = len(self.imports) + 1
88+
gql_func.lineno = len(self._imports) + 1
5289
if self.plugin_manager:
5390
gql_func = self.plugin_manager.generate_gql_function(gql_func)
5491

55-
self.class_def.lineno = len(self.imports) + 3
56-
if not self.class_def.body:
57-
self.class_def.body.append(ast.Pass())
92+
self._class_def.lineno = len(self._imports) + 3
93+
if not self._class_def.body:
94+
self._class_def.body.append(ast.Pass())
5895
if self.plugin_manager:
59-
self.class_def = self.plugin_manager.generate_client_class(self.class_def)
96+
self._class_def = self.plugin_manager.generate_client_class(self._class_def)
6097

6198
module = generate_module(
62-
body=self.imports + [gql_func, self.class_def],
99+
body=self._imports + [gql_func, self._class_def],
63100
)
64101
if self.plugin_manager:
65102
module = self.plugin_manager.generate_client_module(module)
66103
return module
67104

68-
def add_import(self, names: List[str], from_: str, level: int = 0) -> None:
69-
"""Add import to be included in module file."""
70-
if not names:
71-
return
72-
import_ = generate_import_from(names=names, from_=from_, level=level)
73-
if self.plugin_manager:
74-
import_ = self.plugin_manager.generate_client_import(import_)
75-
self.imports.append(import_)
76-
77-
def add_imports(self, imports: List[ast.ImportFrom]):
78-
for import_ in imports:
79-
if self.plugin_manager:
80-
import_ = self.plugin_manager.generate_client_import(import_)
81-
self.imports.append(import_)
82-
83105
def add_method(
84106
self,
107+
definition: OperationDefinitionNode,
85108
name: str,
86109
return_type: str,
87-
arguments: ast.arguments,
88-
arguments_dict: ast.Dict,
110+
return_type_module: str,
89111
operation_str: str,
90112
async_: bool = True,
91113
):
92114
"""Add method to client."""
115+
arguments, arguments_dict = self.arguments_generator.generate(
116+
definition.variable_definitions
117+
)
93118
method_def = (
94119
self._generate_async_method(
95120
name=name,
@@ -107,12 +132,23 @@ def add_method(
107132
operation_str=operation_str,
108133
)
109134
)
110-
method_def.lineno = len(self.class_def.body) + 1
135+
method_def.lineno = len(self._class_def.body) + 1
111136
if self.plugin_manager:
112137
method_def = self.plugin_manager.generate_client_method(
113138
cast(Union[ast.FunctionDef, ast.AsyncFunctionDef], method_def)
114139
)
115-
self.class_def.body.append(method_def)
140+
self._class_def.body.append(method_def)
141+
self._add_import(
142+
generate_import_from(names=[return_type], from_=return_type_module, level=1)
143+
)
144+
145+
def _add_import(self, import_: Optional[ast.ImportFrom] = None):
146+
if not import_:
147+
return
148+
if self.plugin_manager:
149+
import_ = self.plugin_manager.generate_client_import(import_)
150+
if import_.names and import_.module:
151+
self._imports.append(import_)
116152

117153
def _generate_async_method(
118154
self,

ariadne_codegen/client_generators/constants.py

+3
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,6 @@
5252

5353
SCALARS_PARSE_DICT_NAME = "SCALARS_PARSE_FUNCTIONS"
5454
SCALARS_SERIALIZE_DICT_NAME = "SCALARS_SERIALIZE_FUNCTIONS"
55+
56+
UNSET_NAME = "UNSET"
57+
UNSET_TYPE_NAME = "UnsetType"

ariadne_codegen/client_generators/dependencies/async_base_client.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import httpx
44
from pydantic import BaseModel
55

6+
from .base_model import UNSET
67
from .exceptions import (
78
GraphQLClientGraphQLMultiError,
89
GraphQLClientHttpError,
@@ -71,12 +72,16 @@ def get_data(self, response: httpx.Response) -> dict[str, Any]:
7172

7273
def _convert_value(self, value: Any) -> Any:
7374
if isinstance(value, BaseModel):
74-
return value.dict(by_alias=True)
75+
return value.dict(by_alias=True, exclude_unset=True)
7576
if isinstance(value, list):
7677
return [self._convert_value(item) for item in value]
7778
return value
7879

7980
def _convert_dict_to_json_serializable(
8081
self, dict_: Dict[str, Any]
8182
) -> Dict[str, Any]:
82-
return {key: self._convert_value(value) for key, value in dict_.items()}
83+
return {
84+
key: self._convert_value(value)
85+
for key, value in dict_.items()
86+
if value is not UNSET
87+
}

ariadne_codegen/client_generators/dependencies/base_client.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import httpx
44
from pydantic import BaseModel
55

6+
from .base_model import UNSET
67
from .exceptions import (
78
GraphQLClientGraphQLMultiError,
89
GraphQLClientHttpError,
@@ -69,12 +70,16 @@ def get_data(self, response: httpx.Response) -> dict[str, Any]:
6970

7071
def _convert_value(self, value: Any) -> Any:
7172
if isinstance(value, BaseModel):
72-
return value.dict(by_alias=True)
73+
return value.dict(by_alias=True, exclude_unset=True)
7374
if isinstance(value, list):
7475
return [self._convert_value(item) for item in value]
7576
return value
7677

7778
def _convert_dict_to_json_serializable(
7879
self, dict_: Dict[str, Any]
7980
) -> Dict[str, Any]:
80-
return {key: self._convert_value(value) for key, value in dict_.items()}
81+
return {
82+
key: self._convert_value(value)
83+
for key, value in dict_.items()
84+
if value is not UNSET
85+
}

ariadne_codegen/client_generators/dependencies/base_model.py

+8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
from .scalars import SCALARS_PARSE_FUNCTIONS, SCALARS_SERIALIZE_FUNCTIONS
88

99

10+
class UnsetType:
11+
def __bool__(self) -> bool:
12+
return False
13+
14+
15+
UNSET = UnsetType()
16+
17+
1018
class BaseModel(PydanticBaseModel):
1119
class Config:
1220
allow_population_by_field_name = True

0 commit comments

Comments
 (0)