
Commit afac9cf

Authored by Pavlo Hilei (philei-tt)

Remove fallback for aten.squeeze.default, aten.full, aten.fill.Scalar, aten.memory_format (#896)

* Always use native ttnn squeeze
* Remove fallback for aten.full(_like).default
* Enable ttnn.fill for aten.fill.Scalar
* Enable ttnn.empty for aten.empty.memory_format
* Fix precommit check
* Remove changes from autogen test
* Update ttnn version
* Fix typing issue in Embeddings::validate

Co-authored-by: Pavlo Hilei <[email protected]>

1 parent: 93b1645
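All of these lowerings are exercised through the same torch.compile flow the tests below use. A minimal sketch of that flow, assuming an already-opened ttnn device handle (everything else is taken from the test code in this diff):

import torch
import torch_ttnn

class FullModule(torch.nn.Module):
    def forward(self, size, fill_value):
        # aten.full.default now always lowers to ttnn.full (no CPU fallback)
        return torch.full(size, fill_value)

option = torch_ttnn.TorchTtnnOption(device=device)  # device: assumed open ttnn device
m = torch.compile(FullModule(), backend=torch_ttnn.backend, options=option)
result = m([64, 128], 1.23)  # compilation is lazy; the first call triggers it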

File tree

6 files changed: +53 −43 lines

tests/lowering/creation/test_full.py
Lines changed: 9 additions & 7 deletions

@@ -15,22 +15,24 @@ def forward(self, size, fill_value):
 
 
 @pytest.mark.parametrize(
-    "input_shapes",
+    "input_shape",
     [
-        [(64, 128)],
-        [(19, 19)],
-        [(59, 59)],
+        [64, 128],
+        [19, 19],
+        [59, 59],
+        [33],
+        [],  # scalar
     ],
 )
-def test_full(device, input_shapes):
+def test_full(device, input_shape):
     m = FullModule()
     fill_value = 1.23
-    result_before = m.forward(input_shapes[0], fill_value).to(torch.bfloat16)
+    result_before = m.forward(input_shape, fill_value).to(torch.bfloat16)
     option = torch_ttnn.TorchTtnnOption(device=device)
     option.gen_graphviz = True
     # The compilation is lazy, so we need to run forward once to trigger the compilation
     m = torch.compile(m, backend=torch_ttnn.backend, options=option)
-    result_after = m.forward(input_shapes[0], fill_value).to(torch.bfloat16)
+    result_after = m.forward(input_shape, fill_value).to(torch.bfloat16)
     option._out_fx_graphs[0].print_tabular()
 
     # Check the graph has be rewritten and contain ttnn ops
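For context, standard PyTorch semantics rather than anything in this diff: an empty size list makes torch.full return a 0-d (scalar) tensor, which is exactly the case the removed fallback used to special-case:

import torch

t = torch.full([], 1.23)     # empty shape: 0-d scalar tensor
print(t.shape, t.dim())      # torch.Size([]) 0
t1 = torch.full([33], 1.23)  # 1-d tensor of length 33
print(t1.shape)              # torch.Size([33])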

tests/lowering/creation/test_full_like.py
Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,8 @@ def forward(self, tensor, fill_value):
         (1, 1),
         (2, 2),
         (17, 17),
+        (33,),
+        (),
     ],
 )
 def test_full_like(device, input_shape):
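Likewise, torch.full_like mirrors the input's shape, so the new () case produces a 0-d output; a quick reference snippet (plain PyTorch, not part of the commit):

import torch

src = torch.empty(())            # 0-d input tensor
out = torch.full_like(src, 2.0)
print(out.shape)                 # torch.Size([])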

tests/lowering/tensor_manipulation/test_squeeze.py
Lines changed: 9 additions & 6 deletions

@@ -22,6 +22,9 @@ def forward(self, input, dim):
         ((1, 256, 1), -1),
         ((33, 44, 1, 32, 16), 1),
         ((33, 44, 1, 32, 16), 2),
+        ((1, 12), 0),
+        ((1), 0),
+        ((), 0),
     ],
 )
 def test_squeeze_dim(device, input_shape, dim):

@@ -36,11 +39,7 @@ def test_squeeze_dim(device, input_shape, dim):
     option._out_fx_graphs[0].print_tabular()
     # Check the graph has be rewritten and contain ttnn ops
     nodes = list(option._out_fx_graphs[0].nodes)
-    if option.use_less_ttnn_op_types:
-        # squeeze is lowered to reshape
-        assert [node.target for node in nodes].count(ttnn.reshape) == 1
-    else:
-        assert [node.target for node in nodes].count(ttnn.squeeze) == 1
+    assert [node.target for node in nodes].count(ttnn.squeeze) == 1
     # Check inference result
     assert torch.allclose(result_before, result_after)

@@ -60,6 +59,10 @@ def forward(self, input):
         ((1, 1, 55, 23, 44, 32, 32)),
         ((22, 1, 55, 23, 44, 32, 1)),
         ((1, 1, 55, 1, 1, 1, 1)),
+        ((1, 12)),
+        ((1, 1)),
+        ((1)),
+        (()),
     ],
 )
 def test_squeeze_none_dim(device, input_shape):

@@ -74,6 +77,6 @@ def test_squeeze_none_dim(device, input_shape):
     option._out_fx_graphs[0].print_tabular()
     # Check the graph has be rewritten and contain ttnn ops (squeeze without provided dim is lowered to reshape)
     nodes = list(option._out_fx_graphs[0].nodes)
-    assert [node.target for node in nodes].count(ttnn.reshape) == 1
+    assert [node.target for node in nodes].count(ttnn.squeeze) == 1
     # Check inference result
     assert torch.allclose(result_before, result_after)
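For reference, the aten semantics these tests pin down: squeeze with a dim removes that dimension only when its size is 1, and squeeze without a dim removes every size-1 dimension (plain PyTorch, shown as a sanity check):

import torch

x = torch.ones(1, 12)
print(torch.squeeze(x, 0).shape)         # torch.Size([12]); dim 0 has size 1, so it is dropped
print(torch.squeeze(x, 1).shape)         # torch.Size([1, 12]); dim 1 has size 12, left as-is
print(torch.ones(1, 1).squeeze().shape)  # torch.Size([]); all size-1 dims are dropped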

torch_ttnn/passes/lowering/add_data_move_pass.py
Lines changed: 2 additions & 0 deletions

@@ -216,6 +216,8 @@ def is_tt_compute(node) -> bool:
         ttnn.sum,
         ttnn.typecast,
         ttnn.argmax,
+        ttnn.fill,
+        ttnn.empty,
     ]
 )
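Judging by its signature, is_tt_compute(node) classifies an FX node as a device-side ttnn op, which is what lets the pass place data movement around the newly enabled ttnn.fill and ttnn.empty. A simplified sketch of that membership pattern; the set contents beyond the ops shown in this hunk are assumptions:

# Hypothetical, reduced version of the gating check in add_data_move_pass.py
TTNN_COMPUTE_OPS = {ttnn.sum, ttnn.typecast, ttnn.argmax, ttnn.fill, ttnn.empty}

def is_tt_compute(node) -> bool:
    # An FX node's target is the op it calls; device ops get data moves inserted
    return node.target in TTNN_COMPUTE_OPS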

torch_ttnn/passes/lowering/to_tt_pass.py
Lines changed: 27 additions & 26 deletions

@@ -10,6 +10,7 @@
     GraphCleanup,
     TtnnBfloat16,
     TtnnInt32,
+    TtnnUint32,
     TtnnDevice,
     TtnnL1MemoryConfig,
     TtnnRowMajorLayout,

@@ -370,7 +371,7 @@ def __init__(self, target, args, kwargs):
 def torch_dtype_to_ttnn_dtype(dtype: torch.dtype):
     # Add newly supported dtypes here:
     dtype_map = {
-        torch.float32: TtnnBfloat16(),
+        torch.float32: TtnnBfloat16(),  # Should this be changed to TtnnFloat32?
         torch.bfloat16: TtnnBfloat16(),
     }
     if dtype in dtype_map:

@@ -597,21 +598,12 @@ def reshape_1d(code, args=args, kwargs=kwargs):
         return None
 
     if node.target == torch.ops.aten.full.default:
-        # args[0] can be empty for aten.full which simply creates a scalar. Ignore conversion in this case.
-        if args[0]:
-            new_kwargs = {
-                "fill_value": args[1],
-                "device": TtnnDevice(),
-                "layout": TtnnTileLayout(),
-            }
-            return g.call_function(ttnn.full, args=(tuple(args[0]),), kwargs=new_kwargs)
-        # Replace op with scalar for eltwise ops
-        # TODO: Generalize this to support all eltwise ops
-        node_users = list(node.users.keys())
-        for node_user in node_users:
-            if node_user.target == torch.ops.aten.div.Tensor:
-                node_user.update_arg(1, args[1])
-        return None
+        new_kwargs = {
+            "fill_value": args[1],
+            "device": TtnnDevice(),
+            "layout": TtnnTileLayout(),
+        }
+        return g.call_function(ttnn.full, args=(args[0],), kwargs=new_kwargs)
 
     if node.target == torch.ops.aten.baddbmm.default:
         # out = beta * input + alpha * (batch1 @ batch2)

@@ -738,16 +730,7 @@ def reshape_1d(code, args=args, kwargs=kwargs):
         return None
 
     if node.target == torch.ops.aten.squeeze.dim or node.target == torch.ops.aten.squeeze.default:
-        if get_shape(gm, args[0]) in [torch.Size([1]), torch.Size([])]:
-            # see #442
-            return None
-        if use_less_ttnn_op_types or node.target == torch.ops.aten.squeeze.default:
-            # ttnn.squeeze does not support calling the OP without provided dim (torch.ops.aten.squeeze.default)
-            # squeezing is the same as reshaping to shape of output tensor of squeeze
-            output_size = list(node.meta["val"].size())
-            return g.call_function(ttnn.reshape, args=(args[0], output_size))
-        else:
-            return g.call_function(ttnn.squeeze, args=(args[0], args[1]))
+        return g.call_function(ttnn.squeeze, args=args, kwargs=kwargs)
 
     if node.target == torch.ops.aten.unsqueeze.default:
         output_shape_num_element = node.meta["val"].numel()

@@ -906,6 +889,9 @@ def reshape_1d(code, args=args, kwargs=kwargs):
         # Essentially remove this op
         return node.args[0]
 
+    if node.target == torch.ops.aten.fill.Scalar:
+        return g.call_function(ttnn.fill, args=args)
+
     if node.target in [torch.ops.aten.masked_fill.Scalar, torch.ops.aten.masked_fill.Tensor]:
         # aten.masked_fill is equivalent to the following:
         # masked_fill = (tensor * (ones - mask)) + (mask * full)

@@ -1223,6 +1209,21 @@ def reshape_1d(code, args=args, kwargs=kwargs):
         ttnn_all = g.call_function(target_wrappers.all, args=(args[0], input_shape.numel()))
         return g.call_function(torch.ops.aten.squeeze.default, args=(ttnn_all,))
 
+    if node.target == torch.ops.aten.empty.memory_format:
+        # raise RuntimeError(f"{str(kwargs)}, {str(args)}, {str(type(args[0]))}")
+        dtype_mapping = {
+            torch.float32: TtnnBfloat16(),
+            torch.float16: TtnnBfloat16(),
+            torch.int32: TtnnInt32(),
+        }
+        dtype = dtype_mapping.get(kwargs["dtype"], TtnnUint32())
+        new_kwargs = {
+            "dtype": dtype,
+            "layout": TtnnTileLayout(),
+            "device": TtnnDevice(),
+        }
+        return g.call_function(ttnn.empty, args=(args[0],), kwargs=new_kwargs)
+
     # PEP 8 suggests this explicit statement
     return None
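The aten.empty.memory_format branch above relies on dict.get with a default, so any torch dtype without an explicit entry (torch.int64, torch.bool, and so on) falls back to TtnnUint32. A self-contained illustration with stand-in marker classes (the real ones live in torch_ttnn/utils.py):

import torch

class TtnnBfloat16: ...
class TtnnInt32: ...
class TtnnUint32: ...

dtype_mapping = {
    torch.float32: TtnnBfloat16(),
    torch.float16: TtnnBfloat16(),
    torch.int32: TtnnInt32(),
}
print(type(dtype_mapping.get(torch.float32, TtnnUint32())).__name__)  # TtnnBfloat16
print(type(dtype_mapping.get(torch.int64, TtnnUint32())).__name__)    # TtnnUint32 (fallback)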

torch_ttnn/utils.py
Lines changed: 4 additions & 4 deletions

@@ -108,14 +108,14 @@ def __repr__(self):
         return f"ttnn_TILE_LAYOUT"
 
 
-class TtnnUint32:
+class TtnnInt32:
     def __repr__(self):
-        return f"ttnn_uint32"
+        return f"ttnn_int32"
 
 
-class TtnnInt32:
+class TtnnUint32:
     def __repr__(self):
-        return f"ttnn_int32"
+        return f"ttnn_uint32"
 
 
 class TtnnBfloat16:
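Note that this hunk only swaps the definition order of the two marker classes; each __repr__ still matches its class name, so behavior is unchanged:

print(repr(TtnnInt32()))   # ttnn_int32
print(repr(TtnnUint32()))  # ttnn_uint32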
