@@ -554,29 +554,27 @@ def rewrite_node(node):
554
554
return None
555
555
556
556
if node .target == torch .ops .aten .slice .Tensor :
557
+ tensor , dim , start , end , * step = args
558
+ (step ,) = step or [1 ]
559
+ input_size = tensor .meta ["val" ].size ()
557
560
558
- def callback ( tensor , dim , start , end , step = 1 ):
559
- input_size = tensor . meta [ "val" ]. size ()
561
+ if step != 1 or dim >= len ( input_size ):
562
+ return None
560
563
561
- if step != 1 or dim >= len ( input_size ) :
562
- return None
564
+ if start == 0 and end >= input_size [ dim ] :
565
+ return tensor
563
566
564
- if start == 0 and end >= input_size [ dim ] :
565
- return tensor
567
+ if len ( input_size ) != 4 :
568
+ return None
566
569
567
- if len (input_size ) != 4 :
568
- return None
570
+ slice_start = np . zeros ( len (input_size ), dtype = int )
571
+ slice_end = np . array ( input_size )
569
572
570
- slice_start = np .zeros (len (input_size ), dtype = int )
571
- slice_end = np .array (input_size )
573
+ slice_start [dim ] = start
574
+ slice_end [dim ] = min (end , input_size [dim ])
575
+ slice_end -= 1
572
576
573
- slice_start [dim ] = start
574
- slice_end [dim ] = min (end , input_size [dim ])
575
- slice_end -= 1
576
-
577
- return g .call_function (ttnn .slice , (tensor , [* slice_start ], [* slice_end ]))
578
-
579
- return callback (* args )
577
+ return g .call_function (ttnn .slice , (tensor , [* slice_start ], [* slice_end ]))
580
578
581
579
if node .target == torch .ops .aten .unsqueeze .default :
582
580
output_size = node .meta ["val" ].size ()
0 commit comments