Skip to content

Commit a6822d3

Browse files
committed
Broadcast minimum input
1 parent cce8552 commit a6822d3

File tree

2 files changed

+6
-14
lines changed

2 files changed

+6
-14
lines changed

tests/lowering/eltwise/binary/test_minimum.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,10 @@ def forward(self, x, y):
1616
"input_shapes",
1717
(
1818
((32, 32), (32, 32)),
19-
pytest.param(
20-
((64,), (32, 64)),
21-
marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
22-
),
23-
pytest.param(
24-
((64, 32), (64, 1)),
25-
marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
26-
),
27-
pytest.param(
28-
((64, 1), (1, 64)),
29-
marks=pytest.mark.xfail(reason="broadcasting issues (#64)"),
30-
),
19+
((64,), (32, 64)),
20+
((64, 32), (64, 1)),
21+
((64, 1), (1, 64)),
22+
((1, 16, 59, 59), ()),
3123
),
3224
)
3325
def test_minimum(device, input_shapes):

torch_ttnn/passes/lowering/to_tt_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,7 +1170,7 @@ def rewrite_node(node):
11701170
kwargs = node.kwargs
11711171

11721172
# workaround for issue #64
1173-
if node.target == torch.ops.aten.maximum.default:
1173+
if node.target in [torch.ops.aten.maximum.default, torch.ops.aten.minimum.default]:
11741174
self_tensor = args[0]
11751175
if len(args) > 1:
11761176
other_tensor = args[1]
@@ -1179,7 +1179,7 @@ def rewrite_node(node):
11791179
if get_shape(self_tensor) is None or get_shape(other_tensor) is None:
11801180
return None
11811181
broadcasted_shape, broadcasted_tensors = broadcast_tensors(g, [self_tensor, other_tensor])
1182-
return g.call_function(torch.ops.aten.maximum.default, tuple(broadcasted_tensors))
1182+
return g.call_function(node.target, tuple(broadcasted_tensors))
11831183

11841184
with g.inserting_before(node):
11851185
new_node = rewrite_node(node)

0 commit comments

Comments
 (0)