@@ -20,7 +20,6 @@ Module Name:
2020#include " softmax.h"
2121#include " softmax_kernel_neon.h"
2222
23- // TODO(fajin): intra-loop parallelism
2423namespace softmax_neon {
2524
2625template <typename T>
@@ -44,7 +43,7 @@ struct MlasExpConstants {
4443 T MaximumExponent;
4544};
4645
47- const MlasExpConstants<_mlas_fp16_> ExpConstantsFp16 = {
46+ constexpr MlasExpConstants<_mlas_fp16_> ExpConstantsFp16 = {
4847 0xcc55 , // -25 * ln2
4948 0x498c , // 16 * ln2
5049 0xc95f , // -15.5 * ln2
@@ -64,67 +63,65 @@ const MlasExpConstants<_mlas_fp16_> ExpConstantsFp16 = {
6463 0x3C00 , // 15
6564};
6665
67- const MlasExpConstants<float16x8_t > ExpConstantsFp16x8 = {
68- MlasBroadcastFloat16x8 (ExpConstantsFp16.LowerRange ),
69- MlasBroadcastFloat16x8 (ExpConstantsFp16.UpperRange ),
70- MlasBroadcastFloat16x8 (ExpConstantsFp16.LowerRangeSumExp ),
71- MlasBroadcastFloat16x8 (ExpConstantsFp16.UpperRangeSumExp ),
72- MlasBroadcastFloat16x8 (ExpConstantsFp16.RoundingBias ),
73- MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2Reciprocal ),
74- MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2High ),
75- MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2Mid ),
76- MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2Low ),
77- MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_0 ),
78- MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_1 ),
79- MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_2 ),
80- MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_3 ),
81- MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_4 ),
82- MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_56 ),
83- MlasBroadcastFloat16x8 (ExpConstantsFp16.MinimumExponent ),
84- MlasBroadcastFloat16x8 (ExpConstantsFp16.MaximumExponent ),
85- };
86-
87- const MlasExpConstants<float16x4_t > ExpConstantsFp16x4 = {
88- MlasBroadcastFloat16x4 (ExpConstantsFp16.LowerRange ),
89- MlasBroadcastFloat16x4 (ExpConstantsFp16.UpperRange ),
90- MlasBroadcastFloat16x4 (ExpConstantsFp16.LowerRangeSumExp ),
91- MlasBroadcastFloat16x4 (ExpConstantsFp16.UpperRangeSumExp ),
92- MlasBroadcastFloat16x4 (ExpConstantsFp16.RoundingBias ),
93- MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2Reciprocal ),
94- MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2High ),
95- MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2Mid ),
96- MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2Low ),
97- MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_0 ),
98- MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_1 ),
99- MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_2 ),
100- MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_3 ),
101- MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_4 ),
102- MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_56 ),
103- MlasBroadcastFloat16x4 (ExpConstantsFp16.MinimumExponent ),
104- MlasBroadcastFloat16x4 (ExpConstantsFp16.MaximumExponent ),
105- };
106-
10766template <typename T>
10867MLAS_FORCEINLINE
109- MlasExpConstants<T> Get_Exp_Constants ();
68+ const MlasExpConstants<T>& Get_Exp_Constants ();
11069
11170template <>
11271MLAS_FORCEINLINE
113- MlasExpConstants<float16x8_t > Get_Exp_Constants<float16x8_t >() {
72+ const MlasExpConstants<float16x8_t >& Get_Exp_Constants<float16x8_t >() {
73+ const static MlasExpConstants<float16x8_t > ExpConstantsFp16x8 = {
74+ MlasBroadcastFloat16x8 (ExpConstantsFp16.LowerRange ),
75+ MlasBroadcastFloat16x8 (ExpConstantsFp16.UpperRange ),
76+ MlasBroadcastFloat16x8 (ExpConstantsFp16.LowerRangeSumExp ),
77+ MlasBroadcastFloat16x8 (ExpConstantsFp16.UpperRangeSumExp ),
78+ MlasBroadcastFloat16x8 (ExpConstantsFp16.RoundingBias ),
79+ MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2Reciprocal ),
80+ MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2High ),
81+ MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2Mid ),
82+ MlasBroadcastFloat16x8 (ExpConstantsFp16.Log2Low ),
83+ MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_0 ),
84+ MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_1 ),
85+ MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_2 ),
86+ MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_3 ),
87+ MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_4 ),
88+ MlasBroadcastFloat16x8 (ExpConstantsFp16.poly_56 ),
89+ MlasBroadcastFloat16x8 (ExpConstantsFp16.MinimumExponent ),
90+ MlasBroadcastFloat16x8 (ExpConstantsFp16.MaximumExponent ),
91+ };
11492 return ExpConstantsFp16x8;
11593}
11694
11795template <>
11896MLAS_FORCEINLINE
119- MlasExpConstants<float16x4_t > Get_Exp_Constants<float16x4_t >() {
97+ const MlasExpConstants<float16x4_t >& Get_Exp_Constants<float16x4_t >() {
98+ const static MlasExpConstants<float16x4_t > ExpConstantsFp16x4 = {
99+ MlasBroadcastFloat16x4 (ExpConstantsFp16.LowerRange ),
100+ MlasBroadcastFloat16x4 (ExpConstantsFp16.UpperRange ),
101+ MlasBroadcastFloat16x4 (ExpConstantsFp16.LowerRangeSumExp ),
102+ MlasBroadcastFloat16x4 (ExpConstantsFp16.UpperRangeSumExp ),
103+ MlasBroadcastFloat16x4 (ExpConstantsFp16.RoundingBias ),
104+ MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2Reciprocal ),
105+ MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2High ),
106+ MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2Mid ),
107+ MlasBroadcastFloat16x4 (ExpConstantsFp16.Log2Low ),
108+ MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_0 ),
109+ MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_1 ),
110+ MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_2 ),
111+ MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_3 ),
112+ MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_4 ),
113+ MlasBroadcastFloat16x4 (ExpConstantsFp16.poly_56 ),
114+ MlasBroadcastFloat16x4 (ExpConstantsFp16.MinimumExponent ),
115+ MlasBroadcastFloat16x4 (ExpConstantsFp16.MaximumExponent ),
116+ };
120117 return ExpConstantsFp16x4;
121118}
122119
123120// Range reduction + polynomial approximation. Refer algorithm details to MlasComputeExpVector.
124121template <typename T>
125122MLAS_FORCEINLINE
126123T Exp_Vector_Fp16 (T x) {
127- const auto constants = Get_Exp_Constants<T>();
124+ const auto & constants = Get_Exp_Constants<T>();
128125 auto clamped_x = MlasClampFloat16 (x, constants.LowerRange , constants.UpperRange );
129126
130127 // integral
@@ -242,7 +239,7 @@ void Exp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N) {
242239template <typename T>
243240MLAS_FORCEINLINE
244241T SumExp_Vector_Fp16 (T x, T negative_maximum) {
245- const auto constants = Get_Exp_Constants<T>();
242+ const auto & constants = Get_Exp_Constants<T>();
246243 auto clamped_x = MlasMaximumFloat16 (MlasAddFloat16 (x, negative_maximum), constants.LowerRangeSumExp );
247244
248245 // integral
@@ -419,7 +416,7 @@ struct MlasTanhConstants {
419416 T beta_0;
420417};
421418
422- const MlasTanhConstants<_mlas_fp16_> TanhConstantsFp16 = {
419+ constexpr MlasTanhConstants<_mlas_fp16_> TanhConstantsFp16 = {
423420 0xc308 , // -3.51562
424421 0x4308 , // 3.51562
425422 0x0001 ,
@@ -432,53 +429,51 @@ const MlasTanhConstants<_mlas_fp16_> TanhConstantsFp16 = {
432429 0x1d03 ,
433430};
434431
435- const MlasTanhConstants<float16x8_t > TanhConstantsFp16x8 = {
436- MlasBroadcastFloat16x8 (TanhConstantsFp16.LowerRange ),
437- MlasBroadcastFloat16x8 (TanhConstantsFp16.UpperRange ),
438- MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_7 ),
439- MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_5 ),
440- MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_3 ),
441- MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_1 ),
442- MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_6 ),
443- MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_4 ),
444- MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_2 ),
445- MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_0 ),
446- };
447-
448- const MlasTanhConstants<float16x4_t > TanhConstantsFp16x4 = {
449- MlasBroadcastFloat16x4 (TanhConstantsFp16.LowerRange ),
450- MlasBroadcastFloat16x4 (TanhConstantsFp16.UpperRange ),
451- MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_7 ),
452- MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_5 ),
453- MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_3 ),
454- MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_1 ),
455- MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_6 ),
456- MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_4 ),
457- MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_2 ),
458- MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_0 ),
459- };
460-
461432template <typename T>
462433MLAS_FORCEINLINE
463- MlasTanhConstants<T> Get_Tanh_Constants ();
434+ const MlasTanhConstants<T>& Get_Tanh_Constants ();
464435
465436template <>
466437MLAS_FORCEINLINE
467- MlasTanhConstants<float16x8_t > Get_Tanh_Constants<float16x8_t >() {
438+ const MlasTanhConstants<float16x8_t >& Get_Tanh_Constants<float16x8_t >() {
439+ const static MlasTanhConstants<float16x8_t > TanhConstantsFp16x8 = {
440+ MlasBroadcastFloat16x8 (TanhConstantsFp16.LowerRange ),
441+ MlasBroadcastFloat16x8 (TanhConstantsFp16.UpperRange ),
442+ MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_7 ),
443+ MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_5 ),
444+ MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_3 ),
445+ MlasBroadcastFloat16x8 (TanhConstantsFp16.alpha_1 ),
446+ MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_6 ),
447+ MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_4 ),
448+ MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_2 ),
449+ MlasBroadcastFloat16x8 (TanhConstantsFp16.beta_0 ),
450+ };
468451 return TanhConstantsFp16x8;
469452}
470453
471454template <>
472455MLAS_FORCEINLINE
473- MlasTanhConstants<float16x4_t > Get_Tanh_Constants<float16x4_t >() {
456+ const MlasTanhConstants<float16x4_t >& Get_Tanh_Constants<float16x4_t >() {
457+ const static MlasTanhConstants<float16x4_t > TanhConstantsFp16x4 = {
458+ MlasBroadcastFloat16x4 (TanhConstantsFp16.LowerRange ),
459+ MlasBroadcastFloat16x4 (TanhConstantsFp16.UpperRange ),
460+ MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_7 ),
461+ MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_5 ),
462+ MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_3 ),
463+ MlasBroadcastFloat16x4 (TanhConstantsFp16.alpha_1 ),
464+ MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_6 ),
465+ MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_4 ),
466+ MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_2 ),
467+ MlasBroadcastFloat16x4 (TanhConstantsFp16.beta_0 ),
468+ };
474469 return TanhConstantsFp16x4;
475470}
476471
477472// TODO(fajin): optimize polynomial coefficients
478473template <typename T>
479474MLAS_FORCEINLINE
480475T Tanh_Vector_Fp16 (T x) {
481- const auto constants = Get_Tanh_Constants<T>();
476+ const auto & constants = Get_Tanh_Constants<T>();
482477 x = MlasClampFloat16 (x, constants.LowerRange , constants.UpperRange );
483478
484479 T x_2 = MlasMultiplyFloat16 (x, x);
0 commit comments