Skip to content

Commit 8ea8d96

Browse files
committed
Fix most issues
1 parent bd35e7b commit 8ea8d96

13 files changed

+174
-58
lines changed

include/alpaka/math/Complex.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -256,10 +256,9 @@ namespace alpaka
256256
template<typename T>
257257
ALPAKA_FN_HOST_ACC Complex<T> operator/(Complex<T> const& lhs, Complex<T> const& rhs)
258258
{
259-
return lhs
260-
* Complex<T>{
261-
rhs.real() / (rhs.real() * rhs.real() + rhs.imag() * rhs.imag()),
262-
-rhs.imag() / (rhs.real() * rhs.real() + rhs.imag() * rhs.imag())};
259+
return Complex<T>{
260+
(lhs.real() * rhs.real() + lhs.imag() * rhs.imag()) / (rhs.real() * rhs.real() + rhs.imag() * rhs.imag()),
261+
(lhs.imag() * rhs.real() - lhs.real() * rhs.imag()) / (rhs.real() * rhs.real() + rhs.imag() * rhs.imag())};
263262
}
264263

265264
//! Division of a complex number and a real number
@@ -273,10 +272,9 @@ namespace alpaka
273272
template<typename T>
274273
ALPAKA_FN_HOST_ACC Complex<T> operator/(T const& lhs, Complex<T> const& rhs)
275274
{
276-
return lhs
277-
* Complex<T>{
278-
rhs.real() / (rhs.real() * rhs.real() + rhs.imag() * rhs.imag()),
279-
-rhs.imag() / (rhs.real() * rhs.real() + rhs.imag() * rhs.imag())};
275+
return Complex<T>{
276+
lhs * rhs.real() / (rhs.real() * rhs.real() + rhs.imag() * rhs.imag()),
277+
-lhs * rhs.imag() / (rhs.real() * rhs.real() + rhs.imag() * rhs.imag())};
280278
}
281279

282280
//! Equality of two complex numbers

include/alpaka/math/MathStdLib.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ namespace alpaka
5454
, public Atan2StdLib
5555
, public CbrtStdLib
5656
, public CeilStdLib
57+
, public ConjStdLib
5758
, public CosStdLib
5859
, public ErfStdLib
5960
, public ExpStdLib

include/alpaka/math/MathUniformCudaHipBuiltIn.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ namespace alpaka
5656
, public Atan2UniformCudaHipBuiltIn
5757
, public CbrtUniformCudaHipBuiltIn
5858
, public CeilUniformCudaHipBuiltIn
59+
, public ConjUniformCudaHipBuiltIn
5960
, public CosUniformCudaHipBuiltIn
6061
, public ErfUniformCudaHipBuiltIn
6162
, public ExpUniformCudaHipBuiltIn

include/alpaka/math/arg/ArgUniformCudaHipBuiltIn.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace alpaka
3030

3131
namespace traits
3232
{
33-
//! The CUDA/HIP built in arg trait specialization.
33+
//! The CUDA/HIP built in arg trait specialization for float and double.
3434
template<typename TArgument>
3535
struct Arg<ArgUniformCudaHipBuiltIn, TArgument, std::enable_if_t<std::is_floating_point<TArgument>::value>>
3636
{

include/alpaka/math/atan/AtanUniformCudaHipBuiltIn.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ namespace alpaka
6060
{
6161
// This holds everywhere, including the branch cuts: atan(z) = -i/2 * ln((i - z) / (i + z))
6262
return Complex<T>{0.0, -0.5}
63-
* log(ctx, (Complex<T>{0.0, 1.0} + arg) / (Complex<T>{0.0, 1.0} + arg));
63+
* log(ctx, (Complex<T>{0.0, 1.0} - arg) / (Complex<T>{0.0, 1.0} + arg));
6464
}
6565
};
6666
} // namespace traits

include/alpaka/math/conj/ConjUniformCudaHipBuiltIn.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace alpaka
2929

3030
namespace traits
3131
{
32-
//! The CUDA/HIP built in conj trait specialization.
32+
//! The CUDA/HIP built in conj trait specialization for float and double.
3333
template<typename TArg>
3434
struct Conj<ConjUniformCudaHipBuiltIn, TArg, std::enable_if_t<std::is_floating_point<TArg>::value>>
3535
{

include/alpaka/math/exp/ExpUniformCudaHipBuiltIn.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414
# include <alpaka/core/CudaHipMath.hpp>
1515
# include <alpaka/core/Unused.hpp>
1616
# include <alpaka/math/Complex.hpp>
17-
# include <alpaka/math/cos/Traits.hpp>
1817
# include <alpaka/math/exp/Traits.hpp>
19-
# include <alpaka/math/sin/Traits.hpp>
18+
# include <alpaka/math/sincos/Traits.hpp>
2019

2120
# include <type_traits>
2221

@@ -60,7 +59,9 @@ namespace alpaka
6059
__device__ auto operator()(TCtx const& ctx, Complex<T> const& arg)
6160
{
6261
// exp(z) = exp(x + iy) = exp(x) * (cos(y) + i * sin(y))
63-
return exp(ctx, arg.real()) * Complex<T>{cos(ctx, arg.imag()), sin(ctx, arg.imag())};
62+
auto re = T{}, im = T{};
63+
sincos(ctx, arg.imag(), im, re);
64+
return exp(ctx, arg.real()) * Complex<T>{re, im};
6465
}
6566
};
6667
} // namespace traits

