Skip to content

[CORE] Squeeze v15 reverse infer #27526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 35 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
579a5e2
add test with empty axes
barnasm1 Oct 31, 2024
2738301
remove attribute allow_axis_skip for empty axis Ctor
barnasm1 Nov 12, 2024
4ad95df
refactor squeezeV0 shape_infer common validate input
barnasm1 Nov 12, 2024
7255e6f
refactor squeezeV15 shape_infer common validate input
barnasm1 Nov 12, 2024
132a7f1
format
barnasm1 Nov 13, 2024
de3a1c8
format2
barnasm1 Nov 13, 2024
82d8ef1
Merge branch 'master' into squeeze_v15_reverse_infer
barnasm1 Nov 13, 2024
c82e8fd
update squeeze15 ctor args
barnasm1 Nov 13, 2024
a3011b8
unify enable_if usage
barnasm1 Nov 13, 2024
084550b
reverse ctor update
barnasm1 Nov 13, 2024
09ba124
fix namespace selection for windows build
barnasm1 Nov 13, 2024
66e328a
code foramt
barnasm1 Nov 13, 2024
2eafc25
fix test
barnasm1 Nov 13, 2024
37b939f
code foramt
barnasm1 Nov 14, 2024
bd92721
avoid code duplication, use function: set_output_shape
barnasm1 Nov 14, 2024
9dff25c
code foramt
barnasm1 Nov 14, 2024
6ce7ccb
use ov::optional
barnasm1 Nov 14, 2024
d827fd4
separate test cases
barnasm1 Nov 14, 2024
cc1346e
use common in OV for this doxy style
barnasm1 Nov 14, 2024
b08f8bf
direct return output_shape
barnasm1 Nov 14, 2024
b2c7ac8
do not pass lambda functor into validator
barnasm1 Nov 14, 2024
fc0b5ae
unused variable
barnasm1 Nov 14, 2024
cc3da2f
Merge branch 'master' into squeeze_v15_reverse_infer
barnasm1 Nov 14, 2024
727331d
undo doc format for op
barnasm1 Nov 15, 2024
cfd532a
Merge branch 'master' into squeeze_v15_reverse_infer
barnasm1 Nov 19, 2024
a0e5d67
Merge branch 'master' into squeeze_v15_reverse_infer
barnasm1 Nov 25, 2024
9996e6a
gather reverse infer + squeeze v15 rt_info
barnasm1 Nov 25, 2024
fbfb6d5
Merge branch 'master' into squeeze_v15_reverse_infer
barnasm1 Nov 25, 2024
139c0b1
const rt_info name
barnasm1 Nov 25, 2024
771c3af
revert gitignore
barnasm1 Nov 28, 2024
7f4a1a1
direct use ov::PartialShape::dynamic()
barnasm1 Nov 28, 2024
e717976
revert gitignore empty line
barnasm1 Nov 28, 2024
6a66abe
Merge branch 'master' into squeeze_v15_reverse_infer
barnasm1 Nov 28, 2024
2b69d81
update gather reverse infer
barnasm1 Nov 28, 2024
25c5f39
fix gather transformation
barnasm1 Nov 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
405 changes: 334 additions & 71 deletions .gitignore

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "openvino/op/convert_like.hpp"
#include "openvino/op/convolution.hpp"
#include "openvino/op/deformable_convolution.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/group_conv.hpp"
#include "openvino/op/if.hpp"
#include "openvino/op/parameter.hpp"
Expand Down Expand Up @@ -322,6 +323,21 @@ bool ov::pass::ReverseShapeAndTypeInfer::run_on_model(const std::shared_ptr<ov::
is_changed = true;
}
is_changed |= inherit_output_type(op, {0});
} else if (ov::as_type_ptr<ov::op::v0::Convert>(op)) {
is_changed |= inherit_output_shape(op, {0});
is_changed |= inherit_output_rank(op, {0});
} else if (auto gather_op = ov::as_type_ptr<ov::op::v8::Gather>(op)) {
is_changed |= inherit_output_type(op, {0});

const auto& output_shape = op->get_output_partial_shape(0);
const auto& data_shape = op->get_input_partial_shape(0);
const auto batch = gather_op->get_batch_dims();

if (op->get_input_size() > 1 && batch >= 0 && op->get_input_partial_shape(1).rank().is_dynamic()) {
op->get_input_tensor(1).m_partial_shape =
ov::PartialShape::dynamic(output_shape.rank() - data_shape.rank() + batch + 1);
is_changed = true;
}
}
}
return is_changed;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -728,3 +728,125 @@ TEST_F(TransformationTestsF, TransposeWithConstantOrderReverseInfer2) {
model_ref = std::make_shared<Model>(ResultVector{result}, ParameterVector{data});
}
}

