1+ /* ++
2+
3+ Copyright 2025 FUJITSU LIMITED
4+
5+ Module Name:
6+
7+ erf_neon_fp16.cpp
8+
9+ Abstract:
10+
This module implements the ERF kernel using NEON FP16 intrinsics.
12+
13+ --*/
14+
15+ #include " erf_neon_fp16.h"
16+
// Helpers that convert between float and the raw 16-bit FP16 bit pattern,
// using memcpy rather than pointer casts for the reinterpretation.
18+ static float
19+ fp16_to_float (uint16_t h)
20+ {
21+ __fp16 tmp;
22+ memcpy (&tmp, &h, sizeof (h));
23+ return (float )tmp;
24+ }
25+
26+ static uint16_t
27+ float_to_fp16 (float f)
28+ {
29+ __fp16 tmp = (__fp16)f;
30+ uint16_t h;
31+ memcpy (&h, &tmp, sizeof (h));
32+ return h;
33+ }
34+
35+ static inline MLAS_FLOAT16X8
36+ exp_neg_rational_approx_f16 (MLAS_FLOAT16X8 x)
37+ {
38+ const float16_t a0 = 6 .0f ;
39+ MLAS_FLOAT16X8 max_x = MlasBroadcastF16Float16x8 (a0);
40+ x = MlasMinimumFloat16 (x, max_x);
41+
42+ const float16_t c0 = 1 .330f ;
43+ const float16_t c1 = -0 .390f ;
44+ const float16_t c2 = 0 .0288f ;
45+
46+ const float16_t d0 = 1 .338f ;
47+ const float16_t d1 = 0 .848f ;
48+ const float16_t d2 = 0 .467f ;
49+
50+ MLAS_FLOAT16X8 c0v = MlasBroadcastF16Float16x8 (c0);
51+ MLAS_FLOAT16X8 c1v = MlasBroadcastF16Float16x8 (c1);
52+ MLAS_FLOAT16X8 c2v = MlasBroadcastF16Float16x8 (c2);
53+
54+ MLAS_FLOAT16X8 d0v = MlasBroadcastF16Float16x8 (d0);
55+ MLAS_FLOAT16X8 d1v = MlasBroadcastF16Float16x8 (d1);
56+ MLAS_FLOAT16X8 d2v = MlasBroadcastF16Float16x8 (d2);
57+ MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16 (x, x);
58+ MLAS_FLOAT16X8 num = MlasMultiplyAddFloat16 (c1v, x, c0v);
59+ num = MlasMultiplyAddFloat16 (c2v, x2, num);
60+ MLAS_FLOAT16X8 den = MlasMultiplyAddFloat16 (d1v, x, d0v);
61+ den = MlasMultiplyAddFloat16 (d2v, x2, den);
62+ MLAS_FLOAT16X8 recip = MlasApproximateReciprocalFloat16 (den);
63+ recip = MlasMultiplyFloat16 (recip, MlasReciprocalSqrtFloat16 (den, recip));
64+ recip = MlasMultiplyFloat16 (recip, MlasReciprocalSqrtFloat16 (den, recip));
65+ MLAS_FLOAT16X8 result = MlasMultiplyFloat16 (num, recip);
66+ return result;
67+ }
68+
69+ void
70+ MlasNeonErfKernelFp16 (const _mlas_fp16_* Input, _mlas_fp16_* Output, size_t N)
71+ {
72+ const float16_t p = 0 .328f ;
73+ const float16_t a1 = 0 .2505f ;
74+ const float16_t a2 = -0 .2881f ;
75+ const float16_t a3 = 1 .4102f ;
76+ const float16_t a4 = -1 .423f ;
77+ const float16_t a5 = 1 .0547f ;
78+
79+ MLAS_FLOAT16X8 vp = MlasBroadcastF16Float16x8 (p);
80+ MLAS_FLOAT16X8 va1 = MlasBroadcastF16Float16x8 (a1);
81+ MLAS_FLOAT16X8 va2 = MlasBroadcastF16Float16x8 (a2);
82+ MLAS_FLOAT16X8 va3 = MlasBroadcastF16Float16x8 (a3);
83+ MLAS_FLOAT16X8 va4 = MlasBroadcastF16Float16x8 (a4);
84+ MLAS_FLOAT16X8 va5 = MlasBroadcastF16Float16x8 (a5);
85+
86+ constexpr float16_t one_fp16 = 1 .0f ;
87+ constexpr float16_t neg_one_fp16 = -1 .0f ;
88+ constexpr float16_t zero_fp16 = 0 .0f ;
89+ constexpr float16_t four_fp16 = 4 .0f ;
90+
91+ MLAS_FLOAT16X8 vone = MlasBroadcastF16Float16x8 (one_fp16);
92+ MLAS_FLOAT16X8 vneg_one = MlasBroadcastF16Float16x8 (neg_one_fp16);
93+ MLAS_FLOAT16X8 vzero = MlasBroadcastF16Float16x8 (zero_fp16);
94+ MLAS_FLOAT16X8 vth = MlasBroadcastF16Float16x8 (four_fp16);
95+
96+ size_t i = 0 ;
97+ for (; i + 8 <= N; i += 8 ) {
98+ MLAS_FLOAT16X8 x = MlasLoadFloat16x8 (&Input[i]);
99+ MLAS_UINT16X8 neg_mask = MlasCompareLessThanFloat16 (x, vzero);
100+ MLAS_FLOAT16X8 sign = MlasSelectFloat16 (neg_mask, vneg_one, vone);
101+ MLAS_FLOAT16X8 absx = MlasAbsFloat16 (x);
102+ MLAS_UINT16X8 use_mask = MlasCompareLessThanFloat16 (absx, vth);
103+ MLAS_FLOAT16X8 absx_clamped = MlasMinimumFloat16 (absx, vth);
104+ MLAS_FLOAT16X8 denom = MlasMultiplyAddFloat16 (vp, absx_clamped, vone);
105+ MLAS_FLOAT16X8 t = MlasApproximateReciprocalFloat16 (denom);
106+ t = MlasMultiplyFloat16 (t, MlasReciprocalSqrtFloat16 (denom, t));
107+ t = MlasMultiplyFloat16 (t, MlasReciprocalSqrtFloat16 (denom, t));
108+ MLAS_FLOAT16X8 t2 = MlasMultiplyFloat16 (t, t);
109+ MLAS_FLOAT16X8 t3 = MlasMultiplyFloat16 (t2, t);
110+ MLAS_FLOAT16X8 t4 = MlasMultiplyFloat16 (t3, t);
111+ MLAS_FLOAT16X8 t5 = MlasMultiplyFloat16 (t4, t);
112+ MLAS_FLOAT16X8 poly = MlasMultiplyFloat16 (va1, t);
113+ poly = MlasMultiplyAddFloat16 (va2, t2, poly);
114+ poly = MlasMultiplyAddFloat16 (va3, t3, poly);
115+ poly = MlasMultiplyAddFloat16 (va4, t4, poly);
116+ poly = MlasMultiplyAddFloat16 (va5, t5, poly);
117+ MLAS_FLOAT16X8 x2 = MlasMultiplyFloat16 (absx_clamped, absx_clamped);
118+ MLAS_FLOAT16X8 exp_neg_x2 = exp_neg_rational_approx_f16 (x2);
119+ MLAS_FLOAT16X8 poly_mul_exp = MlasMultiplyFloat16 (poly, exp_neg_x2);
120+ MLAS_FLOAT16X8 one_minus_term = MlasSubtractFloat16 (vone, poly_mul_exp);
121+ MLAS_FLOAT16X8 erf_approx = MlasMultiplyFloat16 (sign, one_minus_term);
122+ erf_approx = MlasMinimumFloat16 (erf_approx, vone);
123+ erf_approx = MlasMaximumFloat16 (erf_approx, vneg_one);
124+ MLAS_FLOAT16X8 result = MlasSelectFloat16 (use_mask, erf_approx, sign);
125+ MlasStoreFloat16x8 (&Output[i], result);
126+ }
127+
128+ for (; i < N; i++) {
129+ float x = fp16_to_float (Input[i]);
130+ float sign = (x < 0 ) ? -1 .0f : 1 .0f ;
131+ float absx = fabsf (x);
132+
133+ if (absx > 4 .0f ) {
134+ Output[i] = float_to_fp16 (sign);
135+ continue ;
136+ }
137+
138+ float t = 1 .0f / (1 .0f + p * absx);
139+ float poly = a1 * t + a2 * t * t + a3 * t * t * t + a4 * t * t * t * t + a5 * t * t * t * t * t;
140+ float exp_neg_x2 = expf (-absx * absx);
141+ float erf_approx = sign * (1 .0f - poly * exp_neg_x2);
142+ if (erf_approx > 1 .0f ) erf_approx = 1 .0f ;
143+ if (erf_approx < -1 .0f ) erf_approx = -1 .0f ;
144+
145+ Output[i] = float_to_fp16 (erf_approx);
146+ }
147+ }
0 commit comments