Skip to content

Commit 0089868

Browse files
committed
[passes] Introduce LegalizeAvgpool2d
This commit is to accept torch.nn.AvgPool2d(count_include_pad = False). count_include_pad must be true for Circle Avgpool2D. Therefore, generate mask to make equivalent operation. TICO-DCO-1.0-Signed-off-by: Dayoung Lee <dayoung.lee@samsung.com>
1 parent 962669a commit 0089868

5 files changed

Lines changed: 176 additions & 3 deletions

File tree

test/modules/op/avg_pool2d.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,18 @@ def forward(self, tensor):
9696

9797
def get_example_inputs(self):
9898
return (torch.randn(2, 4, 8, 16),)
99+
100+
101+
class AvgPoolWithNoCountIncludePad(torch.nn.Module):
102+
def __init__(self):
103+
super().__init__()
104+
self.avgpool = torch.nn.AvgPool2d(
105+
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), count_include_pad=False
106+
)
107+
108+
def forward(self, tensor):
109+
result = self.avgpool(tensor)
110+
return result
111+
112+
def get_example_inputs(self):
113+
return (torch.randn(1, 3, 56, 56),)

tico/passes/legalize_avgpool2d.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import TYPE_CHECKING
16+
17+
if TYPE_CHECKING:
18+
import torch.fx
19+
import torch
20+
from torch.export import ExportedProgram
21+
22+
from tico.utils import logging
23+
from tico.utils.passes import PassBase, PassResult
24+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
25+
from tico.utils.validate_args_kwargs import AvgPool2dArgs
26+
27+
28+
@trace_graph_diff_on_pass
29+
class LegalizeAvgpool2D(PassBase):
30+
"""
31+
Let's legalize avg_pool2d with various options.
32+
33+
Now it supports avg_pool2d (count_include_pad=False)
34+
35+
36+
[BEFORE]
37+
38+
input
39+
|
40+
avgpool2d (padding = padding, count_include_pad=False)
41+
|
42+
out
43+
44+
[AFTER]
45+
46+
input full_like (input, 1)
47+
| |
48+
padding (padding = padding) padding (padding = padding)
49+
| |
50+
avgpool2d (count_include_pad = True) avgpool2d
51+
| |
52+
------------------------- mul --------------------(=mask)
53+
|
54+
out
55+
56+
"""
57+
58+
def __init__(self):
59+
super().__init__()
60+
61+
def call(self, exported_program: ExportedProgram) -> PassResult:
62+
logger = logging.getLogger(__name__)
63+
64+
gm = exported_program.graph_module
65+
graph: torch.fx.Graph = gm.graph
66+
modified = False
67+
68+
for node in graph.nodes:
69+
if node.op != "call_function":
70+
continue
71+
72+
if node.target in [
73+
torch.ops.aten.avg_pool2d.default,
74+
]:
75+
args = AvgPool2dArgs(*node.args, **node.kwargs)
76+
input = args.input
77+
kernel_size = args.kernel_size
78+
stride = args.stride
79+
padding = args.padding
80+
count_include_pad = args.count_include_pad
81+
82+
if args.count_include_pad == True:
83+
continue
84+
85+
assert args.count_include_pad == False
86+
87+
with graph.inserting_before(node):
88+
# 1. Pad the input tensor
89+
x_padded = graph.call_function(
90+
torch.ops.aten.constant_pad_nd.default,
91+
(input, [padding[0], padding[0], padding[1], padding[1]], 0),
92+
)
93+
94+
# 2. Perform average pooling (with padding included)
95+
pooled = graph.call_function(
96+
torch.ops.aten.avg_pool2d.default,
97+
(x_padded, kernel_size, stride, [0, 0], count_include_pad),
98+
)
99+
100+
# 3. Calculate mask with valid pixel count ratio
101+
#
102+
# ones_padded -> mask
103+
# 0 0 0 0 . . . .
104+
# 0 1 1 1 -> . 4/9 6/9 6/9
105+
# 0 1 1 1 . 6/9 1 1
106+
ones = graph.call_function(
107+
torch.ops.aten.full_like.default, (pooled, 1.0)
108+
)
109+
ones_padded = graph.call_function(
110+
torch.ops.aten.constant_pad_nd.default,
111+
(ones, [padding[0], padding[0], padding[1], padding[1]], 0),
112+
)
113+
mask = graph.call_function(
114+
torch.ops.aten.avg_pool2d.default,
115+
(
116+
ones_padded,
117+
kernel_size,
118+
stride,
119+
[0, 0],
120+
),
121+
) # Already padded
122+
123+
result = graph.call_function(
124+
torch.ops.aten.div.Tensor, (pooled, mask)
125+
)
126+
127+
node.replace_all_uses_with(result, propagate_meta=True)
128+
129+
modified = True
130+
131+
gm.graph.eliminate_dead_code()
132+
gm.graph.lint()
133+
gm.recompile()
134+
135+
return PassResult(modified)

