Skip to content

Commit 766b269

Browse files
authored
Extend Gelu pattern to support Mul (#29321)
### Details: Extend Gelu patterns to support Mul Simplify the gelu fusion transformations logic Update unit tests ### Tickets: - *CVS-163672*
1 parent 3abf01f commit 766b269

File tree

2 files changed

+96
-178
lines changed

2 files changed

+96
-178
lines changed

src/common/transformations/src/transformations/common_optimizations/gelu_fusion.cpp

Lines changed: 70 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -26,60 +26,63 @@
2626
#include "openvino/pass/pattern/op/wrap_type.hpp"
2727
#include "transformations/utils/utils.hpp"
2828

29+
using namespace ov;
30+
using namespace ov::op::util;
31+
using namespace ov::pass::pattern::op;
32+
33+
constexpr auto SQRT2 = static_cast<float>(M_SQRT2);
34+
constexpr auto SQRT1_2 = static_cast<float>(M_SQRT1_2);
35+
36+
namespace {
37+
38+
Predicate check_value(float ref, float eps = std::numeric_limits<float>::epsilon()) {
39+
return Predicate(
40+
[=](const Output<Node>& output) -> bool {
41+
return ov::op::util::has_constant_value<float>(output.get_node_shared_ptr(), ref, eps);
42+
},
43+
"has_constant_value(" + std::to_string(ref) + ")");
44+
}
45+
46+
bool gelu_replacer(ov::pass::pattern::Matcher& m, const std::shared_ptr<ov::Node>& pattern_input_to_relu) {
47+
ov::pass::NodeRegistry rg;
48+
auto pattern_to_output = m.get_pattern_map();
49+
auto x_output = pattern_to_output.at(pattern_input_to_relu);
50+
51+
auto gelu = rg.make<ov::op::v7::Gelu>(x_output);
52+
53+
gelu->set_friendly_name(m.get_match_root()->get_friendly_name());
54+
copy_runtime_info(m.get_matched_nodes(), rg.get());
55+
replace_node(m.get_match_root(), gelu);
56+
return true;
57+
}
58+
59+
} // namespace
60+
2961
ov::pass::GeluFusionWithErfOne::GeluFusionWithErfOne() {
3062
MATCHER_SCOPE(GeluFusionWithErfOne);
3163
// Replaces a sub-graph with a Gelu op
3264
// Shared by every pattern: (1 + erf(x / sqrt(2)))
3365
auto input = pass::pattern::any_input();
34-
auto div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
66+
auto div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value(SQRT2, 0.001f));
3567
auto div = ov::pass::pattern::wrap_type<ov::op::v1::Divide>({input, div_constant});
36-
auto erf = ov::pass::pattern::wrap_type<ov::op::v0::Erf>({div});
37-
auto add_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
68+
69+
// In case of ConvertDivideWithConstant is applied and Div is converted to Mul
70+
auto mul_as_div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value(SQRT1_2, 0.001f));
71+
auto mul_as_div = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_as_div_constant});
72+
auto erf_input = std::make_shared<Or>(ov::OutputVector{div, mul_as_div});
73+
74+
auto erf = ov::pass::pattern::wrap_type<ov::op::v0::Erf>({erf_input});
75+
76+
auto add_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value(1.0f));
3877
auto add = ov::pass::pattern::wrap_type<ov::op::v1::Add>({add_constant, erf});
39-
auto mul_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
78+
auto mul_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value(0.5f));
4079

4180
// (0.5 * x) * (1 + erf(x / sqrt(2))
4281
auto mul_first = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_constant});
4382
auto mul = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({mul_first, add});
4483

4584
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
46-
auto& pattern_to_output = m.get_pattern_value_map();
47-
auto x_output = pattern_to_output.at(input);
48-
49-
auto div_const_value =
50-
ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());
51-
auto add_const_value =
52-
ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
53-
auto mul_const_value =
54-
ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());
55-
56-
if (!div_const_value || !add_const_value || !mul_const_value) {
57-
return false;
58-
}
59-
60-
bool valid_constant_values =
61-
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2), 0.001f) &&
62-
op::util::has_constant_value<float>(add_const_value, 1.0f) &&
63-
op::util::has_constant_value<float>(mul_const_value, 0.5f);
64-
65-
if (!valid_constant_values) {
66-
return false;
67-
}
68-
69-
auto gelu = std::make_shared<ov::op::v7::Gelu>(x_output);
70-
71-
gelu->set_friendly_name(m.get_match_root()->get_friendly_name());
72-
ov::copy_runtime_info(
73-
{
74-
pattern_to_output.at(div).get_node_shared_ptr(),
75-
pattern_to_output.at(erf).get_node_shared_ptr(),
76-
pattern_to_output.at(add).get_node_shared_ptr(),
77-
pattern_to_output.at(mul_first).get_node_shared_ptr(),
78-
pattern_to_output.at(mul).get_node_shared_ptr(),
79-
},
80-
gelu);
81-
ov::replace_node(m.get_match_root(), gelu);
82-
return true;
85+
return gelu_replacer(m, input);
8386
};
8487

