1+ // Copyright (C) 2025 Intel Corporation
2+ // SPDX-License-Identifier: Apache-2.0
3+ //
4+
5+ #include < gtest/gtest.h>
6+
7+ #include < memory>
8+ #include < vector>
9+
10+ #include " openvino/core/type/element_type.hpp"
11+ #include " openvino/op/add.hpp"
12+ #include " openvino/op/constant.hpp"
13+ #include " openvino/op/interpolate.hpp"
14+ #include " openvino/op/multiply.hpp"
15+ #include " openvino/op/parameter.hpp"
16+ #include " openvino/op/result.hpp"
17+ #include " shared_test_classes/base/ov_subgraph.hpp"
18+ #include " utils/cpu_test_utils.hpp"
19+
20+ namespace ov {
21+ namespace test {
22+
23+ class InterpolateWithPostOps : public SubgraphBaseStaticTest , public ::testing::WithParamInterface<bool > {
24+ public:
25+ static std::string getTestCaseName (const ::testing::TestParamInfo<bool >& info) {
26+ return info.param ? " Interpolate_NoFuse_NCHWAsNHWC" : " Interpolate_Fuse_DefaultAxes" ;
27+ }
28+
29+ protected:
30+ bool NCHWAsNHWC_NoFuse = false ;
31+ void SetUp () override {
32+ NCHWAsNHWC_NoFuse = GetParam ();
33+ ov::element::Type netPrecision = ov::element::f32 ;
34+ targetDevice = ov::test::utils::DEVICE_CPU;
35+
36+ std::shared_ptr<ov::Model> raw_function;
37+ if (NCHWAsNHWC_NoFuse) {
38+ auto input_shape = ov::Shape{1 , 3 , 128 , 128 };
39+ auto mul_const_shape = ov::Shape{1 , 1 , 1 , 128 };
40+ auto add_const_shape = ov::Shape{1 , 1 , 1 , 128 };
41+ auto input = std::make_shared<ov::op::v0::Parameter>(netPrecision, input_shape);
42+ auto sizes = ov::op::v0::Constant::create (ov::element::i64 , {2 }, {256 , 256 });
43+ auto axes = ov::op::v0::Constant::create (ov::element::i64 , {2 }, {1 , 2 });
44+ auto interpolate = std::make_shared<ov::op::v11::Interpolate>(
45+ input,
46+ sizes,
47+ axes,
48+ ov::op::v11::Interpolate::InterpolateAttrs{
49+ ov::op::v11::Interpolate::InterpolateMode::BILINEAR_PILLOW,
50+ ov::op::v11::Interpolate::ShapeCalcMode::SIZES,
51+ {0 , 0 , 0 , 0 },
52+ {0 , 0 , 0 , 0 },
53+ ov::op::v11::Interpolate::CoordinateTransformMode::HALF_PIXEL,
54+ ov::op::v11::Interpolate::NearestMode::FLOOR,
55+ false ,
56+ -0 .75f });
57+ auto mul_const = ov::op::v0::Constant::create (netPrecision, mul_const_shape, {1 .0f });
58+ auto add_const = ov::op::v0::Constant::create (netPrecision, add_const_shape, {2 .0f });
59+ auto mul = std::make_shared<ov::op::v1::Multiply>(interpolate, mul_const);
60+ auto add = std::make_shared<ov::op::v1::Add>(mul, add_const);
61+ auto result = std::make_shared<ov::op::v0::Result>(add);
62+ raw_function = std::make_shared<ov::Model>(result,
63+ ov::ParameterVector{input},
64+ " Interpolate_with_post_ops_NoFuse_NCHWAsNHWC" );
65+ } else {
66+ auto input_shape = ov::Shape{1 , 3 , 128 , 128 };
67+ auto mul_const_shape = ov::Shape{1 , 3 , 1 , 1 };
68+ auto add_const_shape = ov::Shape{1 , 3 , 1 , 1 };
69+ auto input = std::make_shared<ov::op::v0::Parameter>(netPrecision, input_shape);
70+ auto sizes = ov::op::v0::Constant::create (ov::element::i64 , {4 }, {1 , 3 , 256 , 256 });
71+ auto interpolate = std::make_shared<ov::op::v11::Interpolate>(
72+ input,
73+ sizes,
74+ ov::op::v11::Interpolate::InterpolateAttrs{
75+ ov::op::v11::Interpolate::InterpolateMode::LINEAR_ONNX,
76+ ov::op::v11::Interpolate::ShapeCalcMode::SIZES,
77+ {0 , 0 , 0 , 0 },
78+ {0 , 0 , 0 , 0 },
79+ ov::op::v11::Interpolate::CoordinateTransformMode::HALF_PIXEL,
80+ ov::op::v11::Interpolate::NearestMode::FLOOR,
81+ false ,
82+ -0 .75f });
83+ auto mul_const = ov::op::v0::Constant::create (netPrecision, mul_const_shape, {1 .0f });
84+ auto add_const = ov::op::v0::Constant::create (netPrecision, add_const_shape, {2 .0f });
85+ auto mul = std::make_shared<ov::op::v1::Multiply>(interpolate, mul_const);
86+ auto add = std::make_shared<ov::op::v1::Add>(mul, add_const);
87+ auto result = std::make_shared<ov::op::v0::Result>(add);
88+ raw_function = std::make_shared<ov::Model>(result,
89+ ov::ParameterVector{input},
90+ " Interpolate_with_post_ops_Fuse_DefaultAxes" );
91+ }
92+ auto ppp_model = ov::preprocess::PrePostProcessor (raw_function);
93+ ppp_model.input ().tensor ().set_layout (" NHWC" );
94+ function = ppp_model.build ();
95+ }
96+ };
97+
98+ TEST_P (InterpolateWithPostOps, CheckInterpolateWithPostOps) {
99+ run ();
100+ if (NCHWAsNHWC_NoFuse) {
101+ CPUTestUtils::CheckNumberOfNodesWithTypes (compiledModel, {" Subgraph" , " Eltwise" }, 1 );
102+ } else {
103+ CPUTestUtils::CheckNumberOfNodesWithTypes (compiledModel, {" Subgraph" , " Eltwise" }, 0 );
104+ }
105+ }
106+
107+ INSTANTIATE_TEST_SUITE_P (InterpolateWithPostOpsFusion,
108+ InterpolateWithPostOps,
109+ ::testing::Values (true , false ),
110+ InterpolateWithPostOps::getTestCaseName);
111+
112+ } // namespace test
113+ } // namespace ov
0 commit comments