2525#include " Registration.hpp"
2626#include " Objective.hpp"
2727
28+ #include " math.hpp"
2829#include " approximate_math.hpp"
2930#include " compute_wrapper.hpp"
3031
@@ -94,6 +95,10 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
9495 return Avx2_32_Int (_mm256_add_epi32 (m_data, other.m_data ));
9596 }
9697
98+ inline Avx2_32_Int operator -(const Avx2_32_Int& other) const noexcept {
99+ return Avx2_32_Int (_mm256_sub_epi32 (m_data, other.m_data ));
100+ }
101+
97102 inline Avx2_32_Int operator *(const T& other) const noexcept {
98103 return Avx2_32_Int (_mm256_mullo_epi32 (m_data, _mm256_set1_epi32 (other)));
99104 }
@@ -106,6 +111,15 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
106111 return Avx2_32_Int (_mm256_and_si256 (m_data, other.m_data ));
107112 }
108113
114+ inline Avx2_32_Int operator |(const Avx2_32_Int& other) const noexcept {
115+ return Avx2_32_Int (_mm256_or_si256 (m_data, other.m_data ));
116+ }
117+
118+ friend inline Avx2_32_Int IfThenElse (
119+ const Avx2_32_Int& cmp, const Avx2_32_Int& trueVal, const Avx2_32_Int& falseVal) noexcept {
120+ return Avx2_32_Int (_mm256_blendv_epi8 (falseVal.m_data , trueVal.m_data , cmp.m_data ));
121+ }
122+
109123 friend inline Avx2_32_Int PermuteForInterleaf (const Avx2_32_Int& val) noexcept {
110124 // this function permutes the values into positions that the Interleaf function expects
111125 // but for any SIMD implementation the positions can be variable as long as they work together
@@ -124,7 +138,28 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
124138static_assert (std::is_standard_layout<Avx2_32_Int>::value && std::is_trivially_copyable<Avx2_32_Int>::value,
125139 " This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed" );
126140
141+ template <bool bNegateInput = false ,
142+ bool bNaNPossible = true ,
143+ bool bUnderflowPossible = true ,
144+ bool bOverflowPossible = true >
145+ inline Avx2_32_Float Exp (const Avx2_32_Float& val) noexcept ;
146+ template <bool bNegateOutput = false ,
147+ bool bNaNPossible = true ,
148+ bool bNegativePossible = true ,
149+ bool bZeroPossible = true ,
150+ bool bPositiveInfinityPossible = true >
151+ inline Avx2_32_Float Log (const Avx2_32_Float& val) noexcept ;
152+
127153struct alignas (k_cAlignment) Avx2_32_Float final {
154+ template <bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
155+ friend Avx2_32_Float Exp (const Avx2_32_Float& val) noexcept ;
156+ template <bool bNegateOutput,
157+ bool bNaNPossible,
158+ bool bNegativePossible,
159+ bool bZeroPossible,
160+ bool bPositiveInfinityPossible>
161+ friend Avx2_32_Float Log (const Avx2_32_Float& val) noexcept ;
162+
128163 using T = float ;
129164 using TPack = __m256;
130165 using TInt = Avx2_32_Int;
@@ -142,6 +177,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
142177 inline Avx2_32_Float (const double val) noexcept : m_data (_mm256_set1_ps (static_cast <T>(val))) {}
143178 inline Avx2_32_Float (const float val) noexcept : m_data (_mm256_set1_ps (static_cast <T>(val))) {}
144179 inline Avx2_32_Float (const int val) noexcept : m_data (_mm256_set1_ps (static_cast <T>(val))) {}
180+ explicit Avx2_32_Float (const Avx2_32_Int& val) : m_data (_mm256_cvtepi32_ps (val.m_data )) {}
145181
146182 inline Avx2_32_Float operator +() const noexcept { return *this ; }
147183
@@ -150,6 +186,10 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
150186 _mm256_castsi256_ps (_mm256_xor_si256 (_mm256_castps_si256 (m_data), _mm256_set1_epi32 (0x80000000 ))));
151187 }
152188
189+ inline Avx2_32_Float operator ~() const noexcept {
190+ return Avx2_32_Float (_mm256_xor_ps (m_data, _mm256_castsi256_ps (_mm256_set1_epi32 (-1 ))));
191+ }
192+
153193 inline Avx2_32_Float operator +(const Avx2_32_Float& other) const noexcept {
154194 return Avx2_32_Float (_mm256_add_ps (m_data, other.m_data ));
155195 }
@@ -218,6 +258,10 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
218258 return Avx2_32_Float (val) / other;
219259 }
220260
261+ friend inline Avx2_32_Float operator <=(const Avx2_32_Float& left, const Avx2_32_Float& right) noexcept {
262+ return Avx2_32_Float (_mm256_cmp_ps (left.m_data , right.m_data , _CMP_LE_OQ));
263+ }
264+
221265 inline static Avx2_32_Float Load (const T* const a) noexcept { return Avx2_32_Float (_mm256_load_ps (a)); }
222266
223267 inline void Store (T* const a) const noexcept { _mm256_store_ps (a, m_data); }
@@ -484,6 +528,11 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
484528 return Avx2_32_Float (_mm256_blendv_ps (falseVal.m_data , trueVal.m_data , mask));
485529 }
486530
531+ friend inline Avx2_32_Float IfThenElse (
532+ const Avx2_32_Float& cmp, const Avx2_32_Float& trueVal, const Avx2_32_Float& falseVal) noexcept {
533+ return Avx2_32_Float (_mm256_blendv_ps (falseVal.m_data , trueVal.m_data , cmp.m_data ));
534+ }
535+
487536 friend inline Avx2_32_Float IfEqual (const Avx2_32_Float& cmp1,
488537 const Avx2_32_Float& cmp2,
489538 const Avx2_32_Float& trueVal,
@@ -511,6 +560,26 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
511560 return Avx2_32_Float (_mm256_blendv_ps (falseVal.m_data , trueVal.m_data , _mm256_castsi256_ps (mask)));
512561 }
513562
563+ static inline Avx2_32_Int ReinterpretInt (const Avx2_32_Float& val) noexcept {
564+ return Avx2_32_Int (_mm256_castps_si256 (val.m_data ));
565+ }
566+
567+ static inline Avx2_32_Float ReinterpretFloat (const Avx2_32_Int& val) noexcept {
568+ return Avx2_32_Float (_mm256_castsi256_ps (val.m_data ));
569+ }
570+
571+ friend inline Avx2_32_Float Round (const Avx2_32_Float& val) noexcept {
572+ return Avx2_32_Float (_mm256_round_ps (val.m_data , _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
573+ }
574+
575+ friend inline Avx2_32_Float Mantissa (const Avx2_32_Float& val) noexcept {
576+ return ReinterpretFloat ((ReinterpretInt (val) & 0x007FFFFF ) | 0x3F000000 );
577+ }
578+
579+ friend inline Avx2_32_Int Exponent (const Avx2_32_Float& val) noexcept {
580+ return ((ReinterpretInt (val) << 1 ) >> 24 ) - Avx2_32_Int (0x7F );
581+ }
582+
514583 friend inline Avx2_32_Float Abs (const Avx2_32_Float& val) noexcept {
515584 return Avx2_32_Float (_mm256_and_ps (val.m_data , _mm256_castsi256_ps (_mm256_set1_epi32 (0x7FFFFFFF ))));
516585 }
@@ -553,14 +622,6 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
553622 return Avx2_32_Float (_mm256_sqrt_ps (val.m_data ));
554623 }
555624
556- friend inline Avx2_32_Float Exp (const Avx2_32_Float& val) noexcept {
557- return ApplyFunc ([](T x) { return std::exp (x); }, val);
558- }
559-
560- friend inline Avx2_32_Float Log (const Avx2_32_Float& val) noexcept {
561- return ApplyFunc ([](T x) { return std::log (x); }, val);
562- }
563-
564625 template <bool bDisableApprox,
565626 bool bNegateInput = false ,
566627 bool bNaNPossible = true ,
@@ -571,7 +632,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
571632 static inline Avx2_32_Float ApproxExp (const Avx2_32_Float& val,
572633 const int32_t addExpSchraudolphTerm = k_expTermZeroMeanErrorForSoftmaxWithZeroedLogit) noexcept {
573634 UNUSED (addExpSchraudolphTerm);
574- return Exp ( bNegateInput ? -val : val);
635+ return Exp< bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>( val);
575636 }
576637
577638 template <bool bDisableApprox,
@@ -631,8 +692,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
631692 static inline Avx2_32_Float ApproxLog (
632693 const Avx2_32_Float& val, const float addLogSchraudolphTerm = k_logTermLowerBoundInputCloseToOne) noexcept {
633694 UNUSED (addLogSchraudolphTerm);
634- Avx2_32_Float ret = Log (val);
635- return bNegateOutput ? -ret : ret;
695+ return Log<bNegateOutput, bNaNPossible, bNegativePossible, bZeroPossible, bPositiveInfinityPossible>(val);
636696 }
637697
638698 template <bool bDisableApprox,
@@ -723,6 +783,25 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
723783static_assert (std::is_standard_layout<Avx2_32_Float>::value && std::is_trivially_copyable<Avx2_32_Float>::value,
724784 " This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed" );
725785
786+ template <bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
787+ inline Avx2_32_Float Exp (const Avx2_32_Float& val) noexcept {
788+ return Exp32<Avx2_32_Float, bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
789+ }
790+
791+ template <bool bNegateOutput,
792+ bool bNaNPossible,
793+ bool bNegativePossible,
794+ bool bZeroPossible,
795+ bool bPositiveInfinityPossible>
796+ inline Avx2_32_Float Log (const Avx2_32_Float& val) noexcept {
797+ return Log32<Avx2_32_Float,
798+ bNegateOutput,
799+ bNaNPossible,
800+ bNegativePossible,
801+ bZeroPossible,
802+ bPositiveInfinityPossible>(val);
803+ }
804+
726805INTERNAL_IMPORT_EXPORT_BODY ErrorEbm ApplyUpdate_Avx2_32 (
727806 const ObjectiveWrapper* const pObjectiveWrapper, ApplyUpdateBridge* const pData) {
728807 const Objective* const pObjective = static_cast <const Objective*>(pObjectiveWrapper->m_pObjective );
0 commit comments