diff --git a/src/bindings/python/src/openvino/passes/__init__.py b/src/bindings/python/src/openvino/passes/__init__.py index 75b2ed23d05709..5187c68b8f65dd 100644 --- a/src/bindings/python/src/openvino/passes/__init__.py +++ b/src/bindings/python/src/openvino/passes/__init__.py @@ -14,6 +14,7 @@ type_matches, type_matches_any, shape_matches, + attrs_match, ) from openvino._pyopenvino.passes import Serialize, ConstantFolding, VisualizeTree, MakeStateful, LowLatency2, ConvertFP32ToFP16, Version from openvino.passes.manager import Manager diff --git a/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp b/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp index 0f6681997bc503..3cc58935bfa41a 100644 --- a/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp +++ b/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp @@ -18,6 +18,7 @@ #include "openvino/pass/pattern/op/pattern.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "pyopenvino/core/common.hpp" +#include "pyopenvino/utils/utils.hpp" static ov::NodeTypeInfo get_type(const std::string& type_name) { // Supported types: opsetX.OpName or opsetX::OpName @@ -1014,6 +1015,9 @@ inline void reg_predicates(py::module m) { m.def("type_matches", &ov::pass::pattern::type_matches); m.def("type_matches_any", &ov::pass::pattern::type_matches_any); m.def("shape_matches", &ov::pass::pattern::shape_matches); + m.def("attrs_match", [](py::object& attrs) { + return ov::pass::pattern::attrs_match(Common::utils::py_object_to_unordered_any_map(attrs)); + }); } void reg_passes_pattern_ops(py::module m) { diff --git a/src/bindings/python/src/pyopenvino/utils/utils.cpp b/src/bindings/python/src/pyopenvino/utils/utils.cpp index 8520a5781a8370..aa19c20b81d646 100644 --- a/src/bindings/python/src/pyopenvino/utils/utils.cpp +++ b/src/bindings/python/src/pyopenvino/utils/utils.cpp @@ -411,6 +411,21 @@ ov::AnyMap py_object_to_any_map(const py::object& py_obj) { return return_value; } +std::unordered_map py_object_to_unordered_any_map(const py::object& py_obj) { + OPENVINO_ASSERT(py_object_is_any_map(py_obj), "Unsupported attribute type."); + std::unordered_map return_value = {}; + for (auto& item : py::cast(py_obj)) { + std::string key = py::cast(item.first); + py::object value = py::cast(item.second); + if (py_object_is_any_map(value)) { + return_value[key] = Common::utils::py_object_to_any_map(value); + } else { + return_value[key] = Common::utils::py_object_to_any(value); + } + } + return return_value; +} + template std::tuple tuple_from_py_tuple_impl(const py::tuple& py_tuple, std::index_sequence) { return std::make_tuple(py_tuple[I].cast()...); diff --git a/src/bindings/python/src/pyopenvino/utils/utils.hpp b/src/bindings/python/src/pyopenvino/utils/utils.hpp index fc1dbb646ea628..86d73d18d1cdf7 100644 --- a/src/bindings/python/src/pyopenvino/utils/utils.hpp +++ b/src/bindings/python/src/pyopenvino/utils/utils.hpp @@ -140,6 +140,8 @@ class MemoryBuffer : public std::streambuf { ov::AnyMap py_object_to_any_map(const py::object& py_obj); + std::unordered_map py_object_to_unordered_any_map(const py::object& py_obj); + ov::Any py_object_to_any(const py::object& py_obj); ov::pass::Serialize::Version convert_to_version(const std::string& version); diff --git a/src/bindings/python/tests/test_transformations/test_pattern_ops.py b/src/bindings/python/tests/test_transformations/test_pattern_ops.py index 4ce8cadc876b4b..6a381a749c16f5 100644 --- a/src/bindings/python/tests/test_transformations/test_pattern_ops.py +++ b/src/bindings/python/tests/test_transformations/test_pattern_ops.py @@ -16,6 +16,7 @@ type_matches, type_matches_any, shape_matches, + attrs_match, ) from openvino.utils.types import get_element_type @@ -278,6 +279,19 @@ def symbol_matching_test(shape: PartialShape, pattern: str): assert symbols["Six"] == 6, symbols +def test_attrs_match(): + param = ops.parameter([-1, -1]) + + def test_shape_of_attribute(et: str): + node = ops.shape_of(param, output_type=et) + attr = {"output_type": et} + matcher = Matcher(AnyInput(attrs_match(attr)), "Find shape_of with attribute") + assert matcher.match(node), f"Match failed for {node} with attribute" + + test_shape_of_attribute("i64") + test_shape_of_attribute("i32") + + def test_optional_full_match(): model_input = ops.parameter(PartialShape.dynamic()) model_abs = ops.abs(model_input) diff --git a/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp b/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp index 6861d8469b2129..dd0776a09d035a 100644 --- a/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp @@ -88,29 +88,12 @@ class ov::pass::RoPEShareCosSin : public ov::pass::MatcherPass { * @ingroup ov_transformation_common_api * @brief Fuses special sub-graph into an internal Rotary Positional Embedding operation */ -class ov::pass::RoPEFusion : public ov::pass::GraphRewrite { +class ov::pass::RoPEFusion : public ov::pass::ModelPass { public: - OPENVINO_GRAPH_REWRITE_RTTI("RoPEFusion"); - RoPEFusion(bool support_2d_rope = false) { - add_matcher(); - add_matcher(); - add_matcher(); - // optional heads & tails are fused in separate matcher pass, - // after RoPENode has been created. - add_matcher(); - add_matcher(); - add_matcher(); - - add_matcher(0); - add_matcher(1); - if (support_2d_rope) { - add_matcher(0, true); - add_matcher(1, true); - } - - add_matcher(0); - add_matcher(1); - - add_matcher(); - } + OPENVINO_MODEL_PASS_RTTI("RoPEFusion"); + RoPEFusion(bool support_2d_rope = false); + bool run_on_model(const std::shared_ptr& model) override; + +private: + bool m_support_2d_rope; }; diff --git a/src/common/transformations/include/transformations/symbolic_transformations/symbolic_optimizations.hpp b/src/common/transformations/include/transformations/symbolic_transformations/symbolic_optimizations.hpp index 9e85fcd4977430..3801e4ee65b766 100644 --- a/src/common/transformations/include/transformations/symbolic_transformations/symbolic_optimizations.hpp +++ b/src/common/transformations/include/transformations/symbolic_transformations/symbolic_optimizations.hpp @@ -25,7 +25,7 @@ class TRANSFORMATIONS_API LabelResolvingThroughSelect; class ov::pass::SymbolicOptimizations : public ov::pass::ModelPass { public: OPENVINO_MODEL_PASS_RTTI("SymbolicOptimizations"); - explicit SymbolicOptimizations(bool full_run = true); + explicit SymbolicOptimizations(bool full_run = true, std::shared_ptr pass_config = nullptr); bool run_on_model(const std::shared_ptr& m) override; std::shared_ptr get_manager() { return m_manager; diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp index 54bbfabeeb0cc6..8d167e3cb01f5a 100644 --- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -17,13 +17,54 @@ #include "openvino/pass/pattern/matcher.hpp" #include "openvino/pass/pattern/op/or.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/pass/pattern/op/optional.hpp" #include "ov_ops/rotary_positional_embeddings.hpp" #include "ov_ops/type_relaxed.hpp" #include "transformations/utils/gen_pattern.hpp" #include "transformations/utils/utils.hpp" +#include "transformations/symbolic_transformations/symbolic_optimizations.hpp" +#include "transformations/utils/utils.hpp" + +#include "openvino/pass/visualize_tree.hpp" using namespace ov::gen_pattern; +ov::pass::RoPEFusion::RoPEFusion(bool support_2d_rope) : m_support_2d_rope(support_2d_rope) { +} + +bool ov::pass::RoPEFusion::run_on_model(const std::shared_ptr& model) { + RUN_ON_MODEL_SCOPE(RoPEFusion); + ov::pass::SymbolicOptimizations symbolic_optimizations(false, get_pass_config()); + + std::cout << "SETTING UP RoPEFusion" << std::endl; + auto symbolic_ctx_manager = symbolic_optimizations.get_manager(); + + symbolic_ctx_manager->register_pass(); + symbolic_ctx_manager->register_pass(); + symbolic_ctx_manager->register_pass(); + // optional heads & tails are fused in separate matcher pass, + // after RoPENode has been created. + symbolic_ctx_manager->register_pass(); + symbolic_ctx_manager->register_pass(); + symbolic_ctx_manager->register_pass(); + + symbolic_ctx_manager->register_pass(0); + symbolic_ctx_manager->register_pass(1); + if (m_support_2d_rope) { + symbolic_ctx_manager->register_pass(0, true); + symbolic_ctx_manager->register_pass(1, true); + } + symbolic_ctx_manager->register_pass(0); + symbolic_ctx_manager->register_pass(1); + + symbolic_ctx_manager->register_pass(); + + std::cout << "About to run the transformations" << std::endl; + bool a = symbolic_optimizations.get_manager()->run_passes(model); + std::cout << "Run the transformations" << std::endl; + return a; +} + // This is a utility function used in the work around in ChatGLM pattern. // Since the existing implementation of Symbols don't allow for checking // permutations of the same Symbols in a shape, we need to check the @@ -56,51 +97,42 @@ ov::pass::RoPEFusionFlux::RoPEFusionFlux() { // y1 = x * t_cos // y2 = x3 * t_sin // y = y1 + y2 - auto x = makePattern(ov::Rank(4)); - auto t_cos = makePattern(ov::Rank(4)); - auto t_sin = makePattern(ov::Rank(4)); - - auto num_heads = ov::gen_pattern::Symbol("num_heads"); - auto head_size = ov::gen_pattern::Symbol("head_size"); + auto x = pattern::any_input(pattern::rank_equals(4) && pattern::shape_matches("[PRESERVED_DIMS..., head_size]")); + auto t_cos = pattern::any_input(pattern::rank_equals(4)); + auto t_sin = pattern::any_input(pattern::rank_equals(4)); - auto x1_target_shape = makeConst({0, num_heads, 0, -1, 2}); - auto x1 = makePattern({x, x1_target_shape}, {{"special_zero", true}}); - auto split = makePattern({x1, -1}, {{"num_splits", 2}}); + auto x1 = pattern::wrap_type({x, pattern::any_input()}, pattern::shape_matches("[PRESERVED_DIMS..., ?, 2]")); + auto split = pattern::wrap_type({x1, INT_CONSTANT_WITH_PREDICATE(value == std::vector{-1})}, {{"num_splits", 2l}}); split->set_output_size(2); // 3 versions of mulitply by -1 depending on transformations execution prior to this pass - auto x1_1_neg_1 = makePattern({split->output(1), -1.0f}, {{"auto_broadcast", "numpy"}}); - - auto squeeze_2 = makePattern({split->output(1), -1}); - auto x1_1_neg_2 = makePattern({squeeze_2, -1.0f}, {{"auto_broadcast", "numpy"}}); - auto unsqueeze_2 = makePattern({x1_1_neg_2, -1}); - - auto x1_1_neg_3 = makePattern({split->output(1), -1.0f}, {{"auto_broadcast", "numpy"}}); - auto squeeze_3 = makePattern({x1_1_neg_3, -1}); - auto unsqueeze_3 = makePattern({squeeze_3, -1}); + auto opt_squeeze = pattern::optional({split->output(1), INT_CONSTANT_WITH_PREDICATE(value == std::vector{-1})}); + auto x1_1_neg = pattern::wrap_type({opt_squeeze, FLOAT_CONSTANT_WITH_PREDICATE(value == std::vector{-1})}, {{"auto_broadcast", "numpy"}}); + auto opt_squeeze_1 = pattern::optional({x1_1_neg, INT_CONSTANT_WITH_PREDICATE(value == std::vector{-1})}); + auto opt_unsqueeze = pattern::optional({opt_squeeze_1, INT_CONSTANT_WITH_PREDICATE(value == std::vector{-1})}); - auto x2 = makePattern({x1_1_neg_1 | unsqueeze_2 | unsqueeze_3, split->output(0)}, {{"axis", -1}}); - auto x3_target_shape = makeConst({0, num_heads, 0, head_size}); - auto x3 = makePattern({x2, x3_target_shape}, {{"special_zero", true}}); + auto x2 = pattern::wrap_type({opt_unsqueeze, split->output(0)}, {{"axis", -1l}}); + auto x3 = pattern::wrap_type({x2, pattern::any_input()}, pattern::shape_matches("[PRESERVED_DIMS..., head_size]")); - auto y1 = makePattern({x, t_cos}, {{"auto_broadcast", "numpy"}}); - auto y2 = makePattern({x3, t_sin}, {{"auto_broadcast", "numpy"}}); + auto y1 = pattern::wrap_type({x, t_cos}, {{"auto_broadcast", "numpy"}}); + auto y2 = pattern::wrap_type({x3, t_sin}, {{"auto_broadcast", "numpy"}}); - auto y = makePattern({y1, y2}, {{"auto_broadcast", "numpy"}}); - auto result = y; + auto result = pattern::wrap_type({y1, y2}, {{"auto_broadcast", "numpy"}}); matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) { - PatternValidator validator(m); - if (!validator) { - return false; - } - const auto& pattern_map = m.get_pattern_value_map(); auto root = m.get_match_root(); + auto symbols = m.get_symbols(); + auto num_heads = symbols["PRESERVED_DIMS"].g()[1]; + auto head_size = symbols["head_size"]; + if (!num_heads.is_static() || !head_size.is_static()) { + return false; + } + op::internal::RoPE::Config config; - config.head_cnt = static_cast(validator["num_heads"]); - config.head_size = static_cast(validator["head_size"]); + config.head_cnt = static_cast(num_heads.i()); + config.head_size = static_cast(head_size.i()); config.rotary_ndims = config.head_size; config.is_interleaved = true; config.output_trans0213 = false; @@ -126,6 +158,7 @@ ov::pass::RoPEFusionFlux::RoPEFusionFlux() { // this new node may match following additional matchers register_new_node(new_node); + std::cout << "END OF RoPEFusionFlux" << std::endl; return true; }; diff --git a/src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp b/src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp index 778ab2dbfee92a..84ff6c2667deeb 100644 --- a/src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp +++ b/src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp @@ -346,16 +346,12 @@ ov::pass::DeReshapeFullyConnected::DeReshapeFullyConnected() { using namespace ov::op; using namespace ov::pass::pattern; - auto transpose_a_false = [](const std::shared_ptr& node) -> bool { - auto mm = as_type_ptr(node); - return mm && !mm->get_transpose_a(); - }; - auto input = wrap_type({any_input(shape_matches("BATCHES_1...,Y")), any_input()}, shape_matches("BATCHES_2...,Y")); auto converted = pattern::optional(input, consumers_count(1)); auto mm_label = wrap_type({converted, any_input(rank_equals(2))}, - consumers_count(1) && transpose_a_false && shape_matches("BATCHES_2...,Z")); + consumers_count(1) && shape_matches("BATCHES_2...,Z"), + {{"transpose_a", false}}); auto output = wrap_type({mm_label, any_input()}, shape_matches("BATCHES_1...,Z")); ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { diff --git a/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp b/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp index ec9945b8ea1ecc..c49d93efadbf65 100644 --- a/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp +++ b/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp @@ -171,8 +171,12 @@ ov::pass::LabelResolvingThroughSelect::LabelResolvingThroughSelect() { register_matcher(m, matcher_pass_callback); } -ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run) { - m_manager = std::make_shared("Symbolic"); +ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run, std::shared_ptr pass_config) { + if (pass_config) + m_manager = std::make_shared(pass_config, "Symbolic"); + else + m_manager = std::make_shared("Symbolic"); + m_manager->set_per_pass_validation(false); #define REGISTER_SYMBOLIC(region, ...) m_manager->register_pass(__VA_ARGS__); diff --git a/src/core/include/openvino/core/attribute_adapter.hpp b/src/core/include/openvino/core/attribute_adapter.hpp index ec1c8dd992a155..ef3c1d8dffad22 100644 --- a/src/core/include/openvino/core/attribute_adapter.hpp +++ b/src/core/include/openvino/core/attribute_adapter.hpp @@ -38,6 +38,9 @@ class OPENVINO_API ValueAccessor { virtual void set_as_any(const ov::Any& x) { OPENVINO_NOT_IMPLEMENTED; } + virtual ov::Any get_as_any() { + OPENVINO_NOT_IMPLEMENTED; + } }; /// \brief Provides access to values via get/set methods from an m_value, typically from @@ -65,6 +68,9 @@ class ValueAccessor : public ValueAccessor { OPENVINO_THROW("Bad cast from: ", x.type_info().name(), " to: ", typeid(VAT).name()); } } + ov::Any get_as_any() override { + return {get()}; + } }; template <> diff --git a/src/core/include/openvino/pass/pattern/op/label.hpp b/src/core/include/openvino/pass/pattern/op/label.hpp index c535c91ecf5a5d..c1a94663e89ba0 100644 --- a/src/core/include/openvino/pass/pattern/op/label.hpp +++ b/src/core/include/openvino/pass/pattern/op/label.hpp @@ -72,7 +72,7 @@ class OPENVINO_API Label : public Pattern { }; } // namespace op -OPENVINO_API std::shared_ptr any_input(); +OPENVINO_API std::shared_ptr any_input(const Attributes& attrs = {}); template std::shared_ptr any_input(const TPredicate& pred) { diff --git a/src/core/include/openvino/pass/pattern/op/op.hpp b/src/core/include/openvino/pass/pattern/op/op.hpp new file mode 100644 index 00000000000000..f8b67a6ab46081 --- /dev/null +++ b/src/core/include/openvino/pass/pattern/op/op.hpp @@ -0,0 +1,40 @@ +//// Copyright (C) 2018-2025 Intel Corporation +//// SPDX-License-Identifier: Apache-2.0 +//// +// +//#pragma once +// +//#include "openvino/pass/pattern/op/predicate.hpp" +// +// namespace ov::pass::pattern { +//// A glue/syntax-sugar type which allows more types to be used as input to pattern operations +// struct PatternOp { +// std::shared_ptr op; +// size_t output_idx; +// +// operator ov::Output() const; +// ov::Output get_output() const; +// +// PatternOp(); +// PatternOp(const Output &out); +// +// template >>* = nullptr> +// PatternOp(PredicateT pred) { +// op = any_input(pred); +// } +// +// PatternOp(ov::Rank rank); +// +// // Constant matching +// PatternOp(std::string value_notation); +// PatternOp(int v); +// PatternOp(float v); +// PatternOp(double v); +// PatternOp(long long v); +// PatternOp(std::initializer_list v); +// PatternOp(std::initializer_list v); +// PatternOp(std::initializer_list v); +// PatternOp(std::initializer_list v); +// }; +// } \ No newline at end of file diff --git a/src/core/include/openvino/pass/pattern/op/optional.hpp b/src/core/include/openvino/pass/pattern/op/optional.hpp index 5491f3f6bbd332..b5e26774082ead 100644 --- a/src/core/include/openvino/pass/pattern/op/optional.hpp +++ b/src/core/include/openvino/pass/pattern/op/optional.hpp @@ -85,36 +85,50 @@ void collect_type_info(std::vector& type_info_vec) { } template -std::shared_ptr optional(const OutputVector& inputs, const TPredicate& pred) { +std::shared_ptr optional(const OutputVector& inputs, const TPredicate& pred, const Attributes& attrs = {}) { std::vector optional_type_info_vec; collect_type_info(optional_type_info_vec); - return std::make_shared(optional_type_info_vec, inputs, op::Predicate(pred)); + return std::make_shared( + optional_type_info_vec, + inputs, + attrs.empty() ? op::Predicate(pred) : attrs_match(attrs) && op::Predicate(pred)); } template -std::shared_ptr optional(const Output& input, const TPredicate& pred) { - return optional(OutputVector{input}, op::Predicate(pred)); +std::shared_ptr optional(const Output& input, const TPredicate& pred, const Attributes& attrs = {}) { + return optional(OutputVector{input}, pred, attrs); } template >* = nullptr> -std::shared_ptr optional(const TPredicate& pred) { - return optional(OutputVector{}, op::Predicate(pred)); + typename std::enable_if_t && + !std::is_constructible_v>* = nullptr> +std::shared_ptr optional(const TPredicate& pred, const Attributes& attrs = {}) { + return optional(OutputVector{}, op::Predicate(pred), attrs); } template -std::shared_ptr optional(const OutputVector& inputs) { - return optional(inputs, op::Predicate()); +std::shared_ptr optional(const OutputVector& inputs, const Attributes& attrs = {}) { + return optional(inputs, attrs.empty() ? op::Predicate() : attrs_match(attrs)); } template -std::shared_ptr optional(const Output& input) { - return optional(OutputVector{input}, op::Predicate()); +std::shared_ptr optional(std::initializer_list>&& inputs, const Attributes& attrs = {}) { + return optional(OutputVector(inputs), attrs); +} + +template +std::shared_ptr optional(const Output& input, const Attributes& attrs = {}) { + return optional(OutputVector{input}, attrs); +} + +template +std::shared_ptr optional(const Attributes& attrs) { + return optional(OutputVector{}, attrs); } template std::shared_ptr optional() { - return optional(OutputVector{}, op::Predicate()); + return optional(OutputVector{}); } } // namespace ov::pass::pattern diff --git a/src/core/include/openvino/pass/pattern/op/pattern.hpp b/src/core/include/openvino/pass/pattern/op/pattern.hpp index d76289408cf582..976577299895c2 100644 --- a/src/core/include/openvino/pass/pattern/op/pattern.hpp +++ b/src/core/include/openvino/pass/pattern/op/pattern.hpp @@ -22,6 +22,7 @@ using PatternValueMap = std::map, Output>; using PatternValueMaps = std::vector; using PatternMap = std::map, std::shared_ptr>; +using Attributes = std::unordered_map; PatternMap as_pattern_map(const PatternValueMap& pattern_value_map); PatternValueMap as_pattern_value_map(const PatternMap& pattern_map); @@ -61,7 +62,10 @@ OPENVINO_API op::Predicate type_matches_any(const std::vector& ty OPENVINO_API op::Predicate all_of(const std::vector)>>& predicates); +OPENVINO_API op::Predicate attrs_match(const Attributes& expected_attrs); + OPENVINO_API op::Predicate shape_matches(const std::string& shape_notation); +OPENVINO_API op::Predicate value_matches(const std::string& value_notation); namespace op { OPENVINO_DEPRECATED("This method is deprecated. Use constructor of ov::pass::pattern::Predicate instead") diff --git a/src/core/include/openvino/pass/pattern/op/wrap_type.hpp b/src/core/include/openvino/pass/pattern/op/wrap_type.hpp index df3a6c6251eac6..dced56c1095a13 100644 --- a/src/core/include/openvino/pass/pattern/op/wrap_type.hpp +++ b/src/core/include/openvino/pass/pattern/op/wrap_type.hpp @@ -64,19 +64,34 @@ void collect_wrap_info(std::vector& info) { } template -std::shared_ptr wrap_type(const OutputVector& inputs, const TPredicate& pred) { +std::shared_ptr wrap_type(const OutputVector& inputs, const TPredicate& pred, const Attributes& attrs = {}) { std::vector info; collect_wrap_info(info); - return std::make_shared(info, op::Predicate(pred), inputs); + return std::make_shared( + info, + (attrs.empty() ? op::Predicate(pred) : attrs_match(attrs) && op::Predicate(pred)), + inputs); +} + +template >* = nullptr> +std::shared_ptr wrap_type(const TPredicate& pred, const Attributes& attrs = {}) { + return wrap_type({}, op::Predicate(pred), attrs); } template -std::shared_ptr wrap_type(const OutputVector& inputs = {}) { - return wrap_type(inputs, op::Predicate()); +std::shared_ptr wrap_type(const OutputVector& inputs, const Attributes& attrs = {}) { + return wrap_type(inputs, (attrs.empty() ? op::Predicate() : attrs_match(attrs))); } -template -std::shared_ptr wrap_type(const TPredicate& pred) { - return wrap_type({}, op::Predicate(pred)); +template +std::shared_ptr wrap_type(std::initializer_list>&& inputs, const Attributes& attrs = {}) { + return wrap_type(OutputVector(inputs), attrs); +} + +template +std::shared_ptr wrap_type(const Attributes& attrs = {}) { + return wrap_type({}, attrs); } } // namespace ov::pass::pattern diff --git a/src/core/src/pattern/op/label.cpp b/src/core/src/pattern/op/label.cpp index 3b4ec752276485..6bb1bc027105be 100644 --- a/src/core/src/pattern/op/label.cpp +++ b/src/core/src/pattern/op/label.cpp @@ -54,6 +54,6 @@ bool ov::pass::pattern::op::Label::match_value(ov::pass::pattern::Matcher* match return false; } -std::shared_ptr ov::pass::pattern::any_input() { - return std::make_shared(); +std::shared_ptr ov::pass::pattern::any_input(const Attributes& attrs) { + return attrs.empty() ? std::make_shared() : any_input(attrs_match(attrs)); } diff --git a/src/core/src/pattern/op/op.cpp b/src/core/src/pattern/op/op.cpp new file mode 100644 index 00000000000000..114cbd69689b34 --- /dev/null +++ b/src/core/src/pattern/op/op.cpp @@ -0,0 +1,48 @@ +//// Copyright (C) 2018-2025 Intel Corporation +//// SPDX-License-Identifier: Apache-2.0 +//// +// +//#include "openvino/pass/pattern/op/op.hpp" +// +//#include "openvino/pass/pattern/op/label.hpp" +//#include "openvino/pass/pattern/op/wrap_type.hpp" +//#include "openvino/op/constant.hpp" +//#include "openvino/util/common_util.hpp" +// +// namespace ov::pass::pattern { +// PatternOp::operator ov::Output() const { +// return get_output(); +// } +// +// ov::Output PatternOp::get_output() const { +// if (output_idx >= 0) +// return op->output(output_idx); +// return op->get_default_output(); +// } +// +// PatternOp::PatternOp(const Output &out) +// : op(out.get_node_shared_ptr()), +// output_idx(out.get_index()) {} +// +// PatternOp::PatternOp() { +// op = any_input(); +// } +// +// PatternOp::PatternOp(ov::Rank rank) { +// op = any_input(rank_equals(rank)); +// } +// +// PatternOp::PatternOp(std::string value_notation) { +// op = wrap_type(value_matches(value_notation)); +// } +// +// PatternOp::PatternOp(int v) : PatternOp(std::to_string(v)) {} +// PatternOp::PatternOp(float v) : PatternOp(std::to_string(v)) {} +// PatternOp::PatternOp(double v) : PatternOp(std::to_string(v)) {} +// PatternOp::(long long v) : PatternOp(std::to_string(v)) {} +// +// PatternOp::PatternOp(std::initializer_list v) : PatternOp(ov::util::join(v, ",")) {} +// PatternOp::PatternOp(std::initializer_list v) : PatternOp(ov::util::join(v, ",")) {} +// PatternOp::PatternOp(std::initializer_list v) : PatternOp(ov::util::join(v, ",")) {} +// PatternOp::PatternOp(std::initializer_list v) : PatternOp(ov::util::join(v, ",")) {} +//} \ No newline at end of file diff --git a/src/core/src/pattern/op/pattern.cpp b/src/core/src/pattern/op/pattern.cpp index 2108d742b130cf..57d98e2206d887 100644 --- a/src/core/src/pattern/op/pattern.cpp +++ b/src/core/src/pattern/op/pattern.cpp @@ -7,7 +7,9 @@ #include #include +#include "openvino/op/constant.hpp" #include "openvino/util/common_util.hpp" +#include "openvino/util/log.hpp" namespace ov::pass::pattern { namespace op { @@ -181,6 +183,74 @@ op::Predicate all_of(const std::vector)>>& predi "all_of(...)"); } +namespace { +class AttributeMatchingVisitor : public ov::AttributeVisitor { +public: + explicit AttributeMatchingVisitor(const Attributes& expected_attrs) + : ov::AttributeVisitor(), + m_expected_attrs{expected_attrs} {} + + void on_adapter(const std::string& name, ValueAccessor& adapter) override { + if (m_attrs_match && m_expected_attrs.count(name)) { + try { + const auto& node_attribute = adapter.get_as_any(); + const auto& expected_attribute = m_expected_attrs.at(name); + if (node_attribute.type_info() != expected_attribute.type_info()) + OPENVINO_DEBUG(" Attribute `", + name, + "` -- data type does not match. ", + node_attribute.type_info().name(), + " vs ", + expected_attribute.type_info().name()); + bool status = node_attribute == expected_attribute; + if (!status) + OPENVINO_DEBUG(" Attribute `", name, "` -- value does not match. ", [&]() { + std::stringstream ss; + node_attribute.print(ss); + ss << " vs "; + expected_attribute.print(ss); + return ss.str(); + }()); + m_attrs_match &= status; + } catch (...) { + OPENVINO_DEBUG(" Attribute `", name, "` matching went wrong"); + m_attrs_match = false; + } + } + }; + + bool get_match_status() const { + return m_attrs_match; + } + +private: + const Attributes& m_expected_attrs; + bool m_attrs_match = true; +}; +} // namespace + +op::Predicate attrs_match(const Attributes& expected_attrs) { + std::stringstream ss; + ss << "{ "; + bool first = true; + for (const auto& [key, value] : expected_attrs) { + if (!first) + ss << ", "; + first = false; + ss << key << ": "; + value.print(ss); + } + ss << " }"; + return op::Predicate( + [expected_attrs](PatternSymbolMap&, const Output& output) -> bool { + const auto& node = output.get_node_shared_ptr(); + AttributeMatchingVisitor visitor(expected_attrs); + node->visit_attributes(visitor); + return visitor.get_match_status(); + }, + "attrs_match(" + ss.str() + ")"); +} + namespace { bool ends_with(std::string_view str, std::string_view suffix) { return str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0; @@ -241,6 +311,16 @@ std::pair str2int(const std::string& str) { return {0, l}; } +std::pair str2double(const std::string& str) { + auto s = str.c_str(); + char* end; + double d; + d = strtod(s, &end); + if (*s == '\0' || *end != '\0') + return {1, 0}; + return {0, d}; +} + struct GroupDetails { std::string name; int64_t begin = 0, end = 0; @@ -250,6 +330,24 @@ struct GroupDetails { } }; +ov::Any get_element(const ov::Any& values, size_t idx) { + if (values.is>()) { + const auto& vec = values.as>(); + if (idx < 0) + idx += vec.size(); + OPENVINO_ASSERT(idx < vec.size(), "Unexpected index"); + return {vec[idx]}; + } + if (values.is>()) { + const auto& vec = values.as>(); + if (idx < 0) + idx += vec.size(); + OPENVINO_ASSERT(idx < vec.size(), "Unexpected index"); + return {vec[idx]}; + } + OPENVINO_ASSERT(false, "Unreachable"); +} + } // namespace /** @@ -298,6 +396,8 @@ op::Predicate shape_matches(const std::string& shape_notation) { PatternSymbolMap local_m; GroupDetails group; for (const auto& [this_dim_idx, name] : idx_to_name) { + if (!group.name.empty() && this_dim_idx < group.end) + group.end = this_dim_idx; if (name == "?" || name == "...") continue; if (ends_with(name, "...")) { // named group detected @@ -305,8 +405,6 @@ op::Predicate shape_matches(const std::string& shape_notation) { group.begin = this_dim_idx; continue; } - if (!group.name.empty() && this_dim_idx < group.end) - group.end = this_dim_idx; const auto& this_dim = shape[this_dim_idx]; const auto& [conversion_failed, converted_int] = str2int(name); if (conversion_failed) { // failed the conversion -- this is a name @@ -384,4 +482,127 @@ op::Predicate shape_matches(const std::string& shape_notation) { }, "shape_matches('" + shape_notation + "')"); } + +op::Predicate value_matches(const std::string& value_notation) { + auto item = parse_notation(value_notation); + const auto& idx_to_name = item.first; + const auto& element_count_restrictions = item.second; + return op::Predicate( + [idx_to_name, element_count_restrictions](PatternSymbolMap& m, const Output& output) -> bool { + const auto& constant = ov::as_type_ptr(output.get_node_shared_ptr()); + if (!constant) + return false; + + const auto& shape = constant->get_shape(); + const auto& element_count = shape_size(shape); + if (element_count_restrictions == 0) // empty + return element_count == 0; + if (element_count_restrictions == -1) // fully dynamic (impossible to have dynamic number of elements) + return false; + if (element_count_restrictions == -2 && element_count + 1 < idx_to_name.size()) + // minimum num element check; checking element_count + 1 because idx_to_name contains a record with + // group that may match to an empty set of elements + return false; + + if (element_count_restrictions > 0 && static_cast(element_count) != element_count_restrictions) + return false; + + bool is_int = constant->get_element_type().is_integral(); + // TODO: check for dynamic et + const auto& values = + is_int ? ov::Any(constant->cast_vector()) : ov::Any(constant->cast_vector()); + + PatternSymbolMap local_m; + GroupDetails group; + for (const auto& [this_el_idx, name] : idx_to_name) { + if (name == "?" || name == "...") + continue; + if (ends_with(name, "...")) { // named group detected + group.name = {name.substr(0, name.size() - 3)}; + group.begin = this_el_idx; + continue; + } + if (!group.name.empty() && this_el_idx < group.end) + group.end = this_el_idx; + const auto& this_el = get_element(values, this_el_idx); + const auto& [i_conversion_failed, converted_int] = str2int(name); + const auto& [d_conversion_failed, converted_double] = str2double(name); + if (i_conversion_failed && d_conversion_failed) { // failed the conversion -- this is a name + if (m.count(name) || local_m.count(name)) { + const auto& recorded_value = m.count(name) ? m.at(name) : local_m.at(name); + if (recorded_value.is_integer()) { + if (!this_el.is() || this_el.as() != recorded_value.i()) + return false; + } + if (recorded_value.is_double()) { + if (!this_el.is() || + this_el.as() != recorded_value.d()) // TODO: double cmp + return false; + } else { + return false; + } + } else { + if (this_el.is()) + local_m[name] = {this_el.as()}; + else if (this_el.is()) + local_m[name] = {this_el.as()}; + else + return false; + } + } else if (!i_conversion_failed) { // this_dim is not a name, but an integer + if (!this_el.is() || this_el.as() != converted_int) + return false; + } else if (!d_conversion_failed) { // this_dim is not a name, but a double + if (!this_el.is() || this_el.as() != converted_double) // TODO: cmp double + return false; + } + } + + if (!group.name.empty()) { + OPENVINO_ASSERT(group.end <= 0); // end == 0 means group is placed at the end of the notation + group.end = group.end + element_count; + + if (m.count(group.name) || local_m.count(group.name)) { + const auto& recorded_value = m.count(group.name) ? m.at(group.name) : local_m.at(group.name); + OPENVINO_ASSERT(recorded_value.is_group(), + "Mixing group and non group symbolic predicate notation"); + const auto& recorded_group = recorded_value.g(); + if (recorded_group.size() != group.size()) + return false; + for (size_t i = 0; i < recorded_group.size(); ++i) { + const auto& recorded_i = recorded_group[i]; + const auto& this_el = get_element(values, group.begin + i); + if (recorded_i.is_integer()) { + if (!this_el.is() || this_el.as() != recorded_i.i()) + return false; + } + if (recorded_i.is_double()) { + if (!this_el.is() || this_el.as() != recorded_i.d()) // TODO: cmp double + return false; + } else { // FIXME: allow conversion comparison -- int with double + return false; + } + } + } else { + std::vector group_value; + for (size_t i = 0; i < group.size(); ++i) { + const auto& this_el = get_element(values, group.begin + i); + if (this_el.is()) + group_value.emplace_back(this_el.as()); + else if (this_el.is()) + group_value.emplace_back(this_el.as()); + else + return false; + } + local_m[group.name] = group_value; + } + } + + // only write locally collected data to the global map when the match is complete to avoid partially + // collected data for the case when Predicate::operator|| was used + m.insert(local_m.begin(), local_m.end()); + return true; + }, + "value_matches('" + value_notation + "')"); +} } // namespace ov::pass::pattern diff --git a/src/core/src/pattern/op/predicate.cpp b/src/core/src/pattern/op/predicate.cpp index 2044fef377cf8b..2bdfbe14ae2208 100644 --- a/src/core/src/pattern/op/predicate.cpp +++ b/src/core/src/pattern/op/predicate.cpp @@ -70,7 +70,7 @@ constexpr bool symbol_true_predicate(pass::pattern::PatternSymbolMap&, const Out } } // namespace -Predicate::Predicate() : m_pred(symbol_true_predicate) {} +Predicate::Predicate() : m_name("always_true"), m_pred(symbol_true_predicate) {} Predicate::Predicate(std::nullptr_t) : Predicate() {} bool Predicate::operator()(pass::pattern::PatternSymbolMap& m, const Output& output) const { diff --git a/src/core/tests/pattern.cpp b/src/core/tests/pattern.cpp index 582862262a42ac..34d139f701e9f2 100644 --- a/src/core/tests/pattern.cpp +++ b/src/core/tests/pattern.cpp @@ -32,10 +32,12 @@ #include "openvino/op/reduce_sum.hpp" #include "openvino/op/relu.hpp" #include "openvino/op/reshape.hpp" +#include "openvino/op/shape_of.hpp" #include "openvino/op/sigmoid.hpp" #include "openvino/op/strided_slice.hpp" #include "openvino/op/subtract.hpp" #include "openvino/op/transpose.hpp" +#include "openvino/op/util/attr_types.hpp" #include "openvino/op/util/op_types.hpp" #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/manager.hpp" @@ -1467,3 +1469,68 @@ TEST(pattern, pattern_symbol_predicate_and_operators) { ASSERT_NO_THROW(predicate_and(m, input)); } } + +TEST(pattern, predicate_attr_match) { + TestMatcher tm; + auto input = std::make_shared(element::dynamic, PartialShape::dynamic()); + auto constant = op::v0::Constant::create(element::i64, {4}, {0, 0, 0, 20}); + + // boolean attr check + auto pattern_true = pattern::any_input(pattern::attrs_match({{"special_zero", true}})); + auto pattern_false = pattern::any_input(pattern::attrs_match({{"special_zero", false}})); + + auto reshape_true = std::make_shared(input, constant, true); + auto reshape_false = std::make_shared(input, constant, false); + + ASSERT_TRUE(tm.match(pattern_true, reshape_true)); + ASSERT_FALSE(tm.match(pattern_true, reshape_false)); + ASSERT_TRUE(tm.match(pattern_false, reshape_false)); + ASSERT_FALSE(tm.match(pattern_false, reshape_true)); + + // element type check + auto pattern_i64 = + pattern::wrap_type(pattern::attrs_match({{"output_type", "i64"}})); + auto pattern_i32 = pattern::wrap_type({{"output_type", "i32"}}); + + auto shape_of_i64 = std::make_shared(input, element::i64); + auto shape_of_i32 = std::make_shared(input, element::i32); + + ASSERT_TRUE(tm.match(pattern_i64, shape_of_i64)); + ASSERT_FALSE(tm.match(pattern_i64, shape_of_i32)); + ASSERT_FALSE(tm.match(pattern_i32, shape_of_i64)); + ASSERT_TRUE(tm.match(pattern_i32, shape_of_i32)); + + // broadcasting check + auto pattern_numpy = pattern::any_input({{"auto_broadcast", "numpy"}}); + auto pattern_pdpd = pattern::optional({{"auto_broadcast", "pdpd"}}); + auto pattern_numpy_or_pdpd = pattern::any_input(pattern::attrs_match({{"auto_broadcast", "numpy"}}) || + pattern::attrs_match({{"auto_broadcast", "pdpd"}})); + + auto mul_numpy = std::make_shared(input, input, op::AutoBroadcastType::NUMPY); + auto mul_pdpd = std::make_shared(input, input, op::AutoBroadcastType::PDPD); + auto mul_none = std::make_shared(input, input, op::AutoBroadcastType::NONE); + + ASSERT_TRUE(tm.match(pattern_numpy, mul_numpy)); + ASSERT_FALSE(tm.match(pattern_numpy, mul_pdpd)); + ASSERT_FALSE(tm.match(pattern_pdpd, mul_numpy)); + ASSERT_TRUE(tm.match(pattern_pdpd, mul_pdpd)); + + ASSERT_TRUE(tm.match(pattern_numpy_or_pdpd, mul_numpy)); + ASSERT_TRUE(tm.match(pattern_numpy_or_pdpd, mul_pdpd)); + ASSERT_FALSE(tm.match(pattern_numpy_or_pdpd, mul_none)); +} + +TEST(pattern, predicate_value_match) { + TestMatcher tm; + auto constant_i = op::v0::Constant::create(element::i64, {4}, vector{-1, 0, 1, 2}); + auto constant_d = op::v0::Constant::create(element::f64, {4}, vector{-1.5, 0, 1.3, 2.75}); + + // actual value check + auto pattern_i = pattern::any_input(pattern::value_matches("[-1, 0, 1, 2]")); + auto pattern_d = pattern::any_input(pattern::value_matches("[-1.5, 0, 1.3, 2.75]")); + + ASSERT_TRUE(tm.match(pattern_i, constant_i)); + ASSERT_FALSE(tm.match(pattern_i, constant_d)); + ASSERT_TRUE(tm.match(pattern_d, constant_d)); + ASSERT_FALSE(tm.match(pattern_d, constant_i)); +}