Skip to content

Commit 6007c0b

Browse files
committed
Add test cases for Interpolate with post ops:
- NCHWAsNHWC without fused post ops - DefaultAxes with fused post ops
1 parent 96f0a01 commit 6007c0b

File tree

1 file changed

+113
-0
lines changed

1 file changed

+113
-0
lines changed
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Copyright (C) 2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include <gtest/gtest.h>
6+
7+
#include <memory>
8+
#include <vector>
9+
10+
#include "openvino/core/type/element_type.hpp"
11+
#include "openvino/op/add.hpp"
12+
#include "openvino/op/constant.hpp"
13+
#include "openvino/op/interpolate.hpp"
14+
#include "openvino/op/multiply.hpp"
15+
#include "openvino/op/parameter.hpp"
16+
#include "openvino/op/result.hpp"
17+
#include "shared_test_classes/base/ov_subgraph.hpp"
18+
#include "utils/cpu_test_utils.hpp"
19+
20+
namespace ov {
21+
namespace test {
22+
23+
class InterpolateWithPostOps : public SubgraphBaseStaticTest, public ::testing::WithParamInterface<bool> {
24+
public:
25+
static std::string getTestCaseName(const ::testing::TestParamInfo<bool>& info) {
26+
return info.param ? "Interpolate_NoFuse_NCHWAsNHWC" : "Interpolate_Fuse_DefaultAxes";
27+
}
28+
29+
protected:
30+
bool NCHWAsNHWC_NoFuse = false;
31+
void SetUp() override {
32+
NCHWAsNHWC_NoFuse = GetParam();
33+
ov::element::Type netPrecision = ov::element::f32;
34+
targetDevice = ov::test::utils::DEVICE_CPU;
35+
36+
std::shared_ptr<ov::Model> raw_function;
37+
if (NCHWAsNHWC_NoFuse) {
38+
auto input_shape = ov::Shape{1, 3, 128, 128};
39+
auto mul_const_shape = ov::Shape{1, 1, 1, 128};
40+
auto add_const_shape = ov::Shape{1, 1, 1, 128};
41+
auto input = std::make_shared<ov::op::v0::Parameter>(netPrecision, input_shape);
42+
auto sizes = ov::op::v0::Constant::create(ov::element::i64, {2}, {256, 256});
43+
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 2});
44+
auto interpolate = std::make_shared<ov::op::v11::Interpolate>(
45+
input,
46+
sizes,
47+
axes,
48+
ov::op::v11::Interpolate::InterpolateAttrs{
49+
ov::op::v11::Interpolate::InterpolateMode::BILINEAR_PILLOW,
50+
ov::op::v11::Interpolate::ShapeCalcMode::SIZES,
51+
{0, 0, 0, 0},
52+
{0, 0, 0, 0},
53+
ov::op::v11::Interpolate::CoordinateTransformMode::HALF_PIXEL,
54+
ov::op::v11::Interpolate::NearestMode::FLOOR,
55+
false,
56+
-0.75f});
57+
auto mul_const = ov::op::v0::Constant::create(netPrecision, mul_const_shape, {1.0f});
58+
auto add_const = ov::op::v0::Constant::create(netPrecision, add_const_shape, {2.0f});
59+
auto mul = std::make_shared<ov::op::v1::Multiply>(interpolate, mul_const);
60+
auto add = std::make_shared<ov::op::v1::Add>(mul, add_const);
61+
auto result = std::make_shared<ov::op::v0::Result>(add);
62+
raw_function = std::make_shared<ov::Model>(result,
63+
ov::ParameterVector{input},
64+
"Interpolate_with_post_ops_NoFuse_NCHWAsNHWC");
65+
} else {
66+
auto input_shape = ov::Shape{1, 3, 128, 128};
67+
auto mul_const_shape = ov::Shape{1, 3, 1, 1};
68+
auto add_const_shape = ov::Shape{1, 3, 1, 1};
69+
auto input = std::make_shared<ov::op::v0::Parameter>(netPrecision, input_shape);
70+
auto sizes = ov::op::v0::Constant::create(ov::element::i64, {4}, {1, 3, 256, 128});
71+
auto interpolate = std::make_shared<ov::op::v11::Interpolate>(
72+
input,
73+
sizes,
74+
ov::op::v11::Interpolate::InterpolateAttrs{
75+
ov::op::v11::Interpolate::InterpolateMode::LINEAR_ONNX,
76+
ov::op::v11::Interpolate::ShapeCalcMode::SIZES,
77+
{0, 0, 0, 0},
78+
{0, 0, 0, 0},
79+
ov::op::v11::Interpolate::CoordinateTransformMode::HALF_PIXEL,
80+
ov::op::v11::Interpolate::NearestMode::FLOOR,
81+
false,
82+
-0.75f});
83+
auto mul_const = ov::op::v0::Constant::create(netPrecision, mul_const_shape, {1.0f});
84+
auto add_const = ov::op::v0::Constant::create(netPrecision, add_const_shape, {2.0f});
85+
auto mul = std::make_shared<ov::op::v1::Multiply>(interpolate, mul_const);
86+
auto add = std::make_shared<ov::op::v1::Add>(mul, add_const);
87+
auto result = std::make_shared<ov::op::v0::Result>(add);
88+
raw_function = std::make_shared<ov::Model>(result,
89+
ov::ParameterVector{input},
90+
"Interpolate_with_post_ops_Fuse_DefaultAxes");
91+
}
92+
auto ppp_model = ov::preprocess::PrePostProcessor(raw_function);
93+
ppp_model.input().tensor().set_layout("NHWC");
94+
function = ppp_model.build();
95+
}
96+
};
97+
98+
TEST_P(InterpolateWithPostOps, CheckInterpolateWithPostOps) {
99+
run();
100+
if (NCHWAsNHWC_NoFuse) {
101+
CPUTestUtils::CheckNumberOfNodesWithTypes(compiledModel, {"Subgraph", "Eltwise"}, 1);
102+
} else {
103+
CPUTestUtils::CheckNumberOfNodesWithTypes(compiledModel, {"Subgraph", "Eltwise"}, 0);
104+
}
105+
}
106+
107+
INSTANTIATE_TEST_SUITE_P(InterpolateWithPostOpsFusion,
108+
InterpolateWithPostOps,
109+
::testing::Values(true, false),
110+
InterpolateWithPostOps::getTestCaseName);
111+
112+
} // namespace test
113+
} // namespace ov

0 commit comments

Comments
 (0)