8588
auto m = std::make_shared<ov::pass::pattern::Matcher>(mul, matcher_name);
@@ -91,55 +94,25 @@ ov::pass::GeluFusionWithErfTwo::GeluFusionWithErfTwo() {
9194
// Replaces a sub-graph with a Gelu op
9295
// Shared by every pattern: (1 + erf(x / sqrt(2)))
9396
auto input = pass::pattern::any_input();
94-
auto div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
97+
auto div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value(SQRT2, 0.001f));
9598
auto div = ov::pass::pattern::wrap_type<ov::op::v1::Divide>({input, div_constant});
96-
auto erf = ov::pass::pattern::wrap_type<ov::op::v0::Erf>({div});
97-
auto add_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
99+
100+
// In case of ConvertDivideWithConstant is applied and Div is converted to Mul
101+
auto mul_as_div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value(SQRT1_2, 0.001f));
102+
auto mul_as_div = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_as_div_constant});
103+
auto erf_input = std::make_shared<Or>(ov::OutputVector{div, mul_as_div});
104+
105+
auto erf = ov::pass::pattern::wrap_type<ov::op::v0::Erf>({erf_input});
106+
auto add_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value(1.0f));
98107
auto add = ov::pass::pattern::wrap_type<ov::op::v1::Add>({add_constant, erf});
99-
auto mul_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
108+
auto mul_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value(0.5f));
100109

101110
// 0.5 * (x * (1 + erf(x / sqrt(2)))
102111
auto mul_first = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, add});
103112
auto mul = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({mul_constant, mul_first});
104113

105114
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
106-
auto& pattern_to_output = m.get_pattern_value_map();
107-
auto x_output = pattern_to_output.at(input);
108-
109-
auto div_const_value =
110-
ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());
111-
auto add_const_value =
112-
ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
113-
auto mul_const_value =
114-
ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());
115-
116-
if (!div_const_value || !add_const_value || !mul_const_value) {
117-
return false;
118-
}
119-
120-
bool valid_constant_values =
121-
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2), 0.001f) &&
122-
op::util::has_constant_value<float>(add_const_value, 1.0f) &&
123-
op::util::has_constant_value<float>(mul_const_value, 0.5f);
124-
125-
if (!valid_constant_values) {
126-
return false;
127-
}
128-
129-
auto gelu = std::make_shared<ov::op::v7::Gelu>(x_output);
130-
131-
gelu->set_friendly_name(m.get_match_root()->get_friendly_name());
132-
ov::copy_runtime_info(
133-
{
134-
pattern_to_output.at(div).get_node_shared_ptr(),
135-
pattern_to_output.at(erf).get_node_shared_ptr(),
136-
pattern_to_output.at(add).get_node_shared_ptr(),
137-
pattern_to_output.at(mul_first).get_node_shared_ptr(),
138-
pattern_to_output.at(mul).get_node_shared_ptr(),
139-
},
140-
gelu);
141-
ov::replace_node(m.get_match_root(), gelu);
142-
return true;
115+
return gelu_replacer(m, input);
143116
};
144117

145118
auto m = std::make_shared<ov::pass::pattern::Matcher>(mul, matcher_name);
@@ -151,55 +124,25 @@ ov::pass::GeluFusionWithErfThree::GeluFusionWithErfThree() {
151124
// Replaces a sub-graph with a Gelu op
152125
// Shared by every pattern: (1 + erf(x / sqrt(2)))
153126
auto input = pass::pattern::any_input();
154-
auto div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
127+
auto div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value(SQRT2, 0.001f));
155128
auto div = ov::pass::pattern::wrap_type<ov::op::v1::Divide>({input, div_constant});
156-
auto erf = ov::pass::pattern::wrap_type<ov::op::v0::Erf>({div});
157-
auto add_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
129+
130+
// In case of ConvertDivideWithConstant is applied and Div is converted to Mul
131+
auto mul_as_div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value(SQRT1_2, 0.001f));
132+
auto mul_as_div = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_as_div_constant});
133+
auto erf_input = std::make_shared<Or>(ov::OutputVector{div, mul_as_div});
134+
135+
auto erf = ov::pass::pattern::wrap_type<ov::op::v0::Erf>({erf_input});
136+
auto add_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value(1.0f));
158137
auto add = ov::pass::pattern::wrap_type<ov::op::v1::Add>({add_constant, erf});
159-
auto mul_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
138+
auto mul_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value(0.5f));
160139

