[TRANSFORMATIONS] Rewrite RoPEFusionFlux using new Symbolic approach #30184


Draft pull request: wants to merge 39 commits into base: master

39 commits
fb2c7d9
Attributes matching in a pattern
jane-intel Feb 24, 2025
d1b187b
Python API update
jane-intel Feb 24, 2025
9c633bc
Merge branch 'master' into attribute_pattern_matching
jane-intel Feb 24, 2025
d36881e
Linter fix
jane-intel Feb 24, 2025
2cc7a82
Merge branch 'attribute_pattern_matching' of https://github.com/jane-…
jane-intel Feb 24, 2025
56ac8ed
Python API: allow regular Dict[str, ?] to describe expected attributes
jane-intel Feb 25, 2025
540a560
Merge remote-tracking branch 'upstream/master' into attribute_pattern…
jane-intel Feb 25, 2025
19091b3
Minor changes: name for the always true predicate; wrap_type -- no ne…
jane-intel Feb 26, 2025
aee859d
Merge remote-tracking branch 'upstream/master' into attribute_pattern…
jane-intel Feb 26, 2025
e35cd4d
Fixing NodePredicate conversion to ValuePredicate
jane-intel Feb 26, 2025
6b06d96
Removing unused function
jane-intel Feb 26, 2025
ba52522
Merge branch 'master' into CVS-163062_CVS-163051
jane-intel Feb 26, 2025
ab23914
Resolve NodePredicate ambiguity
jane-intel Feb 27, 2025
5277791
Merge remote-tracking branch 'upstream/master' into CVS-163062_CVS-16…
jane-intel Feb 27, 2025
97a4c9d
Merge branch 'CVS-163062_CVS-163051' into attribute_pattern_matching
jane-intel Feb 27, 2025
113efdd
refactor
jane-intel Mar 3, 2025
e7ccae0
Merge remote-tracking branch 'upstream/master' into attribute_pattern…
jane-intel Mar 3, 2025
74453e1
Merge branch 'master' into attribute_pattern_matching
jane-intel Mar 6, 2025
aaff25c
Merge branch 'master' into attribute_pattern_matching
jane-intel Mar 14, 2025
9e4e473
Merge remote-tracking branch 'upstream/master' into attribute_pattern…
jane-intel Mar 14, 2025
ee4c812
Merge remote-tracking branch 'upstream/master' into attribute_pattern…
jane-intel Mar 17, 2025
bb5f44f
Merge branch 'master' into attribute_pattern_matching
jane-intel Apr 7, 2025
792203d
Merge branch 'attribute_pattern_matching' of https://github.com/jane-…
jane-intel Apr 7, 2025
7559ae0
std::nullptr_t
jane-intel Apr 7, 2025
d92f6a4
Style
jane-intel Apr 7, 2025
88b99c7
Merge branch 'master' into attribute_pattern_matching
jane-intel Apr 7, 2025
148e1b7
Merge branch 'master' into attribute_pattern_matching
jane-intel Apr 8, 2025
cc97acb
Revert unnecessary changes
jane-intel Apr 9, 2025
3503e49
Removed unnecessary usage of nullptr_T
jane-intel Apr 9, 2025
7415388
Merge remote-tracking branch 'upstream/master' into attribute_pattern…
jane-intel Apr 9, 2025
9dc9823
Merge branch 'master' into attribute_pattern_matching
jane-intel Apr 10, 2025
53320ef
Apply suggestions from code review
jane-intel Apr 11, 2025
bedd7b7
Merge branch 'master' into attribute_pattern_matching
jane-intel Apr 11, 2025
abe67a7
wip
CuriousPanCake Apr 14, 2025
72197b0
base preparations done
CuriousPanCake Apr 15, 2025
8a69ef5
Merge remote-tracking branch 'upstream/master' into symbolic_feature_…
jane-intel Apr 16, 2025
d1adb6c
Constant matching
jane-intel Apr 16, 2025
4369273
Merge remote-tracking branch 'jane-intel/symbolic_feature_branch' int…
CuriousPanCake Apr 16, 2025
460e0b6
rewrite RoPEFusionFlux
CuriousPanCake Apr 17, 2025
1 change: 1 addition & 0 deletions src/bindings/python/src/openvino/passes/__init__.py
@@ -14,6 +14,7 @@
type_matches,
type_matches_any,
shape_matches,
attrs_match,
)
from openvino._pyopenvino.passes import Serialize, ConstantFolding, VisualizeTree, MakeStateful, LowLatency2, ConvertFP32ToFP16, Version
from openvino.passes.manager import Manager

@@ -18,6 +18,7 @@
#include "openvino/pass/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "pyopenvino/core/common.hpp"
#include "pyopenvino/utils/utils.hpp"

static ov::NodeTypeInfo get_type(const std::string& type_name) {
// Supported types: opsetX.OpName or opsetX::OpName
@@ -1014,6 +1015,9 @@ inline void reg_predicates(py::module m) {
m.def("type_matches", &ov::pass::pattern::type_matches);
m.def("type_matches_any", &ov::pass::pattern::type_matches_any);
m.def("shape_matches", &ov::pass::pattern::shape_matches);
m.def("attrs_match", [](py::object& attrs) {
return ov::pass::pattern::attrs_match(Common::utils::py_object_to_unordered_any_map(attrs));
});
}

void reg_passes_pattern_ops(py::module m) {
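For reference, the same predicate is usable from C++. A minimal sketch mirroring the Python test added further down (illustrative only, not part of this diff; the attrs_match declaration is assumed to sit next to the other predicates in pattern.hpp, and the attribute value is given in its string form exactly as in the test):

```cpp
#include <unordered_map>

#include "openvino/core/any.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"

void build_attr_matcher() {
    using namespace ov::pass::pattern;
    // Match any node whose "output_type" attribute equals i64.
    auto node_with_attr = any_input(
        attrs_match(std::unordered_map<std::string, ov::Any>{{"output_type", std::string("i64")}}));
    Matcher matcher(node_with_attr, "find_node_with_output_type_i64");
    (void)matcher;
}
```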
15 changes: 15 additions & 0 deletions src/bindings/python/src/pyopenvino/utils/utils.cpp
@@ -411,6 +411,21 @@ ov::AnyMap py_object_to_any_map(const py::object& py_obj) {
return return_value;
}

std::unordered_map<std::string, ov::Any> py_object_to_unordered_any_map(const py::object& py_obj) {
OPENVINO_ASSERT(py_object_is_any_map(py_obj), "Unsupported attribute type.");
std::unordered_map<std::string, ov::Any> return_value = {};
for (auto& item : py::cast<py::dict>(py_obj)) {
std::string key = py::cast<std::string>(item.first);
py::object value = py::cast<py::object>(item.second);
if (py_object_is_any_map(value)) {
return_value[key] = Common::utils::py_object_to_any_map(value);
} else {
return_value[key] = Common::utils::py_object_to_any(value);
}
}
return return_value;
}

template <typename... Args, std::size_t... I>
std::tuple<Args...> tuple_from_py_tuple_impl(const py::tuple& py_tuple, std::index_sequence<I...>) {
return std::make_tuple(py_tuple[I].cast<Args>()...);
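For illustration (not part of the diff), the structure this helper produces: nested Python dicts become ov::AnyMap values, other values plain ov::Any. A hand-built equivalent for a hypothetical input {"num_splits": 2, "config": {"mode": "fast"}} would look roughly like this; the exact scalar types depend on py_object_to_any:

```cpp
#include <string>
#include <unordered_map>

#include "openvino/core/any.hpp"

// Roughly the map produced for {"num_splits": 2, "config": {"mode": "fast"}}:
// the nested dict is stored as an ov::AnyMap, the scalar as a plain ov::Any.
std::unordered_map<std::string, ov::Any> expected_attrs = {
    {"num_splits", 2},
    {"config", ov::AnyMap{{"mode", std::string("fast")}}},
};
```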
2 changes: 2 additions & 0 deletions src/bindings/python/src/pyopenvino/utils/utils.hpp
@@ -140,6 +140,8 @@ class MemoryBuffer : public std::streambuf {

ov::AnyMap py_object_to_any_map(const py::object& py_obj);

std::unordered_map<std::string, ov::Any> py_object_to_unordered_any_map(const py::object& py_obj);

ov::Any py_object_to_any(const py::object& py_obj);

ov::pass::Serialize::Version convert_to_version(const std::string& version);

@@ -16,6 +16,7 @@
type_matches,
type_matches_any,
shape_matches,
attrs_match,
)
from openvino.utils.types import get_element_type

@@ -278,6 +279,19 @@ def symbol_matching_test(shape: PartialShape, pattern: str):
assert symbols["Six"] == 6, symbols


def test_attrs_match():
param = ops.parameter([-1, -1])

def test_shape_of_attribute(et: str):
node = ops.shape_of(param, output_type=et)
attr = {"output_type": et}
matcher = Matcher(AnyInput(attrs_match(attr)), "Find shape_of with attribute")
assert matcher.match(node), f"Match failed for {node} with attribute"

test_shape_of_attribute("i64")
test_shape_of_attribute("i32")


def test_optional_full_match():
model_input = ops.parameter(PartialShape.dynamic())
model_abs = ops.abs(model_input)

@@ -88,29 +88,12 @@ class ov::pass::RoPEShareCosSin : public ov::pass::MatcherPass {
* @ingroup ov_transformation_common_api
* @brief Fuses special sub-graph into an internal Rotary Positional Embedding operation
*/
class ov::pass::RoPEFusion : public ov::pass::GraphRewrite {
class ov::pass::RoPEFusion : public ov::pass::ModelPass {
public:
OPENVINO_GRAPH_REWRITE_RTTI("RoPEFusion");
RoPEFusion(bool support_2d_rope = false) {
add_matcher<ov::pass::RoPEFusionFlux>();
add_matcher<ov::pass::RoPEFusionGPTNEOX>();
add_matcher<ov::pass::RoPEFusionGPTJ>();
// optional heads & tails are fused in separate matcher pass,
// after RoPENode has been created.
add_matcher<ov::pass::RoPEFusionCosSinPreprocess>();
add_matcher<ov::pass::RoPEFusionIOSlicing>();
add_matcher<ov::pass::RoPEFusionPreprocess>();

add_matcher<ov::pass::RoPEFusionChatGLM>(0);
add_matcher<ov::pass::RoPEFusionChatGLM>(1);
if (support_2d_rope) {
add_matcher<ov::pass::RoPEFusionChatGLM>(0, true);
add_matcher<ov::pass::RoPEFusionChatGLM>(1, true);
}

add_matcher<ov::pass::RoPEFusionQwen>(0);
add_matcher<ov::pass::RoPEFusionQwen>(1);

add_matcher<ov::pass::RoPEShareCosSin>();
}
OPENVINO_MODEL_PASS_RTTI("RoPEFusion");
RoPEFusion(bool support_2d_rope = false);
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;

private:
bool m_support_2d_rope;
};
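Although RoPEFusion is now a ModelPass rather than a GraphRewrite, callers keep registering it through a pass Manager as before. A minimal usage sketch (illustrative, not part of the diff; the header path follows the current repository layout):

```cpp
#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"

void fuse_rope(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    // support_2d_rope additionally enables the 2D ChatGLM variants inside run_on_model.
    manager.register_pass<ov::pass::RoPEFusion>(true /*support_2d_rope*/);
    manager.run_passes(model);
}
```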

@@ -25,7 +25,7 @@ class TRANSFORMATIONS_API LabelResolvingThroughSelect;
class ov::pass::SymbolicOptimizations : public ov::pass::ModelPass {
public:
OPENVINO_MODEL_PASS_RTTI("SymbolicOptimizations");
explicit SymbolicOptimizations(bool full_run = true);
explicit SymbolicOptimizations(bool full_run = true, std::shared_ptr<ov::pass::PassConfig> pass_config = nullptr);
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
std::shared_ptr<ov::pass::Manager> get_manager() {
return m_manager;

@@ -17,13 +17,54 @@
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "ov_ops/rotary_positional_embeddings.hpp"
#include "ov_ops/type_relaxed.hpp"
#include "transformations/utils/gen_pattern.hpp"
#include "transformations/utils/utils.hpp"
#include "transformations/symbolic_transformations/symbolic_optimizations.hpp"
#include "transformations/utils/utils.hpp"

#include "openvino/pass/visualize_tree.hpp"

using namespace ov::gen_pattern;

ov::pass::RoPEFusion::RoPEFusion(bool support_2d_rope) : m_support_2d_rope(support_2d_rope) {
}

bool ov::pass::RoPEFusion::run_on_model(const std::shared_ptr<ov::Model>& model) {
RUN_ON_MODEL_SCOPE(RoPEFusion);
ov::pass::SymbolicOptimizations symbolic_optimizations(false, get_pass_config());

std::cout << "SETTING UP RoPEFusion" << std::endl;
auto symbolic_ctx_manager = symbolic_optimizations.get_manager();

symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionFlux>();
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionGPTNEOX>();
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionGPTJ>();
// optional heads & tails are fused in separate matcher pass,
// after RoPENode has been created.
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionCosSinPreprocess>();
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionIOSlicing>();
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionPreprocess>();

symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionChatGLM>(0);
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionChatGLM>(1);
if (m_support_2d_rope) {
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionChatGLM>(0, true);
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionChatGLM>(1, true);
}
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionQwen>(0);
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionQwen>(1);

symbolic_ctx_manager->register_pass<ov::pass::RoPEShareCosSin>();

std::cout << "About to run the transformations" << std::endl;
bool a = symbolic_optimizations.get_manager()->run_passes(model);
std::cout << "Run the transformations" << std::endl;
return a;
}

// This is a utility function used in the work around in ChatGLM pattern.
// Since the existing implementation of Symbols don't allow for checking
// permutations of the same Symbols in a shape, we need to check the
@@ -56,51 +97,42 @@ ov::pass::RoPEFusionFlux::RoPEFusionFlux() {
// y1 = x * t_cos
// y2 = x3 * t_sin
// y = y1 + y2
auto x = makePattern(ov::Rank(4));
auto t_cos = makePattern(ov::Rank(4));
auto t_sin = makePattern(ov::Rank(4));

auto num_heads = ov::gen_pattern::Symbol("num_heads");
auto head_size = ov::gen_pattern::Symbol("head_size");
auto x = pattern::any_input(pattern::rank_equals(4) && pattern::shape_matches("[PRESERVED_DIMS..., head_size]"));
auto t_cos = pattern::any_input(pattern::rank_equals(4));
auto t_sin = pattern::any_input(pattern::rank_equals(4));

auto x1_target_shape = makeConst({0, num_heads, 0, -1, 2});
auto x1 = makePattern<opset1::Reshape>({x, x1_target_shape}, {{"special_zero", true}});
auto split = makePattern<opset1::Split>({x1, -1}, {{"num_splits", 2}});
auto x1 = pattern::wrap_type<opset1::Reshape>({x, pattern::any_input()}, pattern::shape_matches("[PRESERVED_DIMS..., ?, 2]"));
auto split = pattern::wrap_type<opset1::Split>({x1, INT_CONSTANT_WITH_PREDICATE(value == std::vector<int64_t>{-1})}, {{"num_splits", 2l}});
split->set_output_size(2);

// 3 versions of multiply by -1 depending on transformations execution prior to this pass
auto x1_1_neg_1 = makePattern<opset1::Multiply>({split->output(1), -1.0f}, {{"auto_broadcast", "numpy"}});

auto squeeze_2 = makePattern<opset1::Squeeze>({split->output(1), -1});
auto x1_1_neg_2 = makePattern<opset1::Multiply>({squeeze_2, -1.0f}, {{"auto_broadcast", "numpy"}});
auto unsqueeze_2 = makePattern<opset1::Unsqueeze>({x1_1_neg_2, -1});

auto x1_1_neg_3 = makePattern<opset1::Multiply>({split->output(1), -1.0f}, {{"auto_broadcast", "numpy"}});
auto squeeze_3 = makePattern<opset1::Squeeze>({x1_1_neg_3, -1});
auto unsqueeze_3 = makePattern<opset1::Unsqueeze>({squeeze_3, -1});
auto opt_squeeze = pattern::optional<opset1::Squeeze>({split->output(1), INT_CONSTANT_WITH_PREDICATE(value == std::vector<int64_t>{-1})});
auto x1_1_neg = pattern::wrap_type<opset1::Multiply>({opt_squeeze, FLOAT_CONSTANT_WITH_PREDICATE(value == std::vector<float>{-1})}, {{"auto_broadcast", "numpy"}});
auto opt_squeeze_1 = pattern::optional<opset1::Squeeze>({x1_1_neg, INT_CONSTANT_WITH_PREDICATE(value == std::vector<int64_t>{-1})});
auto opt_unsqueeze = pattern::optional<opset1::Unsqueeze>({opt_squeeze_1, INT_CONSTANT_WITH_PREDICATE(value == std::vector<int64_t>{-1})});

auto x2 = makePattern<opset1::Concat>({x1_1_neg_1 | unsqueeze_2 | unsqueeze_3, split->output(0)}, {{"axis", -1}});
auto x3_target_shape = makeConst({0, num_heads, 0, head_size});
auto x3 = makePattern<opset1::Reshape>({x2, x3_target_shape}, {{"special_zero", true}});
auto x2 = pattern::wrap_type<opset1::Concat>({opt_unsqueeze, split->output(0)}, {{"axis", -1l}});
auto x3 = pattern::wrap_type<opset1::Reshape>({x2, pattern::any_input()}, pattern::shape_matches("[PRESERVED_DIMS..., head_size]"));

auto y1 = makePattern<opset1::Multiply>({x, t_cos}, {{"auto_broadcast", "numpy"}});
auto y2 = makePattern<opset1::Multiply>({x3, t_sin}, {{"auto_broadcast", "numpy"}});
auto y1 = pattern::wrap_type<opset1::Multiply>({x, t_cos}, {{"auto_broadcast", "numpy"}});
auto y2 = pattern::wrap_type<opset1::Multiply>({x3, t_sin}, {{"auto_broadcast", "numpy"}});

auto y = makePattern<opset1::Add>({y1, y2}, {{"auto_broadcast", "numpy"}});
auto result = y;
auto result = pattern::wrap_type<opset1::Add>({y1, y2}, {{"auto_broadcast", "numpy"}});

matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
PatternValidator validator(m);
if (!validator) {
return false;
}

const auto& pattern_map = m.get_pattern_value_map();
auto root = m.get_match_root();

auto symbols = m.get_symbols();
auto num_heads = symbols["PRESERVED_DIMS"].g()[1];
auto head_size = symbols["head_size"];
if (!num_heads.is_static() || !head_size.is_static()) {
return false;
}

op::internal::RoPE::Config config;
config.head_cnt = static_cast<size_t>(validator["num_heads"]);
config.head_size = static_cast<size_t>(validator["head_size"]);
config.head_cnt = static_cast<size_t>(num_heads.i());
config.head_size = static_cast<size_t>(head_size.i());
config.rotary_ndims = config.head_size;
config.is_interleaved = true;
config.output_trans0213 = false;
@@ -126,6 +158,7 @@ ov::pass::RoPEFusionFlux::RoPEFusionFlux() {

// this new node may match following additional matchers
register_new_node(new_node);
std::cout << "END OF RoPEFusionFlux" << std::endl;

return true;
};
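The rewritten pattern leans on the symbolic shape notation: names used inside shape_matches strings (head_size, PRESERVED_DIMS...) are recorded while matching and read back in the callback through get_symbols(). A stripped-down sketch of that flow with a hypothetical single-input pattern (not taken from the diff):

```cpp
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"

using namespace ov::pass::pattern;

// Capture named dimensions while matching, then read them back in the callback.
auto input = any_input(rank_equals(4) && shape_matches("[BATCH..., head_size]"));

auto callback = [](Matcher& m) -> bool {
    auto symbols = m.get_symbols();          // name -> matched symbolic value
    auto head_size = symbols["head_size"];   // a single dimension
    auto batch_dims = symbols["BATCH"].g();  // a "..." group exposes its values via g()
    (void)batch_dims;
    if (!head_size.is_static())
        return false;                        // fuse only when the value is known
    auto head_size_value = static_cast<size_t>(head_size.i());
    (void)head_size_value;
    return true;
};
```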

@@ -346,16 +346,12 @@ ov::pass::DeReshapeFullyConnected::DeReshapeFullyConnected() {
using namespace ov::op;
using namespace ov::pass::pattern;

auto transpose_a_false = [](const std::shared_ptr<Node>& node) -> bool {
auto mm = as_type_ptr<v0::MatMul>(node);
return mm && !mm->get_transpose_a();
};

auto input = wrap_type<v1::Reshape>({any_input(shape_matches("BATCHES_1...,Y")), any_input()},
shape_matches("BATCHES_2...,Y"));
auto converted = pattern::optional<v0::Convert>(input, consumers_count(1));
auto mm_label = wrap_type<v0::MatMul>({converted, any_input(rank_equals(2))},
consumers_count(1) && transpose_a_false && shape_matches("BATCHES_2...,Z"));
consumers_count(1) && shape_matches("BATCHES_2...,Z"),
{{"transpose_a", false}});
auto output = wrap_type<v1::Reshape>({mm_label, any_input()}, shape_matches("BATCHES_1...,Z"));

ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {

@@ -171,8 +171,12 @@ ov::pass::LabelResolvingThroughSelect::LabelResolvingThroughSelect() {
register_matcher(m, matcher_pass_callback);
}

ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run) {
m_manager = std::make_shared<pass::Manager>("Symbolic");
ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run, std::shared_ptr<ov::pass::PassConfig> pass_config) {
if (pass_config)
m_manager = std::make_shared<pass::Manager>(pass_config, "Symbolic");
else
m_manager = std::make_shared<pass::Manager>("Symbolic");

m_manager->set_per_pass_validation(false);

#define REGISTER_SYMBOLIC(region, ...) m_manager->register_pass<region>(__VA_ARGS__);
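Forwarding the caller's PassConfig means that a matcher disabled on the outer manager stays disabled inside the nested symbolic manager as well. A sketch of how a caller might rely on that (illustrative, assuming the standard PassConfig::disable<T>() API):

```cpp
#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"

void fuse_rope_without_gptj(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::RoPEFusion>();
    // RoPEFusion::run_on_model hands get_pass_config() to SymbolicOptimizations,
    // so this disable also reaches the matcher registered on the nested manager.
    manager.get_pass_config()->disable<ov::pass::RoPEFusionGPTJ>();
    manager.run_passes(model);
}
```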
6 changes: 6 additions & 0 deletions src/core/include/openvino/core/attribute_adapter.hpp
@@ -38,6 +38,9 @@ class OPENVINO_API ValueAccessor<void> {
virtual void set_as_any(const ov::Any& x) {
OPENVINO_NOT_IMPLEMENTED;
}
virtual ov::Any get_as_any() {
OPENVINO_NOT_IMPLEMENTED;
}
};

/// \brief Provides access to values via get/set methods from an m_value, typically from
Expand Down Expand Up @@ -65,6 +68,9 @@ class ValueAccessor : public ValueAccessor<void> {
OPENVINO_THROW("Bad cast from: ", x.type_info().name(), " to: ", typeid(VAT).name());
}
}
ov::Any get_as_any() override {
return {get()};
}
};

template <>
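get_as_any gives generic, type-erased read access to attribute values, which is what map-based attribute matching needs. A hypothetical visitor illustrating the idea (not the implementation attrs_match actually uses):

```cpp
#include <string>
#include <unordered_map>

#include "openvino/core/any.hpp"
#include "openvino/core/attribute_adapter.hpp"
#include "openvino/core/attribute_visitor.hpp"

// Collects every attribute of a node as ov::Any so it can later be compared
// against an expected {name: value} map.
class AttributeCollector : public ov::AttributeVisitor {
public:
    std::unordered_map<std::string, ov::Any> attributes;

    void on_adapter(const std::string& name, ov::ValueAccessor<void>& adapter) override {
        // get_as_any() throws OPENVINO_NOT_IMPLEMENTED for accessors that do not
        // override it, so a real implementation would guard this call.
        attributes[name] = adapter.get_as_any();
    }
};
// Usage: node->visit_attributes(collector); then compare collector.attributes.
```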
2 changes: 1 addition & 1 deletion src/core/include/openvino/pass/pattern/op/label.hpp
@@ -72,7 +72,7 @@ class OPENVINO_API Label : public Pattern {
};
} // namespace op

OPENVINO_API std::shared_ptr<Node> any_input();
OPENVINO_API std::shared_ptr<Node> any_input(const Attributes& attrs = {});

template <typename TPredicate>
std::shared_ptr<Node> any_input(const TPredicate& pred) {
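With the new default argument an attribute map can be handed straight to any_input. A small sketch reusing the transpose_a attribute from the DeReshapeFullyConnected change above (illustrative; the bare-map overload is assumed to behave like attrs_match):

```cpp
using namespace ov::pass::pattern;

// Sketch: intended to match any node that carries a "transpose_a" attribute
// equal to false, regardless of the node's type.
auto matmul_like = any_input({{"transpose_a", false}});
```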
40 changes: 40 additions & 0 deletions src/core/include/openvino/pass/pattern/op/op.hpp
@@ -0,0 +1,40 @@
//// Copyright (C) 2018-2025 Intel Corporation
//// SPDX-License-Identifier: Apache-2.0
////
//
//#pragma once
//
//#include "openvino/pass/pattern/op/predicate.hpp"
//
// namespace ov::pass::pattern {
//// A glue/syntax-sugar type which allows more types to be used as input to pattern operations
// struct PatternOp {
// std::shared_ptr<ov::Node> op;
// size_t output_idx;
//
// operator ov::Output<ov::Node>() const;
// ov::Output<ov::Node> get_output() const;
//
// PatternOp();
// PatternOp(const Output<Node> &out);
//
// template <typename TPredicate,
// typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate>>>* = nullptr>
// PatternOp(PredicateT pred) {
// op = any_input(pred);
// }
//
// PatternOp(ov::Rank rank);
//
// // Constant matching
// PatternOp(std::string value_notation);
// PatternOp(int v);
// PatternOp(float v);
// PatternOp(double v);
// PatternOp(long long v);
// PatternOp(std::initializer_list<int> v);
// PatternOp(std::initializer_list<float> v);
// PatternOp(std::initializer_list<double> v);
// PatternOp(std::initializer_list<long> v);
// };
// }