TEST_F(TransformationTestsF, GatherReverseInferIndicesRank) {
auto dyn = Dimension::dynamic();
{
auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 22, 333, 4444});
auto indices = std::make_shared<opset10::Parameter>(ov::element::i32, PartialShape::dynamic());
auto axis = ov::op::v0::Constant::create(element::i32, Shape{}, {0});
auto gather = std::make_shared<ov::op::v8::Gather>(data, indices, axis);

// Concat is needed to produce static rank for indces
// Specify rank and type in one of Concat input to inherit in another
auto data2 = std::make_shared<opset10::Parameter>(element::f32, PartialShape{1, 22, 333, 4444});
auto concat = std::make_shared<opset10::Concat>(OutputVector{gather, data2}, 1);
auto result = std::make_shared<opset10::Result>(concat);
model = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, indices, data2});
manager.register_pass<pass::ReverseShapeAndTypeInfer>();
}
{
auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, ov::PartialShape{1, 22, 333, 4444});
auto indices = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{dyn});
auto axis = ov::op::v0::Constant::create(element::i32, Shape{}, {0});
auto gather = std::make_shared<ov::op::v8::Gather>(data, indices, axis);

auto data2 = std::make_shared<opset10::Parameter>(element::f32, PartialShape{1, 22, 333, 4444});
auto concat = std::make_shared<opset10::Concat>(OutputVector{gather, data2}, 1);
auto result = std::make_shared<opset10::Result>(concat);
model_ref = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, indices, data2});
}
}

TEST_F(TransformationTestsF, GatherReverseInferIndicesRankCustomBatchDims) {
auto dyn = Dimension::dynamic();
{
auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{2, 5});
auto indices = std::make_shared<opset10::Parameter>(ov::element::i32, ov::PartialShape::dynamic());
auto axis = ov::op::v0::Constant::create(element::i32, Shape{}, {1});
int64_t batch_dims = 1;
auto gather = std::make_shared<ov::op::v8::Gather>(data, indices, axis, batch_dims);

// Concat is needed to produce static rank for indces
// Specify rank and type in one of Concat input to inherit in another
auto data2 = std::make_shared<opset10::Parameter>(element::f32, PartialShape{2, 3});
auto concat = std::make_shared<opset10::Concat>(OutputVector{gather, data2}, 1);
auto result = std::make_shared<opset10::Result>(concat);
model = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, indices, data2});
manager.register_pass<pass::ReverseShapeAndTypeInfer>();
}
{
auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{2, 5});
auto indices = std::make_shared<opset10::Parameter>(ov::element::i32, ov::PartialShape{dyn, dyn});
auto axis = ov::op::v0::Constant::create(element::i32, Shape{}, {1});
int64_t batch_dims = 1;
auto gather = std::make_shared<ov::op::v8::Gather>(data, indices, axis, batch_dims);

auto data2 = std::make_shared<opset10::Parameter>(element::f32, PartialShape{2, 3});
auto concat = std::make_shared<opset10::Concat>(OutputVector{gather, data2}, 1);
auto result = std::make_shared<opset10::Result>(concat);
model_ref = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, indices, data2});
}
}

