Skip to content

Commit a82ecb8

Browse files
Add custom operation generation (#296)
Add custom operation generation
1 parent cb576f8 commit a82ecb8

40 files changed

+3948
-14
lines changed

ariadne_codegen/client_generators/client.py

+201-2
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414
generate_await,
1515
generate_call,
1616
generate_class_def,
17+
generate_comp,
1718
generate_constant,
1819
generate_expr,
1920
generate_import_from,
2021
generate_keyword,
22+
generate_list,
23+
generate_list_comp,
2124
generate_method_definition,
2225
generate_module,
2326
generate_name,
@@ -32,11 +35,20 @@
3235
from .constants import (
3336
ANY,
3437
ASYNC_ITERATOR,
38+
BASE_GRAPHQL_FIELD_CLASS_NAME,
39+
BASE_OPERATION_FILE_PATH,
3540
DICT,
41+
DOCUMENT_NODE,
42+
GRAPHQL_MODULE,
3643
KWARGS_NAMES,
3744
LIST,
3845
MODEL_VALIDATE_METHOD,
46+
NAME_NODE,
47+
OPERATION_DEFINITION_NODE,
48+
OPERATION_TYPE,
3949
OPTIONAL,
50+
PRINT_AST,
51+
SELECTION_SET_NODE,
4052
TYPING_MODULE,
4153
UNION,
4254
UNSET_IMPORT,
@@ -66,10 +78,18 @@ def __init__(
6678
self.custom_scalars = custom_scalars if custom_scalars else {}
6779
self.arguments_generator = arguments_generator
6880

69-
self._imports: List[ast.ImportFrom] = []
81+
self._imports: List[Union[ast.ImportFrom, ast.Import]] = []
7082
self._add_import(
7183
generate_import_from(
72-
[OPTIONAL, LIST, DICT, ANY, UNION, ASYNC_ITERATOR], TYPING_MODULE
84+
[
85+
OPTIONAL,
86+
LIST,
87+
DICT,
88+
ANY,
89+
UNION,
90+
ASYNC_ITERATOR,
91+
],
92+
TYPING_MODULE,
7393
)
7494
)
7595
self._add_import(base_client_import)
@@ -187,6 +207,185 @@ def add_method(
187207
generate_import_from(names=[return_type], from_=return_type_module, level=1)
188208
)
189209

210+
def add_execute_custom_operation_method(self):
211+
self._add_import(
212+
generate_import_from(
213+
[
214+
DOCUMENT_NODE,
215+
OPERATION_DEFINITION_NODE,
216+
NAME_NODE,
217+
SELECTION_SET_NODE,
218+
PRINT_AST,
219+
],
220+
GRAPHQL_MODULE,
221+
)
222+
)
223+
self._add_import(
224+
generate_import_from(
225+
[BASE_GRAPHQL_FIELD_CLASS_NAME], BASE_OPERATION_FILE_PATH.stem, level=1
226+
)
227+
)
228+
execute_await = generate_await(
229+
value=generate_call(
230+
func=generate_attribute(value=generate_name("self"), attr="execute"),
231+
args=[
232+
generate_call(
233+
func=generate_name("print_ast"),
234+
args=[generate_name("operation_ast")],
235+
)
236+
],
237+
keywords=[
238+
generate_keyword(
239+
arg="operation_name", value=generate_name("operation_name")
240+
)
241+
],
242+
)
243+
)
244+
245+
operation_definition_node = generate_call(
246+
func=generate_name("OperationDefinitionNode"),
247+
keywords=[
248+
generate_keyword(
249+
arg="operation", value=generate_name("operation_type")
250+
),
251+
generate_keyword(
252+
arg="name",
253+
value=generate_call(
254+
func=generate_name("NameNode"),
255+
keywords=[
256+
generate_keyword(
257+
arg="value", value=generate_name("operation_name")
258+
)
259+
],
260+
),
261+
),
262+
generate_keyword(
263+
arg="selection_set",
264+
value=generate_call(
265+
func=generate_name("SelectionSetNode"),
266+
keywords=[
267+
generate_keyword(
268+
arg="selections",
269+
value=generate_list_comp(
270+
elt=generate_call(
271+
func=generate_attribute(
272+
value=generate_name("field"),
273+
attr="to_ast",
274+
),
275+
),
276+
generators=[
277+
generate_comp(
278+
target="field",
279+
iter_="fields",
280+
)
281+
],
282+
),
283+
)
284+
],
285+
),
286+
),
287+
],
288+
)
289+
operation_ast = generate_call(
290+
func=generate_name("DocumentNode"),
291+
keywords=[
292+
generate_keyword(
293+
arg="definitions",
294+
value=generate_list(elements=[operation_definition_node]),
295+
)
296+
],
297+
)
298+
body_return = generate_return(
299+
value=generate_call(
300+
func=generate_attribute(value=generate_name("self"), attr="get_data"),
301+
args=[generate_name("response")],
302+
)
303+
)
304+
async_def_node = generate_async_method_definition(
305+
name="execute_custom_operation",
306+
arguments=generate_arguments(
307+
args=[
308+
generate_arg("self"),
309+
generate_arg(
310+
"*fields",
311+
annotation=generate_name("GraphQLField"),
312+
),
313+
generate_arg(
314+
"operation_type",
315+
annotation=generate_name("OperationType"),
316+
),
317+
generate_arg("operation_name", annotation=generate_name("str")),
318+
],
319+
),
320+
body=[
321+
generate_assign(
322+
targets=["operation_ast"],
323+
value=operation_ast,
324+
),
325+
generate_assign(
326+
targets=["response"],
327+
value=execute_await,
328+
),
329+
body_return,
330+
],
331+
return_type=generate_subscript(
332+
generate_name(DICT),
333+
generate_tuple([generate_name("str"), generate_name("Any")]),
334+
),
335+
)
336+
self._class_def.body.append(async_def_node)
337+
338+
def create_custom_operation_method(self, name, operation_type):
339+
self._add_import(
340+
generate_import_from(
341+
[
342+
OPERATION_TYPE,
343+
],
344+
GRAPHQL_MODULE,
345+
)
346+
)
347+
body_return = generate_return(
348+
value=generate_await(
349+
value=generate_call(
350+
func=generate_attribute(
351+
value=generate_name("self"),
352+
attr="execute_custom_operation",
353+
),
354+
args=[
355+
generate_name("*fields"),
356+
],
357+
keywords=[
358+
generate_keyword(
359+
arg="operation_type",
360+
value=generate_attribute(
361+
value=generate_name("OperationType"),
362+
attr=operation_type,
363+
),
364+
),
365+
generate_keyword(
366+
arg="operation_name", value=generate_name("operation_name")
367+
),
368+
],
369+
)
370+
)
371+
)
372+
async_def_query = generate_async_method_definition(
373+
name=name,
374+
arguments=generate_arguments(
375+
args=[
376+
generate_arg("self"),
377+
generate_arg("*fields", annotation=generate_name("GraphQLField")),
378+
generate_arg("operation_name", annotation=generate_name("str")),
379+
],
380+
),
381+
body=[body_return],
382+
return_type=generate_subscript(
383+
generate_name(DICT),
384+
generate_tuple([generate_name("str"), generate_name("Any")]),
385+
),
386+
)
387+
self._class_def.body.append(async_def_query)
388+
190389
def get_variable_names(self, arguments: ast.arguments) -> Dict[str, str]:
191390
mapped_variable_names = [
192391
self._operation_str_variable,

ariadne_codegen/client_generators/constants.py

+20
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,34 @@
1616
LIST = "List"
1717
UNION = "Union"
1818
ANY = "Any"
19+
TYPE = "Type"
20+
TYPE_CHECKING = "TYPE_CHECKING"
1921
DICT = "Dict"
2022
CALLABLE = "Callable"
2123
ANNOTATED = "Annotated"
2224
LITERAL = "Literal"
2325
ASYNC_ITERATOR = "AsyncIterator"
26+
DOCUMENT_NODE = "DocumentNode"
27+
OPERATION_DEFINITION_NODE = "OperationDefinitionNode"
28+
NAME_NODE = "NameNode"
29+
SELECTION_SET_NODE = "SelectionSetNode"
30+
PRINT_AST = "print_ast"
31+
OPERATION_TYPE = "OperationType"
32+
33+
HTTPX = "httpx"
34+
HTTPX_RESPONSE = "httpx.Response"
2435

2536
TIMESTAMP_COMMENT = "# Generated by ariadne-codegen on {}"
2637
STABLE_COMMENT = "# Generated by ariadne-codegen"
2738
SOURCE_COMMENT = "# Source: {}"
2839
COMMENT_DATETIME_FORMAT = "%Y-%m-%d %H:%M"
2940

41+
BASE_OPERATION_FILE_PATH = Path(__file__).parent / "dependencies" / "base_operation.py"
42+
BASE_GRAPHQL_OPERATION_CLASS_NAME = "BaseGraphQLOperation"
43+
BASE_GRAPHQL_FIELD_CLASS_NAME = "GraphQLField"
44+
CUSTOM_FIELDS_FILE_PATH = Path(__file__).parent / "custom_fields.py"
45+
CUSTOM_FIELDS_TYPING_FILE_PATH = Path(__file__).parent / "custom_typing_fields.py"
46+
3047
BASE_MODEL_FILE_PATH = Path(__file__).parent / "dependencies" / "base_model.py"
3148
BASE_MODEL_CLASS_NAME = "BaseModel"
3249
BASE_MODEL_IMPORT = ast.ImportFrom(
@@ -49,6 +66,7 @@
4966
TYPENAME_ALIAS = "typename__"
5067

5168
TYPING_MODULE = "typing"
69+
GRAPHQL_MODULE = "graphql"
5270
PYDANTIC_MODULE = "pydantic"
5371
FIELD_CLASS = "Field"
5472
ALIAS_KEYWORD = "alias"
@@ -100,3 +118,5 @@
100118

101119
SCALARS_PARSE_DICT_NAME = "SCALARS_PARSE_FUNCTIONS"
102120
SCALARS_SERIALIZE_DICT_NAME = "SCALARS_SERIALIZE_FUNCTIONS"
121+
122+
OPERATION_TYPES = ("Query", "Mutation", "Subscription")

0 commit comments

Comments
 (0)