@@ -74,6 +74,8 @@ struct Cpu_64_Int final {
7474
7575 inline Cpu_64_Int operator +(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int (m_data + other.m_data ); }
7676
77+ inline Cpu_64_Int operator -(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int (m_data - other.m_data ); }
78+
7779 inline Cpu_64_Int operator *(const T& other) const noexcept { return Cpu_64_Int (m_data * other); }
7880
7981 inline Cpu_64_Int operator >>(int shift) const noexcept { return Cpu_64_Int (m_data >> shift); }
@@ -82,13 +84,28 @@ struct Cpu_64_Int final {
8284
8385 inline Cpu_64_Int operator &(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int (m_data & other.m_data ); }
8486
87+ inline Cpu_64_Int operator |(const Cpu_64_Int& other) const noexcept { return Cpu_64_Int (m_data | other.m_data ); }
88+
89+ friend inline Cpu_64_Int IfThenElse (const bool cmp, const Cpu_64_Int& trueVal, const Cpu_64_Int& falseVal) noexcept {
90+ return cmp ? trueVal : falseVal;
91+ }
92+
8593 private:
8694 TPack m_data;
8795};
8896static_assert (std::is_standard_layout<Cpu_64_Int>::value && std::is_trivially_copyable<Cpu_64_Int>::value,
8997 " This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed" );
9098
99+ template <bool bNegateInput = false ,
100+ bool bNaNPossible = true ,
101+ bool bUnderflowPossible = true ,
102+ bool bOverflowPossible = true >
103+ inline Cpu_64_Float Exp (const Cpu_64_Float& val) noexcept ;
104+
91105struct Cpu_64_Float final {
106+ template <bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
107+ friend Cpu_64_Float Exp (const Cpu_64_Float& val) noexcept ;
108+
92109 using T = double ;
93110 using TPack = double ;
94111 using TInt = Cpu_64_Int;
@@ -106,6 +123,8 @@ struct Cpu_64_Float final {
106123 inline Cpu_64_Float (const double val) noexcept : m_data(static_cast <T>(val)) {}
107124 inline Cpu_64_Float (const float val) noexcept : m_data(static_cast <T>(val)) {}
108125 inline Cpu_64_Float (const int val) noexcept : m_data(static_cast <T>(val)) {}
126+ inline Cpu_64_Float (const int64_t val) noexcept : m_data(static_cast <T>(val)) {}
127+ explicit Cpu_64_Float (const Cpu_64_Int& val) : m_data(static_cast <T>(val.m_data)) {}
109128
110129 inline Cpu_64_Float operator +() const noexcept { return *this ; }
111130
@@ -179,6 +198,10 @@ struct Cpu_64_Float final {
179198 return Cpu_64_Float (val) / other;
180199 }
181200
201+ friend inline bool operator <=(const Cpu_64_Float& left, const Cpu_64_Float& right) noexcept {
202+ return left.m_data <= right.m_data ;
203+ }
204+
182205 inline static Cpu_64_Float Load (const T* const a) noexcept { return Cpu_64_Float (*a); }
183206
184207 inline void Store (T* const a) const noexcept { *a = m_data; }
@@ -207,6 +230,11 @@ struct Cpu_64_Float final {
207230 return cmp1.m_data < cmp2.m_data ? trueVal : falseVal;
208231 }
209232
233+ friend inline Cpu_64_Float IfThenElse (
234+ const bool cmp, const Cpu_64_Float& trueVal, const Cpu_64_Float& falseVal) noexcept {
235+ return cmp ? trueVal : falseVal;
236+ }
237+
210238 friend inline Cpu_64_Float IfEqual (const Cpu_64_Float& cmp1,
211239 const Cpu_64_Float& cmp2,
212240 const Cpu_64_Float& trueVal,
@@ -226,10 +254,24 @@ struct Cpu_64_Float final {
226254 return cmp1.m_data == cmp2.m_data ? trueVal : falseVal;
227255 }
228256
229- friend inline Cpu_64_Float Abs (const Cpu_64_Float& val) noexcept { return Cpu_64_Float (std::abs (val.m_data )); }
257+ static inline bool ReinterpretInt (const bool val) noexcept { return val; }
258+
259+ static inline Cpu_64_Int ReinterpretInt (const Cpu_64_Float& val) noexcept {
260+ typename Cpu_64_Int::T mem;
261+ memcpy (&mem, &val.m_data , sizeof (T));
262+ return Cpu_64_Int (mem);
263+ }
264+
265+ static inline Cpu_64_Float ReinterpretFloat (const Cpu_64_Int& val) noexcept {
266+ T mem;
267+ memcpy (&mem, &val.m_data , sizeof (T));
268+ return Cpu_64_Float (mem);
269+ }
230270
231271 friend inline Cpu_64_Float Round (const Cpu_64_Float& val) noexcept { return Cpu_64_Float (std::round (val.m_data )); }
232272
273+ friend inline Cpu_64_Float Abs (const Cpu_64_Float& val) noexcept { return Cpu_64_Float (std::abs (val.m_data )); }
274+
233275 friend inline Cpu_64_Float FastApproxReciprocal (const Cpu_64_Float& val) noexcept {
234276 return Cpu_64_Float (T{1.0 } / val.m_data );
235277 }
@@ -250,8 +292,6 @@ struct Cpu_64_Float final {
250292
251293 friend inline Cpu_64_Float Sqrt (const Cpu_64_Float& val) noexcept { return Cpu_64_Float (std::sqrt (val.m_data )); }
252294
253- friend inline Cpu_64_Float Exp (const Cpu_64_Float& val) noexcept { return Cpu_64_Float (std::exp (val.m_data )); }
254-
255295 friend inline Cpu_64_Float Log (const Cpu_64_Float& val) noexcept { return Cpu_64_Float (std::log (val.m_data )); }
256296
257297 template <bool bDisableApprox,
@@ -264,7 +304,7 @@ struct Cpu_64_Float final {
264304 static inline Cpu_64_Float ApproxExp (const Cpu_64_Float& val,
265305 const int32_t addExpSchraudolphTerm = k_expTermZeroMeanErrorForSoftmaxWithZeroedLogit) noexcept {
266306 UNUSED (addExpSchraudolphTerm);
267- return Exp ( bNegateInput ? -val : val);
307+ return Exp< bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>( val);
268308 }
269309
270310 template <bool bDisableApprox,
@@ -354,6 +394,11 @@ struct Cpu_64_Float final {
354394static_assert (std::is_standard_layout<Cpu_64_Float>::value && std::is_trivially_copyable<Cpu_64_Float>::value,
355395 " This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed" );
356396
397+ template <bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
398+ inline Cpu_64_Float Exp (const Cpu_64_Float& val) noexcept {
399+ return Exp64<Cpu_64_Float, bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
400+ }
401+
357402INTERNAL_IMPORT_EXPORT_BODY ErrorEbm ApplyUpdate_Cpu_64 (
358403 const ObjectiveWrapper* const pObjectiveWrapper, ApplyUpdateBridge* const pData) {
359404 const Objective* const pObjective = static_cast <const Objective*>(pObjectiveWrapper->m_pObjective );
0 commit comments