1
1
import ast
2
- from typing import List , Optional , Union , cast
2
+ from typing import Dict , List , Optional , Union , cast
3
+
4
+ from graphql import OperationDefinitionNode
3
5
4
6
from ..codegen import (
5
7
generate_ann_assign ,
22
24
generate_tuple ,
23
25
)
24
26
from ..plugins .manager import PluginManager
25
- from .constants import ANY , LIST , OPTIONAL , TYPING_MODULE
27
+ from .arguments import ArgumentsGenerator
28
+ from .constants import ANY , LIST , OPTIONAL , TYPING_MODULE , UNION
29
+ from .scalars import ScalarData , generate_scalar_imports
26
30
27
31
28
32
class ClientGenerator :
29
33
def __init__ (
30
34
self ,
31
35
name : str ,
32
36
base_client : str ,
37
+ enums_module_name : str ,
38
+ input_types_module_name : str ,
39
+ arguments_generator : ArgumentsGenerator ,
40
+ base_client_import : Optional [ast .ImportFrom ] = None ,
41
+ unset_import : Optional [ast .ImportFrom ] = None ,
42
+ custom_scalars : Optional [Dict [str , ScalarData ]] = None ,
33
43
plugin_manager : Optional [PluginManager ] = None ,
34
44
) -> None :
35
45
self .name = name
46
+ self .enums_module_name = enums_module_name
47
+ self .input_types_module_name = input_types_module_name
36
48
self .plugin_manager = plugin_manager
37
- self .class_def = generate_class_def (name = name , base_names = [base_client ])
38
- self .imports : list = [
39
- generate_import_from ([OPTIONAL , LIST , ANY ], TYPING_MODULE )
40
- ]
49
+ self .custom_scalars = custom_scalars if custom_scalars else {}
50
+ self .arguments_generator = arguments_generator
51
+
52
+ self ._imports : List [ast .ImportFrom ] = []
53
+ self ._add_import (
54
+ generate_import_from ([OPTIONAL , LIST , ANY , UNION ], TYPING_MODULE )
55
+ )
56
+ self ._add_import (base_client_import )
57
+ self ._add_import (unset_import )
41
58
59
+ self ._class_def = generate_class_def (name = name , base_names = [base_client ])
42
60
self ._gql_func_name = "gql"
43
61
self ._operation_str_variable = "query"
44
62
self ._variables_dict_variable = "variables"
@@ -47,49 +65,56 @@ def __init__(
47
65
48
66
def generate (self ) -> ast .Module :
49
67
"""Generate module with class definition of graphql client."""
68
+ self ._add_import (
69
+ generate_import_from (
70
+ names = self .arguments_generator .get_used_inputs (),
71
+ from_ = self .input_types_module_name ,
72
+ level = 1 ,
73
+ )
74
+ )
75
+ self ._add_import (
76
+ generate_import_from (
77
+ names = self .arguments_generator .get_used_enums (),
78
+ from_ = self .enums_module_name ,
79
+ level = 1 ,
80
+ )
81
+ )
82
+ for custom_scalar_name in self .arguments_generator .get_used_custom_scalars ():
83
+ scalar_data = self .custom_scalars [custom_scalar_name ]
84
+ for import_ in generate_scalar_imports (scalar_data ):
85
+ self ._add_import (import_ )
86
+
50
87
gql_func = self ._generate_gql_func ()
51
- gql_func .lineno = len (self .imports ) + 1
88
+ gql_func .lineno = len (self ._imports ) + 1
52
89
if self .plugin_manager :
53
90
gql_func = self .plugin_manager .generate_gql_function (gql_func )
54
91
55
- self .class_def .lineno = len (self .imports ) + 3
56
- if not self .class_def .body :
57
- self .class_def .body .append (ast .Pass ())
92
+ self ._class_def .lineno = len (self ._imports ) + 3
93
+ if not self ._class_def .body :
94
+ self ._class_def .body .append (ast .Pass ())
58
95
if self .plugin_manager :
59
- self .class_def = self .plugin_manager .generate_client_class (self .class_def )
96
+ self ._class_def = self .plugin_manager .generate_client_class (self ._class_def )
60
97
61
98
module = generate_module (
62
- body = self .imports + [gql_func , self .class_def ],
99
+ body = self ._imports + [gql_func , self ._class_def ],
63
100
)
64
101
if self .plugin_manager :
65
102
module = self .plugin_manager .generate_client_module (module )
66
103
return module
67
104
68
- def add_import (self , names : List [str ], from_ : str , level : int = 0 ) -> None :
69
- """Add import to be included in module file."""
70
- if not names :
71
- return
72
- import_ = generate_import_from (names = names , from_ = from_ , level = level )
73
- if self .plugin_manager :
74
- import_ = self .plugin_manager .generate_client_import (import_ )
75
- self .imports .append (import_ )
76
-
77
- def add_imports (self , imports : List [ast .ImportFrom ]):
78
- for import_ in imports :
79
- if self .plugin_manager :
80
- import_ = self .plugin_manager .generate_client_import (import_ )
81
- self .imports .append (import_ )
82
-
83
105
def add_method (
84
106
self ,
107
+ definition : OperationDefinitionNode ,
85
108
name : str ,
86
109
return_type : str ,
87
- arguments : ast .arguments ,
88
- arguments_dict : ast .Dict ,
110
+ return_type_module : str ,
89
111
operation_str : str ,
90
112
async_ : bool = True ,
91
113
):
92
114
"""Add method to client."""
115
+ arguments , arguments_dict = self .arguments_generator .generate (
116
+ definition .variable_definitions
117
+ )
93
118
method_def = (
94
119
self ._generate_async_method (
95
120
name = name ,
@@ -107,12 +132,23 @@ def add_method(
107
132
operation_str = operation_str ,
108
133
)
109
134
)
110
- method_def .lineno = len (self .class_def .body ) + 1
135
+ method_def .lineno = len (self ._class_def .body ) + 1
111
136
if self .plugin_manager :
112
137
method_def = self .plugin_manager .generate_client_method (
113
138
cast (Union [ast .FunctionDef , ast .AsyncFunctionDef ], method_def )
114
139
)
115
- self .class_def .body .append (method_def )
140
+ self ._class_def .body .append (method_def )
141
+ self ._add_import (
142
+ generate_import_from (names = [return_type ], from_ = return_type_module , level = 1 )
143
+ )
144
+
145
+ def _add_import (self , import_ : Optional [ast .ImportFrom ] = None ):
146
+ if not import_ :
147
+ return
148
+ if self .plugin_manager :
149
+ import_ = self .plugin_manager .generate_client_import (import_ )
150
+ if import_ .names and import_ .module :
151
+ self ._imports .append (import_ )
116
152
117
153
def _generate_async_method (
118
154
self ,
0 commit comments