     TtnnL1MemoryConfig,
     TtnnRowMajorLayout,
     TtnnTileLayout,
+    SliceNone,
     get_shape,
 )
 import numpy as np
@@ -692,6 +693,9 @@ def lower_binary_eltwise(fn, args):
 
             return None
 
+        # if node.target == torch.ops.aten.reshape.default:
+        #     return g.call_function(ttnn.reshape, args, kwargs)
+
         if node.target == torch.ops.aten.squeeze.dim or node.target == torch.ops.aten.squeeze.default:
             if use_less_ttnn_op_types or node.target == torch.ops.aten.squeeze.default:
                 # ttnn.squeeze does not support calling the OP without provided dim (torch.ops.aten.squeeze.default)
@@ -1201,14 +1205,16 @@ def decompose_aten_to_aten_ops(gm: torch.fx.GraphModule, g: GraphWrapper, node):
             new_kwargs["dtype"] = node.meta["val"].dtype
             return g.call_function(torch.ops.aten.zeros.default, args=(target_shape, *args[2:]), kwargs=new_kwargs)
 
-    if node.target == torch.ops.aten.index.Tensor:
+    if node.target in [torch.ops.aten.index.Tensor, torch.ops.aten._unsafe_index.Tensor]:
 
         def broadcast_indices(indices):
             indices_shapes = [get_shape(gm, indices[i]) for i in range(len(indices))]
             broadcasted_shape = torch.Size(np.broadcast_shapes(*indices_shapes))
             broadcasted_indices = []
             for i in range(len(indices)):
-                if indices_shapes[i] == broadcasted_shape:
+                if indices_shapes[i] is None:
+                    broadcasted_indices.append(None)
+                elif indices_shapes[i] == broadcasted_shape:
                     broadcasted_indices.append(indices[i])
                 else:
                     broadcasted_indices.append(
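For context on the `None` handling introduced above: in `aten.index.Tensor`, a `None` entry in the index list stands for a full slice of that dimension, and the remaining index tensors are broadcast against each other, which is what `broadcast_indices` mirrors via `np.broadcast_shapes`. A minimal eager-mode sketch of that semantics (the tensors and values are illustrative, not taken from the patch):

```python
import torch
import numpy as np

x = torch.arange(24).reshape(2, 3, 4)
row = torch.tensor([0, 1])
col = torch.tensor([2, 0])

# A None entry in the indices list is equivalent to a full slice on that dim:
# x[row, :, col] corresponds to aten.index(x, [row, None, col]).
ref = x[row, :, col]
alt = torch.ops.aten.index.Tensor(x, [row, None, col])
assert torch.equal(ref, alt)

# The non-None index tensors are broadcast against each other before gathering.
shapes = [tuple(row.shape), tuple(col.shape)]
print(np.broadcast_shapes(*shapes))  # (2,)
```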
@@ -1226,36 +1232,56 @@ def broadcast_indices(indices):
         input_tensor, indices = args
         if get_shape(gm, input_tensor) is None:
             return None
-        if None in [get_shape(gm, indices[i]) for i in range(len(indices))]:
-            return None
+        for index in indices:
+            if index is not None and get_shape(gm, index) is None:
+                return None
         index_shape, indices = broadcast_indices(indices)
+        if index_shape.numel() > 256:
+            # cannot create too many ops, or it will fail with:
+            # runtime args targeting kernel reader_concat_stick_layout_interleaved_start_id on
+            # (x=0,y=0) are too large. Max allowable is 256
+            return None
         input_shape = get_shape(gm, input_tensor)
         num_index = len(indices)
         index_size = index_shape.numel()
-        remained_shape = input_shape[num_index:]
+        remained_shape = []
+        for i in range(len(indices)):
+            if indices[i] is None:
+                remained_shape.append(input_shape[i])
+        remained_shape += input_shape[num_index:]
+        remained_shape = torch.Size(remained_shape)
         reshape_shape = index_shape + remained_shape
         input_dtype = input_tensor.meta["val"].dtype
         flatten_shape = torch.Size([index_size])
-        indices_flatten = [
-            g.call_function(
-                torch.ops.aten.reshape.default,
-                args=(idx, flatten_shape),
-                new_shape=flatten_shape,
-                new_dtype=idx.meta["val"].dtype,
-            )
-            for idx in indices
-        ]
+
+        indices_flatten = []
+        for idx in indices:
+            if idx is None:
+                indices_flatten.append(None)
+            else:
+                indices_flatten.append(
+                    g.call_function(
+                        torch.ops.aten.reshape.default,
+                        args=(idx, flatten_shape),
+                        new_shape=flatten_shape,
+                        new_dtype=idx.meta["val"].dtype,
+                    )
+                )
         output = []
         for i in range(index_size):
-            indexing = [
-                g.call_function(
-                    getitem,
-                    args=(indices_flatten[n], i),
-                    new_shape=torch.Size([]),
-                    new_dtype=indices_flatten[n].meta["val"].dtype,
-                )
-                for n in range(num_index)
-            ]
+            indexing = []
+            for n in range(num_index):
+                if indices_flatten[n] is None:
+                    indexing.append(slice(None))
+                else:
+                    indexing.append(
+                        g.call_function(
+                            getitem,
+                            args=(indices_flatten[n], i),
+                            new_shape=torch.Size([]),
+                            new_dtype=indices_flatten[n].meta["val"].dtype,
+                        )
+                    )
             output.append(
                 g.call_function(getitem, args=(input_tensor, indexing), new_shape=remained_shape, new_dtype=input_dtype)
             )
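As a rough eager-mode sketch of what the updated decomposition computes (the helper name and tensors here are illustrative; the pass itself emits FX `reshape`/`getitem` nodes through `g.call_function` rather than running eagerly, and it additionally reshapes the stacked result to `index_shape + remained_shape`):

```python
import torch

def decompose_index_with_none(x, indices):
    # Broadcast the non-None index tensors, flatten them, then gather one
    # sub-tensor per flat position, substituting slice(None) wherever the
    # original index entry was None (i.e. a dimension kept whole).
    shapes = [idx.shape for idx in indices if idx is not None]
    bshape = torch.broadcast_shapes(*shapes)
    flat = [idx.broadcast_to(bshape).reshape(-1) if idx is not None else None for idx in indices]

    num_index = len(indices)
    remained = [x.shape[i] for i in range(num_index) if indices[i] is None]
    remained += list(x.shape[num_index:])

    rows = []
    for i in range(bshape.numel()):
        key = tuple(slice(None) if f is None else f[i] for f in flat)
        rows.append(x[key])
    return torch.stack(rows).reshape(bshape + torch.Size(remained))

x = torch.arange(24).reshape(2, 3, 4)
row = torch.tensor([0, 1])
col = torch.tensor([2, 0])
out = decompose_index_with_none(x, [row, None, col])
assert torch.equal(out, x[row, :, col])
```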