Skip to content

Commit 4401ae3

Browse files
committed
SIMD style exp for doubles
1 parent db714e3 commit 4401ae3

File tree

2 files changed

+154
-6
lines changed

2 files changed

+154
-6
lines changed

shared/libebm/compute/cpu_ebm/cpu_64.cpp

+49-4
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ struct Cpu_64_Int final {
7474

7575
inline Cpu_64_Int operator+(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int(m_data + other.m_data); }
7676

77+
inline Cpu_64_Int operator-(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int(m_data - other.m_data); }
78+
7779
inline Cpu_64_Int operator*(const T& other) const noexcept { return Cpu_64_Int(m_data * other); }
7880

7981
inline Cpu_64_Int operator>>(int shift) const noexcept { return Cpu_64_Int(m_data >> shift); }
@@ -82,13 +84,28 @@ struct Cpu_64_Int final {
8284

8385
inline Cpu_64_Int operator&(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int(m_data & other.m_data); }
8486

87+
inline Cpu_64_Int operator|(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int(m_data | other.m_data); }
88+
89+
friend inline Cpu_64_Int IfThenElse(const bool cmp, const Cpu_64_Int& trueVal, const Cpu_64_Int& falseVal) noexcept {
90+
return cmp ? trueVal : falseVal;
91+
}
92+
8593
private:
8694
TPack m_data;
8795
};
8896
static_assert(std::is_standard_layout<Cpu_64_Int>::value && std::is_trivially_copyable<Cpu_64_Int>::value,
8997
"This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed");
9098

99+
template<bool bNegateInput = false,
100+
bool bNaNPossible = true,
101+
bool bUnderflowPossible = true,
102+
bool bOverflowPossible = true>
103+
inline Cpu_64_Float Exp(const Cpu_64_Float& val) noexcept;
104+
91105
struct Cpu_64_Float final {
106+
template<bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
107+
friend Cpu_64_Float Exp(const Cpu_64_Float& val) noexcept;
108+
92109
using T = double;
93110
using TPack = double;
94111
using TInt = Cpu_64_Int;
@@ -106,6 +123,8 @@ struct Cpu_64_Float final {
106123
inline Cpu_64_Float(const double val) noexcept : m_data(static_cast<T>(val)) {}
107124
inline Cpu_64_Float(const float val) noexcept : m_data(static_cast<T>(val)) {}
108125
inline Cpu_64_Float(const int val) noexcept : m_data(static_cast<T>(val)) {}
126+
inline Cpu_64_Float(const int64_t val) noexcept : m_data(static_cast<T>(val)) {}
127+
explicit Cpu_64_Float(const Cpu_64_Int& val) : m_data(static_cast<T>(val.m_data)) {}
109128

110129
inline Cpu_64_Float operator+() const noexcept { return *this; }
111130

@@ -179,6 +198,10 @@ struct Cpu_64_Float final {
179198
return Cpu_64_Float(val) / other;
180199
}
181200

201+
friend inline bool operator<=(const Cpu_64_Float& left, const Cpu_64_Float& right) noexcept {
202+
return left.m_data <= right.m_data;
203+
}
204+
182205
inline static Cpu_64_Float Load(const T* const a) noexcept { return Cpu_64_Float(*a); }
183206

184207
inline void Store(T* const a) const noexcept { *a = m_data; }
@@ -207,6 +230,11 @@ struct Cpu_64_Float final {
207230
return cmp1.m_data < cmp2.m_data ? trueVal : falseVal;
208231
}
209232

233+
friend inline Cpu_64_Float IfThenElse(
234+
const bool cmp, const Cpu_64_Float& trueVal, const Cpu_64_Float& falseVal) noexcept {
235+
return cmp ? trueVal : falseVal;
236+
}
237+
210238
friend inline Cpu_64_Float IfEqual(const Cpu_64_Float& cmp1,
211239
const Cpu_64_Float& cmp2,
212240
const Cpu_64_Float& trueVal,
@@ -226,10 +254,24 @@ struct Cpu_64_Float final {
226254
return cmp1.m_data == cmp2.m_data ? trueVal : falseVal;
227255
}
228256

229-
friend inline Cpu_64_Float Abs(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::abs(val.m_data)); }
257+
static inline bool ReinterpretInt(const bool val) noexcept { return val; }
258+
259+
static inline Cpu_64_Int ReinterpretInt(const Cpu_64_Float& val) noexcept {
260+
typename Cpu_64_Int::T mem;
261+
memcpy(&mem, &val.m_data, sizeof(T));
262+
return Cpu_64_Int(mem);
263+
}
264+
265+
static inline Cpu_64_Float ReinterpretFloat(const Cpu_64_Int& val) noexcept {
266+
T mem;
267+
memcpy(&mem, &val.m_data, sizeof(T));
268+
return Cpu_64_Float(mem);
269+
}
230270

231271
friend inline Cpu_64_Float Round(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::round(val.m_data)); }
232272

273+
friend inline Cpu_64_Float Abs(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::abs(val.m_data)); }
274+
233275
friend inline Cpu_64_Float FastApproxReciprocal(const Cpu_64_Float& val) noexcept {
234276
return Cpu_64_Float(T{1.0} / val.m_data);
235277
}
@@ -250,8 +292,6 @@ struct Cpu_64_Float final {
250292

251293
friend inline Cpu_64_Float Sqrt(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::sqrt(val.m_data)); }
252294

253-
friend inline Cpu_64_Float Exp(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::exp(val.m_data)); }
254-
255295
friend inline Cpu_64_Float Log(const Cpu_64_Float& val) noexcept { return Cpu_64_Float(std::log(val.m_data)); }
256296

