25
25
#include " Registration.hpp"
26
26
#include " Objective.hpp"
27
27
28
+ #include " math.hpp"
28
29
#include " approximate_math.hpp"
29
30
#include " compute_wrapper.hpp"
30
31
@@ -94,6 +95,10 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
94
95
return Avx2_32_Int (_mm256_add_epi32 (m_data, other.m_data ));
95
96
}
96
97
98
+ inline Avx2_32_Int operator -(const Avx2_32_Int& other) const noexcept {
99
+ return Avx2_32_Int (_mm256_sub_epi32 (m_data, other.m_data ));
100
+ }
101
+
97
102
// Multiplies every 32-bit lane by the scalar `other`. The scalar is first
// broadcast to all lanes; _mm256_mullo_epi32 keeps the low 32 bits of each product.
inline Avx2_32_Int operator*(const T& other) const noexcept {
   const __m256i broadcast = _mm256_set1_epi32(other);
   return Avx2_32_Int(_mm256_mullo_epi32(m_data, broadcast));
}
@@ -106,6 +111,15 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
106
111
return Avx2_32_Int (_mm256_and_si256 (m_data, other.m_data ));
107
112
}
108
113
114
+ inline Avx2_32_Int operator |(const Avx2_32_Int& other) const noexcept {
115
+ return Avx2_32_Int (_mm256_or_si256 (m_data, other.m_data ));
116
+ }
117
+
118
+ friend inline Avx2_32_Int IfThenElse (
119
+ const Avx2_32_Int& cmp, const Avx2_32_Int& trueVal, const Avx2_32_Int& falseVal) noexcept {
120
+ return Avx2_32_Int (_mm256_blendv_epi8 (falseVal.m_data , trueVal.m_data , cmp.m_data ));
121
+ }
122
+
109
123
friend inline Avx2_32_Int PermuteForInterleaf (const Avx2_32_Int& val) noexcept {
110
124
// this function permutes the values into positions that the Interleaf function expects
111
125
// but for any SIMD implementation the positions can be variable as long as they work together
@@ -124,7 +138,28 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
124
138
// Compile-time guarantee that Avx2_32_Int is both standard-layout and trivially
// copyable; the build fails with the message below if either property is lost.
static_assert (std::is_standard_layout<Avx2_32_Int>::value && std::is_trivially_copyable<Avx2_32_Int>::value,
125
139
" This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed" );
126
140
141
// Forward declaration of the Exp function template, placed ahead of the
// Avx2_32_Float struct so the struct can name it in a friend declaration.
// Every template flag carries its default here: input negation is off and all
// of the "possible" flags are on.
+ template <bool bNegateInput = false ,
141
+ bool bNaNPossible = true ,
142
+ bool bUnderflowPossible = true ,
143
+ bool bOverflowPossible = true >
144
+ inline Avx2_32_Float Exp (const Avx2_32_Float& val) noexcept ;
145
// Forward declaration of the Log function template, for the same reason as Exp.
// Defaults: output negation is off and all of the "possible" flags are on.
+ template <bool bNegateOutput = false ,
146
+ bool bNaNPossible = true ,
147
+ bool bNegativePossible = true ,
148
+ bool bZeroPossible = true ,
149
+ bool bPositiveInfinityPossible = true >
150
+ inline Avx2_32_Float Log (const Avx2_32_Float& val) noexcept ;
152
+
127
153
struct alignas (k_cAlignment) Avx2_32_Float final {
154
// Befriend the namespace-scope Exp template so it can be used with this struct's
// internals. No default template arguments are repeated here.
+ template <bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
155
+ friend Avx2_32_Float Exp (const Avx2_32_Float& val) noexcept ;
156
// Befriend the namespace-scope Log template in the same way.
+ template <bool bNegateOutput,
157
+ bool bNaNPossible,
158
+ bool bNegativePossible,
159
+ bool bZeroPossible,
160
+ bool bPositiveInfinityPossible>
161
+ friend Avx2_32_Float Log (const Avx2_32_Float& val) noexcept ;
162
+
128
163
using T = float ;
129
164
using TPack = __m256;
130
165
using TInt = Avx2_32_Int;
@@ -142,6 +177,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
142
177
inline Avx2_32_Float (const double val) noexcept : m_data (_mm256_set1_ps (static_cast <T>(val))) {}
143
178
inline Avx2_32_Float (const float val) noexcept : m_data (_mm256_set1_ps (static_cast <T>(val))) {}
144
179
inline Avx2_32_Float (const int val) noexcept : m_data (_mm256_set1_ps (static_cast <T>(val))) {}
180
+ explicit Avx2_32_Float (const Avx2_32_Int& val) : m_data (_mm256_cvtepi32_ps (val.m_data )) {}
145
181
146
182
// Unary plus: identity, returns a copy of this pack unchanged.
inline Avx2_32_Float operator+() const noexcept {
   return *this;
}
147
183
@@ -150,6 +186,10 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
150
186
_mm256_castsi256_ps (_mm256_xor_si256 (_mm256_castps_si256 (m_data), _mm256_set1_epi32 (0x80000000 ))));
151
187
}
152
188
189
+ inline Avx2_32_Float operator ~() const noexcept {
190
+ return Avx2_32_Float (_mm256_xor_ps (m_data, _mm256_castsi256_ps (_mm256_set1_epi32 (-1 ))));
191
+ }
192
+
153
193
// Lane-wise single-precision addition of this pack and `other`.
inline Avx2_32_Float operator+(const Avx2_32_Float& other) const noexcept {
   const __m256 sum = _mm256_add_ps(m_data, other.m_data);
   return Avx2_32_Float(sum);
}
@@ -218,6 +258,10 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
218
258
return Avx2_32_Float (val) / other;
219
259
}
220
260
261
+ friend inline Avx2_32_Float operator <=(const Avx2_32_Float& left, const Avx2_32_Float& right) noexcept {
262
+ return Avx2_32_Float (_mm256_cmp_ps (left.m_data , right.m_data , _CMP_LE_OQ));
263
+ }
264
+
221
265
// Loads one full pack of floats starting at `a`. Uses the aligned load
// intrinsic, so `a` must satisfy the 32-byte alignment _mm256_load_ps requires.
inline static Avx2_32_Float Load(const T* const a) noexcept {
   const __m256 loaded = _mm256_load_ps(a);
   return Avx2_32_Float(loaded);
}
222
266
223
267
// Writes the whole pack to memory starting at `a`. Uses the aligned store
// intrinsic, so `a` must satisfy the 32-byte alignment _mm256_store_ps requires.
inline void Store(T* const a) const noexcept {
   _mm256_store_ps(a, m_data);
}
@@ -484,6 +528,11 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
484
528
return Avx2_32_Float (_mm256_blendv_ps (falseVal.m_data , trueVal.m_data , mask));
485
529
}
486
530
531
+ friend inline Avx2_32_Float IfThenElse (
532
+ const Avx2_32_Float& cmp, const Avx2_32_Float& trueVal, const Avx2_32_Float& falseVal) noexcept {
533
+ return Avx2_32_Float (_mm256_blendv_ps (falseVal.m_data , trueVal.m_data , cmp.m_data ));
534
+ }
535
+
487
536
friend inline Avx2_32_Float IfEqual (const Avx2_32_Float& cmp1,
488
537
const Avx2_32_Float& cmp2,
489
538
const Avx2_32_Float& trueVal,
@@ -511,6 +560,26 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
511
560
return Avx2_32_Float (_mm256_blendv_ps (falseVal.m_data , trueVal.m_data , _mm256_castsi256_ps (mask)));
512
561
}
513
562
563
+ static inline Avx2_32_Int ReinterpretInt (const Avx2_32_Float& val) noexcept {
564
+ return Avx2_32_Int (_mm256_castps_si256 (val.m_data ));
565
+ }
566
+
567
+ static inline Avx2_32_Float ReinterpretFloat (const Avx2_32_Int& val) noexcept {
568
+ return Avx2_32_Float (_mm256_castsi256_ps (val.m_data ));
569
+ }
570
+
571
+ friend inline Avx2_32_Float Round (const Avx2_32_Float& val) noexcept {
572
+ return Avx2_32_Float (_mm256_round_ps (val.m_data , _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
573
+ }
574
+
575
+ friend inline Avx2_32_Float Mantissa (const Avx2_32_Float& val) noexcept {
576
+ return ReinterpretFloat ((ReinterpretInt (val) & 0x007FFFFF ) | 0x3F000000 );
577
+ }
578
+
579
+ friend inline Avx2_32_Int Exponent (const Avx2_32_Float& val) noexcept {
580
+ return ((ReinterpretInt (val) << 1 ) >> 24 ) - Avx2_32_Int (0x7F );
581
+ }
582
+
514
583
// Absolute value of each lane, computed by AND-ing with 0x7FFFFFFF so that the
// top (sign) bit is cleared and all other bits are kept.
friend inline Avx2_32_Float Abs(const Avx2_32_Float& val) noexcept {
   const __m256 magnitudeMask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF));
   return Avx2_32_Float(_mm256_and_ps(val.m_data, magnitudeMask));
}
@@ -553,14 +622,6 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
553
622
return Avx2_32_Float (_mm256_sqrt_ps (val.m_data ));
554
623
}
555
624
556
- friend inline Avx2_32_Float Exp (const Avx2_32_Float& val) noexcept {
557
- return ApplyFunc ([](T x) { return std::exp (x); }, val);
558
- }
559
-
560
- friend inline Avx2_32_Float Log (const Avx2_32_Float& val) noexcept {
561
- return ApplyFunc ([](T x) { return std::log (x); }, val);
562
- }
563
-
564
625
template <bool bDisableApprox,
565
626
bool bNegateInput = false ,
566
627
bool bNaNPossible = true ,
@@ -571,7 +632,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
571
632
static inline Avx2_32_Float ApproxExp (const Avx2_32_Float& val,
572
633
const int32_t addExpSchraudolphTerm = k_expTermZeroMeanErrorForSoftmaxWithZeroedLogit) noexcept {
573
634
UNUSED (addExpSchraudolphTerm);
574
- return Exp ( bNegateInput ? -val : val);
635
+ return Exp< bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>( val);
575
636
}
576
637
577
638
template <bool bDisableApprox,
@@ -631,8 +692,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
631
692
static inline Avx2_32_Float ApproxLog (
632
693
const Avx2_32_Float& val, const float addLogSchraudolphTerm = k_logTermLowerBoundInputCloseToOne) noexcept {
633
694
UNUSED (addLogSchraudolphTerm);
634
- Avx2_32_Float ret = Log (val);
635
- return bNegateOutput ? -ret : ret;
695
+ return Log<bNegateOutput, bNaNPossible, bNegativePossible, bZeroPossible, bPositiveInfinityPossible>(val);
636
696
}
637
697
638
698
template <bool bDisableApprox,
@@ -723,6 +783,25 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
723
783
// Compile-time guarantee that Avx2_32_Float is both standard-layout and trivially
// copyable; the build fails with the message below if either property is lost.
static_assert (std::is_standard_layout<Avx2_32_Float>::value && std::is_trivially_copyable<Avx2_32_Float>::value,
724
784
" This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed" );
725
785
786
+ template <bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
787
+ inline Avx2_32_Float Exp (const Avx2_32_Float& val) noexcept {
788
+ return Exp32<Avx2_32_Float, bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
789
+ }
790
+
791
+ template <bool bNegateOutput,
792
+ bool bNaNPossible,
793
+ bool bNegativePossible,
794
+ bool bZeroPossible,
795
+ bool bPositiveInfinityPossible>
796
+ inline Avx2_32_Float Log (const Avx2_32_Float& val) noexcept {
797
+ return Log32<Avx2_32_Float,
798
+ bNegateOutput,
799
+ bNaNPossible,
800
+ bNegativePossible,
801
+ bZeroPossible,
802
+ bPositiveInfinityPossible>(val);
803
+ }
804
+
726
805
INTERNAL_IMPORT_EXPORT_BODY ErrorEbm ApplyUpdate_Avx2_32 (
727
806
const ObjectiveWrapper* const pObjectiveWrapper, ApplyUpdateBridge* const pData) {
728
807
const Objective* const pObjective = static_cast <const Objective*>(pObjectiveWrapper->m_pObjective );
0 commit comments