Commit 61b2c39

Try to support None entries in aten.index's indices; failed because slice(None) is unhashable
1 parent b2caf35 commit 61b2c39
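
For context on the title: the lowering below has to stand in a full slice for each None index entry, and slice objects are unhashable in CPython before 3.12, so any code path that hashes graph arguments (for example, a dict or set keyed on them) raises TypeError. A minimal standalone reproduction, independent of this repo:

# Standalone reproduction of the failure named in the title: in CPython
# before 3.12, slice objects are unhashable (3.12 made them hashable).
try:
    hash(slice(None))
except TypeError as e:
    print(e)  # unhashable type: 'slice'

# Tuples hash their elements, so an indexing tuple containing a slice
# cannot be used as a dict key either:
try:
    cache = {(slice(None), 0): "value"}
except TypeError as e:
    print(e)  # unhashable type: 'slice'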

File tree

2 files changed, +56 -24 lines


tests/utils.py

Lines changed: 4 additions & 1 deletion
@@ -423,7 +423,7 @@ def __init__(self, op_name: str, input_strings: List[str]):
             "aten.index.Tensor": self._adjust_index_tensor,
             "aten.index_put.default": self._adjust_index_tensor,
             "aten._native_batch_norm_legit_no_training.default": self._adjust__native_batch_norm_legit_no_training_default,
-            # "aten._unsafe_index.Tensor": self._adjust_index_tensor,
+            "aten._unsafe_index.Tensor": self._adjust_index_tensor,
         }
 
     def _adjust_bitwise_not_default(self, input_vals):

@@ -503,6 +503,9 @@ def _adjust_index_tensor(self, input_vals):
         new_indices = []
         for i in range(len(indices)):
             indice = indices[i]
+            if indice is None:
+                new_indices.append(None)
+                continue
             new_indice = []
             for j in range(len(indice)):
                 new_indice.append(torch.randint(0, self_shape[i], [1]))
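
As a standalone illustration, the adjusted helper's None handling can be sketched like this (a simplified stand-in, not the test-suite code; only the pass-through of None entries mirrors the diff):

import torch

def adjust_index_tensor(self_shape, indices):
    # Regenerate in-range random indices; pass None entries through
    # untouched, since None marks a dimension that is not indexed.
    new_indices = []
    for i, indice in enumerate(indices):
        if indice is None:
            new_indices.append(None)
            continue
        new_indices.append([torch.randint(0, self_shape[i], [1]) for _ in indice])
    return new_indices

# A (4, 5) input where only dim 1 is indexed:
print(adjust_index_tensor([4, 5], [None, [0, 0, 0]]))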

torch_ttnn/passes/lowering/to_tt_pass.py

Lines changed: 52 additions & 23 deletions
@@ -692,6 +692,9 @@ def lower_binary_eltwise(fn, args):
 
         return None
 
+        # if node.target == torch.ops.aten.reshape.default:
+        #     return g.call_function(ttnn.reshape, args, kwargs)
+
         if node.target == torch.ops.aten.squeeze.dim or node.target == torch.ops.aten.squeeze.default:
             if use_less_ttnn_op_types or node.target == torch.ops.aten.squeeze.default:
                 # ttnn.squeeze does not support calling the OP without provided dim (torch.ops.aten.squeeze.default)

@@ -1201,14 +1204,16 @@ def decompose_aten_to_aten_ops(gm: torch.fx.GraphModule, g: GraphWrapper, node):
             new_kwargs["dtype"] = node.meta["val"].dtype
             return g.call_function(torch.ops.aten.zeros.default, args=(target_shape, *args[2:]), kwargs=new_kwargs)
 
-    if node.target == torch.ops.aten.index.Tensor:
+    if node.target in [torch.ops.aten.index.Tensor, torch.ops.aten._unsafe_index.Tensor]:
 
         def broadcast_indices(indices):
             indices_shapes = [get_shape(gm, indices[i]) for i in range(len(indices))]
             broadcasted_shape = torch.Size(np.broadcast_shapes(*indices_shapes))
             broadcasted_indices = []
             for i in range(len(indices)):
-                if indices_shapes[i] == broadcasted_shape:
+                if indices_shapes[i] is None:
+                    broadcasted_indices.append(None)
+                elif indices_shapes[i] == broadcasted_shape:
                     broadcasted_indices.append(indices[i])
                 else:
                     broadcasted_indices.append(
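
broadcast_indices leans on NumPy's broadcasting rules to unify the index shapes, with None entries now passed through. A shape-level sketch of that logic (illustrative only: the committed code hands all shapes to np.broadcast_shapes directly and expands tensors via g.call_function, whereas this sketch filters the None entries first, since np.broadcast_shapes does not accept None):

import numpy as np
import torch

def broadcast_index_shapes(indices_shapes):
    # Broadcast the shapes of the real index tensors; None entries
    # (un-indexed dimensions) pass through unchanged.
    real = [s for s in indices_shapes if s is not None]
    broadcasted_shape = torch.Size(np.broadcast_shapes(*real))
    return broadcasted_shape, [
        None if s is None else broadcasted_shape for s in indices_shapes
    ]

# Index tensors of shape (3, 1) and (4,) broadcast to (3, 4):
print(broadcast_index_shapes([torch.Size([3, 1]), None, torch.Size([4])]))
# (torch.Size([3, 4]), [torch.Size([3, 4]), None, torch.Size([3, 4])])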
@@ -1226,36 +1231,60 @@ def broadcast_indices(indices):
         input_tensor, indices = args
         if get_shape(gm, input_tensor) is None:
             return None
-        if None in [get_shape(gm, indices[i]) for i in range(len(indices))]:
-            return None
+        for index in indices:
+            if index is None:
+                # slice(None) unhashable!
+                return None
+            if index is not None and get_shape(gm, index) is None:
+                return None
         index_shape, indices = broadcast_indices(indices)
+        if index_shape.numel() > 256:
+            # cannot create too many ops, or it will cause
+            # "runtime args targeting kernel reader_concat_stick_layout_interleaved_start_id on
+            # (x=0,y=0) are too large. Max allowable is 256"
+            return None
         input_shape = get_shape(gm, input_tensor)
         num_index = len(indices)
         index_size = index_shape.numel()
-        remained_shape = input_shape[num_index:]
+        remained_shape = []
+        for i in range(len(indices)):
+            if indices[i] is None:
+                remained_shape.append(input_shape[i])
+        remained_shape += input_shape[num_index:]
+        remained_shape = torch.Size(remained_shape)
         reshape_shape = index_shape + remained_shape
         input_dtype = input_tensor.meta["val"].dtype
         flatten_shape = torch.Size([index_size])
-        indices_flatten = [
-            g.call_function(
-                torch.ops.aten.reshape.default,
-                args=(idx, flatten_shape),
-                new_shape=flatten_shape,
-                new_dtype=idx.meta["val"].dtype,
-            )
-            for idx in indices
-        ]
+
+        indices_flatten = []
+        for idx in indices:
+            if idx is None:
+                indices_flatten.append(None)
+            else:
+                indices_flatten.append(
+                    g.call_function(
+                        torch.ops.aten.reshape.default,
+                        args=(idx, flatten_shape),
+                        new_shape=flatten_shape,
+                        new_dtype=idx.meta["val"].dtype,
+                    )
+                )
         output = []
         for i in range(index_size):
-            indexing = [
-                g.call_function(
-                    getitem,
-                    args=(indices_flatten[n], i),
-                    new_shape=torch.Size([]),
-                    new_dtype=indices_flatten[n].meta["val"].dtype,
-                )
-                for n in range(num_index)
-            ]
+            indexing = []
+            for n in range(num_index):
+                if indices_flatten[n] is None:
+                    # TODO: unhashable!
+                    indexing.append(slice(None))
+                else:
+                    indexing.append(
+                        g.call_function(
+                            getitem,
+                            args=(indices_flatten[n], i),
+                            new_shape=torch.Size([]),
+                            new_dtype=indices_flatten[n].meta["val"].dtype,
+                        )
+                    )
             output.append(
                 g.call_function(getitem, args=(input_tensor, indexing), new_shape=remained_shape, new_dtype=input_dtype)
             )
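
For reference, what the decomposition is emulating in eager terms: a None entry in aten.index's indices leaves that dimension un-indexed (a full slice), which is exactly why the loop above appends slice(None). An illustrative eager-mode snippet (not from this repo):

import torch

x = torch.arange(24).reshape(2, 3, 4)
idx = torch.tensor([2, 0])

# A None entry leaves its dimension un-indexed, i.e. acts as a full slice:
out = torch.ops.aten.index.Tensor(x, [None, idx])
assert torch.equal(out, x[:, idx])
assert out.shape == torch.Size([2, 2, 4])  # dim 0 kept, dim 1 gathered

# The per-element gather the loop above emits, in eager form: a
# slice(None) placeholder alongside one scalar index.
gathered = x[(slice(None), idx[0])]        # same as x[:, 2]
assert torch.equal(gathered, x[:, 2])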
