Skip to content

Commit bd35e7b

Browse files
committed
mostly working version, some TODOs left and a couple tests fail
1 parent 242a858 commit bd35e7b

28 files changed

+267
-145
lines changed

include/alpaka/math/Complex.hpp

Lines changed: 170 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,13 @@ namespace alpaka
166166
T m_real, m_imag;
167167
};
168168

169+
//! Host-device arithmetic operations matching std::complex<T>.
170+
//!
171+
//! They take and return alpaka::Complex.
172+
//!
173+
//! @{
174+
//!
175+
169176
//! Unary plus (added for compatibility with std::complex)
170177
template<typename T>
171178
ALPAKA_FN_HOST_ACC Complex<T> operator+(Complex<T> const& val)
@@ -322,78 +329,214 @@ namespace alpaka
322329
|| !math::floatEqualExactNoWarning(static_cast<T>(0), rhs.imag());
323330
}
324331

325-
//! Host-side abs falling back to std:: implementation
332+
//! @}
333+
334+
//! Host-only output of a complex number
335+
template<typename T, typename TChar, typename TTraits>
336+
std::basic_ostream<TChar, TTraits>& operator<<(std::basic_ostream<TChar, TTraits>& os, Complex<T> const& x)
337+
{
338+
os << std::complex<T>{x};
339+
return os;
340+
}
341+
342+
//! Host-only input of a complex number
343+
template<typename T, typename TChar, typename TTraits>
344+
std::basic_istream<TChar, TTraits>& operator>>(std::basic_istream<TChar, TTraits>& is, Complex<T> const& x)
345+
{
346+
std::complex<T> z;
347+
is >> z;
348+
x = z;
349+
return is;
350+
}
351+
352+
//! Host-only math functions matching std::complex<T>.
353+
//!
354+
//! They take and return alpaka::Complex (or a real number when appropriate).
355+
//! Internally cast, fall back to std::complex implementation and cast back.
356+
//! These functions can be used directly on the host side.
357+
//! They are also picked up by ADL in math traits for CPU backends.
358+
//!
359+
//! On the device side, alpaka math traits must be used instead.
360+
//! Note that the set of the traits is currently a bit smaller.
361+
//!
362+
//! @{
363+
//!
364+
365+
//! Absolute value
326366
template<typename T>
327-
ALPAKA_FN_HOST auto abs(Complex<T> const& x)
367+
constexpr ALPAKA_FN_HOST T abs(Complex<T> const& x)
328368
{
329369
return std::abs(std::complex<T>(x));
330370
}
331371

332-
//! Host-side acos falling back to std:: implementation
372+
//! Arc cosine
333373
template<typename T>
334-
ALPAKA_FN_HOST auto acos(Complex<T> const& x)
374+
constexpr ALPAKA_FN_HOST Complex<T> acos(Complex<T> const& x)
335375
{
336376
return std::acos(std::complex<T>(x));
337377
}
338378

379+
//! Arc hyperbolic cosine
380+
template<typename T>
381+
constexpr ALPAKA_FN_HOST Complex<T> acosh(Complex<T> const& x)
382+
{
383+
return std::acosh(std::complex<T>(x));
384+
}
385+
339386
//! Host-side arg falling back to std:: implementation
340387
template<typename T>
341-
ALPAKA_FN_HOST auto arg(Complex<T> const& x)
388+
constexpr ALPAKA_FN_HOST T arg(Complex<T> const& x)
342389
{
343390
return std::arg(std::complex<T>(x));
344391
}
345392

346-
//! Host-side asin falling back to std:: implementation
393+
//! Arc sine
347394
template<typename T>
348-
ALPAKA_FN_HOST auto asin(Complex<T> const& x)
395+
constexpr ALPAKA_FN_HOST Complex<T> asin(Complex<T> const& x)
349396
{
350397
return std::asin(std::complex<T>(x));
351398
}
352399

