diff --git a/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp b/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp
index 6861d8469b2129..dd0776a09d035a 100644
--- a/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp
+++ b/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp
@@ -88,29 +88,12 @@ class ov::pass::RoPEShareCosSin : public ov::pass::MatcherPass {
  * @ingroup ov_transformation_common_api
  * @brief Fuses special sub-graph into an internal Rotary Positional Embedding operation
  */
-class ov::pass::RoPEFusion : public ov::pass::GraphRewrite {
+class ov::pass::RoPEFusion : public ov::pass::ModelPass {
 public:
-    OPENVINO_GRAPH_REWRITE_RTTI("RoPEFusion");
-    RoPEFusion(bool support_2d_rope = false) {
-        add_matcher<ov::pass::RoPEFusionFlux>();
-        add_matcher<ov::pass::RoPEFusionGPTNEOX>();
-        add_matcher<ov::pass::RoPEFusionGPTJ>();
-        // optional heads & tails are fused in separate matcher pass,
-        // after RoPENode has been created.
-        add_matcher<ov::pass::RoPEFusionCosSinPreprocess>();
-        add_matcher<ov::pass::RoPEFusionIOSlicing>();
-        add_matcher<ov::pass::RoPEFusionPreprocess>();
-
-        add_matcher<ov::pass::RoPEFusionChatGLM>(0);
-        add_matcher<ov::pass::RoPEFusionChatGLM>(1);
-        if (support_2d_rope) {
-            add_matcher<ov::pass::RoPEFusionChatGLM>(0, true);
-            add_matcher<ov::pass::RoPEFusionChatGLM>(1, true);
-        }
-
-        add_matcher<ov::pass::RoPEFusionQwen>(0);
-        add_matcher<ov::pass::RoPEFusionQwen>(1);
-
-        add_matcher<ov::pass::RoPEShareCosSin>();
-    }
+    OPENVINO_MODEL_PASS_RTTI("RoPEFusion");
+    RoPEFusion(bool support_2d_rope = false);
+    bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
+
+private:
+    bool m_support_2d_rope;
 };
diff --git a/src/common/transformations/include/transformations/symbolic_transformations/symbolic_optimizations.hpp b/src/common/transformations/include/transformations/symbolic_transformations/symbolic_optimizations.hpp
index 9e85fcd4977430..3801e4ee65b766 100644
--- a/src/common/transformations/include/transformations/symbolic_transformations/symbolic_optimizations.hpp
+++ b/src/common/transformations/include/transformations/symbolic_transformations/symbolic_optimizations.hpp
@@ -25,7 +25,7 @@ class TRANSFORMATIONS_API LabelResolvingThroughSelect;
 class ov::pass::SymbolicOptimizations : public ov::pass::ModelPass {
 public:
     OPENVINO_MODEL_PASS_RTTI("SymbolicOptimizations");
-    explicit SymbolicOptimizations(bool full_run = true);
+    explicit SymbolicOptimizations(bool full_run = true, std::shared_ptr<ov::pass::PassConfig> pass_config = nullptr);
     bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
     std::shared_ptr<ov::pass::Manager> get_manager() {
         return m_manager;
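
Taken together, the two header changes turn RoPEFusion from a GraphRewrite into a ModelPass and let SymbolicOptimizations adopt an externally supplied PassConfig. As a minimal sketch of the caller side (not part of this patch; the manager name and helper function are illustrative), the pass is now driven like any other model pass:

    #include <memory>

    #include "openvino/pass/manager.hpp"
    #include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"

    // Sketch only: run the reworked RoPEFusion on an already-loaded model.
    void run_rope_fusion(const std::shared_ptr<ov::Model>& model) {
        ov::pass::Manager manager("RoPE");
        manager.register_pass<ov::pass::RoPEFusion>(/*support_2d_rope=*/true);
        manager.run_passes(model);
    }
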
"transformations/utils/gen_pattern.hpp" #include "transformations/utils/utils.hpp" using namespace ov::gen_pattern; +ov::pass::RoPEFusion::RoPEFusion(bool support_2d_rope) : m_support_2d_rope(support_2d_rope) {} + +bool ov::pass::RoPEFusion::run_on_model(const std::shared_ptr& model) { + RUN_ON_MODEL_SCOPE(RoPEFusion); + ov::pass::SymbolicOptimizations symbolic_optimizations(false, get_pass_config()); + + auto symbolic_ctx_manager = symbolic_optimizations.get_manager(); + + symbolic_ctx_manager->register_pass(); + symbolic_ctx_manager->register_pass(); + symbolic_ctx_manager->register_pass(); + // optional heads & tails are fused in separate matcher pass, + // after RoPENode has been created. + symbolic_ctx_manager->register_pass(); + symbolic_ctx_manager->register_pass(); + symbolic_ctx_manager->register_pass(); + + symbolic_ctx_manager->register_pass(0); + symbolic_ctx_manager->register_pass(1); + if (m_support_2d_rope) { + symbolic_ctx_manager->register_pass(0, true); + symbolic_ctx_manager->register_pass(1, true); + } + symbolic_ctx_manager->register_pass(0); + symbolic_ctx_manager->register_pass(1); + + symbolic_ctx_manager->register_pass(); + + return symbolic_optimizations.run_on_model(model); +} + // This is a utility function used in the work around in ChatGLM pattern. // Since the existing implementation of Symbols don't allow for checking // permutations of the same Symbols in a shape, we need to check the @@ -72,51 +106,44 @@ ov::pass::RoPEFusionFlux::RoPEFusionFlux() { // y1 = x * t_cos // y2 = x3 * t_sin // y = y1 + y2 - auto x = makePattern(ov::Rank(4)); - auto t_cos = makePattern(ov::Rank(4)); - auto t_sin = makePattern(ov::Rank(4)); - - auto num_heads = ov::gen_pattern::Symbol("num_heads"); - auto head_size = ov::gen_pattern::Symbol("head_size"); + auto x = pattern::any_input(pattern::rank_equals(4) && pattern::shape_matches("[PRESERVED_DIMS..., head_size]")); + auto t_cos = pattern::any_input(pattern::rank_equals(4)); + auto t_sin = pattern::any_input(pattern::rank_equals(4)); - auto x1_target_shape = makeConst({0, num_heads, 0, -1, 2}); - auto x1 = makePattern({x, x1_target_shape}, {{"special_zero", true}}); - auto split = makePattern({x1, -1}, {{"num_splits", 2}}); + auto x1 = pattern::wrap_type({x, pattern::any_input()}, + pattern::shape_matches("[PRESERVED_DIMS..., ?, 2]")); + auto split = pattern::wrap_type({x1, -1}, {{"num_splits", 2}}); split->set_output_size(2); // 3 versions of mulitply by -1 depending on transformations execution prior to this pass - auto x1_1_neg_1 = makePattern({split->output(1), -1.0f}, {{"auto_broadcast", "numpy"}}); + auto opt_squeeze = pattern::optional({split->output(1), -1}); + auto x1_1_neg = pattern::wrap_type({opt_squeeze, -1}, {{"auto_broadcast", "numpy"}}); + auto opt_squeeze_1 = pattern::optional({x1_1_neg, -1}); + auto opt_unsqueeze = pattern::optional({opt_squeeze_1, -1}); - auto squeeze_2 = makePattern({split->output(1), -1}); - auto x1_1_neg_2 = makePattern({squeeze_2, -1.0f}, {{"auto_broadcast", "numpy"}}); - auto unsqueeze_2 = makePattern({x1_1_neg_2, -1}); + auto x2 = pattern::wrap_type({opt_unsqueeze, split->output(0)}, {{"axis", -1}}); + auto x3 = pattern::wrap_type({x2, pattern::any_input()}, + pattern::shape_matches("[PRESERVED_DIMS..., head_size]")); - auto x1_1_neg_3 = makePattern({split->output(1), -1.0f}, {{"auto_broadcast", "numpy"}}); - auto squeeze_3 = makePattern({x1_1_neg_3, -1}); - auto unsqueeze_3 = makePattern({squeeze_3, -1}); + auto y1 = pattern::wrap_type({x, t_cos}, {{"auto_broadcast", "numpy"}}); + auto 
diff --git a/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp b/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp
index ec9945b8ea1ecc..1cf2e6d0230a9a 100644
--- a/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp
+++ b/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp
@@ -171,8 +171,13 @@ ov::pass::LabelResolvingThroughSelect::LabelResolvingThroughSelect() {
     register_matcher(m, matcher_pass_callback);
 }
 
-ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run) {
-    m_manager = std::make_shared<Manager>("Symbolic");
+ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run,
+                                                       std::shared_ptr<ov::pass::PassConfig> pass_config) {
+    if (pass_config)
+        m_manager = std::make_shared<Manager>(pass_config, "Symbolic");
+    else
+        m_manager = std::make_shared<Manager>("Symbolic");
+
     m_manager->set_per_pass_validation(false);
 
 #define REGISTER_SYMBOLIC(region, ...) m_manager->register_pass<region>(__VA_ARGS__);
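
Threading pass_config through to the internal manager is what keeps per-matcher control from the outer pipeline working after the rework: a matcher disabled on the caller's PassConfig is now also skipped by the manager that RoPEFusion builds inside run_on_model, since both consult the same PassConfig instance. A hedged sketch of the resulting caller-side behavior (not part of the patch; `model` is assumed to be a loaded std::shared_ptr<ov::Model>):

    ov::pass::Manager manager;
    // Disable one of the inner matchers; the shared PassConfig makes this
    // visible to the SymbolicOptimizations-owned manager inside RoPEFusion.
    manager.get_pass_config()->disable<ov::pass::RoPEFusionGPTJ>();
    manager.register_pass<ov::pass::RoPEFusion>();
    manager.run_passes(model);
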