2626#include " openvino/pass/pattern/op/wrap_type.hpp"
2727#include " transformations/utils/utils.hpp"
2828
29+ using namespace ov ;
30+ using namespace ov ::op::util;
31+ using namespace ov ::pass::pattern::op;
32+
33+ constexpr auto SQRT2 = static_cast <float >(M_SQRT2);
34+ constexpr auto SQRT1_2 = static_cast <float >(M_SQRT1_2);
35+
36+ namespace {
37+
38+ Predicate check_value (float ref, float eps = std::numeric_limits<float >::epsilon()) {
39+ return Predicate (
40+ [=](const Output<Node>& output) -> bool {
41+ return ov::op::util::has_constant_value<float >(output.get_node_shared_ptr (), ref, eps);
42+ },
43+ " has_constant_value(" + std::to_string (ref) + " )" );
44+ }
45+
46+ bool gelu_replacer (ov::pass::pattern::Matcher& m, const std::shared_ptr<ov::Node>& pattern_input_to_relu) {
47+ ov::pass::NodeRegistry rg;
48+ auto pattern_to_output = m.get_pattern_map ();
49+ auto x_output = pattern_to_output.at (pattern_input_to_relu);
50+
51+ auto gelu = rg.make <ov::op::v7::Gelu>(x_output);
52+
53+ gelu->set_friendly_name (m.get_match_root ()->get_friendly_name ());
54+ copy_runtime_info (m.get_matched_nodes (), rg.get ());
55+ replace_node (m.get_match_root (), gelu);
56+ return true ;
57+ }
58+
59+ } // namespace
60+
2961ov::pass::GeluFusionWithErfOne::GeluFusionWithErfOne () {
3062 MATCHER_SCOPE (GeluFusionWithErfOne);
3163 // Replaces a sub-graph with a Gelu op
3264 // Shared by every pattern: (1 + erf(x / sqrt(2)))
3365 auto input = pass::pattern::any_input ();
34- auto div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
66+ auto div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value (SQRT2, 0 . 001f ) );
3567 auto div = ov::pass::pattern::wrap_type<ov::op::v1::Divide>({input, div_constant});
36- auto erf = ov::pass::pattern::wrap_type<ov::op::v0::Erf>({div});
37- auto add_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
68+
69+ // In case of ConvertDivideWithConstant is applied and Div is converted to Mul
70+ auto mul_as_div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value (SQRT1_2, 0 .001f ));
71+ auto mul_as_div = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_as_div_constant});
72+ auto erf_input = std::make_shared<Or>(ov::OutputVector{div, mul_as_div});
73+
74+ auto erf = ov::pass::pattern::wrap_type<ov::op::v0::Erf>({erf_input});
75+
76+ auto add_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value (1 .0f ));
3877 auto add = ov::pass::pattern::wrap_type<ov::op::v1::Add>({add_constant, erf});
39- auto mul_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
78+ auto mul_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value ( 0 . 5f ) );
4079
4180 // (0.5 * x) * (1 + erf(x / sqrt(2))
4281 auto mul_first = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_constant});
4382 auto mul = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({mul_first, add});
4483
4584 ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
46- auto & pattern_to_output = m.get_pattern_value_map ();
47- auto x_output = pattern_to_output.at (input);
48-
49- auto div_const_value =
50- ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at (div_constant).get_node_shared_ptr ());
51- auto add_const_value =
52- ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at (add_constant).get_node_shared_ptr ());
53- auto mul_const_value =
54- ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at (mul_constant).get_node_shared_ptr ());
55-
56- if (!div_const_value || !add_const_value || !mul_const_value) {
57- return false ;
58- }
59-
60- bool valid_constant_values =
61- op::util::has_constant_value<float >(div_const_value, static_cast <float >(M_SQRT2), 0 .001f ) &&
62- op::util::has_constant_value<float >(add_const_value, 1 .0f ) &&
63- op::util::has_constant_value<float >(mul_const_value, 0 .5f );
64-
65- if (!valid_constant_values) {
66- return false ;
67- }
68-
69- auto gelu = std::make_shared<ov::op::v7::Gelu>(x_output);
70-
71- gelu->set_friendly_name (m.get_match_root ()->get_friendly_name ());
72- ov::copy_runtime_info (
73- {
74- pattern_to_output.at (div).get_node_shared_ptr (),
75- pattern_to_output.at (erf).get_node_shared_ptr (),
76- pattern_to_output.at (add).get_node_shared_ptr (),
77- pattern_to_output.at (mul_first).get_node_shared_ptr (),
78- pattern_to_output.at (mul).get_node_shared_ptr (),
79- },
80- gelu);
81- ov::replace_node (m.get_match_root (), gelu);
82- return true ;
85+ return gelu_replacer (m, input);
8386 };
8487
8588 auto m = std::make_shared<ov::pass::pattern::Matcher>(mul, matcher_name);
@@ -91,55 +94,25 @@ ov::pass::GeluFusionWithErfTwo::GeluFusionWithErfTwo() {
9194 // Replaces a sub-graph with a Gelu op
9295 // Shared by every pattern: (1 + erf(x / sqrt(2)))
9396 auto input = pass::pattern::any_input ();
94- auto div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
97+ auto div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value (SQRT2, 0 . 001f ) );
9598 auto div = ov::pass::pattern::wrap_type<ov::op::v1::Divide>({input, div_constant});
96- auto erf = ov::pass::pattern::wrap_type<ov::op::v0::Erf>({div});
97- auto add_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
99+
100+ // In case of ConvertDivideWithConstant is applied and Div is converted to Mul
101+ auto mul_as_div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value (SQRT1_2, 0 .001f ));
102+ auto mul_as_div = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_as_div_constant});
103+ auto erf_input = std::make_shared<Or>(ov::OutputVector{div, mul_as_div});
104+
105+ auto erf = ov::pass::pattern::wrap_type<ov::op::v0::Erf>({erf_input});
106+ auto add_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value (1 .0f ));
98107 auto add = ov::pass::pattern::wrap_type<ov::op::v1::Add>({add_constant, erf});
99- auto mul_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
108+ auto mul_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value ( 0 . 5f ) );
100109
101110 // 0.5 * (x * (1 + erf(x / sqrt(2)))
102111 auto mul_first = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, add});
103112 auto mul = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({mul_constant, mul_first});
104113
105114 ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
106- auto & pattern_to_output = m.get_pattern_value_map ();
107- auto x_output = pattern_to_output.at (input);
108-
109- auto div_const_value =
110- ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at (div_constant).get_node_shared_ptr ());
111- auto add_const_value =
112- ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at (add_constant).get_node_shared_ptr ());
113- auto mul_const_value =
114- ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at (mul_constant).get_node_shared_ptr ());
115-
116- if (!div_const_value || !add_const_value || !mul_const_value) {
117- return false ;
118- }
119-
120- bool valid_constant_values =
121- op::util::has_constant_value<float >(div_const_value, static_cast <float >(M_SQRT2), 0 .001f ) &&
122- op::util::has_constant_value<float >(add_const_value, 1 .0f ) &&
123- op::util::has_constant_value<float >(mul_const_value, 0 .5f );
124-
125- if (!valid_constant_values) {
126- return false ;
127- }
128-
129- auto gelu = std::make_shared<ov::op::v7::Gelu>(x_output);
130-
131- gelu->set_friendly_name (m.get_match_root ()->get_friendly_name ());
132- ov::copy_runtime_info (
133- {
134- pattern_to_output.at (div).get_node_shared_ptr (),
135- pattern_to_output.at (erf).get_node_shared_ptr (),
136- pattern_to_output.at (add).get_node_shared_ptr (),
137- pattern_to_output.at (mul_first).get_node_shared_ptr (),
138- pattern_to_output.at (mul).get_node_shared_ptr (),
139- },
140- gelu);
141- ov::replace_node (m.get_match_root (), gelu);
142- return true ;
115+ return gelu_replacer (m, input);
143116 };
144117
145118 auto m = std::make_shared<ov::pass::pattern::Matcher>(mul, matcher_name);
@@ -151,55 +124,25 @@ ov::pass::GeluFusionWithErfThree::GeluFusionWithErfThree() {
151124 // Replaces a sub-graph with a Gelu op
152125 // Shared by every pattern: (1 + erf(x / sqrt(2)))
153126 auto input = pass::pattern::any_input ();
154- auto div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
127+ auto div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value (SQRT2, 0 . 001f ) );
155128 auto div = ov::pass::pattern::wrap_type<ov::op::v1::Divide>({input, div_constant});
156- auto erf = ov::pass::pattern::wrap_type<ov::op::v0::Erf>({div});
157- auto add_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
129+
130+ // In case of ConvertDivideWithConstant is applied and Div is converted to Mul
131+ auto mul_as_div_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value (SQRT1_2, 0 .001f ));
132+ auto mul_as_div = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_as_div_constant});
133+ auto erf_input = std::make_shared<Or>(ov::OutputVector{div, mul_as_div});
134+
135+ auto erf = ov::pass::pattern::wrap_type<ov::op::v0::Erf>({erf_input});
136+ auto add_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value (1 .0f ));
158137 auto add = ov::pass::pattern::wrap_type<ov::op::v1::Add>({add_constant, erf});
159- auto mul_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
138+ auto mul_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(check_value ( 0 . 5f ) );
160139
161140 // x * (0.5 * (1 + erf(x / sqrt(2)))
162141 auto mul_first = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({add, mul_constant});
163142 auto mul = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_first});
164143
165144 ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
166- auto & pattern_to_output = m.get_pattern_value_map ();
167- auto x_output = pattern_to_output.at (input);
168-
169- auto div_const_value =
170- ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at (div_constant).get_node_shared_ptr ());
171- auto add_const_value =
172- ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at (add_constant).get_node_shared_ptr ());
173- auto mul_const_value =
174- ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at (mul_constant).get_node_shared_ptr ());
175-
176- if (!div_const_value || !add_const_value || !mul_const_value) {
177- return false ;
178- }
179-
180- bool valid_constant_values =
181- op::util::has_constant_value<float >(div_const_value, static_cast <float >(M_SQRT2), 0 .001f ) &&
182- op::util::has_constant_value<float >(add_const_value, 1 .0f ) &&
183- op::util::has_constant_value<float >(mul_const_value, 0 .5f );
184-
185- if (!valid_constant_values) {
186- return false ;
187- }
188-
189- auto gelu = std::make_shared<ov::op::v7::Gelu>(x_output);
190-
191- gelu->set_friendly_name (m.get_match_root ()->get_friendly_name ());
192- ov::copy_runtime_info (
193- {
194- pattern_to_output.at (div).get_node_shared_ptr (),
195- pattern_to_output.at (erf).get_node_shared_ptr (),
196- pattern_to_output.at (add).get_node_shared_ptr (),
197- pattern_to_output.at (mul_first).get_node_shared_ptr (),
198- pattern_to_output.at (mul).get_node_shared_ptr (),
199- },
200- gelu);
201- ov::replace_node (m.get_match_root (), gelu);
202- return true ;
145+ return gelu_replacer (m, input);
203146 };
204147
205148 auto m = std::make_shared<ov::pass::pattern::Matcher>(mul, matcher_name);
@@ -212,45 +155,19 @@ ov::pass::GeluFusionWithErfFour::GeluFusionWithErfFour() {
212155 using namespace ov ::pass::pattern;
213156
214157 auto input = any_input ();
215- auto mul1_constant = wrap_type<ov::op::v0::Constant>();
158+ auto mul1_constant = wrap_type<ov::op::v0::Constant>(check_value (SQRT1_2, 0 . 001f ) );
216159 auto mul1 = wrap_type<ov::op::v1::Multiply>({input, mul1_constant});
217160 auto erf = wrap_type<ov::op::v0::Erf>({mul1});
218- auto mul2_constant = wrap_type<ov::op::v0::Constant>();
161+ auto mul2_constant = wrap_type<ov::op::v0::Constant>(check_value ( 0 . 5f ) );
219162 auto mul2 = wrap_type<ov::op::v1::Multiply>({erf, mul2_constant});
220- auto add_constant = wrap_type<ov::op::v0::Constant>();
163+ auto add_constant = wrap_type<ov::op::v0::Constant>(check_value ( 0 . 5f ) );
221164 auto add = wrap_type<ov::op::v1::Add>({add_constant, mul2});
222165
223166 // x * (0.5 + 0.5 * erf(x * (1 / sqrt(2))))
224167 auto mul3 = wrap_type<ov::op::v1::Multiply>({input, add});
225168
226169 matcher_pass_callback callback = [=](Matcher& m) {
227- NodeRegistry rg;
228- auto pattern_to_output = m.get_pattern_map ();
229- auto x_output = pattern_to_output.at (input);
230-
231- auto mul1_const_value = ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at (mul1_constant));
232- auto add_const_value = ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at (add_constant));
233- auto mul2_const_value = ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at (mul2_constant));
234-
235- if (!mul1_const_value || !add_const_value || !mul2_const_value) {
236- return false ;
237- }
238-
239- constexpr auto sqrt2 = static_cast <float >(M_SQRT2);
240- bool valid_constant_values = ov::op::util::has_constant_value<float >(mul1_const_value, 1 .0f / sqrt2, 0 .001f ) &&
241- ov::op::util::has_constant_value<float >(add_const_value, 0 .5f ) &&
242- ov::op::util::has_constant_value<float >(mul2_const_value, 0 .5f );
243-
244- if (!valid_constant_values) {
245- return false ;
246- }
247-
248- auto gelu = rg.make <ov::op::v7::Gelu>(x_output);
249-
250- gelu->set_friendly_name (m.get_match_root ()->get_friendly_name ());
251- copy_runtime_info (m.get_matched_nodes (), rg.get ());
252- replace_node (m.get_match_root (), gelu);
253- return true ;
170+ return gelu_replacer (m, input);
254171 };
255172
256173 auto m = std::make_shared<Matcher>(mul3, matcher_name);
0 commit comments