12
12
from dataclasses import dataclass
13
13
from functools import partial
14
14
from pathlib import Path , PurePath
15
- from typing import Callable , TypedDict
15
+ from typing import Callable , Literal , TypedDict , cast
16
16
17
17
from pants .base .build_environment import get_buildroot
18
18
from pants .base .exiter import PANTS_SUCCEEDED_EXIT_CODE , ExitCode
@@ -225,8 +225,6 @@ def _perform_replacements_on_file(self, file: Path, replacements: list[Replaceme
225
225
This function is not idempotent, so it should be run only once per file (per migration plan).
226
226
227
227
Replacement imports are bulk added below the existing "pants.engine.rules" import.
228
-
229
-
230
228
"""
231
229
232
230
imports_added = False
@@ -241,7 +239,7 @@ def _perform_replacements_on_file(self, file: Path, replacements: list[Replaceme
241
239
modified = False
242
240
for replacement in replacements :
243
241
if line_number == replacement .line_range [0 ]:
244
- line_end = ",\n " if replacement .is_argument else "\n "
242
+ line_end = ",\n " if replacement .add_trailing_comma else "\n "
245
243
# On the first line of the range, emit the new source code where the old "Get" started
246
244
print (line [: replacement .col_range [0 ]], end = "" )
247
245
print (ast .unparse (replacement .new_source ), end = line_end )
@@ -321,7 +319,8 @@ class Replacement:
321
319
current_source : ast .Call
322
320
new_source : ast .Call
323
321
additional_imports : list [ast .ImportFrom ]
324
- is_argument : bool = False
322
+ # TODO: Use libcst or another CST, rather than an ast
323
+ add_trailing_comma : bool = False
325
324
326
325
def sanitized_imports (self ) -> list [ast .ImportFrom ]:
327
326
"""Remove any circular or self-imports."""
@@ -361,7 +360,7 @@ def __str__(self) -> str:
361
360
current_source={ ast .dump (self .current_source , indent = 2 )} ,
362
361
new_source={ ast .dump (self .new_source , indent = 2 )} ,
363
362
additional_imports={ [ast .dump (i , indent = 2 ) for i in self .additional_imports ]} ,
364
- is_argument ={ self .is_argument }
363
+ add_trailing_comma ={ self .add_trailing_comma }
365
364
)
366
365
"""
367
366
@@ -613,19 +612,33 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
613
612
if not self ._should_visit_node (node .decorator_list ):
614
613
return
615
614
616
- for child in node .body :
617
- if call := self ._maybe_replaceable_call (child ):
618
- if replacement := self .syntax_mapper .map_get_to_new_syntax (
619
- call , self .filename , calling_func = node .name
620
- ):
621
- self .replacements .append (replacement )
615
+ self ._recurse_body_statements (node .name , node )
616
+
617
+ def _recurse_body_statements (self , root : str , node : ast .stmt ):
618
+ """Recursively walk the body of a node, including properties of compound statements looking
619
+ for Get() calls to replace.
620
+
621
+ https://docs.python.org/3/reference/compound_stmts.html
622
+ """
623
+ for prop in ["body" , "handlers" , "orelse" , "finalbody" ]:
624
+ for child in getattr (node , prop , []):
625
+ self ._recurse_body_statements (root , cast (ast .stmt , child ))
622
626
623
- for call in self ._maybe_replaceable_multiget (child ):
624
- if replacement := self .syntax_mapper .map_get_to_new_syntax (
625
- call , self .filename , calling_func = node .name
626
- ):
627
- replacement .is_argument = True
628
- self .replacements .append (replacement )
627
+ self ._maybe_add_replacements (root , node )
628
+
629
+ def _maybe_add_replacements (self , calling_func : str , statement : ast .stmt ):
630
+ if call := self ._maybe_replaceable_call (statement ):
631
+ if replacement := self .syntax_mapper .map_get_to_new_syntax (
632
+ call , self .filename , calling_func = calling_func
633
+ ):
634
+ self .replacements .append (replacement )
635
+
636
+ for call , statement_type in self ._maybe_replaceable_multiget (statement ):
637
+ if replacement := self .syntax_mapper .map_get_to_new_syntax (
638
+ call , self .filename , calling_func = calling_func
639
+ ):
640
+ replacement .add_trailing_comma = statement_type == "call"
641
+ self .replacements .append (replacement )
629
642
630
643
def _should_visit_node (self , decorator_list : list [ast .expr ]) -> bool :
631
644
"""Only interested in async functions with the @rule(...) or @goal_rule(...) decorator."""
@@ -667,23 +680,37 @@ def _maybe_replaceable_call(self, statement: ast.stmt) -> ast.Call | None:
667
680
return call_node
668
681
return None
669
682
670
- def _maybe_replaceable_multiget (self , statement : ast .stmt ) -> list [ast .Call ]:
671
- """Looks for the following form of MultiGet that we want to replace:
683
+ def _maybe_replaceable_multiget (
684
+ self , statement : ast .stmt
685
+ ) -> list [tuple [ast .Call , Literal ["call" , "generator" ]]]:
686
+ """Looks for the following forms of MultiGet that we want to replace:
672
687
673
688
- multigot = await MultiGet(Get(...), Get(...), ...)
689
+ - multigot = await MultiGet(Get(...) for x in y)
674
690
"""
675
- if (
691
+ if not (
676
692
isinstance (statement , ast .Assign )
677
693
and isinstance ((await_node := statement .value ), ast .Await )
678
694
and isinstance ((call_node := await_node .value ), ast .Call )
679
695
and isinstance (call_node .func , ast .Name )
680
696
and call_node .func .id == "MultiGet"
681
697
):
682
- return [
683
- arg
684
- for arg in call_node .args
685
- if isinstance (arg , ast .Call )
698
+ return []
699
+
700
+ args : list [tuple [ast .Call , Literal ["call" , "generator" ]]] = []
701
+ for arg in call_node .args :
702
+ if (
703
+ isinstance (arg , ast .Call )
686
704
and isinstance (arg .func , ast .Name )
687
705
and arg .func .id == "Get"
688
- ]
689
- return []
706
+ ):
707
+ args .append ((arg , "call" ))
708
+
709
+ if (
710
+ isinstance (arg , ast .GeneratorExp )
711
+ and isinstance ((call_arg := arg .elt ), ast .Call )
712
+ and isinstance (call_arg .func , ast .Name )
713
+ and call_arg .func .id == "Get"
714
+ ):
715
+ args .append ((call_arg , "generator" ))
716
+ return args
0 commit comments