161140
// x * (0.5 * (1 + erf(x / sqrt(2)))
162141
auto mul_first = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({add, mul_constant});
163142
auto mul = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_first});
164143

165144
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
166-
auto& pattern_to_output = m.get_pattern_value_map();
167-
auto x_output = pattern_to_output.at(input);
168-
169-
auto div_const_value =
170-
ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());
171-
auto add_const_value =
172-
ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
173-
auto mul_const_value =
174-
ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());
175-
176-
if (!div_const_value || !add_const_value || !mul_const_value) {
177-
return false;
178-
}
179-
180-
bool valid_constant_values =
181-
op::util::has_constant_value<float>(div_const_value, static_cast<float>(M_SQRT2), 0.001f) &&
182-
op::util::has_constant_value<float>(add_const_value, 1.0f) &&
183-
op::util::has_constant_value<float>(mul_const_value, 0.5f);
184-
185-
if (!valid_constant_values) {
186-
return false;
187-
}
188-
189-
auto gelu = std::make_shared<ov::op::v7::Gelu>(x_output);
190-
191-
gelu->set_friendly_name(m.get_match_root()->get_friendly_name());
192-
ov::copy_runtime_info(
193-
{
194-
pattern_to_output.at(div).get_node_shared_ptr(),
195-
pattern_to_output.at(erf).get_node_shared_ptr(),
196-
pattern_to_output.at(add).get_node_shared_ptr(),
197-
pattern_to_output.at(mul_first).get_node_shared_ptr(),
198-
pattern_to_output.at(mul).get_node_shared_ptr(),
199-
},
200-
gelu);
201-
ov::replace_node(m.get_match_root(), gelu);
202-
return true;
145+
return gelu_replacer(m, input);
203146
};
204147

205148
auto m = std::make_shared<ov::pass::pattern::Matcher>(mul, matcher_name);
@@ -212,45 +155,19 @@ ov::pass::GeluFusionWithErfFour::GeluFusionWithErfFour() {
212155
using namespace ov::pass::pattern;
213156

214157
auto input = any_input();
215-
auto mul1_constant = wrap_type<ov::op::v0::Constant>();
158+
auto mul1_constant = wrap_type<ov::op::v0::Constant>(check_value(SQRT1_2, 0.001f));
216159
auto mul1 = wrap_type<ov::op::v1::Multiply>({input, mul1_constant});
217160
auto erf = wrap_type<ov::op::v0::Erf>({mul1});
218-
auto mul2_constant = wrap_type<ov::op::v0::Constant>();
161+
auto mul2_constant = wrap_type<ov::op::v0::Constant>(check_value(0.5f));
219162
auto mul2 = wrap_type<ov::op::v1::Multiply>({erf, mul2_constant});
220-
auto add_constant = wrap_type<ov::op::v0::Constant>();
163+
auto add_constant = wrap_type<ov::op::v0::Constant>(check_value(0.5f));
221164
auto add = wrap_type<ov::op::v1::Add>({add_constant, mul2});
222165

223166
// x * (0.5 + 0.5 * erf(x * (1 / sqrt(2))))
224167
auto mul3 = wrap_type<ov::op::v1::Multiply>({input, add});
225168

226169
matcher_pass_callback callback = [=](Matcher& m) {
227-
NodeRegistry rg;
228-
auto pattern_to_output = m.get_pattern_map();
229-
auto x_output = pattern_to_output.at(input);
230-
231-
auto mul1_const_value = ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(mul1_constant));
232-
auto add_const_value = ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(add_constant));
233-
auto mul2_const_value = ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(mul2_constant));
234-
235-
if (!mul1_const_value || !add_const_value || !mul2_const_value) {
236-
return false;
237-
}
238-
239-
constexpr auto sqrt2 = static_cast<float>(M_SQRT2);
240-
bool valid_constant_values = ov::op::util::has_constant_value<float>(mul1_const_value, 1.0f / sqrt2, 0.001f) &&
241-
ov::op::util::has_constant_value<float>(add_const_value, 0.5f) &&
242-
ov::op::util::has_constant_value<float>(mul2_const_value, 0.5f);
243-
244-
if (!valid_constant_values) {
245-
return false;
246-
}
247-
248-
auto gelu = rg.make<ov::op::v7::Gelu>(x_output);
249-
250-
gelu->set_friendly_name(m.get_match_root()->get_friendly_name());
251-
copy_runtime_info(m.get_matched_nodes(), rg.get());
252-
replace_node(m.get_match_root(), gelu);
253-
return true;
170+
return gelu_replacer(m, input);
254171
};
255172

256173
auto m = std::make_shared<Matcher>(mul3, matcher_name);

0 commit comments

Comments
 (0)