Skip to content

Commit 6017280

Browse files
alpaka::math custom implementations for CPU backend (#2525)
Co-authored-by: mehmet yusufoglu <[email protected]>
1 parent 748f830 commit 6017280

File tree

4 files changed

+212
-4
lines changed

4 files changed

+212
-4
lines changed

include/alpaka/core/BitCast.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/* Copyright 2025 Andrea Bocci
2+
* SPDX-License-Identifier: MPL-2.0
3+
*/
4+
5+
#pragma once
6+
7+
#include <cstring>
8+
#include <type_traits>
9+
10+
namespace alpaka::core
11+
{
12+
//! Reinterprets the object representation of \p src as a value of type \p To.
//!
//! C++17 stand-in for std::bit_cast, modelled on the reference implementation at
//! https://en.cppreference.com/w/cpp/numeric/bit_cast.html
//!
//! Participates in overload resolution only if \p To and \p From have the same size and are both trivially
//! copyable.
//!
//! \tparam To Destination type.
//! \tparam From Source type (deduced).
//! \param src Object whose bytes are copied.
//! \return A \p To holding exactly the bytes of \p src.
template<class To, class From>
std::enable_if_t<
    sizeof(To) == sizeof(From) && std::is_trivially_copyable_v<From> && std::is_trivially_copyable_v<To>,
    To>
bit_cast(From const& src) noexcept
{
    if constexpr(std::is_trivially_default_constructible_v<To>)
    {
        // Preferred path: copy into a real object of the destination type. This avoids reading a To through a
        // pointer into storage of a different type and does not need std::aligned_storage_t (deprecated in C++23).
        std::remove_cv_t<To> dst;
        std::memcpy(&dst, &src, sizeof(To));
        return dst;
    }
    else
    {
        // Fallback for types without a trivial default constructor: stage the bytes in a suitably aligned raw
        // buffer and copy the result out of it.
        alignas(To) unsigned char storage[sizeof(To)];
        std::memcpy(storage, &src, sizeof(To));
        return *reinterpret_cast<To*>(storage);
    }
}
23+
24+
} // namespace alpaka::core

include/alpaka/math/MathStdLib.hpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@
55

66
#pragma once
77

8+
#include "alpaka/core/BitCast.hpp"
89
#include "alpaka/core/Decay.hpp"
10+
#include "alpaka/core/Unreachable.hpp"
911
#include "alpaka/math/Traits.hpp"
1012

13+
#include <cstdint>
14+
#include <type_traits>
15+
1116
namespace alpaka::math
1217
{
1318
//! The standard library abs, implementation covered by the general template.
@@ -294,6 +299,99 @@ namespace alpaka::math
294299
ALPAKA_UNREACHABLE(std::common_type_t<Tx, Ty>{});
295300
}
296301
};
302+
303+
//! Custom IEEE 754 bitwise implementation of isfinite.
304+
//! std counterpart does not work correctly for some compiler flags at CPU backend
305+
template<typename TArg>
306+
struct Isfinite<IsfiniteStdLib, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
307+
{
308+
auto operator()(IsfiniteStdLib const& /* ctx */, TArg const& arg) -> bool
309+
{
310+
if constexpr(std::is_same_v<TArg, float>)
311+
{
312+
constexpr std::uint32_t expMask = 0x7F80'0000;
313+
std::uint32_t bits = alpaka::core::bit_cast<std::uint32_t>(arg);
314+
bool result = (bits & expMask) != expMask;
315+
return result;
316+
}
317+
else if constexpr(std::is_same_v<TArg, double>)
318+
{
319+
constexpr std::uint64_t expMask = 0x7FF0'0000'0000'0000ULL;
320+
std::uint64_t bits = alpaka::core::bit_cast<std::uint64_t>(arg);
321+
bool result = (bits & expMask) != expMask;
322+
return result;
323+
}
324+
else
325+
{
326+
static_assert(!sizeof(TArg), "Unsupported floating-point type");
327+
}
328+
ALPAKA_UNREACHABLE(false);
329+
}
330+
};
331+
332+
//! Custom IEEE 754 bitwise implementation of isinf
333+
//! std counterpart does not work correctly for some compiler flags at CPU backend
334+
template<typename TArg>
335+
struct Isinf<IsinfStdLib, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
336+
{
337+
auto operator()(IsinfStdLib const& /* ctx */, TArg const& arg) -> bool
338+
{
339+
if constexpr(std::is_same_v<TArg, float>)
340+
{
341+
constexpr std::uint32_t expMask = 0x7F80'0000;
342+
constexpr std::uint32_t fracMask = 0x007F'FFFF;
343+
std::uint32_t bits = alpaka::core::bit_cast<std::uint32_t>(arg);
344+
bool result = ((bits & expMask) == expMask) && !(bits & fracMask);
345+
return result;
346+
}
347+
else if constexpr(std::is_same_v<TArg, double>)
348+
{
349+
constexpr std::uint64_t expMask = 0x7FF0'0000'0000'0000ULL;
350+
constexpr std::uint64_t fracMask = 0x000F'FFFF'FFFF'FFFFULL;
351+
std::uint64_t bits = alpaka::core::bit_cast<std::uint64_t>(arg);
352+
bool result = ((bits & expMask) == expMask) && !(bits & fracMask);
353+
return result;
354+
}
355+
else
356+
{
357+
static_assert(!sizeof(TArg), "Unsupported floating-point type");
358+
}
359+
ALPAKA_UNREACHABLE(false);
360+
}
361+
};
362+
363+
//! Custom IEEE 754 bitwise implementation of isnan
364+
//! std counterpart does not work correctly for some compiler flags at CPU backend
365+
template<typename TArg>
366+
struct Isnan<IsnanStdLib, TArg, std::enable_if_t<std::is_floating_point_v<TArg>>>
367+
{
368+
auto operator()(IsnanStdLib const& /* ctx */, TArg const& arg) -> bool
369+
{
370+
if constexpr(std::is_same_v<TArg, float>)
371+
{
372+
constexpr std::uint32_t expMask = 0x7F80'0000;
373+
constexpr std::uint32_t fracMask = 0x007F'FFFF;
374+
std::uint32_t bits = alpaka::core::bit_cast<std::uint32_t>(arg);
375+
bool result = ((bits & expMask) == expMask) && (bits & fracMask);
376+
return result;
377+
}
378+
else if constexpr(std::is_same_v<TArg, double>)
379+
{
380+
constexpr std::uint64_t expMask = 0x7FF0'0000'0000'0000ULL;
381+
constexpr std::uint64_t fracMask = 0x000F'FFFF'FFFF'FFFFULL;
382+
std::uint64_t bits = alpaka::core::bit_cast<std::uint64_t>(arg);
383+
bool result = ((bits & expMask) == expMask) && (bits & fracMask);
384+
return result;
385+
}
386+
else
387+
{
388+
static_assert(!sizeof(TArg), "Unsupported floating-point type");
389+
}
390+
ALPAKA_UNREACHABLE(false);
391+
}
392+
};
393+
394+
297395
} // namespace trait
298396

