Skip to content

Commit 05e3b26

Browse files
committed
SIMD style exp for doubles
1 parent db714e3 commit 05e3b26

File tree

2 files changed

+155
-6
lines changed

2 files changed

+155
-6
lines changed

shared/libebm/compute/cpu_ebm/cpu_64.cpp

+49-4
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ struct Cpu_64_Int final {
7474

7575
inline Cpu_64_Int operator+(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int(m_data + other.m_data); }
7676

77+
inline Cpu_64_Int operator-(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int(m_data - other.m_data); }
78+
7779
inline Cpu_64_Int operator*(const T& other) const noexcept { return Cpu_64_Int(m_data * other); }
7880

7981
inline Cpu_64_Int operator>>(int shift) const noexcept { return Cpu_64_Int(m_data >> shift); }
@@ -82,13 +84,28 @@ struct Cpu_64_Int final {
8284

8385
inline Cpu_64_Int operator&(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int(m_data & other.m_data); }
8486

87+
inline Cpu_64_Int operator|(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int(m_data | other.m_data); }
88+
89+
friend inline Cpu_64_Int IfThenElse(const bool cmp, const Cpu_64_Int& trueVal, const Cpu_64_Int& falseVal) noexcept {
90+
return cmp ? trueVal : falseVal;
91+
}
92+
8593
private:
8694
TPack m_data;
8795
};
8896
static_assert(std::is_standard_layout<Cpu_64_Int>::value && std::is_trivially_copyable<Cpu_64_Int>::value,
8997
"This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed");
9098

99+
template<bool bNegateInput = false,
100+
bool bNaNPossible = true,
101+
bool bUnderflowPossible = true,
102+
bool bOverflowPossible = true>
103+
inline Cpu_64_Float Exp(const Cpu_64_Float& val) noexcept;
104+
91105
struct Cpu_64_Float final {
106+
template<bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
107+
friend Cpu_64_Float Exp(const Cpu_64_Float& val) noexcept;
108+
92109
using T = double;
93110
using TPack = double;
94111
using TInt = Cpu_64_Int;
@@ -106,6 +123,8 @@ struct Cpu_64_Float final {
106123
inline Cpu_64_Float(const double val) noexcept : m_data(static_cast<T>(val)) {}
107124
inline Cpu_64_Float(const float val) noexcept : m_data(static_cast<T>(val)) {}
108125
inline Cpu_64_Float(const int val) noexcept : m_data(static_cast<T>(val)) {}
126+
inline Cpu_64_Float(const int64_t val) noexcept : m_data(static_cast<T>(val)) {}
127+
explicit Cpu_64_Float(const Cpu_64_Int& val) : m_data(static_cast<T>(val.m_data)) {}
109128

110129
inline Cpu_64_Float operator+() const noexcept { return *this; }
111130

@@ -179,6 +198,10 @@ struct Cpu_64_Float final {
179198
return Cpu_64_Float(val) / other;
180199
}
181200

201+
friend inline bool operator<=(const Cpu_64_Float& left, const Cpu_64_Float& right) noexcept {
202+
return left.m_data <= right.m_data;
203+
}
204+
182205
inline static Cpu_64_Float Load(const T* const a) noexcept { return Cpu_64_Float(*a); }
183206

184207
inline void Store(T* const a) const noexcept { *a = m_data; }
@@ -207,6 +230,11 @@ struct Cpu_64_Float final {
207230
return cmp1.m_data < cmp2.m_data ? trueVal : falseVal;
208231
}
209232

233+
friend inline Cpu_64_Float IfThenElse(
234+
const bool cmp, const Cpu_64_Float& trueVal, const Cpu_64_Float& falseVal) noexcept {
235+
return cmp ? trueVal : falseVal;
236+
}
237+
210238
friend inline Cpu_64_Float IfEqual(const Cpu_64_Float& cmp1,
211239
const Cpu_64_Float& cmp2,
212240
const Cpu_64_Float& trueVal,
@@ -226,10 +254,24 @@ struct Cpu_64_Float final {
226254
return cmp1.m_data == cmp2.m_data ? trueVal : falseVal;
227255
}
228256

229-
friend inline Cpu_64_Float Abs(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::abs(val.m_data)); }
257+
static inline bool ReinterpretInt(const bool val) noexcept { return val; }
258+
259+
static inline Cpu_64_Int ReinterpretInt(const Cpu_64_Float& val) noexcept {
260+
typename Cpu_64_Int::T mem;
261+
memcpy(&mem, &val.m_data, sizeof(T));
262+
return Cpu_64_Int(mem);
263+
}
264+
265+
static inline Cpu_64_Float ReinterpretFloat(const Cpu_64_Int& val) noexcept {
266+
T mem;
267+
memcpy(&mem, &val.m_data, sizeof(T));
268+
return Cpu_64_Float(mem);
269+
}
230270

231271
friend inline Cpu_64_Float Round(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::round(val.m_data)); }
232272

273+
friend inline Cpu_64_Float Abs(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::abs(val.m_data)); }
274+
233275
friend inline Cpu_64_Float FastApproxReciprocal(const Cpu_64_Float& val) noexcept {
234276
return Cpu_64_Float(T{1.0} / val.m_data);
235277
}
@@ -250,8 +292,6 @@ struct Cpu_64_Float final {
250292

251293
friend inline Cpu_64_Float Sqrt(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::sqrt(val.m_data)); }
252294

253-
friend inline Cpu_64_Float Exp(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::exp(val.m_data)); }
254-
255295
friend inline Cpu_64_Float Log(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::log(val.m_data)); }
256296

