Skip to content

Commit 358f0e9

Browse files
Remove guards of aten.full (#614)
Co-authored-by: Artem Yerofieiev <[email protected]>
1 parent 6cbda40 commit 358f0e9

File tree

2 files changed

+7
-47
lines changed

2 files changed

+7
-47
lines changed

tests/lowering/creation/test_full.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
import torch_ttnn
33
import pytest
4-
import ttnn
54

65

76
class FullModule(torch.nn.Module):
@@ -14,7 +13,11 @@ def forward(self, size, fill_value):
1413

1514
@pytest.mark.parametrize(
1615
"input_shapes",
17-
[[(64, 128)]],
16+
[
17+
[(64, 128)],
18+
[(19, 19)],
19+
[(59, 59)],
20+
],
1821
)
1922
def test_full(device, input_shapes):
2023
m = FullModule()
@@ -29,6 +32,7 @@ def test_full(device, input_shapes):
2932

3033
# Check the graph has been rewritten and contains ttnn ops
3134
nodes = list(option._out_fx_graphs[0].nodes)
32-
assert [node.target for node in nodes].count(ttnn.full) == 1
35+
# Check the graph has been rewritten and aten ops are replaced
36+
assert not any(node.target == torch.ops.aten.full.default for node in nodes)
3337
# Check inference result
3438
assert torch.allclose(result_before, result_after)

torch_ttnn/passes/lowering/to_tt_guard_autogen.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -62,46 +62,6 @@
6262
["Tensor<[1, 16, 1, 60]> self = ?", "Tensor<[]> other = ?"],
6363
]
6464
aten__log_softmax_default_blocklist = [["Tensor<[19, 256008]> self = ?", "int dim = 1", "bool half_to_float = False"]]
65-
aten_full_default_blocklist = [
66-
[
67-
"List[int] size = [19, 19]",
68-
"number fill_value = -3.4028234663852886e+38",
69-
"Optional[Device] device = cpu",
70-
"Optional[bool] pin_memory = False",
71-
],
72-
[
73-
"List[int] size = [7, 7]",
74-
"number fill_value = -3.3895313892515355e+38",
75-
"Optional[Device] device = cpu",
76-
"Optional[bool] pin_memory = False",
77-
],
78-
[
79-
"List[int] size = [45, 45]",
80-
"number fill_value = -3.3895313892515355e+38",
81-
"Optional[Device] device = cpu",
82-
"Optional[bool] pin_memory = False",
83-
],
84-
[
85-
"List[int] size = [59, 59]",
86-
"number fill_value = -3.3895313892515355e+38",
87-
"Optional[Device] device = cpu",
88-
"Optional[bool] pin_memory = False",
89-
],
90-
[
91-
"List[int] size = [19, 19]",
92-
"number fill_value = -3.3895313892515355e+38",
93-
"Optional[Device] device = cpu",
94-
"Optional[bool] pin_memory = False",
95-
],
96-
]
97-
# TODO(#615): Dynamic shape is not supported yet
98-
aten_full_like_default_blocklist = [
99-
[
100-
"Tensor<[s0 + 1, s0 + 1]> self = ?",
101-
"number fill_value = 31",
102-
"Optional[bool] pin_memory = False",
103-
],
104-
]
10565
aten__scaled_dot_product_flash_attention_default_blocklist = [
10666
["Tensor<[1, 16, 197, 64]> query = ?", "Tensor<[1, 16, 197, 64]> key = ?", "Tensor<[1, 16, 197, 64]> value = ?"],
10767
["Tensor<[1, 12, 197, 64]> query = ?", "Tensor<[1, 12, 197, 64]> key = ?", "Tensor<[1, 12, 197, 64]> value = ?"],
@@ -1402,8 +1362,6 @@ def guard_aten(blocklist, node):
14021362
torch.ops.aten.clamp.default: partial(guard_aten, aten_clamp_default_blocklist),
14031363
torch.ops.aten.maximum.default: partial(guard_aten, aten_maximum_default_blocklist),
14041364
torch.ops.aten._log_softmax.default: partial(guard_aten, aten__log_softmax_default_blocklist),
1405-
torch.ops.aten.full.default: partial(guard_aten, aten_full_default_blocklist),
1406-
torch.ops.aten.full_like.default: partial(guard_aten, aten_full_like_default_blocklist),
14071365
torch.ops.aten._scaled_dot_product_flash_attention.default: partial(
14081366
guard_aten, aten__scaled_dot_product_flash_attention_default_blocklist
14091367
),
@@ -1430,8 +1388,6 @@ def guard_aten(blocklist, node):
14301388
"torch.ops.aten.clamp.default",
14311389
"torch.ops.aten.maximum.default",
14321390
"torch.ops.aten._log_softmax.default",
1433-
"torch.ops.aten.full.default",
1434-
"torch.ops.aten.full_like.default",
14351391
"torch.ops.aten.rsub.Scalar",
14361392
"torch.ops.aten._scaled_dot_product_flash_attention.default",
14371393
"torch.ops.aten.transpose.int",

0 commit comments

Comments
 (0)