Skip to content

Commit cb43ae0

Browse files
committed
Updated step validations, now different than 1. Since the compiler already fills with 1 when the array is 1, we can assume that will be always equals or bigger than 1
1 parent a88567e commit cb43ae0

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

torch_ttnn/passes/lowering/to_tt_pass.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ def reshape_1d(code, args=args, kwargs=kwargs):
695695
[step] = step or [1]
696696
rank = len(input_size)
697697

698-
if step != 1 or dim >= rank:
698+
if dim >= rank:
699699
return None
700700

701701
# Check if no-op, just return the input tensor
@@ -1146,9 +1146,7 @@ def reshape_1d(code, args=args, kwargs=kwargs):
11461146
# slice_scatter could be concat([pre_slice_tensor, src_tensor, post_slice_tensor])
11471147
rank = len(tensor_shape)
11481148
[step] = step or [1]
1149-
if step != 1:
1150-
return None
1151-
1149+
11521150
assert dim < rank, f"The slice dim {dim} should be less than rank {rank}"
11531151

11541152
dim = (dim + rank) % rank

0 commit comments

Comments
 (0)