Skip to content

Commit ac215a6

Browse files
authored
[Call-by-name] Fixed ignored compound statements and generator MultiGets (#21042)
This PR handles two cases that the migration previously missed: (1) generator-based `MultiGet`s (very important), and (2) `Get` calls within compound statements (helps with edge cases).
1 parent 32d845b commit ac215a6

File tree

2 files changed

+92
-28
lines changed

2 files changed

+92
-28
lines changed

src/python/pants/goal/migrate_call_by_name.py

+54-27
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from dataclasses import dataclass
1313
from functools import partial
1414
from pathlib import Path, PurePath
15-
from typing import Callable, TypedDict
15+
from typing import Callable, Literal, TypedDict, cast
1616

1717
from pants.base.build_environment import get_buildroot
1818
from pants.base.exiter import PANTS_SUCCEEDED_EXIT_CODE, ExitCode
@@ -225,8 +225,6 @@ def _perform_replacements_on_file(self, file: Path, replacements: list[Replaceme
225225
This function is not idempotent, so it should be run only once per file (per migration plan).
226226
227227
Replacement imports are bulk added below the existing "pants.engine.rules" import.
228-
229-
230228
"""
231229

232230
imports_added = False
@@ -241,7 +239,7 @@ def _perform_replacements_on_file(self, file: Path, replacements: list[Replaceme
241239
modified = False
242240
for replacement in replacements:
243241
if line_number == replacement.line_range[0]:
244-
line_end = ",\n" if replacement.is_argument else "\n"
242+
line_end = ",\n" if replacement.add_trailing_comma else "\n"
245243
# On the first line of the range, emit the new source code where the old "Get" started
246244
print(line[: replacement.col_range[0]], end="")
247245
print(ast.unparse(replacement.new_source), end=line_end)
@@ -321,7 +319,8 @@ class Replacement:
321319
current_source: ast.Call
322320
new_source: ast.Call
323321
additional_imports: list[ast.ImportFrom]
324-
is_argument: bool = False
322+
# TODO: Use libcst or another CST, rather than an ast
323+
add_trailing_comma: bool = False
325324

326325
def sanitized_imports(self) -> list[ast.ImportFrom]:
327326
"""Remove any circular or self-imports."""
@@ -361,7 +360,7 @@ def __str__(self) -> str:
361360
current_source={ast.dump(self.current_source, indent=2)},
362361
new_source={ast.dump(self.new_source, indent=2)},
363362
additional_imports={[ast.dump(i, indent=2) for i in self.additional_imports]},
364-
is_argument={self.is_argument}
363+
add_trailing_comma={self.add_trailing_comma}
365364
)
366365
"""
367366

@@ -613,19 +612,33 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
613612
if not self._should_visit_node(node.decorator_list):
614613
return
615614

616-
for child in node.body:
617-
if call := self._maybe_replaceable_call(child):
618-
if replacement := self.syntax_mapper.map_get_to_new_syntax(
619-
call, self.filename, calling_func=node.name
620-
):
621-
self.replacements.append(replacement)
615+
self._recurse_body_statements(node.name, node)
616+
617+
def _recurse_body_statements(self, root: str, node: ast.stmt):
618+
"""Recursively walk the body of a node, including properties of compound statements, looking
619+
for Get() calls to replace.
620+
621+
https://docs.python.org/3/reference/compound_stmts.html
622+
"""
623+
for prop in ["body", "handlers", "orelse", "finalbody"]:
624+
for child in getattr(node, prop, []):
625+
self._recurse_body_statements(root, cast(ast.stmt, child))
622626

623-
for call in self._maybe_replaceable_multiget(child):
624-
if replacement := self.syntax_mapper.map_get_to_new_syntax(
625-
call, self.filename, calling_func=node.name
626-
):
627-
replacement.is_argument = True
628-
self.replacements.append(replacement)
627+
self._maybe_add_replacements(root, node)
628+
629+
def _maybe_add_replacements(self, calling_func: str, statement: ast.stmt):
630+
if call := self._maybe_replaceable_call(statement):
631+
if replacement := self.syntax_mapper.map_get_to_new_syntax(
632+
call, self.filename, calling_func=calling_func
633+
):
634+
self.replacements.append(replacement)
635+
636+
for call, statement_type in self._maybe_replaceable_multiget(statement):
637+
if replacement := self.syntax_mapper.map_get_to_new_syntax(
638+
call, self.filename, calling_func=calling_func
639+
):
640+
replacement.add_trailing_comma = statement_type == "call"
641+
self.replacements.append(replacement)
629642

630643
def _should_visit_node(self, decorator_list: list[ast.expr]) -> bool:
631644
"""Only interested in async functions with the @rule(...) or @goal_rule(...) decorator."""
@@ -667,23 +680,37 @@ def _maybe_replaceable_call(self, statement: ast.stmt) -> ast.Call | None:
667680
return call_node
668681
return None
669682

670-
def _maybe_replaceable_multiget(self, statement: ast.stmt) -> list[ast.Call]:
671-
"""Looks for the following form of MultiGet that we want to replace:
683+
def _maybe_replaceable_multiget(
684+
self, statement: ast.stmt
685+
) -> list[tuple[ast.Call, Literal["call", "generator"]]]:
686+
"""Looks for the following forms of MultiGet that we want to replace:
672687
673688
- multigot = await MultiGet(Get(...), Get(...), ...)
689+
- multigot = await MultiGet(Get(...) for x in y)
674690
"""
675-
if (
691+
if not (
676692
isinstance(statement, ast.Assign)
677693
and isinstance((await_node := statement.value), ast.Await)
678694
and isinstance((call_node := await_node.value), ast.Call)
679695
and isinstance(call_node.func, ast.Name)
680696
and call_node.func.id == "MultiGet"
681697
):
682-
return [
683-
arg
684-
for arg in call_node.args
685-
if isinstance(arg, ast.Call)
698+
return []
699+
700+
args: list[tuple[ast.Call, Literal["call", "generator"]]] = []
701+
for arg in call_node.args:
702+
if (
703+
isinstance(arg, ast.Call)
686704
and isinstance(arg.func, ast.Name)
687705
and arg.func.id == "Get"
688-
]
689-
return []
706+
):
707+
args.append((arg, "call"))
708+
709+
if (
710+
isinstance(arg, ast.GeneratorExp)
711+
and isinstance((call_arg := arg.elt), ast.Call)
712+
and isinstance(call_arg.func, ast.Name)
713+
and call_arg.func.id == "Get"
714+
):
715+
args.append((call_arg, "generator"))
716+
return args

src/python/pants/goal/migrate_call_by_name_integration_test.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,20 @@ async def variants(black: Black, local_env: ChosenLocalEnvironmentName) -> Foo:
9595
pex = await Get(VenvPex, PexRequest, black.to_pex_request())
9696
digest = await Get(Digest, CreateArchive(EMPTY_SNAPSHOT))
9797
paths = await Get(BinaryPaths, {{BinaryPathRequest(binary_name="time", search_path=("/usr/bin")): BinaryPathRequest, local_env.val: EnvironmentName}})
98+
try:
99+
try_all_targets_try = await Get(AllTargets)
100+
except:
101+
try_all_targets_except = await Get(AllTargets)
102+
else:
103+
try_all_targets_else = await Get(AllTargets)
104+
finally:
105+
try_all_targets_finally = await Get(AllTargets)
106+
if True:
107+
conditional_all_targets_if = await Get(AllTargets)
108+
elif False:
109+
conditional_all_targets_elif = await Get(AllTargets)
110+
else:
111+
conditional_all_targets_else = await Get(AllTargets)
98112
99113
class Bar:
100114
pass
@@ -136,6 +150,10 @@ async def multiget(black: Black) -> Thud:
136150
),
137151
digest_get
138152
)
153+
multigot_forloop = await MultiGet(
154+
Get(Digest, CreateArchive(EMPTY_SNAPSHOT))
155+
for i in [0, 1, 2]
156+
)
139157
140158
def rules():
141159
return collect_rules()
@@ -168,6 +186,20 @@ async def variants(black: Black, local_env: ChosenLocalEnvironmentName) -> Foo:
168186
pex = await create_venv_pex(**implicitly({black.to_pex_request(): PexRequest}))
169187
digest = await create_archive(CreateArchive(EMPTY_SNAPSHOT), **implicitly())
170188
paths = await find_binary(**implicitly({BinaryPathRequest(binary_name='time', search_path='/usr/bin'): BinaryPathRequest, local_env.val: EnvironmentName}))
189+
try:
190+
try_all_targets_try = await find_all_targets(**implicitly())
191+
except:
192+
try_all_targets_except = await find_all_targets(**implicitly())
193+
else:
194+
try_all_targets_else = await find_all_targets(**implicitly())
195+
finally:
196+
try_all_targets_finally = await find_all_targets(**implicitly())
197+
if True:
198+
conditional_all_targets_if = await find_all_targets(**implicitly())
199+
elif False:
200+
conditional_all_targets_elif = await find_all_targets(**implicitly())
201+
else:
202+
conditional_all_targets_else = await find_all_targets(**implicitly())
171203
172204
class Bar:
173205
pass
@@ -201,6 +233,10 @@ async def multiget(black: Black) -> Thud:
201233
create_venv_pex(**implicitly({black.to_pex_request(): PexRequest})),
202234
digest_get
203235
)
236+
multigot_forloop = await concurrently(
237+
create_archive(CreateArchive(EMPTY_SNAPSHOT), **implicitly())
238+
for i in [0, 1, 2]
239+
)
204240
205241
def rules():
206242
return collect_rules()
@@ -275,8 +311,9 @@ def test_migrate_call_by_name_syntax():
275311
# Ensure the JSON output contains the paths to the files we expect to migrate
276312
assert all(str(p) in result.stdout for p in [register_path, rules1_path, rules2_path])
277313
# Ensure the warning for embedded comments is logged
314+
# Note: This assertion is brittle - adding extra content to rules1.py will probably mean updating the range
278315
assert (
279-
f"Comments found in {tmpdir}/src/migrateme/rules1.py within replacement range: (37, 42)"
316+
f"Comments found in {tmpdir}/src/migrateme/rules1.py within replacement range: (51, 56)"
280317
in result.stderr
281318
)
282319

0 commit comments

Comments
 (0)