There seems to be a bug in the `bar_fixed` example on this page: https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
```python
import torch
from functorch.experimental.control_flow import cond

# Example inputs (assumed here; the tutorial defines inp1 and inp2 earlier on the page)
inp1 = torch.randn(10)
inp2 = torch.randn(10)

@torch.compile(fullgraph=True)
def bar_fixed(a, b):
    x = a / (torch.abs(a) + 1)

    def true_branch(y):
        return y * -1

    def false_branch(y):
        # NOTE: torch.cond doesn't allow aliased outputs
        return y.clone()

    x = cond(b.sum() < 0, true_branch, false_branch, (b,))  # <-- suspect line
    return x * b

bar_fixed(inp1, inp2)
bar_fixed(inp1, -inp2)
```
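Is the marked line an error? I think it should be `b = cond(b.sum() < 0, true_branch, false_branch, (b,))`. As written, `cond` overwrites `x`, so the value `a / (torch.abs(a) + 1)` is discarded, `a` never affects the result, and the function returns `-b * b` or `b * b` instead of multiplying `x` by `b` with its sign flipped when `b.sum() < 0`.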
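A quick way to check this without `torch.compile` is to replace `cond` with an ordinary Python `if` that calls `true_branch(b)` or `false_branch(b)` depending on the predicate, and compare both versions against a plain eager function. The sketch below assumes that this branch selection is all `cond` does to the values; `cond_eager`, `bar_reference`, `bar_as_written` and `bar_as_proposed` are hypothetical names used only for this check, and the input tensors are made up.

```python
import torch


def cond_eager(pred, true_fn, false_fn, operands):
    # Hypothetical stand-in for cond: run one branch on the operands
    # depending on a boolean predicate (no tracing, no compilation).
    return true_fn(*operands) if bool(pred) else false_fn(*operands)


def true_branch(y):
    return y * -1


def false_branch(y):
    return y.clone()


def bar_reference(a, b):
    # Plain eager intent: flip the sign of b when its sum is negative.
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


def bar_as_written(a, b):
    x = a / (torch.abs(a) + 1)
    x = cond_eager(b.sum() < 0, true_branch, false_branch, (b,))  # overwrites x
    return x * b


def bar_as_proposed(a, b):
    x = a / (torch.abs(a) + 1)
    b = cond_eager(b.sum() < 0, true_branch, false_branch, (b,))  # reassigns b
    return x * b


a = torch.tensor([1.0, -2.0, 3.0])
b = torch.tensor([-4.0, 0.5, -1.0])  # sum is -4.5, so the true branch runs

print(torch.allclose(bar_as_proposed(a, b), bar_reference(a, b)))  # True
print(torch.allclose(bar_as_written(a, b), bar_reference(a, b)))   # False
```

With these inputs the version as written returns `-b * b`, i.e. `[-16.0, -0.25, -1.0]`, while the proposed version matches the eager reference.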