Skip to content

Commit 7b4af97

Browse files
committed
Support None entries in aten.index's indices
1 parent b2caf35 commit 7b4af97

File tree

2 files changed

+52
-23
lines changed

2 files changed

+52
-23
lines changed

tests/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,9 @@ def _adjust_index_tensor(self, input_vals):
503503
new_indices = []
504504
for i in range(len(indices)):
505505
indice = indices[i]
506+
if indice is None:
507+
new_indices.append(None)
508+
continue
506509
new_indice = []
507510
for j in range(len(indice)):
508511
new_indice.append(torch.randint(0, self_shape[i], [1]))

torch_ttnn/passes/lowering/to_tt_pass.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,9 @@ def lower_binary_eltwise(fn, args):
692692

693693
return None
694694

695+
# if node.target == torch.ops.aten.reshape.default:
696+
# return g.call_function(ttnn.reshape, args, kwargs)
697+
695698
if node.target == torch.ops.aten.squeeze.dim or node.target == torch.ops.aten.squeeze.default:
696699
if use_less_ttnn_op_types or node.target == torch.ops.aten.squeeze.default:
697700
# ttnn.squeeze does not support calling the OP without provided dim (torch.ops.aten.squeeze.default)
@@ -1201,14 +1204,16 @@ def decompose_aten_to_aten_ops(gm: torch.fx.GraphModule, g: GraphWrapper, node):
12011204
new_kwargs["dtype"] = node.meta["val"].dtype
12021205
return g.call_function(torch.ops.aten.zeros.default, args=(target_shape, *args[2:]), kwargs=new_kwargs)
12031206

1204-
if node.target == torch.ops.aten.index.Tensor:
1207+
if node.target in [torch.ops.aten.index.Tensor, torch.ops.aten._unsafe_index.Tensor]:
12051208

12061209
def broadcast_indices(indices):
12071210
indices_shapes = [get_shape(gm, indices[i]) for i in range(len(indices))]
12081211
broadcasted_shape = torch.Size(np.broadcast_shapes(*indices_shapes))
12091212
broadcasted_indices = []
12101213
for i in range(len(indices)):
1211-
if indices_shapes[i] == broadcasted_shape:
1214+
if indices_shapes[i] is None:
1215+
broadcasted_indices.append(None)
1216+
elif indices_shapes[i] == broadcasted_shape:
12121217
broadcasted_indices.append(indices[i])
12131218
else:
12141219
broadcasted_indices.append(
@@ -1226,36 +1231,57 @@ def broadcast_indices(indices):
12261231
input_tensor, indices = args
12271232
if get_shape(gm, input_tensor) is None:
12281233
return None
1229-
if None in [get_shape(gm, indices[i]) for i in range(len(indices))]:
1230-
return None
1234+
for index in indices:
1235+
if index is not None and get_shape(gm, index) is None:
1236+
return None
12311237
index_shape, indices = broadcast_indices(indices)
1238+
if index_shape.numel() > 256:
1239+
# cannot create too many ops, or it will cause
1240+
# runtime args targeting kernel reader_concat_stick_layout_interleaved_start_id on
1241+
# (x=0,y=0) are too large. Max allowable is 256
return None
12321243
input_shape = get_shape(gm, input_tensor)
12331244
num_index = len(indices)
12341245
index_size = index_shape.numel()
1235-
remained_shape = input_shape[num_index:]
1246+
remained_shape = []
1247+
for i in range(len(indices)):
1248+
if indices[i] is None:
1249+
remained_shape.append(input_shape[i])
1250+
remained_shape += input_shape[num_index:]
1251+
remained_shape = torch.Size(remained_shape)
12361252
reshape_shape = index_shape + remained_shape
12371253
input_dtype = input_tensor.meta["val"].dtype
12381254
flatten_shape = torch.Size([index_size])
1239-
indices_flatten = [
1240-
g.call_function(
1241-
torch.ops.aten.reshape.default,
1242-
args=(idx, flatten_shape),
1243-
new_shape=flatten_shape,
1244-
new_dtype=idx.meta["val"].dtype,
1245-
)
1246-
for idx in indices
1247-
]
1255+
1256+
indices_flatten = []
1257+
for idx in indices:
1258+
if idx is None:
1259+
indices_flatten.append(None)
1260+
else:
1261+
indices_flatten.append(
1262+
g.call_function(
1263+
torch.ops.aten.reshape.default,
1264+
args=(idx, flatten_shape),
1265+
new_shape=flatten_shape,
1266+
new_dtype=idx.meta["val"].dtype,
1267+
)
1268+
)
12481269
output = []
12491270
for i in range(index_size):
1250-
indexing = [
1251-
g.call_function(
1252-
getitem,
1253-
args=(indices_flatten[n], i),
1254-
new_shape=torch.Size([]),
1255-
new_dtype=indices_flatten[n].meta["val"].dtype,
1256-
)
1257-
for n in range(num_index)
1258-
]
1271+
indexing = []
1272+
for n in range(num_index):
1273+
if indices_flatten[n] is None:
1274+
# TODO: unhashable!
1275+
indexing.append(slice(None))
1276+
else:
1277+
indexing.append(
1278+
g.call_function(
1279+
getitem,
1280+
args=(indices_flatten[n], i),
1281+
new_shape=torch.Size([]),
1282+
new_dtype=indices_flatten[n].meta["val"].dtype,
1283+
)
1284+
)
12591285
output.append(
12601286
g.call_function(getitem, args=(input_tensor, indexing), new_shape=remained_shape, new_dtype=input_dtype)
12611287
)

0 commit comments

Comments
 (0)