Skip to content

Commit 836a3e1

Browse files
aobolensk and v-Golubev authored
[Snippets][CPU] Unify GELU v0/v7 emitters across architectures (openvinotoolkit#36388)
### Details:
Unify GELU v0/v7 emitters across all architectures (x64/aarch64/riscv64) and extract `CREATE_GELU_V7_EMITTER` into a common header.

### Tickets:
- N/A

### AI Assistance:
- *AI assistance used: no*

Co-authored-by: Vladislav Golubev <vladislav.golubev@intel.com>
1 parent ed7c038 commit 836a3e1

5 files changed

Lines changed: 47 additions & 86 deletions

File tree

src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_ext_emitters.hpp

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include "openvino/core/type/element_type.hpp"
1818
#include "openvino/op/clamp.hpp"
1919
#include "openvino/op/elu.hpp"
20-
#include "openvino/op/gelu.hpp"
2120
#include "openvino/op/round.hpp"
2221
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
2322
#include "utils/general_utils.h"
@@ -117,37 +116,27 @@ class jit_hswish_emitter : public jit_dnnl_emitter {
117116
}
118117
};
119118

120-
class jit_gelu_v0_emitter : public jit_dnnl_emitter {
119+
class jit_gelu_erf_emitter : public jit_dnnl_emitter {
121120
public:
122-
jit_gelu_v0_emitter(dnnl::impl::cpu::x64::jit_generator_t* host,
123-
dnnl::impl::cpu::x64::cpu_isa_t host_isa,
124-
const std::shared_ptr<ov::Node>& n,
125-
ov::element::Type exec_prc = ov::element::f32)
121+
jit_gelu_erf_emitter(dnnl::impl::cpu::x64::jit_generator_t* host,
122+
dnnl::impl::cpu::x64::cpu_isa_t host_isa,
123+
const std::shared_ptr<ov::Node>& n,
124+
ov::element::Type exec_prc = ov::element::f32)
126125
: jit_dnnl_emitter(host, host_isa, n, exec_prc) {
127126
kind = dnnl_eltwise_gelu_erf;
128127

129128
set_injector();
130129
}
131130
};
132131

133-
class jit_gelu_v7_emitter : public jit_dnnl_emitter {
132+
class jit_gelu_tanh_emitter : public jit_dnnl_emitter {
134133
public:
135-
jit_gelu_v7_emitter(dnnl::impl::cpu::x64::jit_generator_t* host,
136-
dnnl::impl::cpu::x64::cpu_isa_t host_isa,
137-
const std::shared_ptr<ov::Node>& n,
138-
ov::element::Type exec_prc = ov::element::f32)
134+
jit_gelu_tanh_emitter(dnnl::impl::cpu::x64::jit_generator_t* host,
135+
dnnl::impl::cpu::x64::cpu_isa_t host_isa,
136+
const std::shared_ptr<ov::Node>& n,
137+
ov::element::Type exec_prc = ov::element::f32)
139138
: jit_dnnl_emitter(host, host_isa, n, exec_prc) {
140-
auto gelu = getNgraphOpAs<ov::op::v7::Gelu>(n);
141-
ov::op::GeluApproximationMode approximationMode = gelu->get_approximation_mode();
142-
if (approximationMode == ov::op::GeluApproximationMode::ERF) {
143-
kind = dnnl_eltwise_gelu_erf;
144-
} else if (approximationMode == ov::op::GeluApproximationMode::TANH) {
145-
kind = dnnl_eltwise_gelu_tanh;
146-
} else {
147-
OPENVINO_THROW_NOT_IMPLEMENTED(
148-
"Subgraph node doesn't support ngraph operation Gelu with approximation mode: ",
149-
approximationMode);
150-
}
139+
kind = dnnl_eltwise_gelu_tanh;
151140

152141
set_injector();
153142
}

src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -139,37 +139,6 @@ static bool is_segfault_detector_emitter(const intel_cpu::aarch64::jit_emitter*
139139

140140
#endif
141141

142-
#define CREATE_GELU_V7_EMITTER(e_type_erf, e_type_tanh) \
143-
{[this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
144-
const auto& n = expr->get_node(); \
145-
const auto& gelu = ov::as_type_ptr<ov::op::v7::Gelu>(n); \
146-
if (gelu == nullptr) { \
147-
OPENVINO_THROW("Can't cast to ov::op::v7::Gelu"); \
148-
} \
149-
const auto approximationMode = gelu->get_approximation_mode(); \
150-
if (approximationMode == ov::op::GeluApproximationMode::ERF) { \
151-
return std::make_shared<e_type_erf>(h.get(), isa, n); \
152-
} \
153-
if (approximationMode == ov::op::GeluApproximationMode::TANH) { \
154-
return std::make_shared<e_type_tanh>(h.get(), isa, n); \
155-
} \
156-
OPENVINO_THROW("Unsupported Gelu approximation mode"); \
157-
}, \
158-
[](const std::shared_ptr<ov::Node>& n) -> std::set<std::vector<element::Type>> { \
159-
const auto& gelu = ov::as_type_ptr<ov::op::v7::Gelu>(n); \
160-
if (gelu == nullptr) { \
161-
OPENVINO_THROW("Can't cast to ov::op::v7::Gelu"); \
162-
} \
163-
const auto approximationMode = gelu->get_approximation_mode(); \
164-
if (approximationMode == ov::op::GeluApproximationMode::ERF) { \
165-
return e_type_erf::get_supported_precisions(n); \
166-
} \
167-
if (approximationMode == ov::op::GeluApproximationMode::TANH) { \
168-
return e_type_tanh::get_supported_precisions(n); \
169-
} \
170-
OPENVINO_THROW("Unsupported Gelu approximation mode"); \
171-
}}
172-
173142
#define CREATE_ROUND_V5_EMITTER(e_type_from_zero, e_type_even) \
174143
{[this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
175144
const auto& n = expr->get_node(); \

src/plugins/intel_cpu/src/emitters/snippets/common/emitter_factory.hpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
#include <vector>
1212

1313
#include "cache/multi_cache.h"
14+
#include "openvino/core/except.hpp"
1415
#include "openvino/core/node.hpp"
1516
#include "openvino/core/type/element_type.hpp"
17+
#include "openvino/op/gelu.hpp"
1618
#include "snippets/lowered/expression.hpp"
1719
#include "snippets/target_machine.hpp"
1820

@@ -162,4 +164,35 @@ template <typename GetHost, typename Isa, typename Wrap, typename GetKernelExecu
162164
EmitterFactory(GetHost, Isa, Wrap, GetKernelExecutorTable, MultiCacheWeakPtr)
163165
-> EmitterFactory<GetHost, Isa, Wrap, GetKernelExecutorTable>;
164166

167+
#define CREATE_GELU_V7_EMITTER(e_type_erf, e_type_tanh) \
168+
{[this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
169+
const auto& n = expr->get_node(); \
170+
const auto& gelu = ov::as_type_ptr<ov::op::v7::Gelu>(n); \
171+
if (gelu == nullptr) { \
172+
OPENVINO_THROW("Can't cast to ov::op::v7::Gelu"); \
173+
} \
174+
const auto approximationMode = gelu->get_approximation_mode(); \
175+
if (approximationMode == ov::op::GeluApproximationMode::ERF) { \
176+
return std::make_shared<e_type_erf>(h.get(), isa, n); \
177+
} \
178+
if (approximationMode == ov::op::GeluApproximationMode::TANH) { \
179+
return std::make_shared<e_type_tanh>(h.get(), isa, n); \
180+
} \
181+
OPENVINO_THROW("Unsupported Gelu approximation mode"); \
182+
}, \
183+
[](const std::shared_ptr<ov::Node>& n) -> std::set<std::vector<element::Type>> { \
184+
const auto& gelu = ov::as_type_ptr<ov::op::v7::Gelu>(n); \
185+
if (gelu == nullptr) { \
186+
OPENVINO_THROW("Can't cast to ov::op::v7::Gelu"); \
187+
} \
188+
const auto approximationMode = gelu->get_approximation_mode(); \
189+
if (approximationMode == ov::op::GeluApproximationMode::ERF) { \
190+
return e_type_erf::get_supported_precisions(n); \
191+
} \
192+
if (approximationMode == ov::op::GeluApproximationMode::TANH) { \
193+
return e_type_tanh::get_supported_precisions(n); \
194+
} \
195+
OPENVINO_THROW("Unsupported Gelu approximation mode"); \
196+
}}
197+
165198
} // namespace ov::intel_cpu

src/plugins/intel_cpu/src/emitters/snippets/riscv64/cpu_generator.cpp

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -116,37 +116,6 @@
116116
# include "snippets/op/perf_count.hpp"
117117
#endif
118118

119-
#define CREATE_GELU_V7_EMITTER(e_type_erf, e_type_tanh) \
120-
{[this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
121-
const auto& n = expr->get_node(); \
122-
const auto& gelu = ov::as_type_ptr<ov::op::v7::Gelu>(n); \
123-
if (gelu == nullptr) { \
124-
OPENVINO_THROW("Can't cast to ov::op::v7::Gelu"); \
125-
} \
126-
const auto approximationMode = gelu->get_approximation_mode(); \
127-
if (approximationMode == ov::op::GeluApproximationMode::ERF) { \
128-
return std::make_shared<e_type_erf>(h.get(), isa, n); \
129-
} \
130-
if (approximationMode == ov::op::GeluApproximationMode::TANH) { \
131-
return std::make_shared<e_type_tanh>(h.get(), isa, n); \
132-
} \
133-
OPENVINO_THROW("Unsupported Gelu approximation mode"); \
134-
}, \
135-
[](const std::shared_ptr<ov::Node>& n) -> std::set<std::vector<ov::element::Type>> { \
136-
const auto& gelu = ov::as_type_ptr<ov::op::v7::Gelu>(n); \
137-
if (gelu == nullptr) { \
138-
OPENVINO_THROW("Can't cast to ov::op::v7::Gelu"); \
139-
} \
140-
const auto approximationMode = gelu->get_approximation_mode(); \
141-
if (approximationMode == ov::op::GeluApproximationMode::ERF) { \
142-
return e_type_erf::get_supported_precisions(n); \
143-
} \
144-
if (approximationMode == ov::op::GeluApproximationMode::TANH) { \
145-
return e_type_tanh::get_supported_precisions(n); \
146-
} \
147-
OPENVINO_THROW("Unsupported Gelu approximation mode"); \
148-
}}
149-
150119
#define CREATE_ROUND_V5_EMITTER(e_type_from_zero, e_type_even) \
151120
{[this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
152121
const auto& n = expr->get_node(); \

src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,9 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho
329329
jitters[ov::intel_cpu::SwishNode::get_type_info_static()] =
330330
emitter_factory.from_node<ov::intel_cpu::jit_swish_emitter>();
331331
jitters[ov::op::v4::HSwish::get_type_info_static()] = emitter_factory.from_node<intel_cpu::jit_hswish_emitter>();
332-
jitters[ov::op::v0::Gelu::get_type_info_static()] = emitter_factory.from_node<intel_cpu::jit_gelu_v0_emitter>();
333-
jitters[ov::op::v7::Gelu::get_type_info_static()] = emitter_factory.from_node<intel_cpu::jit_gelu_v7_emitter>();
332+
jitters[ov::op::v0::Gelu::get_type_info_static()] = emitter_factory.from_node<intel_cpu::jit_gelu_erf_emitter>();
333+
jitters[ov::op::v7::Gelu::get_type_info_static()] =
334+
CREATE_GELU_V7_EMITTER(intel_cpu::jit_gelu_erf_emitter, intel_cpu::jit_gelu_tanh_emitter);
334335
jitters[snippets::op::Fill::get_type_info_static()] = emitter_factory.from_expr<intel_cpu::jit_fill_emitter>();
335336

336337
jitters[snippets::op::HorizonMax::get_type_info_static()] =

0 commit comments

Comments
 (0)