@@ -505,7 +505,28 @@ def create_build_operation_ast_method(self):
505
505
return_type = generate_name ("DocumentNode" ),
506
506
)
507
507
508
- def create_execute_custom_operation_method (self ):
508
+ def create_execute_custom_operation_method (self , async_client : bool ):
509
+ execute_call = generate_call (
510
+ func = generate_attribute (value = generate_name ("self" ), attr = "execute" ),
511
+ args = [
512
+ generate_call (
513
+ func = generate_name ("print_ast" ),
514
+ args = [generate_name ("operation_ast" )],
515
+ )
516
+ ],
517
+ keywords = [
518
+ generate_keyword (
519
+ arg = "variables" , value = generate_name ('combined_variables["values"]' )
520
+ ),
521
+ generate_keyword (
522
+ arg = "operation_name" , value = generate_name ("operation_name" )
523
+ ),
524
+ ],
525
+ )
526
+ response_value = (
527
+ generate_await (value = execute_call ) if async_client else execute_call
528
+ )
529
+
509
530
method_body = [
510
531
generate_assign (
511
532
targets = ["selections" ],
@@ -549,54 +570,31 @@ def create_execute_custom_operation_method(self):
549
570
],
550
571
),
551
572
),
552
- generate_assign (
553
- targets = ["response" ],
554
- value = generate_await (
555
- value = generate_call (
556
- func = generate_attribute (
557
- value = generate_name ("self" ),
558
- attr = "execute" ,
559
- ),
560
- args = [
561
- generate_call (
562
- func = generate_name ("print_ast" ),
563
- args = [generate_name ("operation_ast" )],
564
- )
565
- ],
566
- keywords = [
567
- generate_keyword (
568
- arg = "variables" ,
569
- value = generate_name ('combined_variables["values"]' ),
570
- ),
571
- generate_keyword (
572
- arg = "operation_name" ,
573
- value = generate_name ("operation_name" ),
574
- ),
575
- ],
576
- )
577
- ),
578
- ),
573
+ generate_assign (targets = ["response" ], value = response_value ),
579
574
generate_return (
580
575
value = generate_call (
581
576
func = generate_attribute (
582
- value = generate_name ("self" ),
583
- attr = "get_data" ,
577
+ value = generate_name ("self" ), attr = "get_data"
584
578
),
585
579
args = [generate_name ("response" )],
586
580
)
587
581
),
588
582
]
589
- return generate_async_method_definition (
583
+
584
+ method_definition = (
585
+ generate_async_method_definition
586
+ if async_client
587
+ else generate_method_definition
588
+ )
589
+
590
+ return method_definition (
590
591
name = "execute_custom_operation" ,
591
592
arguments = generate_arguments (
592
593
args = [
593
594
generate_arg ("self" ),
594
595
generate_arg ("*fields" , annotation = generate_name ("GraphQLField" )),
595
596
generate_arg (
596
- "operation_type" ,
597
- annotation = generate_name (
598
- "OperationType" ,
599
- ),
597
+ "operation_type" , annotation = generate_name ("OperationType" )
600
598
),
601
599
generate_arg ("operation_name" , annotation = generate_name ("str" )),
602
600
]
@@ -655,7 +653,7 @@ def create_build_selection_set(self):
655
653
),
656
654
)
657
655
658
- def add_execute_custom_operation_method (self ):
656
+ def add_execute_custom_operation_method (self , async_client : bool ):
659
657
self ._add_import (
660
658
generate_import_from (
661
659
[
@@ -679,13 +677,20 @@ def add_execute_custom_operation_method(self):
679
677
)
680
678
self ._add_import (generate_import_from ([DICT , TUPLE , LIST , ANY ], "typing" ))
681
679
682
- self ._class_def .body .append (self .create_execute_custom_operation_method ())
680
+ self ._class_def .body .append (
681
+ self .create_execute_custom_operation_method (async_client )
682
+ )
683
683
self ._class_def .body .append (self .create_combine_variables_method ())
684
684
self ._class_def .body .append (self .create_build_variable_definitions_method ())
685
685
self ._class_def .body .append (self .create_build_operation_ast_method ())
686
686
self ._class_def .body .append (self .create_build_selection_set ())
687
687
688
- def create_custom_operation_method (self , name , operation_type ):
688
+ def create_custom_operation_method (
689
+ self ,
690
+ name : str ,
691
+ operation_type : str ,
692
+ async_client : bool ,
693
+ ):
689
694
self ._add_import (
690
695
generate_import_from (
691
696
[
@@ -694,6 +699,55 @@ def create_custom_operation_method(self, name, operation_type):
694
699
GRAPHQL_MODULE ,
695
700
)
696
701
)
702
+ if async_client :
703
+ def_query = self ._create_async_operation_method (name , operation_type )
704
+ else :
705
+ def_query = self ._create_sync_operation_method (name , operation_type )
706
+ self ._class_def .body .append (def_query )
707
+
708
+ def _create_sync_operation_method (self , name : str , operation_type : str ):
709
+ body_return = generate_return (
710
+ value = generate_call (
711
+ func = generate_attribute (
712
+ value = generate_name ("self" ),
713
+ attr = "execute_custom_operation" ,
714
+ ),
715
+ args = [
716
+ generate_name ("*fields" ),
717
+ ],
718
+ keywords = [
719
+ generate_keyword (
720
+ arg = "operation_type" ,
721
+ value = generate_attribute (
722
+ value = generate_name ("OperationType" ),
723
+ attr = operation_type ,
724
+ ),
725
+ ),
726
+ generate_keyword (
727
+ arg = "operation_name" , value = generate_name ("operation_name" )
728
+ ),
729
+ ],
730
+ )
731
+ )
732
+
733
+ def_query = generate_method_definition (
734
+ name = name ,
735
+ arguments = generate_arguments (
736
+ args = [
737
+ generate_arg ("self" ),
738
+ generate_arg ("*fields" , annotation = generate_name ("GraphQLField" )),
739
+ generate_arg ("operation_name" , annotation = generate_name ("str" )),
740
+ ],
741
+ ),
742
+ body = [body_return ],
743
+ return_type = generate_subscript (
744
+ generate_name (DICT ),
745
+ generate_tuple ([generate_name ("str" ), generate_name ("Any" )]),
746
+ ),
747
+ )
748
+ return def_query
749
+
750
+ def _create_async_operation_method (self , name : str , operation_type : str ):
697
751
body_return = generate_return (
698
752
value = generate_await (
699
753
value = generate_call (
@@ -734,7 +788,7 @@ def create_custom_operation_method(self, name, operation_type):
734
788
generate_tuple ([generate_name ("str" ), generate_name ("Any" )]),
735
789
),
736
790
)
737
- self . _class_def . body . append ( async_def_query )
791
+ return async_def_query
738
792
739
793
def get_variable_names (self , arguments : ast .arguments ) -> Dict [str , str ]:
740
794
mapped_variable_names = [
0 commit comments