/*++

Copyright 2025 FUJITSU LIMITED

Module Name:

    Gelu.cpp

Abstract:

    This module contains Gelu helper functions.

--*/
#include "gelu.h"

#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
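/*
    For reference: the "none" path below computes the exact GELU

        gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))

    and the "tanh" path computes the standard tanh approximation

        gelu(x) ~= 0.5 * x * (1 + tanh(x * (B + C * x^2)))

    where B = sqrt(2 / pi) ~= 0.7978845608 and C = 0.044715 * B ~= 0.0356774081,
    matching the v_B and v_C constants broadcast in the kernel.
*/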
void
MLASCALL
MlasNeonGeluF16Kernel(const MLAS_FP16* input, MLAS_FP16* output, MLAS_FP16* temp, int64_t count, const std::string& algo)
{
    const float16_t v_half1 = 0.5f;
    const float16_t v_one1 = 1.0f;
    const float16_t v_sqrt1_21 = static_cast<float>(M_SQRT1_2);
    const float16_t v_B1 = 0.7978845608028654f;    // sqrt(2 / pi)
    const float16_t v_C1 = 0.035677408136300125f;  // 0.044715 * sqrt(2 / pi)
    const float16_t c1 = 5.0f;                     // upper clamp for the tanh argument
    const float16_t c2 = -5.0f;                    // lower clamp for the tanh argument
    const MLAS_FLOAT16X8 v_half = MlasBroadcastF16Float16x8(v_half1);
    const MLAS_FLOAT16X8 v_one = MlasBroadcastF16Float16x8(v_one1);
    const MLAS_FLOAT16X8 v_sqrt1_2 = MlasBroadcastF16Float16x8(v_sqrt1_21);
    const MLAS_FLOAT16X8 v_B = MlasBroadcastF16Float16x8(v_B1);
    const MLAS_FLOAT16X8 v_C = MlasBroadcastF16Float16x8(v_C1);

    int64_t i = 0;

    if (algo == "tanh") {
        // Preprocess input into temp[] for tanh
        for (; i + 7 < count; i += 8) {
            MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(input + i));
            MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16(x, x);
            MLAS_FLOAT16X8 inner = MlasMultiplyAddFloat16(v_C, x2, v_B);  // B + C * x^2
            MLAS_FLOAT16X8 tanh_arg = MlasMultiplyFloat16(x, inner);      // x * (B + C * x^2)
            // Clamp the tanh argument to [-5, 5], where fp16 tanh is effectively saturated.
            tanh_arg = MlasMaximumFloat16(MlasBroadcastF16Float16x8(c2), MlasMinimumFloat16(tanh_arg, MlasBroadcastF16Float16x8(c1)));
            MlasStoref16Float16x8(reinterpret_cast<float16_t*>(temp + i), tanh_arg);
        }

        // Tail
        for (; i < count; ++i) {
            float x = static_cast<float>(input[i]);
            float inner = x * (0.7978845608f + 0.035677408f * x * x);  // same B and C as the vector path
            inner = std::max(-5.0f, std::min(5.0f, inner));
            temp[i] = static_cast<MLAS_FP16>(inner);
        }

        // Tanh processing
        MlasComputeTanh<MLAS_FP16>(temp, temp, count);

    } else if (algo == "none") {
        // "none" selects the exact erf-based GELU rather than the tanh approximation.
        // Preprocess input into temp[] for erf
        for (i = 0; i + 7 < count; i += 8) {
            MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(input + i));
            MLAS_FLOAT16X8 scaled = MlasMultiplyFloat16(x, v_sqrt1_2);
            MlasStoref16Float16x8(reinterpret_cast<float16_t*>(temp + i), scaled);
        }

        // Tail
        for (; i < count; ++i) {
            float x = static_cast<float>(input[i]);
            temp[i] = static_cast<MLAS_FP16>(x * 0.70710678f);  // x / sqrt(2)
        }

        // Erf processing
        MlasNeonErfF16Kernel(reinterpret_cast<const _mlas_fp16_*>(temp), reinterpret_cast<_mlas_fp16_*>(temp), count);
    }

    // Final GELU output = 0.5 * x * (1 + tanh|erf)
    i = 0;
    for (; i + 7 < count; i += 8) {
        MLAS_FLOAT16X8 x = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(input + i));
        MLAS_FLOAT16X8 t = MlasLoadf16Float16x8(reinterpret_cast<const float16_t*>(temp + i));
        MLAS_FLOAT16X8 result = MlasMultiplyFloat16(v_half, MlasMultiplyFloat16(x, MlasAddFloat16(v_one, t)));
        MlasStoref16Float16x8(reinterpret_cast<float16_t*>(output + i), result);
    }

    for (; i < count; ++i) {
        float x = static_cast<float>(input[i]);
        float t = static_cast<float>(temp[i]);
        float gelu = 0.5f * x * (1.0f + t);
        output[i] = static_cast<MLAS_FP16>(gelu);
    }
}
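
/*
    Illustrative usage sketch (the caller-provided input/output pointers and count
    are assumed context; only the algo strings come from this kernel):

        std::vector<MLAS_FP16> temp(static_cast<size_t>(count));
        MlasNeonGeluF16Kernel(input, output, temp.data(), count, "tanh");  // tanh approximation
        MlasNeonGeluF16Kernel(input, output, temp.data(), count, "none");  // exact erf-based GELU

    Any other algo value leaves temp[] unwritten, so callers are expected to pass
    one of these two strings.
*/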
#endif  // defined(__ARM_NEON) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)