
Commit e0abc51

Revert "Convert aten.view to ttnn.reshape (#221)" (#239)
This reverts commit 8135edd.
1 parent bcc17ce · commit e0abc51

File tree: 3 files changed (+10, -153 lines)


tests/lowering/tensor_manipulation/test_view.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

torch_ttnn/passes/lowering/add_data_move_pass.py

Lines changed: 6 additions & 19 deletions
@@ -132,11 +132,6 @@ def is_function_call(node) -> bool:
 ]
 
 
-def can_be_tilized(node):
-    size = node.meta["val"].size()
-    return len(size) >= 2 and size[-1] % 32 == 0 and size[-2] % 32 == 0
-
-
 # For operations limitations
 # See https://github.com/tenstorrent-metal/tt-metal/blob/main/ttnn/README.md?plain=1#L19
 def is_tt_compute(node) -> bool:
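
For reference, the helper deleted above gated layout decisions on tile-aligned shapes: TTNN's tile layout packs tensors into 32x32 tiles, so both trailing dimensions must be multiples of 32. A minimal standalone sketch of its behavior, where FakeNode is a hypothetical stand-in for a torch.fx node carrying shape metadata:

import torch

class FakeNode:
    # Hypothetical stand-in for a torch.fx.Node with shape metadata.
    def __init__(self, shape):
        self.meta = {"val": torch.empty(shape)}

def can_be_tilized(node):
    # Same logic as the helper removed in the hunk above.
    size = node.meta["val"].size()
    return len(size) >= 2 and size[-1] % 32 == 0 and size[-2] % 32 == 0

print(can_be_tilized(FakeNode((1, 16, 256, 256))))  # True: 256 % 32 == 0
print(can_be_tilized(FakeNode((1, 16, 256, 100))))  # False: 100 % 32 != 0
print(can_be_tilized(FakeNode((64,))))              # False: rank < 2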
@@ -269,7 +264,7 @@ def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> torch.fx.node.N
     with g.inserting_before(dst_node):
         kwargs = {}
         if (
-            (dst_node.target == ttnn.reshape and not can_be_tilized(dst_node))
+            dst_node.target == ttnn.reshape
             or dst_node.target == ttnn.embedding
             or dst_node.target == ttnn.zeros_like
             or dst_node.target == target_wrappers.repeat
@@ -282,7 +277,9 @@ def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> torch.fx.node.N
             kwargs["dtype"] = TtnnBfloat16()
 
         # For reshape only put tensor on device if rank is 4
-        if is_tt_compute(dst_node):
+        if (is_tt_compute(dst_node) and dst_node.target != ttnn.reshape) or (
+            dst_node.target == ttnn.reshape and len(dst_node.args[1]) == 4
+        ):
             kwargs["device"] = device
 
         new_nodes.append(g.call_function(ttnn.from_torch, (src_node,), kwargs))
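
The rewritten condition above keeps the old behavior for non-reshape compute ops, while ttnn.reshape inputs are moved to device only when the target shape has rank 4, matching the comment. A hedged restatement with hypothetical plain arguments standing in for dst_node.target, dst_node.args[1], and is_tt_compute(dst_node):

def should_move_to_device(target, out_shape, is_compute):
    if target == "ttnn.reshape":
        return len(out_shape) == 4  # reshape goes on device only at rank 4
    return is_compute  # every other op keeps the is_tt_compute rule

print(should_move_to_device("ttnn.reshape", (1, 16, 256, 256), True))  # True
print(should_move_to_device("ttnn.reshape", (256, 256), True))         # False
print(should_move_to_device("ttnn.add", (1, 32, 32), True))            # True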
@@ -305,12 +302,7 @@ def try_add_layout_change_before_node(src_node, dst_idx, dst_node) -> torch.fx.n
         return None
     if not is_function_call(dst_node):
         return None
-    if (
-        dst_node.target not in layout_change_ops
-        or dst_idx != 0
-        or not is_tt(src_node)
-        or (dst_node.target == ttnn.reshape and can_be_tilized(dst_node))
-    ):
+    if dst_node.target not in layout_change_ops or dst_idx != 0 or not is_tt(src_node):
         return None
 
     g = dst_node.graph
@@ -326,12 +318,7 @@ def try_add_layout_change_after_node(src_node, dst_idx, dst_node) -> torch.fx.no
     # Consider src_node is ttnn.repeat, and dst_node should be any tt_compute node that uses ttnn.repeat
     if not is_function_call(src_node):
         return None
-    if (
-        src_node.target not in layout_change_ops
-        or not is_tt_compute(dst_node)
-        or dst_node.target == ttnn.embedding
-        or (src_node.target == ttnn.reshape and can_be_tilized(src_node))
-    ):
+    if src_node.target not in layout_change_ops or not is_tt_compute(dst_node) or dst_node.target == ttnn.embedding:
         return None
 
     g = dst_node.graph

torch_ttnn/passes/lowering/to_tt_pass.py

Lines changed: 4 additions & 50 deletions
@@ -295,6 +295,10 @@ def call_function(self, target, args, kwargs):
         if target == torch.ops.aten.permute.default:
             return self.call_function_prop_meta(ttnn.permute, args, kwargs)
 
+        if target == torch.ops.aten.view.default:
+            # aten.reshape is more stable if the input nodes have changed
+            return self.call_function_prop_meta(torch.ops.aten.reshape.default, args, kwargs)
+
         ############################################################
         # Other ops
         ############################################################
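
The added branch rewrites aten.view to aten.reshape one-for-one. A short plain-PyTorch illustration (independent of this repo) of why reshape is the safer target once earlier passes may have changed a node's input: view requires stride-compatible input, while reshape falls back to a copy when needed:

import torch

x = torch.arange(12)
assert torch.equal(x.view(3, 4), x.reshape(3, 4))  # identical on contiguous input

y = torch.arange(12).reshape(3, 4).t()  # transpose makes y non-contiguous
z = y.reshape(12)                       # works: reshape copies when needed
# y.view(12) would raise a RuntimeError here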
@@ -619,56 +623,6 @@ def rewrite_node(node):
             input = g.call_function(ttnn.to_layout, args=(input, TtnnRowMajorLayout()))
             return g.call_function(ttnn.pad, args=(input, full_pad, value))
 
-        if node.target == torch.ops.aten.view.default:
-            source_shape = args[0].meta["val"].size()
-            out_shape = args[1]
-            source_rank = len(source_shape)
-            out_rank = len(out_shape)
-
-            # Allow lowering by default
-            can_reshape = True
-
-            # Unsupported:
-            # (1) -> (1, 1) or (1, 1, 1), etc
-            if (source_rank == 1) and (np.prod(source_shape) == 1):
-                can_reshape = False
-            elif not has_valid_page_size(source_shape):
-                can_reshape = False
-            # Same as ttnn.squeeze with dim = 0
-            # Supported:
-            # (1, 16, 256, 256) -> (16, 256, 256)
-            # (1, 256, 256) -> (256, 256)
-            elif (source_rank != 1) and (out_rank == (source_rank - 1)) and (source_shape[0] == 1):
-                for i in range(0, out_rank):
-                    if source_shape[i + 1] != out_shape[i]:
-                        can_reshape = False
-                        break
-
-            # Same as ttnn.unsqueeze_to_4D
-            # Supported:
-            # (16, 256, 256) -> (1, 16, 256, 256)
-            # (256, 256) -> (1, 1, 256, 256)
-            elif (out_rank > 1) and (out_rank <= 4) and (source_rank > 0) and (source_rank <= 4):
-                for i in range(0, out_rank):
-                    si = i + (source_rank - out_rank)
-                    if si < 0:
-                        if out_shape[i] != 1:
-                            can_reshape = False
-                            break
-                    else:
-                        if out_shape[i] != source_shape[si]:
-                            can_reshape = False
-                            break
-            else:
-                can_reshape = False
-
-            # Transform to ttnn.reshape if possible
-            if can_reshape:
-                return g.call_function(ttnn.reshape, (args[0], args[1]), {})
-            else:
-                # Fallback: aten.reshape is more stable if the input nodes have changed
-                return g.call_function(torch.ops.aten.reshape.default, args, kwargs)
-
         with g.inserting_before(node):
             new_node = rewrite_node(node)
             if new_node is not None:
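
For reference, the deleted branch above classified which aten.view calls were safe to lower to ttnn.reshape: it accepted dim-0 squeezes and unsqueeze-to-4D patterns and fell back to aten.reshape otherwise. A standalone sketch of that classification, with has_valid_page_size stubbed out since its definition is not part of this diff:

import numpy as np

def has_valid_page_size(shape):
    return True  # stub; the real check lives elsewhere in to_tt_pass.py

def can_reshape(source_shape, out_shape):
    source_rank, out_rank = len(source_shape), len(out_shape)
    # Unsupported: (1) -> (1, 1), (1, 1, 1), etc.
    if source_rank == 1 and np.prod(source_shape) == 1:
        return False
    if not has_valid_page_size(source_shape):
        return False
    # Same as ttnn.squeeze with dim=0: (1, 16, 256, 256) -> (16, 256, 256)
    if source_rank != 1 and out_rank == source_rank - 1 and source_shape[0] == 1:
        return all(source_shape[i + 1] == out_shape[i] for i in range(out_rank))
    # Same as ttnn.unsqueeze_to_4D: (256, 256) -> (1, 1, 256, 256)
    if 1 < out_rank <= 4 and 0 < source_rank <= 4:
        for i in range(out_rank):
            si = i + (source_rank - out_rank)
            expected = 1 if si < 0 else source_shape[si]
            if out_shape[i] != expected:
                return False
        return True
    return False

print(can_reshape((1, 16, 256, 256), (16, 256, 256)))  # True (squeeze)
print(can_reshape((256, 256), (1, 1, 256, 256)))       # True (unsqueeze)
print(can_reshape((1,), (1, 1)))                       # False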
