
Commit ed7eec6
SIMD exp/log
1 parent 25e958f

4 files changed, +302 -11 lines

shared/libebm/compute/avx2_ebm/avx2_32.cpp (+90 -11)
@@ -25,6 +25,7 @@
 #include "Registration.hpp"
 #include "Objective.hpp"

+#include "math.hpp"
 #include "approximate_math.hpp"
 #include "compute_wrapper.hpp"

@@ -94,6 +95,10 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
       return Avx2_32_Int(_mm256_add_epi32(m_data, other.m_data));
    }

+   inline Avx2_32_Int operator-(const Avx2_32_Int& other) const noexcept {
+      return Avx2_32_Int(_mm256_sub_epi32(m_data, other.m_data));
+   }
+
    inline Avx2_32_Int operator*(const T& other) const noexcept {
       return Avx2_32_Int(_mm256_mullo_epi32(m_data, _mm256_set1_epi32(other)));
    }
@@ -106,6 +111,15 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
       return Avx2_32_Int(_mm256_and_si256(m_data, other.m_data));
    }

+   inline Avx2_32_Int operator|(const Avx2_32_Int& other) const noexcept {
+      return Avx2_32_Int(_mm256_or_si256(m_data, other.m_data));
+   }
+
+   friend inline Avx2_32_Int IfThenElse(
+         const Avx2_32_Int& cmp, const Avx2_32_Int& trueVal, const Avx2_32_Int& falseVal) noexcept {
+      return Avx2_32_Int(_mm256_blendv_epi8(falseVal.m_data, trueVal.m_data, cmp.m_data));
+   }
+
    friend inline Avx2_32_Int PermuteForInterleaf(const Avx2_32_Int& val) noexcept {
       // this function permutes the values into positions that the Interleaf function expects
       // but for any SIMD implementation the positions can be variable as long as they work together
@@ -124,7 +138,28 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
 static_assert(std::is_standard_layout<Avx2_32_Int>::value && std::is_trivially_copyable<Avx2_32_Int>::value,
       "This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed");

+template<bool bNegateInput = false,
+      bool bNaNPossible = true,
+      bool bUnderflowPossible = true,
+      bool bOverflowPossible = true>
+inline Avx2_32_Float Exp(const Avx2_32_Float& val) noexcept;
+template<bool bNegateOutput = false,
+      bool bNaNPossible = true,
+      bool bNegativePossible = true,
+      bool bZeroPossible = true,
+      bool bPositiveInfinityPossible = true>
+inline Avx2_32_Float Log(const Avx2_32_Float& val) noexcept;
+
 struct alignas(k_cAlignment) Avx2_32_Float final {
+   template<bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
+   friend Avx2_32_Float Exp(const Avx2_32_Float& val) noexcept;
+   template<bool bNegateOutput,
+         bool bNaNPossible,
+         bool bNegativePossible,
+         bool bZeroPossible,
+         bool bPositiveInfinityPossible>
+   friend Avx2_32_Float Log(const Avx2_32_Float& val) noexcept;
+
    using T = float;
    using TPack = __m256;
    using TInt = Avx2_32_Int;
@@ -142,6 +177,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
    inline Avx2_32_Float(const double val) noexcept : m_data(_mm256_set1_ps(static_cast<T>(val))) {}
    inline Avx2_32_Float(const float val) noexcept : m_data(_mm256_set1_ps(static_cast<T>(val))) {}
    inline Avx2_32_Float(const int val) noexcept : m_data(_mm256_set1_ps(static_cast<T>(val))) {}
+   explicit Avx2_32_Float(const Avx2_32_Int& val) : m_data(_mm256_cvtepi32_ps(val.m_data)) {}

    inline Avx2_32_Float operator+() const noexcept { return *this; }

@@ -150,6 +186,10 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
             _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(m_data), _mm256_set1_epi32(0x80000000))));
    }

+   inline Avx2_32_Float operator~() const noexcept {
+      return Avx2_32_Float(_mm256_xor_ps(m_data, _mm256_castsi256_ps(_mm256_set1_epi32(-1))));
+   }
+
    inline Avx2_32_Float operator+(const Avx2_32_Float& other) const noexcept {
       return Avx2_32_Float(_mm256_add_ps(m_data, other.m_data));
    }
@@ -218,6 +258,10 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
       return Avx2_32_Float(val) / other;
    }

+   friend inline Avx2_32_Float operator<=(const Avx2_32_Float& left, const Avx2_32_Float& right) noexcept {
+      return Avx2_32_Float(_mm256_cmp_ps(left.m_data, right.m_data, _CMP_LE_OQ));
+   }
+
    inline static Avx2_32_Float Load(const T* const a) noexcept { return Avx2_32_Float(_mm256_load_ps(a)); }

    inline void Store(T* const a) const noexcept { _mm256_store_ps(a, m_data); }
@@ -484,6 +528,11 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
       return Avx2_32_Float(_mm256_blendv_ps(falseVal.m_data, trueVal.m_data, mask));
    }

+   friend inline Avx2_32_Float IfThenElse(
+         const Avx2_32_Float& cmp, const Avx2_32_Float& trueVal, const Avx2_32_Float& falseVal) noexcept {
+      return Avx2_32_Float(_mm256_blendv_ps(falseVal.m_data, trueVal.m_data, cmp.m_data));
+   }
+
    friend inline Avx2_32_Float IfEqual(const Avx2_32_Float& cmp1,
          const Avx2_32_Float& cmp2,
          const Avx2_32_Float& trueVal,
@@ -511,6 +560,26 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
       return Avx2_32_Float(_mm256_blendv_ps(falseVal.m_data, trueVal.m_data, _mm256_castsi256_ps(mask)));
    }

+   static inline Avx2_32_Int ReinterpretInt(const Avx2_32_Float& val) noexcept {
+      return Avx2_32_Int(_mm256_castps_si256(val.m_data));
+   }
+
+   static inline Avx2_32_Float ReinterpretFloat(const Avx2_32_Int& val) noexcept {
+      return Avx2_32_Float(_mm256_castsi256_ps(val.m_data));
+   }
+
+   friend inline Avx2_32_Float Round(const Avx2_32_Float& val) noexcept {
+      return Avx2_32_Float(_mm256_round_ps(val.m_data, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+   }
+
+   friend inline Avx2_32_Float Mantissa(const Avx2_32_Float& val) noexcept {
+      return ReinterpretFloat((ReinterpretInt(val) & 0x007FFFFF) | 0x3F000000);
+   }
+
+   friend inline Avx2_32_Int Exponent(const Avx2_32_Float& val) noexcept {
+      return ((ReinterpretInt(val) << 1) >> 24) - Avx2_32_Int(0x7F);
+   }
+
    friend inline Avx2_32_Float Abs(const Avx2_32_Float& val) noexcept {
       return Avx2_32_Float(_mm256_and_ps(val.m_data, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF))));
    }
@@ -553,14 +622,6 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
       return Avx2_32_Float(_mm256_sqrt_ps(val.m_data));
    }

-   friend inline Avx2_32_Float Exp(const Avx2_32_Float& val) noexcept {
-      return ApplyFunc([](T x) { return std::exp(x); }, val);
-   }
-
-   friend inline Avx2_32_Float Log(const Avx2_32_Float& val) noexcept {
-      return ApplyFunc([](T x) { return std::log(x); }, val);
-   }
-
    template<bool bDisableApprox,
          bool bNegateInput = false,
          bool bNaNPossible = true,
@@ -571,7 +632,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
    static inline Avx2_32_Float ApproxExp(const Avx2_32_Float& val,
          const int32_t addExpSchraudolphTerm = k_expTermZeroMeanErrorForSoftmaxWithZeroedLogit) noexcept {
       UNUSED(addExpSchraudolphTerm);
-      return Exp(bNegateInput ? -val : val);
+      return Exp<bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
    }

    template<bool bDisableApprox,
@@ -631,8 +692,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
    static inline Avx2_32_Float ApproxLog(
          const Avx2_32_Float& val, const float addLogSchraudolphTerm = k_logTermLowerBoundInputCloseToOne) noexcept {
       UNUSED(addLogSchraudolphTerm);
-      Avx2_32_Float ret = Log(val);
-      return bNegateOutput ? -ret : ret;
+      return Log<bNegateOutput, bNaNPossible, bNegativePossible, bZeroPossible, bPositiveInfinityPossible>(val);
    }

    template<bool bDisableApprox,
@@ -723,6 +783,25 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
 static_assert(std::is_standard_layout<Avx2_32_Float>::value && std::is_trivially_copyable<Avx2_32_Float>::value,
       "This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed");

+template<bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
+inline Avx2_32_Float Exp(const Avx2_32_Float& val) noexcept {
+   return Exp32<Avx2_32_Float, bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
+}
+
+template<bool bNegateOutput,
+      bool bNaNPossible,
+      bool bNegativePossible,
+      bool bZeroPossible,
+      bool bPositiveInfinityPossible>
+inline Avx2_32_Float Log(const Avx2_32_Float& val) noexcept {
+   return Log32<Avx2_32_Float,
+         bNegateOutput,
+         bNaNPossible,
+         bNegativePossible,
+         bZeroPossible,
+         bPositiveInfinityPossible>(val);
+}
+
 INTERNAL_IMPORT_EXPORT_BODY ErrorEbm ApplyUpdate_Avx2_32(
       const ObjectiveWrapper* const pObjectiveWrapper, ApplyUpdateBridge* const pData) {
    const Objective* const pObjective = static_cast<const Objective*>(pObjectiveWrapper->m_pObjective);
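
Exp32 and Log32 come from the newly included math.hpp, presumably among the changed files not shown on this page, so their implementations are not visible here. As rough orientation only (a generic sketch under that assumption, not the math.hpp code): a vectorized exp of this style reduces x = k·ln 2 + r, approximates e^r with a polynomial, and rebuilds 2^k by writing the exponent field directly, which is what the newly added Round and ReinterpretInt/ReinterpretFloat helpers enable per lane, while Mantissa and Exponent play the matching role for log.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Rough scalar sketch of classic range-reduction exp; coefficients are
// illustrative and low accuracy, and there is no NaN/underflow/overflow
// handling (presumably what the bNaNPossible/bUnderflowPossible/
// bOverflowPossible template flags above gate in the real Exp32).
static float ExpSketch(float x) {
   const float k = std::round(x * 1.44269504f); // k = round(x / ln 2), cf. Round()
   const float r = x - k * 0.693147181f;        // reduced argument, |r| <= ln(2)/2
   const float p = 1.0f + r * (1.0f + r * (0.5f + r * (1.0f / 6.0f))); // ~ e^r
   std::uint32_t bits = static_cast<std::uint32_t>(static_cast<std::int32_t>(k) + 127) << 23;
   float twoK;
   std::memcpy(&twoK, &bits, sizeof(twoK));     // 2^k built from raw bits, cf. ReinterpretFloat()
   return p * twoK;                             // e^x = 2^k * e^r
}

int main() {
   std::printf("ExpSketch(1.0f) = %g (std::exp = %g)\n", ExpSketch(1.0f), std::exp(1.0f));
   return 0;
}

A log built from the same pieces would start from Mantissa(x) and Exponent(x) as added in the earlier hunk.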

shared/libebm/compute/cpu_ebm/cpu_64.cpp (+3)
@@ -22,6 +22,7 @@
 #include "Registration.hpp"
 #include "Objective.hpp"

+#include "math.hpp"
 #include "approximate_math.hpp"
 #include "compute_wrapper.hpp"

@@ -227,6 +228,8 @@ struct Cpu_64_Float final {

    friend inline Cpu_64_Float Abs(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::abs(val.m_data)); }

+   friend inline Cpu_64_Float Round(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::round(val.m_data)); }
+
    friend inline Cpu_64_Float FastApproxReciprocal(const Cpu_64_Float& val) noexcept {
       return Cpu_64_Float(T{1.0} / val.m_data);
    }

0 comments