Skip to content

Commit 86f8693

Browse files
committed
Make sure meta val of aten.index is correct
1 parent 2c24c7d commit 86f8693

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

torch_ttnn/passes/lowering/to_tt_pass.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from torch.fx.passes.infra.pass_base import PassBase, PassResult
1919
import torch.fx.traceback as fx_traceback
20+
from torch._subclasses.fake_tensor import FakeTensorMode
2021
from . import target_wrappers
2122
from .to_tt_guard import can_lowering_to_ttnn
2223
from operator import getitem
@@ -447,7 +448,7 @@ def __init__(self, node):
447448

448449
def call_function(self, target, args=(), kwargs={}):
449450
new_node = self.g.call_function(target, args, kwargs)
450-
new_node.meta = self.node.meta
451+
new_node.meta = self.node.meta.copy()
451452
if hasattr(self.node.target, "_schema"):
452453
new_node.meta["original_input_variations"] = metrics.collect_input_variation_from_node(self.node)
453454
if target == ttnn.layer_norm:
@@ -1133,16 +1134,21 @@ def batch_norm_inference(input, weight, bias, mean, var, momentum, eps):
11331134

11341135

11351136
def DigestAtenOps(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
1136-
g = gm.graph
11371137
nodes = list(gm.graph.nodes)
11381138
for node in nodes:
1139+
g = GraphWrapper(node)
11391140

11401141
def rewrite_node(node):
11411142
args = node.args
11421143
kwargs = node.kwargs
11431144

11441145
if node.target == torch.ops.aten.index.Tensor:
11451146

1147+
def edit_meta_val(node, shape, dtype):
1148+
fake_mode = FakeTensorMode()
1149+
fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))
1150+
node.meta["val"] = fake_tensor
1151+
11461152
def broadcast_indices(indices):
11471153
import numpy as np
11481154

@@ -1156,27 +1162,47 @@ def broadcast_indices(indices):
11561162
broadcasted_indices.append(
11571163
g.call_function(torch.ops.aten.expand.default, (indices[i], broadcasted_shape))
11581164
)
1165+
edit_meta_val(broadcasted_indices[-1], broadcasted_shape, indices[i].meta["val"].dtype)
11591166
return broadcasted_shape, broadcasted_indices
11601167

11611168
# for example, input.shape = (3, 4, 5), indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])]
11621169
# then output is [[input[0][2], input[1][1], input[1][2]]]
11631170
input_tensor, indices = args
1171+
if get_shape(input_tensor) is None:
1172+
return None
1173+
if None in [get_shape(indices[i]) for i in range(len(indices))]:
1174+
return None
11641175
index_shape, indices = broadcast_indices(indices)
11651176
input_shape = get_shape(input_tensor)
11661177
num_index = len(indices)
11671178
index_size = index_shape.numel()
11681179
remained_shape = input_shape[num_index:]
11691180
reshape_shape = index_shape + remained_shape
1170-
indices_flatten = [g.call_function(torch.ops.aten.flatten, args=(idx,)) for idx in indices]
1181+
input_dtype = input_tensor.meta["val"].dtype
1182+
flatten_shape = torch.Size([index_size])
1183+
indices_flatten = [
1184+
g.call_function(torch.ops.aten.reshape.default, args=(idx, flatten_shape)) for idx in indices
1185+
]
1186+
for i in range(len(indices_flatten)):
1187+
edit_meta_val(indices_flatten[i], flatten_shape, indices[i].meta["val"].dtype)
11711188
output = []
11721189
for i in range(index_size):
11731190
indexing = [g.call_function(getitem, args=(indices_flatten[n], i)) for n in range(num_index)]
1191+
for n in range(num_index):
1192+
edit_meta_val(indexing[n], torch.Size([]), indices_flatten[n].meta["val"].dtype)
11741193
output.append(g.call_function(getitem, args=(input_tensor, indexing)))
1194+
edit_meta_val(output[-1], remained_shape, input_dtype)
11751195
# aten.cat cannot concat zero dim tensor
11761196
if len(remained_shape) == 0:
1177-
output = [g.call_function(torch.ops.aten.reshape, args=(o, [1])) for o in output]
1178-
output_cat = g.call_function(torch.ops.aten.cat, args=(output,))
1179-
output_reshape = g.call_function(torch.ops.aten.reshape, args=(output_cat, reshape_shape))
1197+
remained_shape = torch.Size([1])
1198+
output = [g.call_function(torch.ops.aten.reshape.default, args=(o, remained_shape)) for o in output]
1199+
for o in output:
1200+
edit_meta_val(o, remained_shape, input_dtype)
1201+
output_cat = g.call_function(torch.ops.aten.cat.default, args=(output,))
1202+
output_cat_shape = torch.Size([len(output)] + list(remained_shape))
1203+
edit_meta_val(output_cat, output_cat_shape, input_dtype)
1204+
output_reshape = g.call_function(torch.ops.aten.reshape.default, args=(output_cat, reshape_shape))
1205+
edit_meta_val(output_reshape, reshape_shape, input_dtype)
11801206
return output_reshape
11811207

11821208
with g.inserting_before(node):

0 commit comments

Comments
 (0)