diff --git a/tests/lowering/reduction/test_amax.py b/tests/lowering/reduction/test_amax.py
new file mode 100644
index 0000000000..2adb12a9ce
--- /dev/null
+++ b/tests/lowering/reduction/test_amax.py
@@ -0,0 +1,69 @@
+import torch
+import torch_ttnn
+import pytest
+import ttnn
+
+
+class AmaxModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input, dim, keepdim):
+        return torch.amax(input, dim=dim, keepdim=keepdim)
+
+
+@pytest.mark.parametrize("sign", [1, -1])
+@pytest.mark.parametrize(
+    "input_shape, dim, keepdim",
+    [
+        ((32, 32), [], True),
+        ((16, 32, 32), [], True),
+        ((16, 32, 32), [1], True),
+        ((16, 32, 32), 1, True),
+        ((16, 32, 32), [2], True),
+        ((16, 32, 32), [1, 2], True),
+        # TODO(#240): keepdim = false is not supported
+        pytest.param((32, 32), [1], False, marks=pytest.mark.xfail(reason="keepdim = false is not supported (#240)")),
+        # TODO(#240): Not support reduction on < rank - 2 dims
+        pytest.param(
+            (16, 32, 32), [0], True, marks=pytest.mark.xfail(reason="Not support reduction on < rank - 2 dims (#240)")
+        ),
+        pytest.param(
+            (32, 32, 32),
+            [0, 1, 2],
+            True,
+            marks=pytest.mark.xfail(reason="Not support reduction on < rank - 2 dims (#240)"),
+        ),
+        # TODO(#240): Unexpected output shape (1, 1) instead of (1)
+        pytest.param((32,), [], True, marks=pytest.mark.xfail(reason="Need -inf padding value (#240)")),
+        # TODO(#240): Need -inf padding value
+        pytest.param((1, 32), [], True, marks=pytest.mark.xfail(reason="Need -inf padding value (#240)")),
+        pytest.param((32, 1), [], True, marks=pytest.mark.xfail(reason="Need -inf padding value (#240)")),
+        # TODO(#240): Output reshape inside generic reduction can't handle non-tile-aligned size
+        pytest.param(
+            (1, 32),
+            [1],
+            True,
+            marks=pytest.mark.xfail(
+                reason="Output reshape inside generic reduction can't handle non-tile-aligned size (#240)"
+            ),
+        ),
+    ],
+)
+def test_amax(device, sign, input_shape, dim, keepdim):
+    m = AmaxModule()
+    input = torch.rand(input_shape, dtype=torch.bfloat16) * sign
+    result_before = m.forward(input, dim, keepdim)
+
+    option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=True)
+    # The compilation is lazy, so we need to run forward once to trigger the compilation
+    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+    result_after = m.forward(input, dim, keepdim)
+    option._out_fx_graphs[0].print_tabular()
+
+    # Check the graph has been rewritten
+    nodes = list(option._out_fx_graphs[0].nodes)
+    assert [node.target for node in nodes].count(ttnn.max) == 1
+    # Check inference result
+    assert result_before.shape == result_after.shape
+    assert torch.allclose(result_before, result_after)
diff --git a/tests/lowering/reduction/test_amin.py b/tests/lowering/reduction/test_amin.py
new file mode 100644
index 0000000000..c5c6f2f265
--- /dev/null
+++ b/tests/lowering/reduction/test_amin.py
@@ -0,0 +1,71 @@
+import torch
+import torch_ttnn
+import pytest
+import ttnn
+
+
+class AminModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input, dim, keepdim):
+        return torch.amin(input, dim=dim, keepdim=keepdim)
+
+
+@pytest.mark.parametrize("sign", [1, -1])
+@pytest.mark.parametrize(
+    "input_shape, dim, keepdim",
+    [
+        ((32, 32), [], True),
+        ((16, 32, 32), [], True),
+        ((16, 32, 32), 1, True),
+        ((16, 32, 32), [1], True),
+        ((16, 32, 32), [2], True),
+        ((16, 32, 32), [1, 2], True),
+        # TODO(#240): keepdim = false is not supported
+        pytest.param((32, 32), [1], False, marks=pytest.mark.xfail(reason="keepdim = false is not supported (#240)")),
+        # TODO(#240): Not support reduction on < rank - 2 dims
+        pytest.param(
+            (16, 32, 32), [0], True, marks=pytest.mark.xfail(reason="Not support reduction on < rank - 2 dims (#240)")
+        ),
+        pytest.param(
+            (32, 32, 32),
+            [0, 1, 2],
+            True,
+            marks=pytest.mark.xfail(reason="Not support reduction on < rank - 2 dims (#240)"),
+        ),
+        # TODO(#240): Unexpected output shape (1, 1) instead of (1)
+        pytest.param(
+            (32,), [], True, marks=pytest.mark.xfail(reason="Unexpected output shape (1, 1) instead of (1) (#240)")
+        ),
+        # TODO(#240): Need inf padding value
+        pytest.param((1, 32), [], True, marks=pytest.mark.xfail(reason="Need inf padding value (#240)")),
+        pytest.param((32, 1), [], True, marks=pytest.mark.xfail(reason="Need inf padding value (#240)")),
+        # TODO(#240): Output reshape inside generic reduction can't handle non-tile-aligned size
+        pytest.param(
+            (1, 32),
+            [1],
+            True,
+            marks=pytest.mark.xfail(
+                reason="Output reshape inside generic reduction can't handle non-tile-aligned size (#240)"
+            ),
+        ),
+    ],
+)
+def test_amin(device, sign, input_shape, dim, keepdim):
+    m = AminModule()
+    input = torch.rand(input_shape, dtype=torch.bfloat16) * sign
+    result_before = m.forward(input, dim, keepdim)
+
+    option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=True)
+    # The compilation is lazy, so we need to run forward once to trigger the compilation
+    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+    result_after = m.forward(input, dim, keepdim)
+    option._out_fx_graphs[0].print_tabular()
+
+    # Check the graph has been rewritten
+    nodes = list(option._out_fx_graphs[0].nodes)
+    assert [node.target for node in nodes].count(ttnn.min) == 1
+    # Check inference result
+    assert result_before.shape == result_after.shape
+    assert torch.allclose(result_before, result_after)
diff --git a/torch_ttnn/passes/lowering/add_data_move_pass.py b/torch_ttnn/passes/lowering/add_data_move_pass.py
index 7add48c366..40c6343d6e 100644
--- a/torch_ttnn/passes/lowering/add_data_move_pass.py
+++ b/torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -50,7 +50,6 @@ def is_function_call(node) -> bool:
     ttnn.log1p,
     ttnn.log2,
     ttnn.logical_not,
-    ttnn.min,
     ttnn.neg,
     ttnn.reciprocal,
     ttnn.relu,
@@ -109,6 +108,12 @@ def is_function_call(node) -> bool:
     ttnn.where,
 ]
 
+TTNN_REDUCTION_OPS = [
+    ttnn.max,
+    ttnn.mean,
+    ttnn.min,
+]
+
 TTNN_MATRIX_MULPIPLICATION_OPS = [
     ttnn.matmul,
     ttnn.linear,
@@ -148,6 +153,7 @@ def is_tt_compute(node) -> bool:
         + TTNN_POINTWISE_BINARY_OPS
         + TTNN_POINTWISE_TRINARY_OPS
         + TTNN_MATRIX_MULPIPLICATION_OPS
+        + TTNN_REDUCTION_OPS
         + TTNN_TARGET_WRAPPERS
         + TTNN_DATAMOVE_OPS
         + TTNN_NORM_OPS
@@ -157,7 +163,6 @@ def is_tt_compute(node) -> bool:
             ttnn.tril,
             ttnn.arange,
             ttnn.zeros_like,
-            ttnn.mean,
             ttnn.global_avg_pool2d,
             ttnn.clip,
             ttnn.squeeze,
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index b0bb2c8f2e..4b1d59eac6 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -623,6 +623,29 @@ def rewrite_node(node):
                 input = g.call_function(ttnn.to_layout, args=(input, TtnnRowMajorLayout()))
                 return g.call_function(ttnn.pad, args=(input, full_pad, value))
 
+            if node.target in [torch.ops.aten.amin.default, torch.ops.aten.amax.default]:
+                input_shape = args[0].meta["val"].size()
+                # TODO(#240): Not support keepdim = false (default value)
+                if len(args) < 3 or args[2] == False:
+                    return None
+                # TODO(#240): Not support rank < 2 or non-tile-size-aligned tensor
+                if len(input_shape) < 2 or any(size % ttnn.TILE_SIZE != 0 for size in input_shape[-2:]):
+                    return None
+                new_args = list(args)
+                # Convert dim int/list to tuple
+                if len(args) >= 2:
+                    dim = args[1]
+                    dim = (dim,) if isinstance(dim, int) else tuple(dim)
+                    # TODO(#240): Not support reduction on < rank - 2 dims
+                    if any(idx < len(input_shape) - 2 for idx in dim):
+                        return None
+                    new_args[1] = dim if len(dim) > 0 else None
+                return g.call_function(
+                    ttnn.min if node.target == torch.ops.aten.amin.default else ttnn.max,
+                    tuple(new_args),
+                    kwargs,
+                )
+
         with g.inserting_before(node):
             new_node = rewrite_node(node)
             if new_node is not None:
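A note for reviewers, not part of the diff: the sketch below restates in plain Python how the new aten.amin / aten.amax branch in rewrite_node normalizes the dim argument and decides whether the reduction can be lowered to ttnn.min / ttnn.max. normalize_reduction_dims is a hypothetical helper written only for this illustration; it is not an API of torch_ttnn or ttnn.

# Illustrative sketch of the dim handling added to rewrite_node (hypothetical helper).
def normalize_reduction_dims(input_rank, dim):
    """Return (lowerable, dim_arg); dim_arg is the tuple handed to ttnn, or None to reduce over all dims."""
    # Mirror the diff: a single int becomes a 1-tuple, a list becomes a tuple.
    dims = (dim,) if isinstance(dim, int) else tuple(dim)
    # Per TODO(#240) in the diff, only reductions over the last two (tile) dims are lowered;
    # anything else is left as the original aten op.
    if any(d < input_rank - 2 for d in dims):
        return False, None
    # An empty dim list means "reduce over all dims", which is passed to ttnn as None.
    return True, (dims if dims else None)


if __name__ == "__main__":
    assert normalize_reduction_dims(3, 1) == (True, (1,))        # single int dim
    assert normalize_reduction_dims(3, [1, 2]) == (True, (1, 2))  # last two dims
    assert normalize_reduction_dims(3, []) == (True, None)        # reduce over all dims
    assert normalize_reduction_dims(3, [0]) == (False, None)      # dim < rank - 2: fall back

Returning a "not lowerable" result instead of raising matches the pass's behavior of leaving unsupported reductions to the aten fallback, which is what the xfail-marked test cases above exercise.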