257297
template<bool bDisableApprox,
@@ -264,7 +304,7 @@ struct Cpu_64_Float final {
264304
static inline Cpu_64_Float ApproxExp(const Cpu_64_Float& val,
265305
const int32_t addExpSchraudolphTerm = k_expTermZeroMeanErrorForSoftmaxWithZeroedLogit) noexcept {
266306
UNUSED(addExpSchraudolphTerm);
267-
return Exp(bNegateInput ? -val : val);
307+
return Exp<bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
268308
}
269309

270310
template<bool bDisableApprox,
@@ -354,6 +394,11 @@ struct Cpu_64_Float final {
354394
static_assert(std::is_standard_layout<Cpu_64_Float>::value && std::is_trivially_copyable<Cpu_64_Float>::value,
355395
"This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed");
356396

397+
template<bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
398+
inline Cpu_64_Float Exp(const Cpu_64_Float& val) noexcept {
399+
return Exp64<Cpu_64_Float, bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
400+
}
401+
357402
INTERNAL_IMPORT_EXPORT_BODY ErrorEbm ApplyUpdate_Cpu_64(
358403
const ObjectiveWrapper* const pObjectiveWrapper, ApplyUpdateBridge* const pData) {
359404
const Objective* const pObjective = static_cast<const Objective*>(pObjectiveWrapper->m_pObjective);

shared/libebm/compute/math.hpp

+105-2
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,14 @@ template<typename TFloat> static INLINE_ALWAYS typename TFloat::TInt Exponent32(
2222
return ((TFloat::ReinterpretInt(val) << 1) >> 24) - typename TFloat::TInt(0x7F);
2323
}
2424

25-
template<typename TFloat> static INLINE_ALWAYS TFloat Power2(const TFloat val) {
25+
template<typename TFloat> static INLINE_ALWAYS TFloat Power32(const TFloat val) {
2626
return TFloat::ReinterpretFloat(TFloat::ReinterpretInt(val + TFloat{8388608 + 127}) << 23);
2727
}
2828

29+
template<typename TFloat> static INLINE_ALWAYS TFloat Power64(const TFloat val) {
30+
return TFloat::ReinterpretFloat(TFloat::ReinterpretInt(val + TFloat{4503599627370496 + 1023}) << 52);
31+
}
32+
2933
template<typename TFloat>
3034
static INLINE_ALWAYS TFloat Polynomial(const TFloat x,
3135
const TFloat c0,
@@ -60,6 +64,32 @@ static INLINE_ALWAYS TFloat Polynomial(const TFloat x,
6064
FusedMultiplyAdd(FusedMultiplyAdd(c3, x, c2), x2, FusedMultiplyAdd(c1, x, c0) + c8 * x8));
6165
}
6266

67+
template<typename TFloat>
68+
static INLINE_ALWAYS TFloat Polynomial(const TFloat x,
69+
const TFloat c2,
70+
const TFloat c3,
71+
const TFloat c4,
72+
const TFloat c5,
73+
const TFloat c6,
74+
const TFloat c7,
75+
const TFloat c8,
76+
const TFloat c9,
77+
const TFloat c10,
78+
const TFloat c11,
79+
const TFloat c12,
80+
const TFloat c13) {
81+
TFloat x2 = x * x;
82+
TFloat x4 = x2 * x2;
83+
TFloat x8 = x4 * x4;
84+
return FusedMultiplyAdd(FusedMultiplyAdd(FusedMultiplyAdd(c13, x, c12),
85+
x4,
86+
FusedMultiplyAdd(FusedMultiplyAdd(c11, x, c10), x2, FusedMultiplyAdd(c9, x, c8))),
87+
x8,
88+
FusedMultiplyAdd(FusedMultiplyAdd(FusedMultiplyAdd(c7, x, c6), x2, FusedMultiplyAdd(c5, x, c4)),
89+
x4,
90+
FusedMultiplyAdd(FusedMultiplyAdd(c3, x, c2), x2, x)));
91+
}
92+
6393
template<typename TFloat,
6494
bool bNegateInput = false,
6595
bool bNaNPossible = true,
@@ -89,7 +119,7 @@ static INLINE_ALWAYS TFloat Exp32(const TFloat val) {
89119
TFloat{1} / TFloat{5040});
90120
ret = FusedMultiplyAdd(ret, x2, x);
91121

92-
const TFloat rounded2 = Power2(rounded);
122+
const TFloat rounded2 = Power32(rounded);
93123

94124
ret = (ret + TFloat{1}) * rounded2;
95125

@@ -210,6 +240,79 @@ static INLINE_ALWAYS TFloat Log32(const TFloat& val) noexcept {
210240
return ret;
211241
}
212242

243+
template<typename TFloat,
244+
bool bNegateInput = false,
245+
bool bNaNPossible = true,
246+
bool bUnderflowPossible = true,
247+
bool bOverflowPossible = true>
248+
static INLINE_ALWAYS TFloat Exp64(const TFloat val) {
249+
// algorithm comes from:
250+
// https://github.com/vectorclass/version2/blob/f4617df57e17efcd754f5bbe0ec87883e0ed9ce6/vectormath_exp.h#L327
251+
252+
// k_expUnderflow is set to a value that prevents us from returning a denormal number.
253+
static constexpr float k_expUnderflow = -708.25f; // this is exactly representable in IEEE 754
254+
static constexpr float k_expOverflow = 708.25f; // this is exactly representable in IEEE 754
255+
256+
// TODO: make this negation more efficient
257+
TFloat x = bNegateInput ? -val : val;
258+
const TFloat rounded = Round(x * TFloat{1.44269504088896340736});
259+
x = FusedNegateMultiplyAdd(rounded, TFloat{0.693145751953125}, x);
260+
x = FusedNegateMultiplyAdd(rounded, TFloat{1.42860682030941723212E-6}, x);
261+
262+
TFloat ret = Polynomial(x,
263+
TFloat{1} / TFloat{2},
264+
TFloat{1} / TFloat{6},
265+
TFloat{1} / TFloat{24},
266+
TFloat{1} / TFloat{120},
267+
TFloat{1} / TFloat{720},
268+
TFloat{1} / TFloat{5040},
269+
TFloat{1} / TFloat{40320},
270+
TFloat{1} / TFloat{362880},
271+
TFloat{1} / TFloat{3628800},
272+
TFloat{1} / TFloat{39916800},
273+
TFloat{1} / TFloat{479001600},
274+
TFloat{1} / TFloat{int64_t{6227020800}});
275+
276+
const TFloat rounded2 = Power64(rounded);
277+
278+
ret = (ret + TFloat{1}) * rounded2;
279+
280+
if(bOverflowPossible) {
281+
if(bNegateInput) {
282+
ret = IfLess(val,
283+
static_cast<typename TFloat::T>(-k_expOverflow),
284+
std::numeric_limits<typename TFloat::T>::infinity(),
285+
ret);
286+
} else {
287+
ret = IfLess(static_cast<typename TFloat::T>(k_expOverflow),
288+
val,
289+
std::numeric_limits<typename TFloat::T>::infinity(),
290+
ret);
291+
}
292+
}
293+
if(bUnderflowPossible) {
294+
if(bNegateInput) {
295+
ret = IfLess(static_cast<typename TFloat::T>(-k_expUnderflow), val, 0.0f, ret);
296+
} else {
297+
ret = IfLess(val, static_cast<typename TFloat::T>(k_expUnderflow), 0.0f, ret);
298+
}
299+
}
300+
if(bNaNPossible) {
301+
ret = IfNaN(val, val, ret);
302+
}
303+
304+
#ifndef NDEBUG
305+
TFloat::Execute(
306+
[](int, typename TFloat::T orig, typename TFloat::T ret) {
307+
EBM_ASSERT(IsApproxEqual(std::exp(orig), ret, typename TFloat::T{1e-12}));
308+
},
309+
bNegateInput ? -val : val,
310+
ret);
311+
#endif // NDEBUG
312+
313+
return ret;
314+
}
315+
213316
} // namespace DEFINED_ZONE_NAME
214317

215318
#endif // REGISTRATION_HPP

0 commit comments

Comments
 (0)