include/alpaka/math/sqrt/SqrtUniformCudaHipBuiltIn.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ namespace alpaka
6363
// principal value of sqrt(z) = sqrt(|z|) * e^(i * arg(z) / 2)
6464
auto const halfArg = T(0.5) * arg(ctx, argument);
6565
auto re = T{}, im = T{};
66-
sincos(ctx, halfArg, re, im);
66+
sincos(ctx, halfArg, im, re);
6767
return sqrt(ctx, abs(ctx, argument)) * Complex<T>(re, im);
6868
}
6969
};

include/alpaka/math/tan/TanUniformCudaHipBuiltIn.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ namespace alpaka
5858
template<typename TCtx>
5959
__device__ auto operator()(TCtx const& ctx, Complex<T> const& arg)
6060
{
61-
// tan(z) = i * (e^-iz - e^iz) / (e^-iz + e^iz)
62-
auto const exp1 = exp(ctx, Complex<T>{0.0, -1.0} * arg);
63-
auto const exp2 = exp(ctx, Complex<T>{0.0, 1.0} * arg);
64-
return Complex<T>{0.0, 1.0} * (exp1 - exp2) / (exp1 + exp2);
61+
// tan(z) = i * (e^(-iz) - e^(iz)) / (e^(-iz) + e^(iz)) = i * (1 - e^(2iz)) / (1 + e^(2iz))
62+
// Warning: this straightforward implementation can easily result in NaN as 0/0 or inf/inf.
63+
auto const expValue = exp(ctx, Complex<T>{0.0, 2.0} * arg);
64+
return Complex<T>{0.0, 1.0} * (T{1.0} - expValue) / (T{1.0} + expValue);
6565
}
6666
};
6767
} // namespace traits

test/unit/math/src/DataGen.hpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "Defines.hpp"
1313

1414
#include <cassert>
15+
#include <cmath>
1516
#include <limits>
1617
#include <random>
1718

@@ -26,12 +27,12 @@ namespace alpaka
2627
template<typename TData>
2728
struct RngHelper
2829
{
29-
static constexpr auto getMax()
30+
static auto getMax()
3031
{
3132
return std::numeric_limits<TData>::max();
3233
}
3334

34-
static constexpr auto getLowest()
35+
static auto getLowest()
3536
{
3637
return std::numeric_limits<TData>::lowest();
3738
}
@@ -48,16 +49,16 @@ namespace alpaka
4849
template<typename TData>
4950
struct RngHelper<Complex<TData>>
5051
{
51-
static constexpr auto getMax()
52+
static auto getMax()
5253
{
53-
auto const max = std::numeric_limits<TData>::max();
54-
return Complex<TData>{max, max};
54+
/// auto const max = TData{0.01} * std::sqrt(std::numeric_limits<TData>::max());
55+
/// return Complex<TData>{max, TData{-0.7} * max};
56+
return Complex<TData>{2.0, 2.0};
5557
}
5658

57-
static constexpr auto getLowest()
59+
static auto getLowest()
5860
{
59-
auto const lowest = std::numeric_limits<TData>::lowest();
60-
return Complex<TData>{lowest, lowest};
61+
return -getMax();
6162
}
6263

6364
using Distribution = std::uniform_real_distribution<TData>;
@@ -98,13 +99,13 @@ namespace alpaka
9899
static_assert(TArgs::capacity > 6, "Set of args must provide > 6 entries.");
99100
using RngHelper = RngHelper<TData>;
100101
auto rngHelper = RngHelper{};
101-
constexpr auto max = rngHelper.getMax();
102-
constexpr auto low = rngHelper.getLowest();
102+
auto const max = rngHelper.getMax();
103+
auto const low = rngHelper.getLowest();
103104
std::default_random_engine eng{static_cast<std::default_random_engine::result_type>(seed)};
104105

105106
// These pseudo-random numbers are implementation/platform specific!
106107
using Distribution = typename RngHelper::Distribution;
107-
Distribution dist(0, 1000);
108+
Distribution dist(0, /*1000*/ 10.0);
108109
Distribution distOne(-1, 1);
109110
for(size_t k = 0; k < TFunctor::arity_nr; ++k)
110111
{

0 commit comments

Comments
 (0)