@@ -692,6 +692,9 @@ def lower_binary_eltwise(fn, args):
            return None

+        # if node.target == torch.ops.aten.reshape.default:
+        #     return g.call_function(ttnn.reshape, args, kwargs)
+
        if node.target == torch.ops.aten.squeeze.dim or node.target == torch.ops.aten.squeeze.default:
            if use_less_ttnn_op_types or node.target == torch.ops.aten.squeeze.default:
                # ttnn.squeeze does not support calling the OP without provided dim (torch.ops.aten.squeeze.default)
@@ -1201,14 +1204,16 @@ def decompose_aten_to_aten_ops(gm: torch.fx.GraphModule, g: GraphWrapper, node):
        new_kwargs["dtype"] = node.meta["val"].dtype
        return g.call_function(torch.ops.aten.zeros.default, args=(target_shape, *args[2:]), kwargs=new_kwargs)

-    if node.target == torch.ops.aten.index.Tensor:
+    if node.target in [torch.ops.aten.index.Tensor, torch.ops.aten._unsafe_index.Tensor]:

        def broadcast_indices(indices):
            indices_shapes = [get_shape(gm, indices[i]) for i in range(len(indices))]
            broadcasted_shape = torch.Size(np.broadcast_shapes(*indices_shapes))
            broadcasted_indices = []
            for i in range(len(indices)):
-                if indices_shapes[i] == broadcasted_shape:
+                if indices_shapes[i] is None:
+                    broadcasted_indices.append(None)
+                elif indices_shapes[i] == broadcasted_shape:
                    broadcasted_indices.append(indices[i])
                else:
                    broadcasted_indices.append(
@@ -1226,36 +1231,60 @@ def broadcast_indices(indices):
        input_tensor, indices = args
        if get_shape(gm, input_tensor) is None:
            return None
-        if None in [get_shape(gm, indices[i]) for i in range(len(indices))]:
-            return None
+        for index in indices:
+            if index is None:
+                # slice(None) is unhashable!
+                return None
+            if index is not None and get_shape(gm, index) is None:
+                return None
        index_shape, indices = broadcast_indices(indices)
+        if index_shape.numel() > 256:
+            # Cannot create too many ops, or it will cause:
+            # "runtime args targeting kernel reader_concat_stick_layout_interleaved_start_id on
+            # (x=0,y=0) are too large. Max allowable is 256"
+            return None
        input_shape = get_shape(gm, input_tensor)
        num_index = len(indices)
        index_size = index_shape.numel()
-        remained_shape = input_shape[num_index:]
+        remained_shape = []
+        for i in range(len(indices)):
+            if indices[i] is None:
+                remained_shape.append(input_shape[i])
+        remained_shape += input_shape[num_index:]
+        remained_shape = torch.Size(remained_shape)
        reshape_shape = index_shape + remained_shape
        input_dtype = input_tensor.meta["val"].dtype
        flatten_shape = torch.Size([index_size])
-        indices_flatten = [
-            g.call_function(
-                torch.ops.aten.reshape.default,
-                args=(idx, flatten_shape),
-                new_shape=flatten_shape,
-                new_dtype=idx.meta["val"].dtype,
-            )
-            for idx in indices
-        ]
+
+        indices_flatten = []
+        for idx in indices:
+            if idx is None:
+                indices_flatten.append(None)
+            else:
+                indices_flatten.append(
+                    g.call_function(
+                        torch.ops.aten.reshape.default,
+                        args=(idx, flatten_shape),
+                        new_shape=flatten_shape,
+                        new_dtype=idx.meta["val"].dtype,
+                    )
+                )
        output = []
        for i in range(index_size):
-            indexing = [
-                g.call_function(
-                    getitem,
-                    args=(indices_flatten[n], i),
-                    new_shape=torch.Size([]),
-                    new_dtype=indices_flatten[n].meta["val"].dtype,
-                )
-                for n in range(num_index)
-            ]
+            indexing = []
+            for n in range(num_index):
+                if indices_flatten[n] is None:
+                    # TODO: slice(None) is unhashable!
+                    indexing.append(slice(None))
+                else:
+                    indexing.append(
+                        g.call_function(
+                            getitem,
+                            args=(indices_flatten[n], i),
+                            new_shape=torch.Size([]),
+                            new_dtype=indices_flatten[n].meta["val"].dtype,
+                        )
+                    )
            output.append(
                g.call_function(getitem, args=(input_tensor, indexing), new_shape=remained_shape, new_dtype=input_dtype)
            )
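
For context on what this decomposition computes, the sketch below is a standalone eager-mode reference (not part of this change) that mirrors the same steps: broadcast the index tensors, flatten them, gather one element per flattened index position while keeping `None`-indexed dimensions as `slice(None)`, then reshape to `index_shape + remained_shape`. The helper name `index_decomposition_reference` is made up for illustration; the sanity check only covers a pattern this decomposition supports (advanced index dims leading the output).

```python
import torch

def index_decomposition_reference(input_tensor, indices):
    # Broadcast all non-None index tensors to a common shape.
    shapes = [idx.shape for idx in indices if idx is not None]
    index_shape = torch.Size(torch.broadcast_shapes(*shapes))
    index_size = index_shape.numel()
    num_index = len(indices)

    # Dimensions covered by a None index, plus all trailing dimensions,
    # stay in the output (the `remained_shape` in the pass above).
    remained_shape = [input_tensor.shape[i] for i in range(num_index) if indices[i] is None]
    remained_shape += list(input_tensor.shape[num_index:])

    # Flatten each broadcast index; keep None as a placeholder for slice(None).
    flat = [
        idx.broadcast_to(index_shape).reshape(index_size) if idx is not None else None
        for idx in indices
    ]

    # One getitem per flattened index position, then stack and reshape.
    rows = []
    for i in range(index_size):
        key = tuple(slice(None) if f is None else f[i] for f in flat)
        rows.append(input_tensor[key])
    return torch.stack(rows).reshape(index_shape + torch.Size(remained_shape))

# Sanity check against the built-in op for one supported pattern:
x = torch.arange(24).reshape(2, 3, 4)
idx = [torch.tensor([0, 1]), None, torch.tensor([1, 2])]
assert torch.equal(index_decomposition_reference(x, idx), torch.ops.aten.index.Tensor(x, idx))
```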