1
1
import torch
2
2
import ttnn
3
3
import math
4
+ import numpy as np
4
5
from torch ._subclasses .fake_tensor import unset_fake_temporarily
5
6
from torch_ttnn .utils import (
6
7
GraphCleanup ,
@@ -446,13 +447,19 @@ def __init__(self, node):
446
447
self .g = node .graph
447
448
self .node = node
448
449
449
def call_function(self, target, args=(), kwargs=None, new_shape=None, new_dtype=None):
    """Insert a call_function node into the graph and carry over metadata.

    The new node inherits a copy of the wrapped node's ``meta``. When
    ``new_shape`` and/or ``new_dtype`` are given, ``meta["val"]`` is rebuilt
    as a fresh fake tensor so downstream passes see the correct shape/dtype
    of the value this node actually produces.

    Args:
        target: op to call (an aten op, ``ttnn`` op, or plain callable).
        args: positional arguments for the op.
        kwargs: keyword arguments for the op; ``None`` means no kwargs.
        new_shape: optional shape for the rebuilt ``meta["val"]``; defaults
            to the inherited fake tensor's size.
        new_dtype: optional dtype for the rebuilt ``meta["val"]``; defaults
            to the inherited fake tensor's dtype.

    Returns:
        The newly created graph node.
    """
    # Fix: the original signature used a mutable default (`kwargs={}`),
    # which is shared across calls; normalize from None instead.
    kwargs = {} if kwargs is None else kwargs
    new_node = self.g.call_function(target, args, kwargs)
    new_node.meta = self.node.meta.copy()
    if hasattr(self.node.target, "_schema"):
        new_node.meta["original_input_variations"] = metrics.collect_input_variation_from_node(self.node)
    if target == ttnn.layer_norm:
        # layer_norm's inherited meta["val"] is indexable; keep only the
        # first entry — presumably the normalized output tensor (NOTE(review):
        # confirm against ttnn.layer_norm's multi-output meta convention).
        new_node.meta["val"] = new_node.meta["val"][0]
    if new_shape is not None or new_dtype is not None:
        # Any field not overridden falls back to the inherited fake value.
        shape = new_shape if new_shape is not None else new_node.meta["val"].size()
        dtype = new_dtype if new_dtype is not None else new_node.meta["val"].dtype
        fake_mode = FakeTensorMode()
        fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))
        new_node.meta["val"] = fake_tensor
    return new_node