299397
} // namespace alpaka::math

test/unit/math/src/DataGen.hpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,14 +185,37 @@ namespace mathtest
185185
args(2).arg[k] = std::numeric_limits<TData>::signaling_NaN();
186186
args(3).arg[k] = std::numeric_limits<TData>::infinity();
187187
args(4).arg[k] = -std::numeric_limits<TData>::infinity();
188-
constexpr size_t nFixed = 5;
188+
// One negative one positive value
189+
if constexpr(std::is_same_v<TData, float>)
190+
{
191+
args(5).arg[k] = 1.1f; // Use float literal
192+
args(6).arg[k] = -1.1f;
193+
}
194+
else if constexpr(std::is_same_v<TData, double>)
195+
{
196+
args(5).arg[k] = 1.1; // Use double literal
197+
args(6).arg[k] = -1.1;
198+
}
199+
else if constexpr(std::is_same_v<TData, alpaka::Complex<float>>)
200+
{
201+
args(5).arg[k] = alpaka::Complex<float>{1.1f, 2.1f}; // Complex float
202+
args(6).arg[k] = alpaka::Complex<float>{-1.1f, -2.1f};
203+
}
204+
else if constexpr(std::is_same_v<TData, alpaka::Complex<double>>)
205+
{
206+
args(5).arg[k] = alpaka::Complex<double>{1.1, 2.1}; // Complex double
207+
args(6).arg[k] = alpaka::Complex<double>{-1.1, -2.1};
208+
}
209+
210+
constexpr size_t nFixed = 7;
189211
size_t i = nFixed;
190212
// no need to test for denormal for now: not supported by CUDA
191213
// for(; i < nFixed + (TArgs::capacity - nFixed) / 2; ++i)
192214
// {
193215
// const TData v = rngWrapper.getNumber(dist, eng) *
194216
// std::numeric_limits<TData>::denorm_min(); args(i).arg[k] = (i % 2 == 0) ? v : -v;
195217
// }
218+
// Next values
196219
for(; i < TArgs::capacity; ++i)
197220
{
198221
TData const v = rngWrapper.getNumber(dist, eng);

test/unit/math/src/TestTemplate.hpp

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,53 @@ namespace mathtest
5353
using type = T;
5454
};
5555

