41
41
42
42
#define SOFTMAX_HACK
43
43
44
- #ifdef __AVX2__
44
+ #ifdef __AVX__
45
45
#include <immintrin.h>
46
+
47
+
48
+ #ifdef __AVX2__
46
49
static __m256 exp8_approx (__m256 X )
47
50
{
48
51
const __m256 K0 = _mm256_set1_ps (0.99992522f );
@@ -65,7 +68,44 @@ static __m256 exp8_approx(__m256 X)
65
68
Y = _mm256_castsi256_ps (_mm256_and_si256 (mask , _mm256_add_epi32 (I , _mm256_castps_si256 (Y ))));
66
69
return Y ;
67
70
}
68
-
71
+ #else
72
+ #define _mm256_fmadd_ps (a ,b ,c ) _mm256_add_ps(_mm256_mul_ps(a, b), c)
73
+ #define _mm_fmadd_ps (a ,b ,c ) _mm_add_ps(_mm_mul_ps(a, b), c)
74
+ static __m128 exp4_approx (__m128 X )
75
+ {
76
+ const __m128 K0 = _mm_set1_ps (0.99992522f );
77
+ const __m128 K1 = _mm_set1_ps (0.69583354f );
78
+ const __m128 K2 = _mm_set1_ps (0.22606716f );
79
+ const __m128 K3 = _mm_set1_ps (0.078024523f );
80
+ const __m128 log2_E = _mm_set1_ps (1.44269504 );
81
+ const __m128 max_in = _mm_set1_ps (50.f );
82
+ const __m128 min_in = _mm_set1_ps (-50.f );
83
+ const __m128i mask = _mm_set1_epi32 (0x7fffffff );
84
+ __m128 XF , Y ;
85
+ __m128i I ;
86
+ X = _mm_mul_ps (X , log2_E );
87
+ X = _mm_max_ps (min_in , _mm_min_ps (max_in , X ));
88
+ XF = _mm_floor_ps (X );
89
+ I = _mm_cvtps_epi32 (XF );
90
+ X = _mm_sub_ps (X , XF );
91
+ Y = _mm_fmadd_ps (_mm_fmadd_ps (_mm_fmadd_ps (K3 , X , K2 ), X , K1 ), X , K0 );
92
+ I = _mm_slli_epi32 (I , 23 );
93
+ Y = _mm_castsi128_ps (_mm_and_si128 (mask , _mm_add_epi32 (I , _mm_castps_si128 (Y ))));
94
+ return Y ;
95
+ }
96
+ static __m256 exp8_approx (__m256 X )
97
+ {
98
+ __m256 Y ;
99
+ __m128 Xhi , Xlo , Yhi , Ylo ;
100
+ Xhi = _mm256_extractf128_ps (X , 1 );
101
+ Xlo = _mm256_extractf128_ps (X , 0 );
102
+ Yhi = exp4_approx (Xhi );
103
+ Ylo = exp4_approx (Xlo );
104
+ Y = _mm256_insertf128_ps (_mm256_setzero_ps (), Yhi , 1 );
105
+ Y = _mm256_insertf128_ps (Y , Ylo , 0 );
106
+ return Y ;
107
+ }
108
+ #endif
69
109
70
110
static float celt_exp (float x )
71
111
{
0 commit comments