2525#include " Registration.hpp"
2626#include " Objective.hpp"
2727
28+ #include " math.hpp"
2829#include " approximate_math.hpp"
2930#include " compute_wrapper.hpp"
3031
@@ -94,6 +95,10 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
9495 return Avx2_32_Int (_mm256_add_epi32 (m_data, other.m_data ));
9596 }
9697
98+ inline Avx2_32_Int operator -(const Avx2_32_Int& other) const noexcept {
99+ return Avx2_32_Int (_mm256_sub_epi32 (m_data, other.m_data ));
100+ }
101+
97102 inline Avx2_32_Int operator *(const T& other) const noexcept {
98103 return Avx2_32_Int (_mm256_mullo_epi32 (m_data, _mm256_set1_epi32 (other)));
99104 }
@@ -106,6 +111,15 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
106111 return Avx2_32_Int (_mm256_and_si256 (m_data, other.m_data ));
107112 }
108113
114+ inline Avx2_32_Int operator |(const Avx2_32_Int& other) const noexcept {
115+ return Avx2_32_Int (_mm256_or_si256 (m_data, other.m_data ));
116+ }
117+
118+ friend inline Avx2_32_Int IfThenElse (
119+ const Avx2_32_Int& cmp, const Avx2_32_Int& trueVal, const Avx2_32_Int& falseVal) noexcept {
120+ return Avx2_32_Int (_mm256_blendv_epi8 (falseVal.m_data , trueVal.m_data , cmp.m_data ));
121+ }
122+
109123 friend inline Avx2_32_Int PermuteForInterleaf (const Avx2_32_Int& val) noexcept {
110124 // this function permutes the values into positions that the Interleaf function expects
111125 // but for any SIMD implementation the positions can be variable as long as they work together
@@ -124,7 +138,28 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
124138static_assert (std::is_standard_layout<Avx2_32_Int>::value && std::is_trivially_copyable<Avx2_32_Int>::value,
125139 " This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed" );
126140
141+ template <bool bNegateInput = false ,
142+ bool bNaNPossible = true ,
143+ bool bUnderflowPossible = true ,
144+ bool bOverflowPossible = true >
145+ inline Avx2_32_Float Exp (const Avx2_32_Float& val) noexcept ;
146+ template <bool bNegateOutput = false ,
147+ bool bNaNPossible = true ,
148+ bool bNegativePossible = true ,
149+ bool bZeroPossible = true ,
150+ bool bPositiveInfinityPossible = true >
151+ inline Avx2_32_Float Log (const Avx2_32_Float& val) noexcept ;
152+
127153struct alignas (k_cAlignment) Avx2_32_Float final {
154+ template <bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
155+ friend Avx2_32_Float Exp (const Avx2_32_Float& val) noexcept ;
156+ template <bool bNegateOutput,
157+ bool bNaNPossible,
158+ bool bNegativePossible,
159+ bool bZeroPossible,
160+ bool bPositiveInfinityPossible>
161+ friend Avx2_32_Float Log (const Avx2_32_Float& val) noexcept ;
162+
128163 using T = float ;
129164 using TPack = __m256;
130165 using TInt = Avx2_32_Int;
@@ -142,6 +177,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
142177 inline Avx2_32_Float (const double val) noexcept : m_data (_mm256_set1_ps (static_cast <T>(val))) {}
143178 inline Avx2_32_Float (const float val) noexcept : m_data (_mm256_set1_ps (static_cast <T>(val))) {}
144179 inline Avx2_32_Float (const int val) noexcept : m_data (_mm256_set1_ps (static_cast <T>(val))) {}
180+ explicit Avx2_32_Float (const Avx2_32_Int& val) : m_data (_mm256_cvtepi32_ps (val.m_data )) {}
145181
146182 inline Avx2_32_Float operator +() const noexcept { return *this ; }
147183
@@ -150,6 +186,10 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
150186 _mm256_castsi256_ps (_mm256_xor_si256 (_mm256_castps_si256 (m_data), _mm256_set1_epi32 (0x80000000 ))));
151187 }
152188
189+ inline Avx2_32_Float operator ~() const noexcept {
190+ return Avx2_32_Float (_mm256_xor_ps (m_data, _mm256_castsi256_ps (_mm256_set1_epi32 (-1 ))));
191+ }
192+
153193 inline Avx2_32_Float operator +(const Avx2_32_Float& other) const noexcept {
154194 return Avx2_32_Float (_mm256_add_ps (m_data, other.m_data ));
155195 }
@@ -218,6 +258,10 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
218258 return Avx2_32_Float (val) / other;
219259 }
220260
261+ friend inline Avx2_32_Float operator <=(const Avx2_32_Float& left, const Avx2_32_Float& right) noexcept {
262+ return Avx2_32_Float (_mm256_cmp_ps (left.m_data , right.m_data , _CMP_LE_OQ));
263+ }
264+
221265 inline static Avx2_32_Float Load (const T* const a) noexcept { return Avx2_32_Float (_mm256_load_ps (a)); }
222266
223267 inline void Store (T* const a) const noexcept { _mm256_store_ps (a, m_data); }
@@ -484,6 +528,11 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
484528 return Avx2_32_Float (_mm256_blendv_ps (falseVal.m_data , trueVal.m_data , mask));
485529 }
486530
531+ friend inline Avx2_32_Float IfThenElse (
532+ const Avx2_32_Float& cmp, const Avx2_32_Float& trueVal, const Avx2_32_Float& falseVal) noexcept {
533+ return Avx2_32_Float (_mm256_blendv_ps (falseVal.m_data , trueVal.m_data , cmp.m_data ));
534+ }
535+
487536 friend inline Avx2_32_Float IfEqual (const Avx2_32_Float& cmp1,
488537 const Avx2_32_Float& cmp2,
489538 const Avx2_32_Float& trueVal,
@@ -511,6 +560,18 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
511560 return Avx2_32_Float (_mm256_blendv_ps (falseVal.m_data , trueVal.m_data , _mm256_castsi256_ps (mask)));
512561 }
513562
563+ static inline Avx2_32_Int ReinterpretInt (const Avx2_32_Float& val) noexcept {
564+ return Avx2_32_Int (_mm256_castps_si256 (val.m_data ));
565+ }
566+
567+ static inline Avx2_32_Float ReinterpretFloat (const Avx2_32_Int& val) noexcept {
568+ return Avx2_32_Float (_mm256_castsi256_ps (val.m_data ));
569+ }
570+
571+ friend inline Avx2_32_Float Round (const Avx2_32_Float& val) noexcept {
572+ return Avx2_32_Float (_mm256_round_ps (val.m_data , _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
573+ }
574+
514575 friend inline Avx2_32_Float Abs (const Avx2_32_Float& val) noexcept {
515576 return Avx2_32_Float (_mm256_and_ps (val.m_data , _mm256_castsi256_ps (_mm256_set1_epi32 (0x7FFFFFFF ))));
516577 }
@@ -553,14 +614,6 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
553614 return Avx2_32_Float (_mm256_sqrt_ps (val.m_data ));
554615 }
555616
556- friend inline Avx2_32_Float Exp (const Avx2_32_Float& val) noexcept {
557- return ApplyFunc ([](T x) { return std::exp (x); }, val);
558- }
559-
560- friend inline Avx2_32_Float Log (const Avx2_32_Float& val) noexcept {
561- return ApplyFunc ([](T x) { return std::log (x); }, val);
562- }
563-
564617 template <bool bDisableApprox,
565618 bool bNegateInput = false ,
566619 bool bNaNPossible = true ,
@@ -571,7 +624,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
571624 static inline Avx2_32_Float ApproxExp (const Avx2_32_Float& val,
572625 const int32_t addExpSchraudolphTerm = k_expTermZeroMeanErrorForSoftmaxWithZeroedLogit) noexcept {
573626 UNUSED (addExpSchraudolphTerm);
574- return Exp ( bNegateInput ? -val : val);
627+ return Exp< bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>( val);
575628 }
576629
577630 template <bool bDisableApprox,
@@ -631,8 +684,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
631684 static inline Avx2_32_Float ApproxLog (
632685 const Avx2_32_Float& val, const float addLogSchraudolphTerm = k_logTermLowerBoundInputCloseToOne) noexcept {
633686 UNUSED (addLogSchraudolphTerm);
634- Avx2_32_Float ret = Log (val);
635- return bNegateOutput ? -ret : ret;
687+ return Log<bNegateOutput, bNaNPossible, bNegativePossible, bZeroPossible, bPositiveInfinityPossible>(val);
636688 }
637689
638690 template <bool bDisableApprox,
@@ -723,6 +775,25 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
723775static_assert (std::is_standard_layout<Avx2_32_Float>::value && std::is_trivially_copyable<Avx2_32_Float>::value,
724776 " This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed" );
725777
778+ template <bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
779+ inline Avx2_32_Float Exp (const Avx2_32_Float& val) noexcept {
780+ return Exp32<Avx2_32_Float, bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
781+ }
782+
783+ template <bool bNegateOutput,
784+ bool bNaNPossible,
785+ bool bNegativePossible,
786+ bool bZeroPossible,
787+ bool bPositiveInfinityPossible>
788+ inline Avx2_32_Float Log (const Avx2_32_Float& val) noexcept {
789+ return Log32<Avx2_32_Float,
790+ bNegateOutput,
791+ bNaNPossible,
792+ bNegativePossible,
793+ bZeroPossible,
794+ bPositiveInfinityPossible>(val);
795+ }
796+
726797INTERNAL_IMPORT_EXPORT_BODY ErrorEbm ApplyUpdate_Avx2_32 (
727798 const ObjectiveWrapper* const pObjectiveWrapper, ApplyUpdateBridge* const pData) {
728799 const Objective* const pObjective = static_cast <const Objective*>(pObjectiveWrapper->m_pObjective );
0 commit comments