257297
template<bool bDisableApprox,
@@ -264,7 +304,7 @@ struct Cpu_64_Float final {
264304
static inline Cpu_64_Float ApproxExp(const Cpu_64_Float& val,
265305
const int32_t addExpSchraudolphTerm = k_expTermZeroMeanErrorForSoftmaxWithZeroedLogit) noexcept {
266306
UNUSED(addExpSchraudolphTerm);
267-
return Exp(bNegateInput ? -val : val);
307+
return Exp<bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
268308
}
269309

270310
template<bool bDisableApprox,
@@ -354,6 +394,11 @@ struct Cpu_64_Float final {
354394
static_assert(std::is_standard_layout<Cpu_64_Float>::value && std::is_trivially_copyable<Cpu_64_Float>::value,
355395
"This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed");
356396

397+
template<bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
398+
inline Cpu_64_Float Exp(const Cpu_64_Float& val) noexcept {
399+
return Exp64<Cpu_64_Float, bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
400+
}
401+
357402
INTERNAL_IMPORT_EXPORT_BODY ErrorEbm ApplyUpdate_Cpu_64(
358403
const ObjectiveWrapper* const pObjectiveWrapper, ApplyUpdateBridge* const pData) {
359404
const Objective* const pObjective = static_cast<const Objective*>(pObjectiveWrapper->m_pObjective);

shared/libebm/compute/math.hpp

+106-2
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,15 @@ template<typename TFloat> static INLINE_ALWAYS typename TFloat::TInt Exponent32(
2222
return ((TFloat::ReinterpretInt(val) << 1) >> 24) - typename TFloat::TInt(0x7F);
2323
}
2424

25-
template<typename TFloat> static INLINE_ALWAYS TFloat Power2(const TFloat val) {
25+
template<typename TFloat> static INLINE_ALWAYS TFloat Power32(const TFloat val) {
2626
return TFloat::ReinterpretFloat(TFloat::ReinterpretInt(val + TFloat{8388608 + 127}) << 23);
2727
}
2828

29+
template<typename TFloat> static INLINE_ALWAYS TFloat Power64(const TFloat val) {
30+
return TFloat::ReinterpretFloat(
31+
TFloat::ReinterpretInt(val + TFloat{int64_t{4503599627370496} + int64_t{1023}}) << 52);
32+
}
33+
2934
template<typename TFloat>
3035
static INLINE_ALWAYS TFloat Polynomial(const TFloat x,
3136
const TFloat c0,
@@ -60,6 +65,32 @@ static INLINE_ALWAYS TFloat Polynomial(const TFloat x,
6065
FusedMultiplyAdd(FusedMultiplyAdd(c3, x, c2), x2, FusedMultiplyAdd(c1, x, c0) + c8 * x8));
6166
}
6267

68+
template<typename TFloat>
69+
static INLINE_ALWAYS TFloat Polynomial(const TFloat x,
70+
const TFloat c2,
71+
const TFloat c3,
72+
const TFloat c4,
73+
const TFloat c5,
74+
const TFloat c6,
75+
const TFloat c7,
76+
const TFloat c8,
77+
const TFloat c9,
78+
const TFloat c10,
79+
const TFloat c11,
80+
const TFloat c12,
81+
const TFloat c13) {
82+
TFloat x2 = x * x;
83+
TFloat x4 = x2 * x2;
84+
TFloat x8 = x4 * x4;
85+
return FusedMultiplyAdd(FusedMultiplyAdd(FusedMultiplyAdd(c13, x, c12),
86+
x4,
87+
FusedMultiplyAdd(FusedMultiplyAdd(c11, x, c10), x2, FusedMultiplyAdd(c9, x, c8))),
88+
x8,
89+
FusedMultiplyAdd(FusedMultiplyAdd(FusedMultiplyAdd(c7, x, c6), x2, FusedMultiplyAdd(c5, x, c4)),
90+
x4,
91+
FusedMultiplyAdd(FusedMultiplyAdd(c3, x, c2), x2, x)));
92+
}
93+
6394
template<typename TFloat,
6495
bool bNegateInput = false,
6596
bool bNaNPossible = true,
@@ -89,7 +120,7 @@ static INLINE_ALWAYS TFloat Exp32(const TFloat val) {
89120
TFloat{1} / TFloat{5040});
90121
ret = FusedMultiplyAdd(ret, x2, x);
91122

92-
const TFloat rounded2 = Power2(rounded);
123+
const TFloat rounded2 = Power32(rounded);
93124

94125
ret = (ret + TFloat{1}) * rounded2;
95126

@@ -210,6 +241,79 @@ static INLINE_ALWAYS TFloat Log32(const TFloat& val) noexcept {
210241
return ret;
211242
}
212243

244+
template<typename TFloat,
245+
bool bNegateInput = false,
246+
bool bNaNPossible = true,
247+
bool bUnderflowPossible = true,
248+
bool bOverflowPossible = true>
249+
static INLINE_ALWAYS TFloat Exp64(const TFloat val) {
250+
// algorithm comes from:
251+
// https://github.com/vectorclass/version2/blob/f4617df57e17efcd754f5bbe0ec87883e0ed9ce6/vectormath_exp.h#L327
252+
253+
// k_expUnderflow is set to a value that prevents us from returning a denormal number.
254+
static constexpr float k_expUnderflow = -708.25f; // this is exactly representable in IEEE 754
255+
static constexpr float k_expOverflow = 708.25f; // this is exactly representable in IEEE 754
256+
257+
// TODO: make this negation more efficient
258+
TFloat x = bNegateInput ? -val : val;
259+
const TFloat rounded = Round(x * TFloat{1.44269504088896340736});
260+
x = FusedNegateMultiplyAdd(rounded, TFloat{0.693145751953125}, x);
261+
x = FusedNegateMultiplyAdd(rounded, TFloat{1.42860682030941723212E-6}, x);
262+
263+
TFloat ret = Polynomial(x,
264+
TFloat{1} / TFloat{2},
265+
TFloat{1} / TFloat{6},
266+
TFloat{1} / TFloat{24},
267+
TFloat{1} / TFloat{120},
268+
TFloat{1} / TFloat{720},
269+
TFloat{1} / TFloat{5040},
270+
TFloat{1} / TFloat{40320},
271+
TFloat{1} / TFloat{362880},
272+
TFloat{1} / TFloat{3628800},
273+
TFloat{1} / TFloat{39916800},
274+
TFloat{1} / TFloat{479001600},
275+
TFloat{1} / TFloat{int64_t{6227020800}});
276+
277+
const TFloat rounded2 = Power64(rounded);
278+
279+
ret = (ret + TFloat{1}) * rounded2;
280+
281+
if(bOverflowPossible) {
282+
if(bNegateInput) {
283+
ret = IfLess(val,
284+
static_cast<typename TFloat::T>(-k_expOverflow),
285+
std::numeric_limits<typename TFloat::T>::infinity(),
286+
ret);
287+
} else {
288+
ret = IfLess(static_cast<typename TFloat::T>(k_expOverflow),
289+
val,
290+
std::numeric_limits<typename TFloat::T>::infinity(),
291+
ret);
292+
}
293+
}
294+
if(bUnderflowPossible) {
295+
if(bNegateInput) {
296+
ret = IfLess(static_cast<typename TFloat::T>(-k_expUnderflow), val, 0.0f, ret);
297+
} else {
298+
ret = IfLess(val, static_cast<typename TFloat::T>(k_expUnderflow), 0.0f, ret);
299+
}
300+
}
301+
if(bNaNPossible) {
302+
ret = IfNaN(val, val, ret);
303+
}
304+
305+
#ifndef NDEBUG
306+
TFloat::Execute(
307+
[](int, typename TFloat::T orig, typename TFloat::T ret) {
308+
EBM_ASSERT(IsApproxEqual(std::exp(orig), ret, typename TFloat::T{1e-12}));
309+
},
310+
bNegateInput ? -val : val,
311+
ret);
312+
#endif // NDEBUG
313+
314+
return ret;
315+
}
316+
213317
} // namespace DEFINED_ZONE_NAME
214318

215319
#endif // REGISTRATION_HPP

0 commit comments

Comments
 (0)