Skip to content

Commit e058e81

Browse files
committed
Support for plain AVX with no FMA
1 parent b38c4b4 commit e058e81

File tree

1 file changed

+42
-2
lines changed

1 file changed

+42
-2
lines changed

src/nnet.c

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,11 @@
4141

4242
#define SOFTMAX_HACK
4343

44-
#ifdef __AVX2__
44+
#ifdef __AVX__
4545
#include <immintrin.h>
46+
47+
48+
#ifdef __AVX2__
4649
static __m256 exp8_approx(__m256 X)
4750
{
4851
const __m256 K0 = _mm256_set1_ps(0.99992522f);
@@ -65,7 +68,44 @@ static __m256 exp8_approx(__m256 X)
6568
Y = _mm256_castsi256_ps(_mm256_and_si256(mask, _mm256_add_epi32(I, _mm256_castps_si256(Y))));
6669
return Y;
6770
}
68-
71+
#else
72+
#define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
73+
#define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
74+
static __m128 exp4_approx(__m128 X)
75+
{
76+
const __m128 K0 = _mm_set1_ps(0.99992522f);
77+
const __m128 K1 = _mm_set1_ps(0.69583354f);
78+
const __m128 K2 = _mm_set1_ps(0.22606716f);
79+
const __m128 K3 = _mm_set1_ps(0.078024523f);
80+
const __m128 log2_E = _mm_set1_ps(1.44269504);
81+
const __m128 max_in = _mm_set1_ps(50.f);
82+
const __m128 min_in = _mm_set1_ps(-50.f);
83+
const __m128i mask = _mm_set1_epi32(0x7fffffff);
84+
__m128 XF, Y;
85+
__m128i I;
86+
X = _mm_mul_ps(X, log2_E);
87+
X = _mm_max_ps(min_in, _mm_min_ps(max_in, X));
88+
XF = _mm_floor_ps(X);
89+
I = _mm_cvtps_epi32(XF);
90+
X = _mm_sub_ps(X, XF);
91+
Y = _mm_fmadd_ps(_mm_fmadd_ps(_mm_fmadd_ps(K3, X, K2), X, K1), X, K0);
92+
I = _mm_slli_epi32(I, 23);
93+
Y = _mm_castsi128_ps(_mm_and_si128(mask, _mm_add_epi32(I, _mm_castps_si128(Y))));
94+
return Y;
95+
}
96+
static __m256 exp8_approx(__m256 X)
97+
{
98+
__m256 Y;
99+
__m128 Xhi, Xlo, Yhi, Ylo;
100+
Xhi = _mm256_extractf128_ps(X, 1);
101+
Xlo = _mm256_extractf128_ps(X, 0);
102+
Yhi = exp4_approx(Xhi);
103+
Ylo = exp4_approx(Xlo);
104+
Y = _mm256_insertf128_ps(_mm256_setzero_ps(), Yhi, 1);
105+
Y = _mm256_insertf128_ps(Y, Ylo, 0);
106+
return Y;
107+
}
108+
#endif
69109

70110
static float celt_exp(float x)
71111
{

0 commit comments

Comments
 (0)