@@ -74,6 +74,8 @@ struct Cpu_64_Int final {
74
74
75
75
inline Cpu_64_Int operator +(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int (m_data + other.m_data ); }
76
76
77
+ inline Cpu_64_Int operator -(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int (m_data - other.m_data ); }
78
+
77
79
inline Cpu_64_Int operator *(const T& other) const noexcept { return Cpu_64_Int (m_data * other); }
78
80
79
81
inline Cpu_64_Int operator >>(int shift) const noexcept { return Cpu_64_Int (m_data >> shift); }
@@ -82,13 +84,28 @@ struct Cpu_64_Int final {
82
84
83
85
inline Cpu_64_Int operator &(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int (m_data & other.m_data ); }
84
86
87
+ inline Cpu_64_Int operator |(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int (m_data | other.m_data ); }
88
+
89
+ friend inline Cpu_64_Int IfThenElse (const bool cmp, const Cpu_64_Int& trueVal, const Cpu_64_Int& falseVal) noexcept {
90
+ return cmp ? trueVal : falseVal;
91
+ }
92
+
85
93
private:
86
94
TPack m_data;
87
95
};
88
96
static_assert (std::is_standard_layout<Cpu_64_Int>::value && std::is_trivially_copyable<Cpu_64_Int>::value,
89
97
" This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed" );
90
98
99
+ template <bool bNegateInput = false ,
100
+ bool bNaNPossible = true ,
101
+ bool bUnderflowPossible = true ,
102
+ bool bOverflowPossible = true >
103
+ inline Cpu_64_Float Exp (const Cpu_64_Float& val) noexcept ;
104
+
91
105
struct Cpu_64_Float final {
106
+ template <bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
107
+ friend Cpu_64_Float Exp (const Cpu_64_Float& val) noexcept ;
108
+
92
109
using T = double ;
93
110
using TPack = double ;
94
111
using TInt = Cpu_64_Int;
@@ -106,6 +123,8 @@ struct Cpu_64_Float final {
106
123
inline Cpu_64_Float (const double val) noexcept : m_data(static_cast <T>(val)) {}
107
124
inline Cpu_64_Float (const float val) noexcept : m_data(static_cast <T>(val)) {}
108
125
inline Cpu_64_Float (const int val) noexcept : m_data(static_cast <T>(val)) {}
126
+ inline Cpu_64_Float (const int64_t val) noexcept : m_data(static_cast <T>(val)) {}
127
+ explicit Cpu_64_Float (const Cpu_64_Int& val) : m_data(static_cast <T>(val.m_data)) {}
109
128
110
129
inline Cpu_64_Float operator +() const noexcept { return *this ; }
111
130
@@ -179,6 +198,10 @@ struct Cpu_64_Float final {
179
198
return Cpu_64_Float (val) / other;
180
199
}
181
200
201
+ friend inline bool operator <=(const Cpu_64_Float& left, const Cpu_64_Float& right) noexcept {
202
+ return left.m_data <= right.m_data ;
203
+ }
204
+
182
205
inline static Cpu_64_Float Load (const T* const a) noexcept { return Cpu_64_Float (*a); }
183
206
184
207
inline void Store (T* const a) const noexcept { *a = m_data; }
@@ -207,6 +230,11 @@ struct Cpu_64_Float final {
207
230
return cmp1.m_data < cmp2.m_data ? trueVal : falseVal;
208
231
}
209
232
233
+ friend inline Cpu_64_Float IfThenElse (
234
+ const bool cmp, const Cpu_64_Float& trueVal, const Cpu_64_Float& falseVal) noexcept {
235
+ return cmp ? trueVal : falseVal;
236
+ }
237
+
210
238
friend inline Cpu_64_Float IfEqual (const Cpu_64_Float& cmp1,
211
239
const Cpu_64_Float& cmp2,
212
240
const Cpu_64_Float& trueVal,
@@ -226,10 +254,24 @@ struct Cpu_64_Float final {
226
254
return cmp1.m_data == cmp2.m_data ? trueVal : falseVal;
227
255
}
228
256
229
- friend inline Cpu_64_Float Abs (const Cpu_64_Float& val) noexcept { return Cpu_64_Float (std::abs (val.m_data )); }
257
+ static inline bool ReinterpretInt (const bool val) noexcept { return val; }
258
+
259
+ static inline Cpu_64_Int ReinterpretInt (const Cpu_64_Float& val) noexcept {
260
+ typename Cpu_64_Int::T mem;
261
+ memcpy (&mem, &val.m_data , sizeof (T));
262
+ return Cpu_64_Int (mem);
263
+ }
264
+
265
+ static inline Cpu_64_Float ReinterpretFloat (const Cpu_64_Int& val) noexcept {
266
+ T mem;
267
+ memcpy (&mem, &val.m_data , sizeof (T));
268
+ return Cpu_64_Float (mem);
269
+ }
230
270
231
271
friend inline Cpu_64_Float Round (const Cpu_64_Float& val) noexcept { return Cpu_64_Float (std::round (val.m_data )); }
232
272
273
+ friend inline Cpu_64_Float Abs (const Cpu_64_Float& val) noexcept { return Cpu_64_Float (std::abs (val.m_data )); }
274
+
233
275
friend inline Cpu_64_Float FastApproxReciprocal (const Cpu_64_Float& val) noexcept {
234
276
return Cpu_64_Float (T{1.0 } / val.m_data );
235
277
}
@@ -250,8 +292,6 @@ struct Cpu_64_Float final {
250
292
251
293
friend inline Cpu_64_Float Sqrt (const Cpu_64_Float& val) noexcept { return Cpu_64_Float (std::sqrt (val.m_data )); }
252
294
253
- friend inline Cpu_64_Float Exp (const Cpu_64_Float& val) noexcept { return Cpu_64_Float (std::exp (val.m_data )); }
254
-
255
295
friend inline Cpu_64_Float Log (const Cpu_64_Float& val) noexcept { return Cpu_64_Float (std::log (val.m_data )); }
256
296
257
297
template <bool bDisableApprox,
@@ -264,7 +304,7 @@ struct Cpu_64_Float final {
264
304
static inline Cpu_64_Float ApproxExp (const Cpu_64_Float& val,
265
305
const int32_t addExpSchraudolphTerm = k_expTermZeroMeanErrorForSoftmaxWithZeroedLogit) noexcept {
266
306
UNUSED (addExpSchraudolphTerm);
267
- return Exp ( bNegateInput ? -val : val);
307
+ return Exp< bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>( val);
268
308
}
269
309
270
310
template <bool bDisableApprox,
@@ -354,6 +394,11 @@ struct Cpu_64_Float final {
354
394
static_assert (std::is_standard_layout<Cpu_64_Float>::value && std::is_trivially_copyable<Cpu_64_Float>::value,
355
395
" This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed" );
356
396
397
+ template <bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
398
+ inline Cpu_64_Float Exp (const Cpu_64_Float& val) noexcept {
399
+ return Exp64<Cpu_64_Float, bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
400
+ }
401
+
357
402
INTERNAL_IMPORT_EXPORT_BODY ErrorEbm ApplyUpdate_Cpu_64 (
358
403
const ObjectiveWrapper* const pObjectiveWrapper, ApplyUpdateBridge* const pData) {
359
404
const Objective* const pObjective = static_cast <const Objective*>(pObjectiveWrapper->m_pObjective );
0 commit comments