1
1
import torch
2
2
import ttnn
3
3
import math
4
+ import numpy as np
4
5
from torch ._subclasses .fake_tensor import unset_fake_temporarily
5
6
from torch_ttnn .utils import (
6
7
GraphCleanup ,
@@ -446,13 +447,19 @@ def __init__(self, node):
446
447
self .g = node .graph
447
448
self .node = node
448
449
449
def call_function(self, target, args=(), kwargs=None, new_shape=None, new_dtype=None):
    """Insert a ``call_function`` node for *target* into the wrapped graph,
    cloning this wrapper's node metadata onto it.

    Args:
        target: The op to call (an aten/ttnn overload or a plain callable).
        args: Positional arguments for the new node.
        kwargs: Keyword arguments for the new node; ``None`` means an empty
            dict. (Was ``kwargs={}`` — a mutable default shared across every
            call; a fresh dict per call avoids accidental cross-call state.)
        new_shape: If given, override the shape of the new node's fake
            ``meta["val"]`` tensor.
        new_dtype: If given, override the dtype of the new node's fake
            ``meta["val"]`` tensor.

    Returns:
        The newly created ``torch.fx`` node, with ``meta`` copied from
        ``self.node`` and ``meta["val"]`` adjusted as requested.
    """
    if kwargs is None:
        kwargs = {}
    new_node = self.g.call_function(target, args, kwargs)
    new_node.meta = self.node.meta.copy()
    if hasattr(self.node.target, "_schema"):
        new_node.meta["original_input_variations"] = metrics.collect_input_variation_from_node(self.node)
    if target == ttnn.layer_norm:
        # layer_norm's meta["val"] is a tuple; the tensor output is element 0.
        new_node.meta["val"] = new_node.meta["val"][0]
    if new_shape is not None or new_dtype is not None:
        # Fall back to the inherited meta for whichever of shape/dtype was
        # not overridden.
        shape = new_shape if new_shape is not None else new_node.meta["val"].size()
        dtype = new_dtype if new_dtype is not None else new_node.meta["val"].dtype
        fake_mode = FakeTensorMode()
        # torch.empty instead of torch.zeros: from_tensor only consumes the
        # metadata (shape/dtype), so there is no need to zero-fill what could
        # be a large buffer.
        fake_tensor = fake_mode.from_tensor(torch.empty(shape, dtype=dtype))
        new_node.meta["val"] = fake_tensor
    return new_node
457
464
458
465
def inserting_before (self , node ):
@@ -1144,14 +1151,7 @@ def rewrite_node(node):
1144
1151
1145
1152
if node .target == torch .ops .aten .index .Tensor :
1146
1153
1147
- def edit_meta_val (node , shape , dtype ):
1148
- fake_mode = FakeTensorMode ()
1149
- fake_tensor = fake_mode .from_tensor (torch .zeros (shape , dtype = dtype ))
1150
- node .meta ["val" ] = fake_tensor
1151
-
1152
1154
def broadcast_indices (indices ):
1153
- import numpy as np
1154
-
1155
1155
indices_shapes = [get_shape (indices [i ]) for i in range (len (indices ))]
1156
1156
broadcasted_shape = torch .Size (np .broadcast_shapes (* indices_shapes ))
1157
1157
broadcasted_indices = []
@@ -1160,9 +1160,13 @@ def broadcast_indices(indices):
1160
1160
broadcasted_indices .append (indices [i ])
1161
1161
else :
1162
1162
broadcasted_indices .append (
1163
- g .call_function (torch .ops .aten .expand .default , (indices [i ], broadcasted_shape ))
1163
+ g .call_function (
1164
+ torch .ops .aten .expand .default ,
1165
+ (indices [i ], broadcasted_shape ),
1166
+ new_shape = broadcasted_shape ,
1167
+ new_dtype = indices [i ].meta ["val" ].dtype ,
1168
+ )
1164
1169
)
1165
- edit_meta_val (broadcasted_indices [- 1 ], broadcasted_shape , indices [i ].meta ["val" ].dtype )
1166
1170
return broadcasted_shape , broadcasted_indices
1167
1171
1168
1172
# for example, input.shape = (3, 4, 5), indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])]
@@ -1181,28 +1185,55 @@ def broadcast_indices(indices):
1181
1185
input_dtype = input_tensor .meta ["val" ].dtype
1182
1186
flatten_shape = torch .Size ([index_size ])
1183
1187
indices_flatten = [
1184
- g .call_function (torch .ops .aten .reshape .default , args = (idx , flatten_shape )) for idx in indices
1188
+ g .call_function (
1189
+ torch .ops .aten .reshape .default ,
1190
+ args = (idx , flatten_shape ),
1191
+ new_shape = flatten_shape ,
1192
+ new_dtype = idx .meta ["val" ].dtype ,
1193
+ )
1194
+ for idx in indices
1185
1195
]
1186
- for i in range (len (indices_flatten )):
1187
- edit_meta_val (indices_flatten [i ], flatten_shape , indices [i ].meta ["val" ].dtype )
1188
1196
output = []
1189
1197
for i in range (index_size ):
1190
- indexing = [g .call_function (getitem , args = (indices_flatten [n ], i )) for n in range (num_index )]
1191
- for n in range (num_index ):
1192
- edit_meta_val (indexing [n ], torch .Size ([]), indices_flatten [n ].meta ["val" ].dtype )
1193
- output .append (g .call_function (getitem , args = (input_tensor , indexing )))
1194
- edit_meta_val (output [- 1 ], remained_shape , input_dtype )
1198
+ indexing = [
1199
+ g .call_function (
1200
+ getitem ,
1201
+ args = (indices_flatten [n ], i ),
1202
+ new_shape = torch .Size ([]),
1203
+ new_dtype = indices_flatten [n ].meta ["val" ].dtype ,
1204
+ )
1205
+ for n in range (num_index )
1206
+ ]
1207
+ output .append (
1208
+ g .call_function (
1209
+ getitem , args = (input_tensor , indexing ), new_shape = remained_shape , new_dtype = input_dtype
1210
+ )
1211
+ )
1195
1212
# aten.cat cannot concat zero dim tensor
1196
1213
if len (remained_shape ) == 0 :
1197
1214
remained_shape = torch .Size ([1 ])
1198
- output = [g .call_function (torch .ops .aten .reshape .default , args = (o , remained_shape )) for o in output ]
1199
- for o in output :
1200
- edit_meta_val (o , remained_shape , input_dtype )
1201
- output_cat = g .call_function (torch .ops .aten .cat .default , args = (output ,))
1215
+ output = [
1216
+ g .call_function (
1217
+ torch .ops .aten .reshape .default ,
1218
+ args = (o , remained_shape ),
1219
+ new_shape = remained_shape ,
1220
+ new_dtype = input_dtype ,
1221
+ )
1222
+ for o in output
1223
+ ]
1202
1224
output_cat_shape = torch .Size ([len (output )] + list (remained_shape ))
1203
- edit_meta_val (output_cat , output_cat_shape , input_dtype )
1204
- output_reshape = g .call_function (torch .ops .aten .reshape .default , args = (output_cat , reshape_shape ))
1205
- edit_meta_val (output_reshape , reshape_shape , input_dtype )
1225
+ output_cat = g .call_function (
1226
+ torch .ops .aten .cat .default ,
1227
+ args = (output ,),
1228
+ new_shape = output_cat_shape ,
1229
+ new_dtype = input_dtype ,
1230
+ )
1231
+ output_reshape = g .call_function (
1232
+ torch .ops .aten .reshape .default ,
1233
+ args = (output_cat , reshape_shape ),
1234
+ new_shape = reshape_shape ,
1235
+ new_dtype = input_dtype ,
1236
+ )
1206
1237
return output_reshape
1207
1238
1208
1239
with g .inserting_before (node ):
0 commit comments