Skip to content

Commit c2be067

Browse files
committed
Support None entries in aten.index's indices
1 parent b2caf35 commit c2be067

File tree

2 files changed

+52
-23
lines changed

2 files changed

+52
-23
lines changed

tests/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,9 @@ def _adjust_index_tensor(self, input_vals):
503503
new_indices = []
504504
for i in range(len(indices)):
505505
indice = indices[i]
506+
if indice is None:
507+
new_indices.append(None)
508+
continue
506509
new_indice = []
507510
for j in range(len(indice)):
508511
new_indice.append(torch.randint(0, self_shape[i], [1]))

torch_ttnn/passes/lowering/to_tt_pass.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
TtnnL1MemoryConfig,
1212
TtnnRowMajorLayout,
1313
TtnnTileLayout,
14+
SliceNone,
1415
get_shape,
1516
)
1617
import numpy as np
@@ -692,6 +693,9 @@ def lower_binary_eltwise(fn, args):
692693

693694
return None
694695

696+
# if node.target == torch.ops.aten.reshape.default:
697+
# return g.call_function(ttnn.reshape, args, kwargs)
698+
695699
if node.target == torch.ops.aten.squeeze.dim or node.target == torch.ops.aten.squeeze.default:
696700
if use_less_ttnn_op_types or node.target == torch.ops.aten.squeeze.default:
697701
# ttnn.squeeze does not support calling the OP without provided dim (torch.ops.aten.squeeze.default)
@@ -1201,14 +1205,16 @@ def decompose_aten_to_aten_ops(gm: torch.fx.GraphModule, g: GraphWrapper, node):
12011205
new_kwargs["dtype"] = node.meta["val"].dtype
12021206
return g.call_function(torch.ops.aten.zeros.default, args=(target_shape, *args[2:]), kwargs=new_kwargs)
12031207

1204-
if node.target == torch.ops.aten.index.Tensor:
1208+
if node.target in [torch.ops.aten.index.Tensor, torch.ops.aten._unsafe_index.Tensor]:
12051209

12061210
def broadcast_indices(indices):
12071211
indices_shapes = [get_shape(gm, indices[i]) for i in range(len(indices))]
12081212
broadcasted_shape = torch.Size(np.broadcast_shapes(*indices_shapes))
12091213
broadcasted_indices = []
12101214
for i in range(len(indices)):
1211-
if indices_shapes[i] == broadcasted_shape:
1215+
if indices_shapes[i] is None:
1216+
broadcasted_indices.append(None)
1217+
elif indices_shapes[i] == broadcasted_shape:
12121218
broadcasted_indices.append(indices[i])
12131219
else:
12141220
broadcasted_indices.append(
@@ -1226,36 +1232,56 @@ def broadcast_indices(indices):
12261232
input_tensor, indices = args
12271233
if get_shape(gm, input_tensor) is None:
12281234
return None
1229-
if None in [get_shape(gm, indices[i]) for i in range(len(indices))]:
1230-
return None
1235+
for index in indices:
1236+
if index is not None and get_shape(gm, index) is None:
1237+
return None
12311238
index_shape, indices = broadcast_indices(indices)
1239+
if index_shape.numel() > 256:
1240+
# cannot create too much op, or will cause
1241+
# runtime args targeting kernel reader_concat_stick_layout_interleaved_start_id on
1242+
# (x=0,y=0) are too large. Max allowable is 256
1243+
return None
12321244
input_shape = get_shape(gm, input_tensor)
12331245
num_index = len(indices)
12341246
index_size = index_shape.numel()
1235-
remained_shape = input_shape[num_index:]
1247+
remained_shape = []
1248+
for i in range(len(indices)):
1249+
if indices[i] is None:
1250+
remained_shape.append(input_shape[i])
1251+
remained_shape += input_shape[num_index:]
1252+
remained_shape = torch.Size(remained_shape)
12361253
reshape_shape = index_shape + remained_shape
12371254
input_dtype = input_tensor.meta["val"].dtype
12381255
flatten_shape = torch.Size([index_size])
1239-
indices_flatten = [
1240-
g.call_function(
1241-
torch.ops.aten.reshape.default,
1242-
args=(idx, flatten_shape),
1243-
new_shape=flatten_shape,
1244-
new_dtype=idx.meta["val"].dtype,
1245-
)
1246-
for idx in indices
1247-
]
1256+
1257+
indices_flatten = []
1258+
for idx in indices:
1259+
if idx is None:
1260+
indices_flatten.append(None)
1261+
else:
1262+
indices_flatten.append(
1263+
g.call_function(
1264+
torch.ops.aten.reshape.default,
1265+
args=(idx, flatten_shape),
1266+
new_shape=flatten_shape,
1267+
new_dtype=idx.meta["val"].dtype,
1268+
)
1269+
)
12481270
output = []
12491271
for i in range(index_size):
1250-
indexing = [
1251-
g.call_function(
1252-
getitem,
1253-
args=(indices_flatten[n], i),
1254-
new_shape=torch.Size([]),
1255-
new_dtype=indices_flatten[n].meta["val"].dtype,
1256-
)
1257-
for n in range(num_index)
1258-
]
1272+
indexing = []
1273+
for n in range(num_index):
1274+
if indices_flatten[n] is None:
1275+
indexing.append(slice(None))
1276+
else:
1277+
indexing.append(
1278+
g.call_function(
1279+
getitem,
1280+
args=(indices_flatten[n], i),
1281+
new_shape=torch.Size([]),
1282+
new_dtype=indices_flatten[n].meta["val"].dtype,
1283+
)
1284+
)
12591285
output.append(
12601286
g.call_function(getitem, args=(input_tensor, indexing), new_shape=remained_shape, new_dtype=input_dtype)
12611287
)

0 commit comments

Comments (0)