TEST_F(TransformationTestsF, GatherReverseInferIndicesRankCustomBatchDims2) {
auto dyn = Dimension::dynamic();
{
auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{2, 2, 5});
auto indices = std::make_shared<opset10::Parameter>(ov::element::i32, ov::PartialShape::dynamic());
auto axis = ov::op::v0::Constant::create(element::i32, Shape{}, {2});
int64_t batch_dims = 2;
auto gather = std::make_shared<ov::op::v8::Gather>(data, indices, axis, batch_dims);

// Concat is needed to produce static rank for indces
// Specify rank and type in one of Concat input to inherit in another
auto data2 = std::make_shared<opset10::Parameter>(element::f32, PartialShape{2, 2, 3});
auto concat = std::make_shared<opset10::Concat>(OutputVector{gather, data2}, 1);
auto result = std::make_shared<opset10::Result>(concat);
model = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, indices, data2});
manager.register_pass<pass::ReverseShapeAndTypeInfer>();
}
{
auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{2, 2, 5});
auto indices = std::make_shared<opset10::Parameter>(ov::element::i32, ov::PartialShape{dyn, dyn, dyn});
auto axis = ov::op::v0::Constant::create(element::i32, Shape{}, {2});
int64_t batch_dims = 2;
auto gather = std::make_shared<ov::op::v8::Gather>(data, indices, axis, batch_dims);

auto data2 = std::make_shared<opset10::Parameter>(element::f32, PartialShape{2, 2, 3});
auto concat = std::make_shared<opset10::Concat>(OutputVector{gather, data2}, 1);
auto result = std::make_shared<opset10::Result>(concat);
model_ref = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, indices, data2});
}
}

TEST_F(TransformationTestsF, GatherReverseInferIndicesRankCustomBatchDims3) {
auto dyn = Dimension::dynamic();
{
auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{2, 1, 5, 4});
auto indices = std::make_shared<opset10::Parameter>(ov::element::i32, ov::PartialShape::dynamic());
auto axis = ov::op::v0::Constant::create(element::i32, Shape{}, {2});
int64_t batch_dims = 1;
auto gather = std::make_shared<ov::op::v8::Gather>(data, indices, axis, batch_dims);

// Concat is needed to produce static rank for indces
// Specify rank and type in one of Concat input to inherit in another
auto data2 = std::make_shared<opset10::Parameter>(element::f32, PartialShape{2, 1, 3, 4});
auto concat = std::make_shared<opset10::Concat>(OutputVector{gather, data2}, 1);
auto result = std::make_shared<opset10::Result>(concat);
model = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, indices, data2});
manager.register_pass<pass::ReverseShapeAndTypeInfer>();
}
{
auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{2, 1, 5, 4});
auto indices = std::make_shared<opset10::Parameter>(ov::element::i32, ov::PartialShape{dyn, dyn});
auto axis = ov::op::v0::Constant::create(element::i32, Shape{}, {2});
int64_t batch_dims = 1;
auto gather = std::make_shared<ov::op::v8::Gather>(data, indices, axis, batch_dims);

auto data2 = std::make_shared<opset10::Parameter>(element::f32, PartialShape{2, 1, 3, 4});
auto concat = std::make_shared<opset10::Concat>(OutputVector{gather, data2}, 1);
auto result = std::make_shared<opset10::Result>(concat);
model_ref = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, indices, data2});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ std::shared_ptr<ov::Model> create_v15_model(const IndicesMode indices_mode,
ov::ParameterVector params = {data};
std::shared_ptr<op::v15::Squeeze> squeeze;
if (indices_mode == IndicesMode::NONE) {
squeeze = std::make_shared<ov::opset15::Squeeze>(data, allow_axis_skip);
squeeze = std::make_shared<ov::opset15::Squeeze>(data);
} else if (indices_mode == IndicesMode::PARAM) {
const auto& indices =
std::make_shared<ov::opset15::Parameter>(ov::element::i32, PartialShape({data_shape.rank()}));
Expand Down
2 changes: 2 additions & 0 deletions src/core/include/openvino/op/squeeze.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class OPENVINO_API Squeeze : public util::SqueezeBase {
public:
OPENVINO_OP("Squeeze", "opset15");

static constexpr const char* output_rt_info_name = "rt_info_reverse_infer_output_rank";

Squeeze();
/// \brief Constructs a squeeze v15 operation.
///
Expand Down
Loading
Loading