56+
//!
57+
//! \brief setExpectedResultForSpecificInput
58+
//! This function is for testing alpaka::math functions isinf, isnan, isfinite. Since for some compile
59+
//! options for CPU backends; std::isnan, std::isinf and std::isfinite does not work properpy; test results can
60+
//! only be tested by setting the expected results for the known input. For 3 testing operators OpIsnan, OpIsinf,
61+
//! OpIsfinite; at the beginning of test input array, specific values are used and their expected results are set
62+
//! in that function.
63+
//! input[0]: [ 0 ]
64+
//! input[1]: [ nan ]
65+
//! input[2]: [ nan ]
66+
//! input[3]: [ inf ]
67+
//! input[4]: [ -inf ]
68+
//! \param stdExpectedResult Expected value for the operator, the type of resulting operation could either be type
69+
//! of operand (although for uniary op. like isInf it is bool) Since all operation outputs are represented by
70+
//! operand type in the code, this function uses 0 and 1 for the results. \param idx is the index in the input
71+
//! buffer
72+
//!
73+
74+
template<typename TFunctor, typename TData>
75+
void setExpectedResultForSpecificInput(TData& stdExpectedResult, size_t idx)
76+
{
77+
// marked as [[maybe_unused]] because nvcc 11.2 ignores the "else" branch when a previous "if" is true.
78+
[[maybe_unused]] constexpr bool isIsnan = std::is_same_v<TFunctor, OpIsnan>;
79+
[[maybe_unused]] constexpr bool isIsinf = std::is_same_v<TFunctor, OpIsinf>;
80+
[[maybe_unused]] constexpr bool isIsfinite = std::is_same_v<TFunctor, OpIsfinite>;
81+
82+
if constexpr(isIsnan)
83+
{
84+
// for the input[1] and input[2] input is Nan and isNan should be tested by result 1.
85+
stdExpectedResult = (idx == 1 || idx == 2) ? static_cast<TData>(1) : static_cast<TData>(0);
86+
}
87+
else if constexpr(isIsinf)
88+
{
89+
// for the input[3] and input[4] input is Inf and -Inf should be tested by result 1.
90+
stdExpectedResult = (idx == 3 || idx == 4) ? static_cast<TData>(1) : static_cast<TData>(0);
91+
}
92+
else if constexpr(isIsfinite)
93+
{
94+
// input[0] is 0 hence it is finite, other data starting after nan and infs are finite.
95+
stdExpectedResult = (idx == 0 || idx > 4) ? static_cast<TData>(1) : static_cast<TData>(0);
96+
}
97+
else
98+
{
99+
stdExpectedResult = static_cast<TData>(0);
100+
}
101+
}
102+
56103
//! Base test template for math unit tests
57104
//! @tparam TAcc Accelerator.
58105
//! @tparam TFunctor Functor defined in Functor.hpp.
@@ -140,9 +187,25 @@ namespace mathtest
140187
#endif
141188
for(size_t i = 0; i < Args::capacity; ++i)
142189
{
143-
TData std_result = functor(args(i));
144-
INFO("Idx i: " << i << " computed : " << results(i) << " vs expected: " << std_result);
145-
REQUIRE(isApproxEqual(results(i), std_result));
190+
TData stdExpectedResult{};
191+
192+
constexpr bool isSpecialCase = std::is_same_v<TFunctor, OpIsnan> || std::is_same_v<TFunctor, OpIsinf>
193+
|| std::is_same_v<TFunctor, OpIsfinite>;
194+
195+
// Only for specific operators, the results for the test inputs can only be verified by setting the
196+
// expected specific result manually.
197+
if constexpr((std::is_same_v<TData, float> || std::is_same_v<TData, double>) &&isSpecialCase)
198+
{
199+
setExpectedResultForSpecificInput<TFunctor>(stdExpectedResult, i);
200+
}
201+
else
202+
{
203+
// Calculated expected result using std functions
204+
stdExpectedResult = functor(args(i));
205+
}
206+
INFO("Idx i: " << i << " computed : " << results(i) << " vs expected: " << stdExpectedResult);
207+
// Validate
208+
REQUIRE(isApproxEqual(results(i), stdExpectedResult));
146209
}
147210
}
148211

0 commit comments

Comments
 (0)