Commit 5ecdc42

ydwu4 authored and pytorchmergebot committed
[while_loop][inductor] support sym expression as cond_fn output (pytorch#146222)
As titled. Previously we only supported a tensor output from cond_fn; this PR also allows a shape expression to be returned from cond_fn. The AOTI-generated output code looks like:

```
V0203 11:28:05.750000 2611693 torch/_inductor/compile_fx.py:1091] [1/0] [__output_code] bool buf7_cond_result;
.... (while_loop_cond_graph_0_arg2_1_handle);
V0203 11:27:59.336000 2611693 torch/_inductor/compile_fx.py:1091] [1/0] [__output_code] buf7_cond_result = u0 + u1 < 10L;
V0203 11:27:59.336000 2611693 torch/_inductor/compile_fx.py:1091] [1/0] [__output_code] if (!buf7_cond_result) break;
```

Pull Request resolved: pytorch#146222
Approved by: https://github.com/desertfire
ghstack dependencies: pytorch#146194, pytorch#146195
1 parent 1b879fd commit 5ecdc42
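
For readers who want to try the new behavior, below is a minimal usage sketch adapted from the SymExprCond test model added in test/inductor/test_control_flow.py in this commit. The torch.compile entry point, the dynamo config flags, and the zero-dim counter tensor `c` are assumptions drawn from the test decorators and the prepend_counters helper used by the test suite, not something this commit prescribes; the point is only that cond_fn returns a sym expression rather than a boolean scalar tensor.

```python
import torch

# Adapted from the SymExprCond model added in test_control_flow.py below.
class SymExprCond(torch.nn.Module):
    def forward(self, c, a, b):
        d = a.sum().to(torch.int64).item()  # unbacked symint from .item()
        e = torch.nonzero(b).size(0)        # unbacked symint from a data-dependent shape

        def cond_fn(c, a, b):
            # The predicate is a sym expression, not a boolean scalar tensor.
            return d + e + a.shape[0] - b.shape[0] < 10

        def body_fn(c, a, b):
            return c + 1, a + e, b + d

        return torch._higher_order_ops.while_loop(cond_fn, body_fn, [c, a, b])

# The config flags mirror the decorators on the new test; the zero-dim counter `c`
# is an assumption standing in for what prepend_counters builds in the test suite.
with torch._dynamo.config.patch(
    {"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
):
    compiled = torch.compile(SymExprCond())
    c = torch.zeros((), dtype=torch.int64)
    out = compiled(c, torch.randn(10, 20), torch.randn(10, 20))
```

With the cpp wrapper (AOTI), the predicate then lowers to a plain bool assignment like the buf7_cond_result lines in the log above, instead of extracting a scalar from a boolean tensor.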

File tree

5 files changed (+79, -13 lines)

test/inductor/test_aot_inductor.py (+20)

@@ -1511,6 +1511,26 @@ def test_while_loop_with_unbacked_symint_closure(self, dynamic):
             dynamic_shapes=dynamic_shapes,
         )
 
+    @common_utils.parametrize("dynamic", [False, True])
+    def test_while_loop_with_sym_expr_cond(self, dynamic):
+        inputs = (
+            torch.randn(10, 20, device=self.device),
+            torch.randn(10, 20, device=self.device),
+        )
+        dim0_ab = Dim("s0", min=2, max=1024)
+        dynamic_shapes = None
+        if dynamic:
+            dynamic_shapes = {
+                "c": {},
+                "a": {0: dim0_ab, 1: None},
+                "b": {0: dim0_ab, 1: None},
+            }
+        self.check_model_with_multiple_inputs(
+            WhileLoopModels.SymExprCond(),
+            prepend_counters(inputs),
+            dynamic_shapes=dynamic_shapes,
+        )
+
     @config.patch({"is_predispatch": True})
     def test_constant(self):
         class M(torch.nn.Module):

test/inductor/test_control_flow.py (+34)

@@ -876,6 +876,23 @@ def body_fn(c, a, b):
                 [c, a, b],
             )
 
+    class SymExprCond(torch.nn.Module):
+        def forward(self, c, a, b):
+            d = a.sum().to(torch.int64).item()
+            e = torch.nonzero(b).size(0)
+
+            def cond_fn(c, a, b):
+                return d + e + a.shape[0] - b.shape[0] < 10
+
+            def body_fn(c, a, b):
+                return c + 1, a + e, b + d
+
+            return torch._higher_order_ops.while_loop(
+                cond_fn,
+                body_fn,
+                [c, a, b],
+            )
+
 
 class WhileLoopTests(TestCase):
     def _run_test(

@@ -1139,6 +1156,23 @@ def test_while_loop_with_unbacked_symint_closure(self, device, dynamic):
             dynamic=dynamic,
         )
 
+    @requires_gpu
+    @parametrize("device", ["cpu", GPU_TYPE])
+    @parametrize("dynamic", [True, False])
+    @torch._dynamo.config.patch(
+        {"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
+    )
+    def test_while_loop_with_sym_expr_cond(self, device, dynamic):
+        self._run_test(
+            model=WhileLoopModels.SymExprCond(),
+            inputs=(
+                torch.randn(10, 20),
+                torch.randn(10, 20),
+            ),
+            device=device,
+            dynamic=dynamic,
+        )
+
 
 class AssociativeScanTests(TestCase):
     @requires_gpu

torch/_inductor/codegen/cpp_wrapper_cpu.py (+20, -10)

@@ -1594,13 +1594,14 @@ def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
             subgraph.graph.graph_outputs, outer_outputs
         ):
             src = inner_output.codegen_reference()
-            # in ABI-compatible mode, we need to std::move subgraph output (inner_output)
-            # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
-            # constructor is deleted.
-            src = f"std::move({src})"
-            # in case the outer_output carried a value
-            # before (e.g., in the while_loop codegen)
-            self.writeline(f"{outer_output}.reset();")
+            if not isinstance(inner_output, ir.ShapeAsConstantBuffer):
+                # in ABI-compatible mode, we need to std::move subgraph output (inner_output)
+                # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
+                # constructor is deleted.
+                src = f"std::move({src})"
+                # in case the outer_output carried a value
+                # before (e.g., in the while_loop codegen)
+                self.writeline(f"{outer_output}.reset();")
             self.writeline(f"{outer_output} = {src};")
 
     def codegen_invoke_subgraph(self, invoke_subgraph):

@@ -1662,6 +1663,9 @@ def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
         self.pop_codegened_graph()
 
     def codegen_while_loop(self, while_loop):
+        is_bool_pred = isinstance(
+            while_loop.cond_subgraph.graph.graph_outputs[0], ir.ShapeAsConstantBuffer
+        )
         name = while_loop.get_name()
         outer_carried_inputs = [
             buf.codegen_reference() for buf in while_loop.carried_inputs

@@ -1670,7 +1674,10 @@ def codegen_while_loop(self, while_loop):
             buf.codegen_reference() for buf in while_loop.additional_inputs
         ]
         cond_result_name = f"{name}_cond_result"
-        self.writeline(f"RAIIAtenTensorHandle {cond_result_name};")
+        if is_bool_pred:
+            self.writeline(f"bool {cond_result_name};")
+        else:
+            self.writeline(f"RAIIAtenTensorHandle {cond_result_name};")
 
         cond_outer_inputs = []
         for inp, out in zip(outer_carried_inputs, while_loop.outputs):

@@ -1700,8 +1707,11 @@ def codegen_while_loop(self, while_loop):
             while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
         )
 
-        cond_result = f"{cond_result_name}_scalar"
-        self.codegen_tensor_item(torch.bool, cond_result_name, cond_result)
+        if is_bool_pred:
+            cond_result = f"{cond_result_name}"
+        else:
+            cond_result = f"{cond_result_name}_scalar"
+            self.codegen_tensor_item(torch.bool, cond_result_name, cond_result)
         self.writeline(f"if (!{cond_result}) break;")
 
         self.writeline(ExitSubgraphLine(self))

torch/_inductor/codegen/wrapper.py (+1, -1)

@@ -2551,7 +2551,7 @@ def codegen_while_loop(self, while_loop):
             while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
         )
         self.writeline(
-            f"if not {cond_outer_outputs[0]}.item(): break"
+            f"if not {cond_outer_outputs[0]}: break"
        )  # condition doesn't hold
         self.writeline(ExitSubgraphLine(self))
         self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))

torch/_inductor/ir.py (+4, -2)

@@ -7404,8 +7404,10 @@ def create( # type: ignore[no-untyped-def]
 
         # make sure cond_fn returns a boolean scalar Tensor
         assert len(cond_outputs) == 1, cond_outputs
-        assert cond_outputs[0].get_dtype() == torch.bool, cond_outputs
-        assert len(cond_outputs[0].get_size()) == 0, cond_outputs
+        p = cond_outputs[0]
+        if not isinstance(p, ShapeAsConstantBuffer):
+            assert p.get_dtype() == torch.bool, p
+            assert len(p.get_size()) == 0, p
 
         assert (
             len(all_inputs) > 0