tico/passes/legalize_predefined_layout_operators.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
if TYPE_CHECKING:
1919
import torch.fx
20+
import math
21+
2022
import torch
2123
from torch.export import ExportedProgram
2224

@@ -317,11 +319,26 @@ def legalize_avg_pool2d(self, exported_program, node) -> bool:
317319
if ceil_mode:
318320
raise NotYetSupportedError("Only support non-ceil model.")
319321
count_include_pad = args.count_include_pad
322+
320323
if not count_include_pad:
321-
# NOTE count_include_pad = False can be partially supported with SAME padding in circle.
322-
raise NotYetSupportedError(
323-
"For the case that the count_include_pad is False is not yet supported."
324+
# count_include_pad is False means that padding type is SAME
325+
input_shape = extract_shape(input_)
326+
327+
output_height = (
328+
math.ceil((input_shape[2] - kernel_size[0] + 1) / stride[0])
329+
+ padding[0] * 2
324330
)
331+
output_width = (
332+
math.ceil((input_shape[3] - kernel_size[1] + 1) / stride[1])
333+
+ padding[1] * 2
334+
)
335+
336+
# Check if its padding type is SAME
337+
if not (input_shape[2] == output_height and input_shape[3] == output_width):
338+
raise NotYetSupportedError(
339+
"Only support count_include_pad=False with SAME padding case"
340+
)
341+
325342
divisor_override = args.divisor_override
326343
if divisor_override is not None:
327344
raise NotYetSupportedError(

tico/serialize/operators/op_avg_pool2d.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ def define_node(
4646
stride = args.stride
4747
padding = args.padding
4848

49+
if args.count_include_pad == False:
50+
# count_include_pad must be legalized by LegalizeAvgPool2D pass
51+
raise ValueError("count_include_pad must be True")
52+
4953
avgpool_input: torch.fx.Node | circle.Tensor.TensorT = input
5054

5155
def define_padding_node():

tico/utils/convert.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from tico.passes.extract_dtype_kwargs import ExtractDtypeKwargsPass
5353
from tico.passes.fill_meta_val import FillMetaVal
5454
from tico.passes.fuse_redundant_reshape_to_mean import FuseRedundantReshapeToMean
55+
from tico.passes.legalize_avgpool2d import LegalizeAvgpool2D
5556
from tico.passes.legalize_causal_mask_value import LegalizeCausalMaskValue
5657
from tico.passes.legalize_predefined_layout_operators import (
5758
LegalizePreDefinedLayoutOperators,
@@ -208,6 +209,7 @@ def convert_exported_module_to_circle(
208209
DecomposeGroupedConv2d(),
209210
CastATenWhereArgType(),
210211
ConvertRepeatToExpandCopy(),
212+
LegalizeAvgpool2D(),
211213
*RemoveRedundantPermutePasses(),
212214
RemoveRedundantAssertionNodes(),
213215
RemoveRedundantExpand(),

0 commit comments

Comments
 (0)