@@ -1594,13 +1594,14 @@ def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
1594
1594
subgraph .graph .graph_outputs , outer_outputs
1595
1595
):
1596
1596
src = inner_output .codegen_reference ()
1597
- # in ABI-compatible mode, we need to std::move subgraph output (inner_output)
1598
- # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
1599
- # constructor is deleted.
1600
- src = f"std::move({ src } )"
1601
- # in case the outer_output carried a value
1602
- # before (e.g., in the while_loop codegen)
1603
- self .writeline (f"{ outer_output } .reset();" )
1597
+ if not isinstance (inner_output , ir .ShapeAsConstantBuffer ):
1598
+ # in ABI-compatible mode, we need to std::move subgraph output (inner_output)
1599
+ # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
1600
+ # constructor is deleted.
1601
+ src = f"std::move({ src } )"
1602
+ # in case the outer_output carried a value
1603
+ # before (e.g., in the while_loop codegen)
1604
+ self .writeline (f"{ outer_output } .reset();" )
1604
1605
self .writeline (f"{ outer_output } = { src } ;" )
1605
1606
1606
1607
def codegen_invoke_subgraph (self , invoke_subgraph ):
@@ -1662,6 +1663,9 @@ def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
1662
1663
self .pop_codegened_graph ()
1663
1664
1664
1665
def codegen_while_loop (self , while_loop ):
1666
+ is_bool_pred = isinstance (
1667
+ while_loop .cond_subgraph .graph .graph_outputs [0 ], ir .ShapeAsConstantBuffer
1668
+ )
1665
1669
name = while_loop .get_name ()
1666
1670
outer_carried_inputs = [
1667
1671
buf .codegen_reference () for buf in while_loop .carried_inputs
@@ -1670,7 +1674,10 @@ def codegen_while_loop(self, while_loop):
1670
1674
buf .codegen_reference () for buf in while_loop .additional_inputs
1671
1675
]
1672
1676
cond_result_name = f"{ name } _cond_result"
1673
- self .writeline (f"RAIIAtenTensorHandle { cond_result_name } ;" )
1677
+ if is_bool_pred :
1678
+ self .writeline (f"bool { cond_result_name } ;" )
1679
+ else :
1680
+ self .writeline (f"RAIIAtenTensorHandle { cond_result_name } ;" )
1674
1681
1675
1682
cond_outer_inputs = []
1676
1683
for inp , out in zip (outer_carried_inputs , while_loop .outputs ):
@@ -1700,8 +1707,11 @@ def codegen_while_loop(self, while_loop):
1700
1707
while_loop .cond_subgraph , cond_outer_inputs , cond_outer_outputs
1701
1708
)
1702
1709
1703
- cond_result = f"{ cond_result_name } _scalar"
1704
- self .codegen_tensor_item (torch .bool , cond_result_name , cond_result )
1710
+ if is_bool_pred :
1711
+ cond_result = f"{ cond_result_name } "
1712
+ else :
1713
+ cond_result = f"{ cond_result_name } _scalar"
1714
+ self .codegen_tensor_item (torch .bool , cond_result_name , cond_result )
1705
1715
self .writeline (f"if (!{ cond_result } ) break;" )
1706
1716
1707
1717
self .writeline (ExitSubgraphLine (self ))
0 commit comments