457
464
458
465
def inserting_before (self , node ):
@@ -1144,14 +1151,7 @@ def rewrite_node(node):
1144
1151
1145
1152
if node .target == torch .ops .aten .index .Tensor :
1146
1153
1147
- def edit_meta_val (node , shape , dtype ):
1148
- fake_mode = FakeTensorMode ()
1149
- fake_tensor = fake_mode .from_tensor (torch .zeros (shape , dtype = dtype ))
1150
- node .meta ["val" ] = fake_tensor
1151
-
1152
1154
def broadcast_indices (indices ):
1153
- import numpy as np
1154
-
1155
1155
indices_shapes = [get_shape (indices [i ]) for i in range (len (indices ))]
1156
1156
broadcasted_shape = torch .Size (np .broadcast_shapes (* indices_shapes ))
1157
1157
broadcasted_indices = []
@@ -1160,9 +1160,10 @@ def broadcast_indices(indices):
1160
1160
broadcasted_indices .append (indices [i ])
1161
1161
else :
1162
1162
broadcasted_indices .append (
1163
- g .call_function (torch .ops .aten .expand .default , (indices [i ], broadcasted_shape ))
1163
+ g .call_function (torch .ops .aten .expand .default , (indices [i ], broadcasted_shape )),
1164
+ new_shape = broadcasted_shape ,
1165
+ new_dtype = indices [i ].meta ["val" ].dtype ,
1164
1166
)
1165
- edit_meta_val (broadcasted_indices [- 1 ], broadcasted_shape , indices [i ].meta ["val" ].dtype )
1166
1167
return broadcasted_shape , broadcasted_indices
1167
1168
1168
1169
# for example, input.shape = (3, 4, 5), indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])]
@@ -1181,28 +1182,55 @@ def broadcast_indices(indices):
1181
1182
input_dtype = input_tensor .meta ["val" ].dtype
1182
1183
flatten_shape = torch .Size ([index_size ])
1183
1184
indices_flatten = [
1184
- g .call_function (torch .ops .aten .reshape .default , args = (idx , flatten_shape )) for idx in indices
1185
+ g .call_function (
1186
+ torch .ops .aten .reshape .default ,
1187
+ args = (idx , flatten_shape ),
1188
+ new_shape = flatten_shape ,
1189
+ new_dtype = idx .meta ["val" ].dtype ,
1190
+ )
1191
+ for idx in indices
1185
1192
]
1186
- for i in range (len (indices_flatten )):
1187
- edit_meta_val (indices_flatten [i ], flatten_shape , indices [i ].meta ["val" ].dtype )
1188
1193
output = []
1189
1194
for i in range (index_size ):
1190
- indexing = [g .call_function (getitem , args = (indices_flatten [n ], i )) for n in range (num_index )]
1191
- for n in range (num_index ):
1192
- edit_meta_val (indexing [n ], torch .Size ([]), indices_flatten [n ].meta ["val" ].dtype )
1193
- output .append (g .call_function (getitem , args = (input_tensor , indexing )))
1194
- edit_meta_val (output [- 1 ], remained_shape , input_dtype )
1195
+ indexing = [
1196
+ g .call_function (
1197
+ getitem ,
1198
+ args = (indices_flatten [n ], i ),
1199
+ new_shape = torch .Size ([]),
1200
+ new_dtype = indices_flatten [n ].meta ["val" ].dtype ,
1201
+ )
1202
+ for n in range (num_index )
1203
+ ]
1204
+ output .append (
1205
+ g .call_function (getitem , args = (input_tensor , indexing )),
1206
+ new_shape = remained_shape ,
1207
+ new_dtype = input_dtype ,
1208
+ )
1195
1209
# aten.cat cannot concat zero dim tensor
1196
1210
if len (remained_shape ) == 0 :
1197
1211
remained_shape = torch .Size ([1 ])
1198
- output = [g .call_function (torch .ops .aten .reshape .default , args = (o , remained_shape )) for o in output ]
1199
- for o in output :
1200
- edit_meta_val (o , remained_shape , input_dtype )
1201
- output_cat = g .call_function (torch .ops .aten .cat .default , args = (output ,))
1212
+ output = [
1213
+ g .call_function (
1214
+ torch .ops .aten .reshape .default ,
1215
+ args = (o , remained_shape ),
1216
+ new_shape = remained_shape ,
1217
+ new_dtype = input_dtype ,
1218
+ )
1219
+ for o in output
1220
+ ]
1202
1221
output_cat_shape = torch .Size ([len (output )] + list (remained_shape ))
1203
- edit_meta_val (output_cat , output_cat_shape , input_dtype )
1204
- output_reshape = g .call_function (torch .ops .aten .reshape .default , args = (output_cat , reshape_shape ))
1205
- edit_meta_val (output_reshape , reshape_shape , input_dtype )
1222
+ output_cat = g .call_function (
1223
+ torch .ops .aten .cat .default ,
1224
+ args = (output ,),
1225
+ new_shape = output_cat_shape ,
1226
+ new_dtype = input_dtype ,
1227
+ )
1228
+ output_reshape = g .call_function (
1229
+ torch .ops .aten .reshape .default ,
1230
+ args = (output_cat , reshape_shape ),
1231
+ new_shape = reshape_shape ,
1232
+ new_dtype = input_dtype ,
1233
+ )
1206
1234
return output_reshape
1207
1235
1208
1236
with g .inserting_before (node ):
0 commit comments