Skip to content

[CORE] Squeeze v15 reverse infer #27526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 35 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
579a5e2
add test with empty axes
barnasm1 Oct 31, 2024
2738301
remove attribute allow_axis_skip for empty axis Ctor
barnasm1 Nov 12, 2024
4ad95df
refactor squeezeV0 shape_infer common validate input
barnasm1 Nov 12, 2024
7255e6f
refactor squeezeV15 shape_infer common validate input
barnasm1 Nov 12, 2024
132a7f1
format
barnasm1 Nov 13, 2024
de3a1c8
format2
barnasm1 Nov 13, 2024
82d8ef1
Merge branch 'master' into squeeze_v15_reverse_infer
barnasm1 Nov 13, 2024
c82e8fd
update squeeze15 ctor args
barnasm1 Nov 13, 2024
a3011b8
unify enable_if usage
barnasm1 Nov 13, 2024
084550b
reverse ctor update
barnasm1 Nov 13, 2024
09ba124
fix namespace selection for windows build
barnasm1 Nov 13, 2024
66e328a
code foramt
barnasm1 Nov 13, 2024
2eafc25
fix test
barnasm1 Nov 13, 2024
37b939f
code foramt
barnasm1 Nov 14, 2024
bd92721
avoid code duplication, use function: set_output_shape
barnasm1 Nov 14, 2024
9dff25c
code foramt
barnasm1 Nov 14, 2024
6ce7ccb
use ov::optional
barnasm1 Nov 14, 2024
d827fd4
separate test cases
barnasm1 Nov 14, 2024
cc1346e
use common in OV for this doxy style
barnasm1 Nov 14, 2024
b08f8bf
direct return output_shape
barnasm1 Nov 14, 2024
b2c7ac8
do not pass lambda functor into validator
barnasm1 Nov 14, 2024
fc0b5ae
unused variable
barnasm1 Nov 14, 2024
cc3da2f
Merge branch 'master' into squeeze_v15_reverse_infer
barnasm1 Nov 14, 2024
727331d
undo doc format for op
barnasm1 Nov 15, 2024
cfd532a
Merge branch 'master' into squeeze_v15_reverse_infer
barnasm1 Nov 19, 2024
a0e5d67
Merge branch 'master' into squeeze_v15_reverse_infer
barnasm1 Nov 25, 2024
9996e6a
gather reverse infer + squeeze v15 rt_info
barnasm1 Nov 25, 2024
fbfb6d5
Merge branch 'master' into squeeze_v15_reverse_infer
barnasm1 Nov 25, 2024
139c0b1
const rt_info name
barnasm1 Nov 25, 2024
771c3af
revert gitignore
barnasm1 Nov 28, 2024
7f4a1a1
direct use ov::PartialShape::dynamic()
barnasm1 Nov 28, 2024
e717976
revert gitignore empty line
barnasm1 Nov 28, 2024
6a66abe
Merge branch 'master' into squeeze_v15_reverse_infer
barnasm1 Nov 28, 2024
2b69d81
update gather reverse infer
barnasm1 Nov 28, 2024
25c5f39
fix gather transformation
barnasm1 Nov 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ std::shared_ptr<ov::Model> create_v15_model(const IndicesMode indices_mode,
ov::ParameterVector params = {data};
std::shared_ptr<op::v15::Squeeze> squeeze;
if (indices_mode == IndicesMode::NONE) {
squeeze = std::make_shared<ov::opset15::Squeeze>(data, allow_axis_skip);
squeeze = std::make_shared<ov::opset15::Squeeze>(data);
} else if (indices_mode == IndicesMode::PARAM) {
const auto& indices =
std::make_shared<ov::opset15::Parameter>(ov::element::i32, PartialShape({data_shape.rank()}));
Expand Down
2 changes: 1 addition & 1 deletion src/core/include/openvino/op/squeeze.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class OPENVINO_API Squeeze : public util::SqueezeBase {
///
/// \param data Input tensor with data
/// \param allow_axis_skip Shape inference result dynamic rank if selected axis has 1 in range of its dynamic
Squeeze(const Output<Node>& data, const bool allow_axis_skip = false);
Squeeze(const Output<Node>& data);
/// \brief Constructs a squeeze v15 operation.
///
/// \param data Input tensor with data
Expand Down
115 changes: 67 additions & 48 deletions src/core/shape_inference/include/squeeze_shape_inference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,27 @@

namespace ov {
namespace op {
namespace v0 {
template <class T, class TRShape = result_shape_t<T>>
std::vector<TRShape> shape_infer(const Squeeze* op,
const std::vector<T>& input_shapes,
const ITensorAccessor& ta = make_tensor_accessor()) {
namespace util {
template <class T,
class TRShape = result_shape_t<T>,
class Lambda,
class Squeeze,
typename = typename std::enable_if<std::is_same<Squeeze, ov::op::v0::Squeeze>::value ||
std::is_same<Squeeze, ov::op::v15::Squeeze>::value,
bool>::type>
bool validate_input_and_try_set_output_shape(const Squeeze* op,
std::unique_ptr<std::set<int64_t>>& unique_axes,
const std::vector<T>& input_shapes,
const ITensorAccessor& ta,
TRShape& output_shape,
Lambda output_shape_for_squeezable_dim) {
using DimType = typename T::value_type;

const auto number_of_inputs = input_shapes.size();
OPENVINO_ASSERT(!input_shapes.empty());
const auto number_of_inputs = input_shapes.size();

const auto& arg_shape = input_shapes[0];
const auto& arg_rank = arg_shape.rank();
auto output_shapes = std::vector<TRShape>(1);
auto& output_shape = output_shapes[0];

std::unique_ptr<std::set<int64_t>> unique_axes;

if (number_of_inputs == 1) {
unique_axes.reset(new std::set<int64_t>());
Expand All @@ -36,7 +41,7 @@ std::vector<TRShape> shape_infer(const Squeeze* op,
"Second input (axes) should not be of rank higher than 1. Got: ",
axes_shape.rank().get_length());

std::vector<int64_t> axes;
std::vector<int64_t> axes{};
if (arg_rank.is_static() && axes_shape.is_static()) {
if (auto axes = get_input_const_data_as<TRShape, int64_t>(op, 1, ta)) {
// The values of `axes` input are known
Expand All @@ -49,18 +54,56 @@ std::vector<TRShape> shape_infer(const Squeeze* op,
return dim.compatible(1);
});
if (has_squeezable_dim) {
output_shape = PartialShape::dynamic(arg_rank.get_length() - 1);
output_shape = output_shape_for_squeezable_dim();
} else {
output_shape = arg_shape;
}
return output_shapes;
return true;
}
}
} else {
// Invalid number of inputs, empty error message for backward compatibility.
NODE_VALIDATION_CHECK(op, false);
}

return false;
}
} // namespace util
namespace v0 {
/**
* \brief Do Squeeze shape inference.
*
* \tparam T Type of input/output shapes.
*
* \param op Squeeze operator pointer.
* \param input_shapes Squeeze input shapes.
* \param ta Tensor accessor to constant data.
*/
template <class T, class TRShape = result_shape_t<T>>
std::vector<TRShape> shape_infer(const Squeeze* op,
const std::vector<T>& input_shapes,
const ITensorAccessor& ta = make_tensor_accessor()) {
using DimType = typename T::value_type;

const auto& arg_shape = input_shapes[0];
const auto& arg_rank = arg_shape.rank();
auto output_shapes = std::vector<TRShape>(1);
auto& output_shape = output_shapes[0];
std::unique_ptr<std::set<int64_t>> unique_axes{};

auto output_shape_for_squeezable_dim = [&]() {
return PartialShape::dynamic(arg_rank.get_length() - 1);
};

if (util::validate_input_and_try_set_output_shape(op,
unique_axes,
input_shapes,
ta,
output_shape,
output_shape_for_squeezable_dim)) {
return output_shapes;
}

if (arg_rank.is_static() && (unique_axes != nullptr)) {
output_shape.resize(0);
if (unique_axes->empty()) {
Expand Down Expand Up @@ -138,48 +181,24 @@ std::vector<TRShape> shape_infer(const Squeeze* op,
const ITensorAccessor& ta = make_tensor_accessor()) {
using DimType = typename T::value_type;

const auto number_of_inputs = input_shapes.size();
OPENVINO_ASSERT(!input_shapes.empty());

const auto& arg_shape = input_shapes[0];
const auto& arg_rank = arg_shape.rank();
auto output_shapes = std::vector<TRShape>(1);
auto& output_shape = output_shapes[0];

std::unique_ptr<std::set<int64_t>> unique_axes;
std::unique_ptr<std::set<int64_t>> unique_axes{};

if (number_of_inputs == 1) {
unique_axes.reset(new std::set<int64_t>());
} else if (number_of_inputs == 2) {
const auto& axes_shape = input_shapes[1];
NODE_VALIDATION_CHECK(op,
axes_shape.is_dynamic() || ov::util::is_rank_compatible_any_of(axes_shape.rank(), {0, 1}),
"Second input (axes) should not be of rank higher than 1. Got: ",
axes_shape.rank().get_length());
auto output_shape_for_squeezable_dim = [&]() {
return PartialShape::dynamic();
};

std::vector<int64_t> axes;
if (arg_rank.is_static() && axes_shape.is_static()) {
if (auto axes = get_input_const_data_as<TRShape, int64_t>(op, 1, ta)) {
// The values of `axes` input are known
ov::util::try_normalize_axes(*axes, arg_rank, *op);
unique_axes.reset(new std::set<int64_t>(axes->cbegin(), axes->cend()));
} else if (arg_rank.get_length() > 0 && shape_size(axes_shape.to_shape()) == 1) {
// The `axes` input is a single element tensor which is unique by definition, deducing output rank
const auto has_squeezable_dim =
std::any_of(arg_shape.cbegin(), arg_shape.cend(), [](const DimType& dim) {
return dim.compatible(1);
});
if (has_squeezable_dim) {
output_shape = PartialShape::dynamic();
} else {
output_shape = arg_shape;
}
return output_shapes;
}
}
} else {
// Invalid number of inputs, empty error message for backward compatibility.
NODE_VALIDATION_CHECK(op, false);
if (util::validate_input_and_try_set_output_shape(op,
unique_axes,
input_shapes,
ta,
output_shape,
output_shape_for_squeezable_dim)) {
return output_shapes;
}

if (!arg_rank.is_static() || (unique_axes == nullptr) || apply_allow_axis_skip(op, unique_axes, arg_shape)) {
Expand Down
6 changes: 2 additions & 4 deletions src/core/src/op/squeeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ bool Squeeze::evaluate(TensorVector& outputs, const TensorVector& inputs) const
namespace v15 {
Squeeze::Squeeze() : util::SqueezeBase() {}

Squeeze::Squeeze(const Output<Node>& data, const bool allow_axis_skip)
: util::SqueezeBase(data),
m_allow_axis_skip{allow_axis_skip} {
Squeeze::Squeeze(const Output<Node>& data) : util::SqueezeBase(data) {
constructor_validate_and_infer_types();
}

Expand All @@ -80,7 +78,7 @@ std::shared_ptr<Node> Squeeze::clone_with_new_inputs(const OutputVector& new_arg

switch (new_args.size()) {
case 1:
return std::make_shared<Squeeze>(new_args[0], m_allow_axis_skip);
return std::make_shared<Squeeze>(new_args[0]);
case 2:
return std::make_shared<Squeeze>(new_args[0], new_args[1], m_allow_axis_skip);
default:
Expand Down
38 changes: 38 additions & 0 deletions src/core/tests/type_prop/squeeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,4 +639,42 @@ TEST(SqueezeDynamicAxis, squeeze_dynamic_non_const_axes) {
EXPECT_EQ(squeeze->get_allow_axis_skip(), allow_axis_skip);
}

TEST(SqueezeDynamicAxis, squeeze_dynamic_empty_axes) {
auto p_shape = PartialShape{1, 2, -1, 4};
auto axes_node = make_shared<ov::op::v0::Constant>(element::u64, Shape{});
auto exp_shape = PartialShape{2, -1, 4};
auto param = make_shared<ov::op::v0::Parameter>(element::f32, p_shape);

const auto squeeze0 = std::make_shared<op::v0::Squeeze>(param, axes_node);
const auto squeeze1 = std::make_shared<op::v15::Squeeze>(param, axes_node, true);
const auto squeeze2 = std::make_shared<op::v15::Squeeze>(param, axes_node, false);

EXPECT_EQ(squeeze0->get_element_type(), element::f32);
EXPECT_EQ(squeeze0->get_output_partial_shape(0), exp_shape);

EXPECT_EQ(squeeze1->get_element_type(), element::f32);
EXPECT_EQ(squeeze1->get_output_partial_shape(0), exp_shape);
EXPECT_EQ(squeeze1->get_allow_axis_skip(), true);

EXPECT_EQ(squeeze2->get_element_type(), element::f32);
EXPECT_EQ(squeeze2->get_output_partial_shape(0), exp_shape);
EXPECT_EQ(squeeze2->get_allow_axis_skip(), false);
}
TEST(SqueezeDynamicAxis, squeeze_dynamic_no_axes) {
auto p_shape = PartialShape{1, 2, -1, 4};
auto axes_node = make_shared<ov::op::v0::Constant>(element::u64, Shape{});
auto exp_shape = PartialShape::dynamic();
auto param = make_shared<ov::op::v0::Parameter>(element::f32, p_shape);

const auto squeeze0 = std::make_shared<op::v0::Squeeze>(param);
const auto squeeze1 = std::make_shared<op::v15::Squeeze>(param);

EXPECT_EQ(squeeze0->get_element_type(), element::f32);
EXPECT_EQ(squeeze0->get_output_partial_shape(0), exp_shape);

EXPECT_EQ(squeeze1->get_element_type(), element::f32);
EXPECT_EQ(squeeze1->get_output_partial_shape(0), exp_shape);
EXPECT_EQ(squeeze1->get_allow_axis_skip(), false);
}

} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class ReferenceSqueezeV15AttributeSetLayerTest : public ReferenceSqueezeLayerTes
std::make_shared<op::v0::Constant>(params.m_axes_type, params.m_axes_shape, params.m_axes_value.data());
squeeze = std::make_shared<op::v15::Squeeze>(in, axes_node, true);
} else {
squeeze = std::make_shared<op::v15::Squeeze>(in, true);
squeeze = std::make_shared<op::v15::Squeeze>(in);
}

return std::make_shared<ov::Model>(squeeze, ParameterVector{in});
Expand Down
Loading