Skip to content

Commit fd2ccb1

Browse files
committed
move edit_meta_val to call_function
1 parent ddd7fe3 commit fd2ccb1

File tree

2 files changed

+54
-26
lines changed

2 files changed

+54
-26
lines changed

tests/lowering/misc/test_index.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def forward(self, input, indices):
2424
((3, 4, 5), [[[0, 1, 1], [1, 1, 0]], [[2, 1, 2]]]), # broadcast
2525
],
2626
)
27-
def test_select(device, input_shapes, indices):
27+
def test_index(device, input_shapes, indices):
2828
m = IndexModule()
2929
inputs = torch.rand(input_shapes, dtype=torch.bfloat16)
3030
indices = [torch.tensor(index) for index in indices]

torch_ttnn/passes/lowering/to_tt_pass.py

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import ttnn
33
import math
4+
import numpy as np
45
from torch._subclasses.fake_tensor import unset_fake_temporarily
56
from torch_ttnn.utils import (
67
GraphCleanup,
@@ -446,13 +447,19 @@ def __init__(self, node):
446447
self.g = node.graph
447448
self.node = node
448449

449-
def call_function(self, target, args=(), kwargs={}):
450+
def call_function(self, target, args=(), kwargs={}, new_shape=None, new_dtype=None):
450451
new_node = self.g.call_function(target, args, kwargs)
451452
new_node.meta = self.node.meta.copy()
452453
if hasattr(self.node.target, "_schema"):
453454
new_node.meta["original_input_variations"] = metrics.collect_input_variation_from_node(self.node)
454455
if target == ttnn.layer_norm:
455456
new_node.meta["val"] = new_node.meta["val"][0]
457+
if new_shape is not None or new_dtype is not None:
458+
shape = new_shape if new_shape is not None else new_node.meta["val"].size()
459+
dtype = new_dtype if new_dtype is not None else new_node.meta["val"].dtype
460+
fake_mode = FakeTensorMode()
461+
fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))
462+
new_node.meta["val"] = fake_tensor
456463
return new_node
457464

458465
def inserting_before(self, node):
@@ -1144,14 +1151,7 @@ def rewrite_node(node):
11441151

11451152
if node.target == torch.ops.aten.index.Tensor:
11461153

1147-
def edit_meta_val(node, shape, dtype):
1148-
fake_mode = FakeTensorMode()
1149-
fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))
1150-
node.meta["val"] = fake_tensor
1151-
11521154
def broadcast_indices(indices):
1153-
import numpy as np
1154-
11551155
indices_shapes = [get_shape(indices[i]) for i in range(len(indices))]
11561156
broadcasted_shape = torch.Size(np.broadcast_shapes(*indices_shapes))
11571157
broadcasted_indices = []
@@ -1160,9 +1160,10 @@ def broadcast_indices(indices):
11601160
broadcasted_indices.append(indices[i])
11611161
else:
11621162
broadcasted_indices.append(
1163-
g.call_function(torch.ops.aten.expand.default, (indices[i], broadcasted_shape))
1163+
g.call_function(torch.ops.aten.expand.default, (indices[i], broadcasted_shape)),
1164+
new_shape=broadcasted_shape,
1165+
new_dtype=indices[i].meta["val"].dtype,
11641166
)
1165-
edit_meta_val(broadcasted_indices[-1], broadcasted_shape, indices[i].meta["val"].dtype)
11661167
return broadcasted_shape, broadcasted_indices
11671168

11681169
# for example, input.shape = (3, 4, 5), indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])]
@@ -1181,28 +1182,55 @@ def broadcast_indices(indices):
11811182
input_dtype = input_tensor.meta["val"].dtype
11821183
flatten_shape = torch.Size([index_size])
11831184
indices_flatten = [
1184-
g.call_function(torch.ops.aten.reshape.default, args=(idx, flatten_shape)) for idx in indices
1185+
g.call_function(
1186+
torch.ops.aten.reshape.default,
1187+
args=(idx, flatten_shape),
1188+
new_shape=flatten_shape,
1189+
new_dtype=idx.meta["val"].dtype,
1190+
)
1191+
for idx in indices
11851192
]
1186-
for i in range(len(indices_flatten)):
1187-
edit_meta_val(indices_flatten[i], flatten_shape, indices[i].meta["val"].dtype)
11881193
output = []
11891194
for i in range(index_size):
1190-
indexing = [g.call_function(getitem, args=(indices_flatten[n], i)) for n in range(num_index)]
1191-
for n in range(num_index):
1192-
edit_meta_val(indexing[n], torch.Size([]), indices_flatten[n].meta["val"].dtype)
1193-
output.append(g.call_function(getitem, args=(input_tensor, indexing)))
1194-
edit_meta_val(output[-1], remained_shape, input_dtype)
1195+
indexing = [
1196+
g.call_function(
1197+
getitem,
1198+
args=(indices_flatten[n], i),
1199+
new_shape=torch.Size([]),
1200+
new_dtype=indices_flatten[n].meta["val"].dtype,
1201+
)
1202+
for n in range(num_index)
1203+
]
1204+
output.append(
1205+
g.call_function(getitem, args=(input_tensor, indexing)),
1206+
new_shape=remained_shape,
1207+
new_dtype=input_dtype,
1208+
)
11951209
# aten.cat cannot concat zero dim tensor
11961210
if len(remained_shape) == 0:
11971211
remained_shape = torch.Size([1])
1198-
output = [g.call_function(torch.ops.aten.reshape.default, args=(o, remained_shape)) for o in output]
1199-
for o in output:
1200-
edit_meta_val(o, remained_shape, input_dtype)
1201-
output_cat = g.call_function(torch.ops.aten.cat.default, args=(output,))
1212+
output = [
1213+
g.call_function(
1214+
torch.ops.aten.reshape.default,
1215+
args=(o, remained_shape),
1216+
new_shape=remained_shape,
1217+
new_dtype=input_dtype,
1218+
)
1219+
for o in output
1220+
]
12021221
output_cat_shape = torch.Size([len(output)] + list(remained_shape))
1203-
edit_meta_val(output_cat, output_cat_shape, input_dtype)
1204-
output_reshape = g.call_function(torch.ops.aten.reshape.default, args=(output_cat, reshape_shape))
1205-
edit_meta_val(output_reshape, reshape_shape, input_dtype)
1222+
output_cat = g.call_function(
1223+
torch.ops.aten.cat.default,
1224+
args=(output,),
1225+
new_shape=output_cat_shape,
1226+
new_dtype=input_dtype,
1227+
)
1228+
output_reshape = g.call_function(
1229+
torch.ops.aten.reshape.default,
1230+
args=(output_cat, reshape_shape),
1231+
new_shape=reshape_shape,
1232+
new_dtype=input_dtype,
1233+
)
12061234
return output_reshape
12071235

12081236
with g.inserting_before(node):

0 commit comments

Comments
 (0)