Skip to content

Commit 684aae4

Browse files
committed
Reduce one level of nesting
1 parent 29411f8 commit 684aae4

File tree

1 file changed

+15
-17
lines changed

1 file changed

+15
-17
lines changed

torch_ttnn/passes/lowering/to_tt_pass.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -554,29 +554,27 @@ def rewrite_node(node):
554554
return None
555555

556556
if node.target == torch.ops.aten.slice.Tensor:
557+
tensor, dim, start, end, *step = args
558+
(step,) = step or [1]
559+
input_size = tensor.meta["val"].size()
557560

558-
def callback(tensor, dim, start, end, step=1):
559-
input_size = tensor.meta["val"].size()
561+
if step != 1 or dim >= len(input_size):
562+
return None
560563

561-
if step != 1 or dim >= len(input_size):
562-
return None
564+
if start == 0 and end >= input_size[dim]:
565+
return tensor
563566

564-
if start == 0 and end >= input_size[dim]:
565-
return tensor
567+
if len(input_size) != 4:
568+
return None
566569

567-
if len(input_size) != 4:
568-
return None
570+
slice_start = np.zeros(len(input_size), dtype=int)
571+
slice_end = np.array(input_size)
569572

570-
slice_start = np.zeros(len(input_size), dtype=int)
571-
slice_end = np.array(input_size)
573+
slice_start[dim] = start
574+
slice_end[dim] = min(end, input_size[dim])
575+
slice_end -= 1
572576

573-
slice_start[dim] = start
574-
slice_end[dim] = min(end, input_size[dim])
575-
slice_end -= 1
576-
577-
return g.call_function(ttnn.slice, (tensor, [*slice_start], [*slice_end]))
578-
579-
return callback(*args)
577+
return g.call_function(ttnn.slice, (tensor, [*slice_start], [*slice_end]))
580578

581579
if node.target == torch.ops.aten.unsqueeze.default:
582580
output_size = node.meta["val"].size()

0 commit comments

Comments
 (0)