353-
//! Host-side atan falling back to std:: implementation
400+
//! Arc hyperbolic sine
354401
template<typename T>
355-
ALPAKA_FN_HOST auto atan(Complex<T> const& x)
402+
constexpr ALPAKA_FN_HOST Complex<T> asinh(Complex<T> const& x)
403+
{
404+
return std::asinh(std::complex<T>(x));
405+
}
406+
407+
//! Arc tangent
408+
template<typename T>
409+
constexpr ALPAKA_FN_HOST Complex<T> atan(Complex<T> const& x)
356410
{
357411
return std::atan(std::complex<T>(x));
358412
}
359413

360-
//! Host-side pow falling back to std:: implementation
414+
//! Arc hyperbolic tangent
415+
template<typename T>
416+
constexpr ALPAKA_FN_HOST Complex<T> atanh(Complex<T> const& x)
417+
{
418+
return std::atanh(std::complex<T>(x));
419+
}
420+
421+
//! Complex conjugate
422+
template<typename T>
423+
constexpr ALPAKA_FN_HOST Complex<T> conj(Complex<T> const& x)
424+
{
425+
return std::conj(std::complex<T>(x));
426+
}
427+
428+
//! Cosine
429+
template<typename T>
430+
constexpr ALPAKA_FN_HOST Complex<T> cos(Complex<T> const& x)
431+
{
432+
return std::cos(std::complex<T>(x));
433+
}
434+
435+
//! Hyperbolic cosine
436+
template<typename T>
437+
constexpr ALPAKA_FN_HOST Complex<T> cosh(Complex<T> const& x)
438+
{
439+
return std::cosh(std::complex<T>(x));
440+
}
441+
442+
//! Exponential
443+
template<typename T>
444+
constexpr ALPAKA_FN_HOST Complex<T> exp(Complex<T> const& x)
445+
{
446+
return std::exp(std::complex<T>(x));
447+
}
448+
449+
//! Natural logarithm
450+
template<typename T>
451+
constexpr ALPAKA_FN_HOST Complex<T> log(Complex<T> const& x)
452+
{
453+
return std::log(std::complex<T>(x));
454+
}
455+
456+
//! Base 10 logarithm
457+
template<typename T>
458+
constexpr ALPAKA_FN_HOST Complex<T> log10(Complex<T> const& x)
459+
{
460+
return std::log10(std::complex<T>(x));
461+
}
462+
463+
//! Squared magnitude
464+
template<typename T>
465+
constexpr ALPAKA_FN_HOST T norm(Complex<T> const& x)
466+
{
467+
return std::norm(std::complex<T>(x));
468+
}
469+
470+
//! Get a complex number with given magnitude and phase angle
471+
template<typename T>
472+
constexpr ALPAKA_FN_HOST Complex<T> polar(T const& r, T const& theta = T())
473+
{
474+
return std::polar(r, theta);
475+
}
476+
477+
//! Complex power of a complex number
361478
template<typename T, typename U>
362-
ALPAKA_FN_HOST auto pow(Complex<T> const& x, Complex<U> const& y)
479+
constexpr ALPAKA_FN_HOST Complex<T> pow(Complex<T> const& x, Complex<U> const& y)
363480
{
364481
return std::pow(std::complex<T>(x), std::complex<U>(y));
365482
}
366483

367-
//! Host-side pow falling back to std:: implementation
484+
//! Real power of a complex number
368485
template<typename T, typename U>
369-
ALPAKA_FN_HOST auto pow(Complex<T> const& x, U const& y)
486+
constexpr ALPAKA_FN_HOST Complex<T> pow(Complex<T> const& x, U const& y)
370487
{
371488
return std::pow(std::complex<T>(x), y);
372489
}
373490

