 
 from torch.fx.passes.infra.pass_base import PassBase, PassResult
 import torch.fx.traceback as fx_traceback
+from torch._subclasses.fake_tensor import FakeTensorMode
 from . import target_wrappers
 from .to_tt_guard import can_lowering_to_ttnn
 from operator import getitem
@@ -447,7 +448,7 @@ def __init__(self, node):
 
     def call_function(self, target, args=(), kwargs={}):
         new_node = self.g.call_function(target, args, kwargs)
-        new_node.meta = self.node.meta
+        new_node.meta = self.node.meta.copy()
         if hasattr(self.node.target, "_schema"):
             new_node.meta["original_input_variations"] = metrics.collect_input_variation_from_node(self.node)
         if target == ttnn.layer_norm:
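Note on the new_node.meta change in the hunk above: node.meta is a plain dict, so the previous assignment made the rewritten node alias the source node's metadata, and any later in-place update (such as the meta["val"] edits added below) would also mutate the source node. A minimal sketch of that aliasing, using plain dicts rather than real fx nodes:

# Assigning the dict shares it; copying it decouples the two nodes.
source_meta = {"val": "fake_tensor_A"}

aliased = source_meta              # old behaviour: new_node.meta = self.node.meta
aliased["val"] = "fake_tensor_B"
assert source_meta["val"] == "fake_tensor_B"   # source node mutated too

source_meta = {"val": "fake_tensor_A"}
copied = source_meta.copy()        # new behaviour: new_node.meta = self.node.meta.copy()
copied["val"] = "fake_tensor_B"
assert source_meta["val"] == "fake_tensor_A"   # source node untouched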
@@ -1133,16 +1134,21 @@ def batch_norm_inference(input, weight, bias, mean, var, momentum, eps):
 
 
 def DigestAtenOps(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
-    g = gm.graph
     nodes = list(gm.graph.nodes)
     for node in nodes:
+        g = GraphWrapper(node)
 
         def rewrite_node(node):
             args = node.args
             kwargs = node.kwargs
 
             if node.target == torch.ops.aten.index.Tensor:
 
+                def edit_meta_val(node, shape, dtype):
+                    fake_mode = FakeTensorMode()
+                    fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))
+                    node.meta["val"] = fake_tensor
+
                 def broadcast_indices(indices):
                     import numpy as np
 
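The edit_meta_val helper added above attaches a FakeTensor to node.meta["val"] so that later passes can read the shape and dtype of the rewritten nodes without executing them. A small illustrative sketch of what FakeTensorMode.from_tensor produces (not part of the patch):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
# A FakeTensor carries shape/dtype/device metadata but allocates no real storage.
fake = fake_mode.from_tensor(torch.zeros((1, 3, 5), dtype=torch.int64))
print(fake.shape, fake.dtype)  # torch.Size([1, 3, 5]) torch.int64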
@@ -1156,27 +1162,47 @@ def broadcast_indices(indices):
                             broadcasted_indices.append(
                                 g.call_function(torch.ops.aten.expand.default, (indices[i], broadcasted_shape))
                             )
+                            edit_meta_val(broadcasted_indices[-1], broadcasted_shape, indices[i].meta["val"].dtype)
                     return broadcasted_shape, broadcasted_indices
 
                 # for example, input.shape = (3, 4, 5), indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])]
                 # then output is [[input[0][2], input[1][1], input[1][2]]]
                 input_tensor, indices = args
+                if get_shape(input_tensor) is None:
+                    return None
+                if None in [get_shape(indices[i]) for i in range(len(indices))]:
+                    return None
                 index_shape, indices = broadcast_indices(indices)
                 input_shape = get_shape(input_tensor)
                 num_index = len(indices)
                 index_size = index_shape.numel()
                 remained_shape = input_shape[num_index:]
                 reshape_shape = index_shape + remained_shape
-                indices_flatten = [g.call_function(torch.ops.aten.flatten, args=(idx,)) for idx in indices]
+                input_dtype = input_tensor.meta["val"].dtype
+                flatten_shape = torch.Size([index_size])
+                indices_flatten = [
+                    g.call_function(torch.ops.aten.reshape.default, args=(idx, flatten_shape)) for idx in indices
+                ]
+                for i in range(len(indices_flatten)):
+                    edit_meta_val(indices_flatten[i], flatten_shape, indices[i].meta["val"].dtype)
                 output = []
                 for i in range(index_size):
                     indexing = [g.call_function(getitem, args=(indices_flatten[n], i)) for n in range(num_index)]
+                    for n in range(num_index):
+                        edit_meta_val(indexing[n], torch.Size([]), indices_flatten[n].meta["val"].dtype)
                     output.append(g.call_function(getitem, args=(input_tensor, indexing)))
+                    edit_meta_val(output[-1], remained_shape, input_dtype)
                 # aten.cat cannot concat zero dim tensor
                 if len(remained_shape) == 0:
-                    output = [g.call_function(torch.ops.aten.reshape, args=(o, [1])) for o in output]
-                output_cat = g.call_function(torch.ops.aten.cat, args=(output,))
-                output_reshape = g.call_function(torch.ops.aten.reshape, args=(output_cat, reshape_shape))
+                    remained_shape = torch.Size([1])
+                    output = [g.call_function(torch.ops.aten.reshape.default, args=(o, remained_shape)) for o in output]
+                    for o in output:
+                        edit_meta_val(o, remained_shape, input_dtype)
+                output_cat = g.call_function(torch.ops.aten.cat.default, args=(output,))
+                output_cat_shape = torch.Size([len(output)] + list(remained_shape))
+                edit_meta_val(output_cat, output_cat_shape, input_dtype)
+                output_reshape = g.call_function(torch.ops.aten.reshape.default, args=(output_cat, reshape_shape))
+                edit_meta_val(output_reshape, reshape_shape, input_dtype)
                 return output_reshape
 
         with g.inserting_before(node):
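For reference, a small eager-mode sketch of the semantics this lowering reproduces for aten.index.Tensor, using the same example as the comment in the pass. torch.stack stands in for the reshape-plus-cat sequence the pass emits, and the indices are assumed to already share a common (broadcast) shape:

import torch

# x.shape = (3, 4, 5), indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])]
x = torch.arange(3 * 4 * 5).reshape(3, 4, 5)
indices = [torch.tensor([[0, 1, 1]]), torch.tensor([[2, 1, 2]])]

reference = x[indices[0], indices[1]]  # what aten.index.Tensor computes; shape (1, 3, 5)

# Mirror of the lowering: flatten the indices, gather one slice per index
# position, then stitch the pieces back together and restore the index shape.
flat0, flat1 = indices[0].reshape(-1), indices[1].reshape(-1)
gathered = [x[flat0[i], flat1[i]] for i in range(flat0.numel())]
rebuilt = torch.stack(gathered).reshape(indices[0].shape + x.shape[2:])

assert torch.equal(reference, rebuilt)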