Skip to content

Commit 0adc88b

Browse files
authored
Fixes while documenting (kokkos#2466)
* BLAS - scal: removing check on assignable memory spaces That check is stricter than required as we will access values by reference to perform copies and won't try to reassign pointers. Signed-off-by: Luc Berger-Vergiat <[email protected]> * BLAS - rot: check at runtime that X and Y have same extent Signed-off-by: Luc Berger-Vergiat <[email protected]> * BLAS - rot: improving static assertions Signed-off-by: Luc Berger-Vergiat <[email protected]> * BLAS - rotg: check for non-complex types Signed-off-by: Luc Berger-Vergiat <[email protected]> * BLAS - ger: check that matrix stores values as non-const Signed-off-by: Luc Berger-Vergiat <[email protected]> * BLAS - trmm: check for valid execution space type. Signed-off-by: Luc Berger-Vergiat <[email protected]> * BLAS: fix missing semi-colon at end of static_assert Signed-off-by: Luc Berger-Vergiat <[email protected]> * Applying clang-format Signed-off-by: Luc Berger-Vergiat <[email protected]> * More clang-format Signed-off-by: Luc Berger-Vergiat <[email protected]> * Blas - rot: fixing interface of rot The cosine coefficient is strictly real while the sine coefficient can be real or complex leading to a bug in the current API. This commit should fix that for the native and TPL implementation and the associated unit-test is also fixed accordingly. Signed-off-by: Luc Berger-Vergiat <[email protected]> * BLAS - ROT: fixing types for Host TPL calls to ROT function The types for the arguments c and s are actually different and need to be appropriately propagated through the TPL layers of the library. Signed-off-by: Luc Berger-Vergiat <[email protected]> --------- Signed-off-by: Luc Berger-Vergiat <[email protected]>
1 parent 85bbf1f commit 0adc88b

12 files changed

+234
-149
lines changed

blas/impl/KokkosBlas1_rot_impl.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,15 @@
2323
namespace KokkosBlas {
2424
namespace Impl {
2525

26-
template <class VectorView, class ScalarView>
26+
template <class VectorView, class MagnitudeView, class ScalarView>
2727
struct rot_functor {
2828
using scalar_type = typename VectorView::non_const_value_type;
2929

3030
VectorView X, Y;
31-
ScalarView c, s;
31+
MagnitudeView c;
32+
ScalarView s;
3233

33-
rot_functor(VectorView const& X_, VectorView const& Y_, ScalarView const& c_, ScalarView const& s_)
34+
rot_functor(VectorView const& X_, VectorView const& Y_, MagnitudeView const& c_, ScalarView const& s_)
3435
: X(X_), Y(Y_), c(c_), s(s_) {}
3536

3637
KOKKOS_INLINE_FUNCTION
@@ -41,8 +42,8 @@ struct rot_functor {
4142
}
4243
};
4344

44-
template <class ExecutionSpace, class VectorView, class ScalarView>
45-
void Rot_Invoke(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
45+
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView>
46+
void Rot_Invoke(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
4647
ScalarView const& s) {
4748
Kokkos::RangePolicy<ExecutionSpace> rot_policy(space, 0, X.extent(0));
4849
rot_functor rot_func(X, Y, c, s);

blas/impl/KokkosBlas1_rot_spec.hpp

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
namespace KokkosBlas {
3030
namespace Impl {
3131
// Specialization struct which defines whether a specialization exists
32-
template <class ExecutionSpace, class VectorView, class ScalarView>
32+
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView>
3333
struct rot_eti_spec_avail {
3434
enum : bool { value = false };
3535
};
@@ -43,14 +43,15 @@ struct rot_eti_spec_avail {
4343
// We may spread out definitions (see _INST macro below) across one or
4444
// more .cpp files.
4545
//
46-
#define KOKKOSBLAS1_ROT_ETI_SPEC_AVAIL(SCALAR, LAYOUT, EXECSPACE, MEMSPACE) \
47-
template <> \
48-
struct rot_eti_spec_avail< \
49-
EXECSPACE, \
50-
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
51-
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
52-
Kokkos::MemoryTraits<Kokkos::Unmanaged>>> { \
53-
enum : bool { value = true }; \
46+
#define KOKKOSBLAS1_ROT_ETI_SPEC_AVAIL(SCALAR, LAYOUT, EXECSPACE, MEMSPACE) \
47+
template <> \
48+
struct rot_eti_spec_avail< \
49+
EXECSPACE, \
50+
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
51+
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
52+
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
53+
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>> { \
54+
enum : bool { value = true }; \
5455
};
5556

5657
// Include the actual specialization declarations
@@ -61,19 +62,19 @@ namespace KokkosBlas {
6162
namespace Impl {
6263

6364
// Unification layer
64-
template <class ExecutionSpace, class VectorView, class ScalarView,
65-
bool tpl_spec_avail = rot_tpl_spec_avail<ExecutionSpace, VectorView, ScalarView>::value,
66-
bool eti_spec_avail = rot_eti_spec_avail<ExecutionSpace, VectorView, ScalarView>::value>
65+
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView,
66+
bool tpl_spec_avail = rot_tpl_spec_avail<ExecutionSpace, VectorView, MagnitudeView, ScalarView>::value,
67+
bool eti_spec_avail = rot_eti_spec_avail<ExecutionSpace, VectorView, MagnitudeView, ScalarView>::value>
6768
struct Rot {
68-
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
69+
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
6970
ScalarView const& s);
7071
};
7172

7273
#if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY
7374
//! Full specialization of Rot.
74-
template <class ExecutionSpace, class VectorView, class ScalarView>
75-
struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {
76-
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
75+
template <class ExecutionSpace, class VectorView, class MagnitudeView, class ScalarView>
76+
struct Rot<ExecutionSpace, VectorView, MagnitudeView, ScalarView, false, KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {
77+
static void rot(ExecutionSpace const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
7778
ScalarView const& s) {
7879
Kokkos::Profiling::pushRegion(KOKKOSKERNELS_IMPL_COMPILE_LIBRARY ? "KokkosBlas::rot[ETI]"
7980
: "KokkosBlas::rot[noETI]");
@@ -86,7 +87,7 @@ struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COM
8687
typeid(VectorView).name(), typeid(ScalarView).name());
8788
}
8889
#endif
89-
Rot_Invoke<ExecutionSpace, VectorView, ScalarView>(space, X, Y, c, s);
90+
Rot_Invoke<ExecutionSpace, VectorView, MagnitudeView, ScalarView>(space, X, Y, c, s);
9091
Kokkos::Profiling::popRegion();
9192
}
9293
};
@@ -108,6 +109,7 @@ struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COM
108109
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
109110
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
110111
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
112+
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
111113
false, true>;
112114

113115
//
@@ -121,6 +123,7 @@ struct Rot<ExecutionSpace, VectorView, ScalarView, false, KOKKOSKERNELS_IMPL_COM
121123
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
122124
Kokkos::View<typename Kokkos::ArithTraits<SCALAR>::mag_type, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, \
123125
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
126+
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXECSPACE, MEMSPACE>, Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
124127
false, true>;
125128

126129
#include <KokkosBlas1_rot_tpl_spec_decl.hpp>

blas/src/KokkosBlas1_rot.hpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,45 +21,68 @@
2121

2222
namespace KokkosBlas {
2323

24-
template <class execution_space, class VectorView, class ScalarView>
25-
void rot(execution_space const& space, VectorView const& X, VectorView const& Y, ScalarView const& c,
24+
template <class execution_space, class VectorView, class MagnitudeView, class ScalarView>
25+
void rot(execution_space const& space, VectorView const& X, VectorView const& Y, MagnitudeView const& c,
2626
ScalarView const& s) {
2727
static_assert(Kokkos::is_execution_space<execution_space>::value,
2828
"rot: execution_space template parameter is not a Kokkos "
2929
"execution space.");
30+
static_assert(Kokkos::is_view_v<VectorView>, "KokkosBlas::rot: VectorView is not a Kokkos::View.");
31+
static_assert(Kokkos::is_view_v<MagnitudeView>, "KokkosBlas::rot: MagnitudeView is not a Kokkos::View.");
32+
static_assert(Kokkos::is_view_v<ScalarView>, "KokkosBlas::rot: ScalarView is not a Kokkos::View.");
3033
static_assert(VectorView::rank == 1, "rot: VectorView template parameter needs to be a rank 1 view");
34+
static_assert(MagnitudeView::rank == 0, "rot: MagnitudeView template parameter needs to be a rank 0 view");
3135
static_assert(ScalarView::rank == 0, "rot: ScalarView template parameter needs to be a rank 0 view");
3236
static_assert(Kokkos::SpaceAccessibility<execution_space, typename VectorView::memory_space>::accessible,
3337
"rot: VectorView template parameter memory space needs to be accessible "
3438
"from "
3539
"execution_space template parameter");
40+
static_assert(Kokkos::SpaceAccessibility<execution_space, typename MagnitudeView::memory_space>::accessible,
41+
"rot: MagnitudeView template parameter memory space needs to be accessible "
42+
"from "
43+
"execution_space template parameter");
3644
static_assert(Kokkos::SpaceAccessibility<execution_space, typename ScalarView::memory_space>::accessible,
37-
"rot: VectorView template parameter memory space needs to be accessible "
45+
"rot: ScalarView template parameter memory space needs to be accessible "
3846
"from "
3947
"execution_space template parameter");
4048
static_assert(std::is_same<typename VectorView::non_const_value_type, typename VectorView::value_type>::value,
4149
"rot: VectorView template parameter needs to store non-const values");
4250

51+
// Check compatibility of dimensions at run time.
52+
if (X.extent(0) != Y.extent(0)) {
53+
std::ostringstream os;
54+
os << "KokkosBlas::rot: Dimensions of X and Y do not match: "
55+
<< "X: " << X.extent(0) << ", Y: " << Y.extent(0);
56+
KokkosKernels::Impl::throw_runtime_exception(os.str());
57+
}
58+
4359
using VectorView_Internal = Kokkos::View<typename VectorView::non_const_value_type*,
4460
typename KokkosKernels::Impl::GetUnifiedLayout<VectorView>::array_layout,
4561
Kokkos::Device<execution_space, typename VectorView::memory_space>,
4662
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
4763

64+
using MagnitudeView_Internal = Kokkos::View<typename MagnitudeView::non_const_value_type,
65+
typename KokkosKernels::Impl::GetUnifiedLayout<ScalarView>::array_layout,
66+
Kokkos::Device<execution_space, typename ScalarView::memory_space>,
67+
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
68+
4869
using ScalarView_Internal = Kokkos::View<typename ScalarView::non_const_value_type,
4970
typename KokkosKernels::Impl::GetUnifiedLayout<ScalarView>::array_layout,
5071
Kokkos::Device<execution_space, typename ScalarView::memory_space>,
5172
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
5273

5374
VectorView_Internal X_(X), Y_(Y);
54-
ScalarView_Internal c_(c), s_(s);
75+
MagnitudeView_Internal c_(c);
76+
ScalarView_Internal s_(s);
5577

5678
Kokkos::Profiling::pushRegion("KokkosBlas::rot");
57-
Impl::Rot<execution_space, VectorView_Internal, ScalarView_Internal>::rot(space, X_, Y_, c_, s_);
79+
Impl::Rot<execution_space, VectorView_Internal, MagnitudeView_Internal, ScalarView_Internal>::rot(space, X_, Y_, c_,
80+
s_);
5881
Kokkos::Profiling::popRegion();
5982
}
6083

61-
template <class VectorView, class ScalarView>
62-
void rot(VectorView const& X, VectorView const& Y, ScalarView const& c, ScalarView const& s) {
84+
template <class VectorView, class MagnitudeView, class ScalarView>
85+
void rot(VectorView const& X, VectorView const& Y, MagnitudeView const& c, ScalarView const& s) {
6386
const typename VectorView::execution_space space = typename VectorView::execution_space();
6487
rot(space, X, Y, c, s);
6588
}

blas/src/KokkosBlas1_rotg.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ void rotg(execution_space const& space, SViewType const& a, SViewType const& b,
4444
"rotg: execution_space cannot access data in SViewType");
4545
static_assert(Kokkos::SpaceAccessibility<execution_space, typename MViewType::memory_space>::accessible,
4646
"rotg: execution_space cannot access data in MViewType");
47+
static_assert(!Kokkos::ArithTraits<typename MViewType::value_type>::is_complex,
48+
"rotg: MViewType cannot hold complex values.");
4749

4850
using SView_Internal = Kokkos::View<
4951
typename SViewType::value_type, typename KokkosKernels::Impl::GetUnifiedLayout<SViewType>::array_layout,

blas/src/KokkosBlas1_scal.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,6 @@ void scal(const execution_space& space, const RMV& R, const AV& a, const XMV& X)
5858
"X is not a Kokkos::View.");
5959
static_assert(Kokkos::SpaceAccessibility<execution_space, typename XMV::memory_space>::accessible,
6060
"KokkosBlas::scal: XMV must be accessible from execution_space");
61-
static_assert(Kokkos::SpaceAccessibility<typename RMV::memory_space, typename XMV::memory_space>::assignable,
62-
"KokkosBlas::scal: XMV must be assignable to RMV");
6361
static_assert(std::is_same<typename RMV::value_type, typename RMV::non_const_value_type>::value,
6462
"KokkosBlas::scal: R is const. "
6563
"It must be nonconst, because it is an output argument "

blas/src/KokkosBlas2_ger.hpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,22 @@ template <class ExecutionSpace, class XViewType, class YViewType, class AViewTyp
4343
void ger(const ExecutionSpace& space, const char trans[], const typename AViewType::const_value_type& alpha,
4444
const XViewType& x, const YViewType& y, const AViewType& A) {
4545
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename AViewType::memory_space>::accessible,
46-
"AViewType memory space must be accessible from ExecutionSpace");
46+
"ger: AViewType memory space must be accessible from ExecutionSpace");
4747
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename XViewType::memory_space>::accessible,
48-
"XViewType memory space must be accessible from ExecutionSpace");
48+
"ger: XViewType memory space must be accessible from ExecutionSpace");
4949
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename YViewType::memory_space>::accessible,
50-
"YViewType memory space must be accessible from ExecutionSpace");
50+
"ger: YViewType memory space must be accessible from ExecutionSpace");
5151

52-
static_assert(Kokkos::is_view<AViewType>::value, "AViewType must be a Kokkos::View.");
53-
static_assert(Kokkos::is_view<XViewType>::value, "XViewType must be a Kokkos::View.");
54-
static_assert(Kokkos::is_view<YViewType>::value, "YViewType must be a Kokkos::View.");
52+
static_assert(Kokkos::is_view<AViewType>::value, "ger: AViewType must be a Kokkos::View.");
53+
static_assert(Kokkos::is_view<XViewType>::value, "ger: XViewType must be a Kokkos::View.");
54+
static_assert(Kokkos::is_view<YViewType>::value, "ger: YViewType must be a Kokkos::View.");
5555

56-
static_assert(static_cast<int>(AViewType::rank) == 2, "AViewType must have rank 2.");
57-
static_assert(static_cast<int>(XViewType::rank) == 1, "XViewType must have rank 1.");
58-
static_assert(static_cast<int>(YViewType::rank) == 1, "YViewType must have rank 1.");
56+
static_assert(static_cast<int>(AViewType::rank) == 2, "ger: AViewType must have rank 2.");
57+
static_assert(static_cast<int>(XViewType::rank) == 1, "ger: XViewType must have rank 1.");
58+
static_assert(static_cast<int>(YViewType::rank) == 1, "ger: YViewType must have rank 1.");
59+
60+
static_assert(std::is_same_v<typename AViewType::value_type, typename AViewType::non_const_value_type>,
61+
"ger: AViewType must store non const values.");
5962

6063
// Check compatibility of dimensions at run time.
6164
if ((A.extent(0) != x.extent(0)) || (A.extent(1) != y.extent(0))) {

blas/src/KokkosBlas3_trmm.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,14 @@ namespace KokkosBlas {
6666
template <class execution_space, class AViewType, class BViewType>
6767
void trmm(const execution_space& space, const char side[], const char uplo[], const char trans[], const char diag[],
6868
typename BViewType::const_value_type& alpha, const AViewType& A, const BViewType& B) {
69-
static_assert(Kokkos::is_view<AViewType>::value, "AViewType must be a Kokkos::View.");
70-
static_assert(Kokkos::is_view<BViewType>::value, "BViewType must be a Kokkos::View.");
71-
static_assert(static_cast<int>(AViewType::rank) == 2, "AViewType must have rank 2.");
72-
static_assert(static_cast<int>(BViewType::rank) == 2, "BViewType must have rank 2.");
69+
static_assert(Kokkos::is_execution_space_v<execution_space>,
70+
"trmm: execution_space must be a Kokkos::execution_space.");
71+
static_assert(Kokkos::is_view_v<AViewType>,
72+
"trmm: AViewType must be a "
73+
"Kokkos::View.");
74+
static_assert(Kokkos::is_view_v<BViewType>, "trmm: BViewType must be a Kokkos::View.");
75+
static_assert(static_cast<int>(AViewType::rank) == 2, "trmm: AViewType must have rank 2.");
76+
static_assert(static_cast<int>(BViewType::rank) == 2, "trmm: BViewType must have rank 2.");
7377

7478
// Check validity of indicator argument
7579
bool valid_side = (side[0] == 'L') || (side[0] == 'l') || (side[0] == 'R') || (side[0] == 'r');

0 commit comments

Comments
 (0)