374-
//! Host-side pow falling back to std:: implementation
491+
//! Complex power of a real number
375492
template<typename T, typename U>
376-
ALPAKA_FN_HOST auto pow(T const& x, Complex<U> const& y)
493+
constexpr ALPAKA_FN_HOST Complex<T> pow(T const& x, Complex<U> const& y)
377494
{
378495
return std::pow(x, std::complex<U>(y));
379496
}
380497

381-
//! Host-only output of a complex number
382-
template<typename T, typename TChar, typename TTraits>
383-
std::basic_ostream<TChar, TTraits>& operator<<(std::basic_ostream<TChar, TTraits>& os, Complex<T> const& x)
498+
//! Projection onto the Riemann sphere
499+
template<typename T>
500+
constexpr ALPAKA_FN_HOST Complex<T> proj(Complex<T> const& x)
384501
{
385-
os << std::complex<T>{x};
386-
return os;
502+
return std::proj(std::complex<T>(x));
387503
}
388504

389-
//! Host-only input of a complex number
390-
template<typename T, typename TChar, typename TTraits>
391-
std::basic_istream<TChar, TTraits>& operator>>(std::basic_istream<TChar, TTraits>& is, Complex<T> const& x)
505+
//! Sine
506+
template<typename T>
507+
constexpr ALPAKA_FN_HOST Complex<T> sin(Complex<T> const& x)
392508
{
393-
std::complex<T> z;
394-
is >> z;
395-
x = z;
396-
return is;
509+
return std::sin(std::complex<T>(x));
510+
}
511+
512+
//! Hyperbolic sine
513+
template<typename T>
514+
constexpr ALPAKA_FN_HOST Complex<T> sinh(Complex<T> const& x)
515+
{
516+
return std::sinh(std::complex<T>(x));
397517
}
398518

519+
//! Square root
520+
template<typename T>
521+
constexpr ALPAKA_FN_HOST Complex<T> sqrt(Complex<T> const& x)
522+
{
523+
return std::sqrt(std::complex<T>(x));
524+
}
525+
526+
//! Tangent
527+
template<typename T>
528+
constexpr ALPAKA_FN_HOST Complex<T> tan(Complex<T> const& x)
529+
{
530+
return std::tan(std::complex<T>(x));
531+
}
532+
533+
//! Hyperbolic tangent
534+
template<typename T>
535+
constexpr ALPAKA_FN_HOST Complex<T> tanh(Complex<T> const& x)
536+
{
537+
return std::tanh(std::complex<T>(x));
538+
}
539+
540+
//! @}
541+
399542
} // namespace alpaka

include/alpaka/math/MathStdLib.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ namespace alpaka
4848
class MathStdLib
4949
: public AbsStdLib
5050
, public AcosStdLib
51+
, public ArgStdLib
5152
, public AsinStdLib
5253
, public AtanStdLib
5354
, public Atan2StdLib

include/alpaka/math/StdLibArg.hpp

Lines changed: 0 additions & 56 deletions
This file was deleted.

include/alpaka/math/abs/AbsUniformCudaHipBuiltIn.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,11 @@ namespace alpaka
6464
template<typename T>
6565
struct Abs<AbsUniformCudaHipBuiltIn, Complex<T>>
6666
{
67-
__device__ auto operator()(AbsUniformCudaHipBuiltIn const& abs_ctx, Complex<T> const& arg)
67+
//! Take context as original (accelerator) type, since we call other math functions
68+
template<typename TCtx>
69+
__device__ auto operator()(TCtx const& ctx, Complex<T> const& arg)
6870
{
69-
// Call alpaka's sqrt as it handles all types T correctly
70-
return sqrt(abs_ctx, arg.real() * arg.real() + arg.imag() * arg.imag());
71+
return sqrt(ctx, arg.real() * arg.real() + arg.imag() * arg.imag());
7172
}
7273
};
7374
} // namespace traits

