Skip to content

Commit d1f7ac6

Browse files
3l1facebook-github-bot
authored andcommitted
Handle avg_pool2d with padding == 0 as no padding (pytorch#10697)
Summary: Pull Request resolved: pytorch#10697 Differential Revision: D74117402
1 parent 3ffb704 commit d1f7ac6

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

backends/arm/operator_support/pool_2d_support.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
5454
kernel = cast(tuple[int, int], node.args[1])
5555
stride = cast(tuple[int, int], node.args[2])
5656
if len(node.args) > 3:
57+
padding = cast(tuple[int, int], node.args[3])
5758
# Padding case
58-
if not all(1 <= k <= 8 for k in kernel):
59+
if not all(1 <= k <= 8 for k in kernel) and not all(v == 0 for v in padding):
5960
self.reporter.report_reject(
6061
node, f"Avgpool2d with padding needs kernel dims < 8, got {kernel}"
6162
)

backends/arm/test/ops/test_avg_pool2d.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99

1010
from typing import Tuple
1111

12+
import pytest
13+
1214
import torch
1315

14-
from executorch.backends.arm.test import common
16+
from executorch.backends.arm.test import common, conftest
1517

1618
from executorch.backends.arm.test.tester.test_pipeline import (
1719
EthosU55PipelineBI,
@@ -64,15 +66,24 @@ def forward(self, x):
6466

6567

6668
@common.parametrize("test_module", test_modules)
69+
@pytest.mark.tosa_ref_model
6770
def test_avgpool2d_tosa_MI(test_module):
6871
model, input_tensor = test_module
6972

70-
pipeline = TosaPipelineMI[input_t](model, input_tensor, aten_op, exir_op)
71-
pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1)
72-
pipeline.run()
73+
pipeline = TosaPipelineMI[input_t](
74+
model,
75+
input_tensor,
76+
aten_op,
77+
exir_op,
78+
run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"),
79+
)
80+
if conftest.get_option("tosa_version") == "1.0":
81+
pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1)
82+
pipeline.run()
7383

7484

7585
@common.parametrize("test_module", test_modules)
86+
@pytest.mark.tosa_ref_model
7687
def test_avgpool2d_tosa_BI(test_module):
7788
model, input_tensor = test_module
7889

@@ -82,9 +93,11 @@ def test_avgpool2d_tosa_BI(test_module):
8293
aten_op,
8394
exir_op,
8495
symmetric_io_quantization=True,
96+
run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"),
8597
)
86-
pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1)
87-
pipeline.run()
98+
if conftest.get_option("tosa_version") == "0.80":
99+
pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1)
100+
pipeline.run()
88101

89102

90103
@common.parametrize("test_module", test_modules)

backends/arm/test/targets.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def define_arm_tests():
1313

1414
# Operators
1515
test_files += [
16+
"ops/test_avg_pool2d.py",
1617
"ops/test_linear.py",
1718
"ops/test_slice.py",
1819
"ops/test_sigmoid.py",

0 commit comments

Comments
 (0)