Skip to content

Commit 56ae59c

Browse files
committed
.
1 parent d421c5c commit 56ae59c

6 files changed

+290
-165
lines changed

blas/impl/KokkosBlas1_scal_impl.hpp

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
#include <Kokkos_Core.hpp>
2121
#include <Kokkos_InnerProductSpaceTraits.hpp>
2222
#include <KokkosBlas1_scal_spec.hpp>
23+
#include <KokkosKernels_AlwaysFalse.hpp>
24+
#include <KokkosBlas1_scal_unified_scalar_view_impl.hpp>
25+
#include <KokkosKernels_ScalarHint.hpp>
2326

2427
#ifndef KOKKOSBLAS_OPTIMIZATION_LEVEL_SCAL
2528
#define KOKKOSBLAS_OPTIMIZATION_LEVEL_SCAL 2
@@ -28,7 +31,10 @@
2831
namespace KokkosBlas {
2932
namespace Impl {
3033

31-
// Single-vector version of MV_Scal_Functor. By default, a is still a
34+
35+
// Single-vector version of MV_Scal_Functor.
36+
// a has been unified into either a scalar or a 0D view
37+
3238
// The partial specialization that used to let a be a
3339
// scalar is disabled (#if 0) below. This functor computes any of the following:
3440
//
@@ -42,7 +48,7 @@ namespace Impl {
4248
// Any literal coefficient of zero has BLAS semantics of ignoring the
4349
// corresponding (multi)vector entry. This does not apply to
4450
// coefficients in the a vector, if used.
45-
template <class RV, class AV, class XV, int scalar_x, class SizeType>
51+
template <class RV, class AV, class XV, KokkosKernels::Impl::ScalarHint ALPHA_HINT, class SizeType>
4652
struct V_Scal_Functor {
4753
typedef SizeType size_type;
4854
typedef Kokkos::ArithTraits<typename RV::non_const_value_type> ATS;
@@ -51,46 +57,55 @@ struct V_Scal_Functor {
5157
XV m_x;
5258
AV m_a;
5359

54-
V_Scal_Functor(const RV& r, const XV& x, const AV& a,
55-
const SizeType startingColumn)
60+
V_Scal_Functor(const RV& r, const XV& x, const AV& a, const SizeType startingColumn)
5661
: m_r(r), m_x(x), m_a(a) {
5762
static_assert(Kokkos::is_view<RV>::value,
5863
"V_Scal_Functor: RV is not a Kokkos::View.");
59-
static_assert(Kokkos::is_view<AV>::value,
60-
"V_Scal_Functor: AV is not a Kokkos::View.");
64+
65+
// TODO: static assert truths about AV
66+
6167
static_assert(Kokkos::is_view<XV>::value,
6268
"V_Scal_Functor: XV is not a Kokkos::View.");
6369
static_assert(RV::rank == 1, "V_Scal_Functor: RV is not rank 1.");
64-
static_assert(AV::rank == 1, "V_Scal_Functor: AV is not rank 1.");
6570
static_assert(XV::rank == 1, "V_Scal_Functor: XV is not rank 1.");
6671

67-
if (startingColumn != 0) {
72+
73+
if constexpr (Kokkos::is_view_v<AV>) {
74+
if (startingColumn != 0) {
6875
m_a = Kokkos::subview(
6976
a,
7077
std::make_pair(startingColumn, static_cast<SizeType>(a.extent(0))));
78+
}
7179
}
7280
}
7381

7482
KOKKOS_INLINE_FUNCTION
7583
void operator()(const size_type& i) const {
76-
// scalar_x is a compile-time constant (since it is a template
84+
85+
using ScalarHint = KokkosKernels::Impl::ScalarHint;
86+
87+
// ALPHA_HINT is a compile-time constant (since it is a template
7788
// parameter), so the compiler should evaluate these branches at
7889
// compile time.
79-
if (scalar_x == 0) {
90+
if constexpr (ALPHA_HINT == ScalarHint::zero) {
8091
m_r(i) = ATS::zero();
8192
}
82-
if (scalar_x == -1) {
93+
else if constexpr (ALPHA_HINT == ScalarHint::neg_one) {
8394
m_r(i) = -m_x(i);
8495
}
85-
if (scalar_x == 1) {
96+
else if constexpr (ALPHA_HINT == ScalarHint::pos_one) {
8697
m_r(i) = m_x(i);
8798
}
88-
if (scalar_x == 2) {
89-
m_r(i) = m_a(0) * m_x(i);
99+
else if constexpr (ALPHA_HINT == ScalarHint::none) {
100+
m_r(i) = KokkosBlas::Impl::as_scalar(m_a) * m_x(i);
101+
}
102+
else {
103+
static_assert(KokkosKernels::Impl::always_false_v<AV>, "Unexpected value for ALPHA_HINT");
90104
}
91105
}
92106
};
93107

108+
#if 0
94109
// Partial specialization of V_Scal_Functor that lets a be a scalar
95110
// (rather than a 1-D View, as in the most general version above).
96111
// This functor computes any of the following:
@@ -128,13 +143,16 @@ struct V_Scal_Functor<RV, typename XV::non_const_value_type, XV, scalar_x,
128143
}
129144
}
130145
};
146+
#endif
131147

132148
// Variant of MV_Scal_Generic for single vectors (1-D Views) r and x.
133149
// As above, av is either a 1-D View (and only its first entry will be
134150
// read), or a scalar.
135151
template <class execution_space, class RV, class AV, class XV, class SizeType>
136152
void V_Scal_Generic(const execution_space& space, const RV& r, const AV& av,
137-
const XV& x, const SizeType startingColumn, int a = 2) {
153+
const XV& x,
154+
const SizeType startingColumn,
155+
const KokkosKernels::Impl::ScalarHint &alphaHint) {
138156
static_assert(Kokkos::is_view<RV>::value,
139157
"V_Scal_Generic: RV is not a Kokkos::View.");
140158
static_assert(Kokkos::is_view<XV>::value,
@@ -145,24 +163,23 @@ void V_Scal_Generic(const execution_space& space, const RV& r, const AV& av,
145163
const SizeType numRows = x.extent(0);
146164
Kokkos::RangePolicy<execution_space, SizeType> policy(space, 0, numRows);
147165

148-
if (a == 0) {
149-
V_Scal_Functor<RV, AV, XV, 0, SizeType> op(r, x, av, startingColumn);
166+
if (alphaHint == KokkosKernels::Impl::ScalarHint::zero) {
167+
V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::zero, SizeType> op(r, x, av, startingColumn);
150168
Kokkos::parallel_for("KokkosBlas::Scal::S0", policy, op);
151169
return;
152170
}
153-
if (a == -1) {
154-
V_Scal_Functor<RV, AV, XV, -1, SizeType> op(r, x, av, startingColumn);
171+
else if (alphaHint == KokkosKernels::Impl::ScalarHint::neg_one) {
172+
V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::neg_one, SizeType> op(r, x, av, startingColumn);
155173
Kokkos::parallel_for("KokkosBlas::Scal::S1", policy, op);
156174
return;
157175
}
158-
if (a == 1) {
159-
V_Scal_Functor<RV, AV, XV, 1, SizeType> op(r, x, av, startingColumn);
176+
else if (alphaHint == KokkosKernels::Impl::ScalarHint::pos_one) {
177+
V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::pos_one, SizeType> op(r, x, av, startingColumn);
160178
Kokkos::parallel_for("KokkosBlas::Scal::S2", policy, op);
161179
return;
162180
}
163181

164-
// a arbitrary (not -1, 0, or 1)
165-
V_Scal_Functor<RV, AV, XV, 2, SizeType> op(r, x, av, startingColumn);
182+
V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::none, SizeType> op(r, x, av, startingColumn);
166183
Kokkos::parallel_for("KokkosBlas::Scal::S3", policy, op);
167184
}
168185

blas/impl/KokkosBlas1_scal_mv_impl.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <Kokkos_InnerProductSpaceTraits.hpp>
2222
#include <KokkosBlas1_scal_spec.hpp>
2323
#include <KokkosBlas1_scal_impl.hpp>
24+
#include <KokkosKernels_ScalarHint.hpp>
2425

2526
#ifndef KOKKOSBLAS_OPTIMIZATION_LEVEL_SCAL
2627
#define KOKKOSBLAS_OPTIMIZATION_LEVEL_SCAL 2
@@ -323,24 +324,25 @@ template <class execution_space, class RMV, class aVector, class XMV,
323324
int UNROLL, class SizeType>
324325
void MV_Scal_Unrolled(const execution_space& space, const RMV& r,
325326
const aVector& av, const XMV& x,
326-
const SizeType startingColumn, int a = 2) {
327-
if (a == 0) {
327+
const SizeType startingColumn,
328+
const KokkosKernels::Impl::ScalarHint &a = KokkosKernels::Impl::ScalarHint::none) {
329+
if (a == KokkosKernels::Impl::ScalarHint::zero) {
328330
MV_Scal_Unroll_Functor<RMV, aVector, XMV, 0, UNROLL, SizeType> op(
329331
r, x, av, startingColumn);
330332
const SizeType numRows = x.extent(0);
331333
Kokkos::RangePolicy<execution_space, SizeType> policy(space, 0, numRows);
332334
Kokkos::parallel_for("KokkosBlas::Scal::MV::S0", policy, op);
333335
return;
334336
}
335-
if (a == -1) {
337+
if (a == KokkosKernels::Impl::ScalarHint::neg_one) {
336338
MV_Scal_Unroll_Functor<RMV, aVector, XMV, -1, UNROLL, SizeType> op(
337339
r, x, av, startingColumn);
338340
const SizeType numRows = x.extent(0);
339341
Kokkos::RangePolicy<execution_space, SizeType> policy(space, 0, numRows);
340342
Kokkos::parallel_for("KokkosBlas::Scal::MV::S1", policy, op);
341343
return;
342344
}
343-
if (a == 1) {
345+
if (a == KokkosKernels::Impl::ScalarHint::pos_one) {
344346
MV_Scal_Unroll_Functor<RMV, aVector, XMV, 1, UNROLL, SizeType> op(
345347
r, x, av, startingColumn);
346348
const SizeType numRows = x.extent(0);
@@ -349,7 +351,6 @@ void MV_Scal_Unrolled(const execution_space& space, const RMV& r,
349351
return;
350352
}
351353

352-
// a arbitrary (not -1, 0, or 1)
353354
MV_Scal_Unroll_Functor<RMV, aVector, XMV, 2, UNROLL, SizeType> op(
354355
r, x, av, startingColumn);
355356
const SizeType numRows = x.extent(0);
@@ -420,7 +421,8 @@ void MV_Scal_Generic(const execution_space& space, const RVector& r,
420421
// coefficient(s) in av, if used.
421422
template <class execution_space, class RMV, class AV, class XMV, class SizeType>
422423
void MV_Scal_Invoke_Left(const execution_space& space, const RMV& r,
423-
const AV& av, const XMV& x, int a = 2) {
424+
const AV& av, const XMV& x,
425+
const KokkosKernels::Impl::ScalarHint &a = KokkosKernels::Impl::ScalarHint::none) {
424426
const SizeType numCols = x.extent(1);
425427

426428
#if KOKKOSBLAS_OPTIMIZATION_LEVEL_SCAL <= 2

blas/impl/KokkosBlas1_scal_spec.hpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -152,23 +152,23 @@ struct Scal<execution_space, RV, typename XV::non_const_value_type, XV, 1,
152152
#endif
153153

154154
const size_type numRows = X.extent(0);
155-
int a = 2;
155+
KokkosKernels::Impl::ScalarHint alphaHint = KokkosKernels::Impl::ScalarHint::none;
156156
if (alpha == ATA::zero()) {
157-
a = 0;
157+
alphaHint = KokkosKernels::Impl::ScalarHint::zero;
158158
} else if (alpha == -ATA::one()) {
159-
a = -1;
159+
alphaHint = KokkosKernels::Impl::ScalarHint::neg_one;
160160
} else if (alpha == ATA::one()) {
161-
a = 1;
161+
alphaHint = KokkosKernels::Impl::ScalarHint::pos_one;
162162
}
163163

164164
if (numRows < static_cast<size_type>(INT_MAX)) {
165165
typedef int index_type;
166166
V_Scal_Generic<execution_space, RV, AV, XV, index_type>(space, R, alpha,
167-
X, a);
167+
X, 0, alphaHint);
168168
} else {
169169
typedef typename XV::size_type index_type;
170170
V_Scal_Generic<execution_space, RV, AV, XV, index_type>(space, R, alpha,
171-
X, a);
171+
X, 0, alphaHint);
172172
}
173173
Kokkos::Profiling::popRegion();
174174
}
@@ -183,6 +183,7 @@ struct Scal<execution_space, RV, typename XV::non_const_value_type, XV, 1,
183183
template <class execution_space, class RMV, class AV, class XMV>
184184
struct Scal<execution_space, RMV, AV, XMV, 2, false,
185185
KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {
186+
using ScalarHint = KokkosKernels::Impl::ScalarHint;
186187
typedef typename XMV::size_type size_type;
187188
typedef Kokkos::ArithTraits<typename XMV::non_const_value_type> ATA;
188189

@@ -221,16 +222,16 @@ struct Scal<execution_space, RMV, AV, XMV, 2, false,
221222

222223
const size_type numRows = X.extent(0);
223224
const size_type numCols = X.extent(1);
224-
const int a = (av.extent(0) == 0) ? 0 : 2;
225+
const ScalarHint alphaHint = (av.extent(0) == 0) ? ScalarHint::zero : ScalarHint::none;
225226
if (numRows < static_cast<size_type>(INT_MAX) &&
226227
numRows * numCols < static_cast<size_type>(INT_MAX)) {
227228
typedef int index_type;
228229
MV_Scal_Invoke_Left<execution_space, RMV, AV, XMV, index_type>(space, R,
229-
av, X, a);
230+
av, X, alphaHint);
230231
} else {
231232
typedef typename XMV::size_type index_type;
232233
MV_Scal_Invoke_Left<execution_space, RMV, AV, XMV, index_type>(space, R,
233-
av, X, a);
234+
av, X, alphaHint);
234235
}
235236
Kokkos::Profiling::popRegion();
236237
}
@@ -245,6 +246,7 @@ struct Scal<execution_space, RMV, AV, XMV, 2, false,
245246
template <class execution_space, class RMV, class XMV>
246247
struct Scal<execution_space, RMV, typename XMV::non_const_value_type, XMV, 2,
247248
false, KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {
249+
using ScalarHint = KokkosKernels::Impl::ScalarHint;
248250
typedef typename XMV::non_const_value_type AV;
249251
typedef typename XMV::size_type size_type;
250252
typedef Kokkos::ArithTraits<typename XMV::non_const_value_type> ATA;
@@ -279,26 +281,26 @@ struct Scal<execution_space, RMV, typename XMV::non_const_value_type, XMV, 2,
279281

280282
const size_type numRows = X.extent(0);
281283
const size_type numCols = X.extent(1);
282-
int a = 2;
284+
ScalarHint alphaHint = ScalarHint::none;
283285
if (alpha == ATA::zero()) {
284-
a = 0;
286+
alphaHint = ScalarHint::zero;
285287
} else if (alpha == -ATA::one()) {
286-
a = -1;
288+
alphaHint = ScalarHint::neg_one;
287289
} else if (alpha == ATA::one()) {
288-
a = 1;
290+
alphaHint = ScalarHint::pos_one;
289291
}
290292

291293
if (numRows < static_cast<size_type>(INT_MAX) &&
292294
numRows * numCols < static_cast<size_type>(INT_MAX)) {
293295
typedef int index_type;
294296
MV_Scal_Invoke_Left<execution_space, RMV,
295297
typename XMV::non_const_value_type, XMV, index_type>(
296-
space, R, alpha, X, a);
298+
space, R, alpha, X, alphaHint);
297299
} else {
298300
typedef typename XMV::size_type index_type;
299301
MV_Scal_Invoke_Left<execution_space, RMV,
300302
typename XMV::non_const_value_type, XMV, index_type>(
301-
space, R, alpha, X, a);
303+
space, R, alpha, X, alphaHint);
302304
}
303305
Kokkos::Profiling::popRegion();
304306
}

0 commit comments

Comments
 (0)