include/alpaka/math/acos/AcosUniformCudaHipBuiltIn.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,13 @@ namespace alpaka
5555
template<typename T>
5656
struct Acos<AcosUniformCudaHipBuiltIn, Complex<T>>
5757
{
58-
__device__ auto operator()(AcosUniformCudaHipBuiltIn const& acos_ctx, Complex<T> const& arg)
58+
//! Take context as original (accelerator) type, since we call other math functions
59+
template<typename TCtx>
60+
__device__ auto operator()(TCtx const& ctx, Complex<T> const& arg)
5961
{
6062
// This holds everywhere, including the branch cuts: acos(z) = -i * ln(z + i * sqrt(1 - z^2))
6163
return Complex<T>{0.0, -1.0}
62-
* log(acos_ctx, arg + Complex<T>{0.0, 1.0} * sqrt(acos_ctx, T(1.0) - arg * arg));
64+
* log(ctx, arg + Complex<T>{0.0, 1.0} * sqrt(ctx, T(1.0) - arg * arg));
6365
}
6466
};
6567
} // namespace traits

include/alpaka/math/arg/ArgUniformCudaHipBuiltIn.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,11 @@ namespace alpaka
4444
template<typename T>
4545
struct Arg<ArgUniformCudaHipBuiltIn, Complex<T>>
4646
{
47-
__device__ auto operator()(ArgUniformCudaHipBuiltIn const& arg_ctx, Complex<T> const& argument)
47+
//! Take context as original (accelerator) type, since we call other math functions
48+
template<typename TCtx>
49+
__device__ auto operator()(TCtx const& ctx, Complex<T> const& argument)
4850
{
49-
return atan2(arg_ctx, argument.imag(), argument.real());
51+
return atan2(ctx, argument.imag(), argument.real());
5052
}
5153
};
5254
} // namespace traits

include/alpaka/math/asin/AsinUniformCudaHipBuiltIn.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,12 @@ namespace alpaka
5555
template<typename T>
5656
struct Asin<AsinUniformCudaHipBuiltIn, Complex<T>>
5757
{
58-
__device__ auto operator()(AsinUniformCudaHipBuiltIn const& asin_ctx, Complex<T> const& arg)
58+
//! Take context as original (accelerator) type, since we call other math functions
59+
template<typename TCtx>
60+
__device__ auto operator()(TCtx const& ctx, Complex<T> const& arg)
5961
{
6062
// This holds everywhere, including the branch cuts: asin(z) = i * ln(sqrt(1 - z^2) - i * z)
61-
return Complex<T>{0.0, 1.0}
62-
* log(asin_ctx, sqrt(asin_ctx, T(1.0) - arg * arg) - Complex<T>{0.0, 1.0} * arg);
63+
return Complex<T>{0.0, 1.0} * log(ctx, sqrt(ctx, T(1.0) - arg * arg) - Complex<T>{0.0, 1.0} * arg);
6364
}
6465
};
6566
} // namespace traits

include/alpaka/math/atan/AtanUniformCudaHipBuiltIn.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,13 @@ namespace alpaka
5454
template<typename T>
5555
struct Atan<AtanUniformCudaHipBuiltIn, Complex<T>>
5656
{
57-
__device__ auto operator()(AtanUniformCudaHipBuiltIn const& atan_ctx, Complex<T> const& arg)
57+
//! Take context as original (accelerator) type, since we call other math functions
58+
template<typename TCtx>
59+
__device__ auto operator()(TCtx const& ctx, Complex<T> const& arg)
5860
{
5961
// This holds everywhere, including the branch cuts: atan(z) = -i/2 * ln((i - z) / (i + z))
6062
return Complex<T>{0.0, -0.5}
61-
* log(atan_ctx, (Complex<T>{0.0, 1.0} + arg) / (Complex<T>{0.0, 1.0} + arg));
63+
* log(ctx, (Complex<T>{0.0, 1.0} + arg) / (Complex<T>{0.0, 1.0} + arg));
6264
}
6365
};
6466
} // namespace traits

0 commit comments

Comments
 (0)