20
20
#include < Kokkos_Core.hpp>
21
21
#include < Kokkos_InnerProductSpaceTraits.hpp>
22
22
#include < KokkosBlas1_scal_spec.hpp>
23
+ #include < KokkosKernels_AlwaysFalse.hpp>
24
+ #include < KokkosBlas1_scal_unified_scalar_view_impl.hpp>
25
+ #include < KokkosKernels_ScalarHint.hpp>
23
26
24
27
#ifndef KOKKOSBLAS_OPTIMIZATION_LEVEL_SCAL
25
28
#define KOKKOSBLAS_OPTIMIZATION_LEVEL_SCAL 2
28
31
namespace KokkosBlas {
29
32
namespace Impl {
30
33
31
- // Single-vector version of MV_Scal_Functor. By default, a is still a
34
+
35
+ // Single-vector version of MV_Scal_Functor.
36
+ // a has been unified into either a scalar or a 0D view
37
+
32
38
// 1-D View. Below is a partial specialization that lets a be a
33
39
// scalar. This functor computes any of the following:
34
40
//
@@ -42,7 +48,7 @@ namespace Impl {
42
48
// Any literal coefficient of zero has BLAS semantics of ignoring the
43
49
// corresponding (multi)vector entry. This does not apply to
44
50
// coefficients in the a vector, if used.
45
- template <class RV , class AV , class XV , int scalar_x , class SizeType >
51
+ template <class RV , class AV , class XV , KokkosKernels::Impl::ScalarHint ALPHA_HINT , class SizeType >
46
52
struct V_Scal_Functor {
47
53
typedef SizeType size_type;
48
54
typedef Kokkos::ArithTraits<typename RV::non_const_value_type> ATS;
@@ -51,46 +57,55 @@ struct V_Scal_Functor {
51
57
XV m_x;
52
58
AV m_a;
53
59
54
- V_Scal_Functor (const RV& r, const XV& x, const AV& a,
55
- const SizeType startingColumn)
60
+ V_Scal_Functor (const RV& r, const XV& x, const AV& a, const SizeType startingColumn)
56
61
: m_r(r), m_x(x), m_a(a) {
57
62
static_assert (Kokkos::is_view<RV>::value,
58
63
" V_Scal_Functor: RV is not a Kokkos::View." );
59
- static_assert (Kokkos::is_view<AV>::value,
60
- " V_Scal_Functor: AV is not a Kokkos::View." );
64
+
65
+ // TODO: static assert truths about AV
66
+
61
67
static_assert (Kokkos::is_view<XV>::value,
62
68
" V_Scal_Functor: XV is not a Kokkos::View." );
63
69
static_assert (RV::rank == 1 , " V_Scal_Functor: RV is not rank 1." );
64
- static_assert (AV::rank == 1 , " V_Scal_Functor: AV is not rank 1." );
65
70
static_assert (XV::rank == 1 , " V_Scal_Functor: XV is not rank 1." );
66
71
67
- if (startingColumn != 0 ) {
72
+
73
+ if constexpr (Kokkos::is_view_v<AV>) {
74
+ if (startingColumn != 0 ) {
68
75
m_a = Kokkos::subview (
69
76
a,
70
77
std::make_pair (startingColumn, static_cast <SizeType>(a.extent (0 ))));
78
+ }
71
79
}
72
80
}
73
81
74
82
KOKKOS_INLINE_FUNCTION
75
83
void operator ()(const size_type& i) const {
76
- // scalar_x is a compile-time constant (since it is a template
84
+
85
+ using ScalarHint = KokkosKernels::Impl::ScalarHint;
86
+
87
+ // scalar_a is a compile-time constant (since it is a template
77
88
// parameter), so the compiler should evaluate these branches at
78
89
// compile time.
79
- if (scalar_x == 0 ) {
90
+ if constexpr (ALPHA_HINT == ScalarHint::zero ) {
80
91
m_r (i) = ATS::zero ();
81
92
}
82
- if (scalar_x == - 1 ) {
93
+ else if constexpr (ALPHA_HINT == ScalarHint::neg_one ) {
83
94
m_r (i) = -m_x (i);
84
95
}
85
- if (scalar_x == 1 ) {
96
+ else if constexpr (ALPHA_HINT == ScalarHint::pos_one ) {
86
97
m_r (i) = m_x (i);
87
98
}
88
- if (scalar_x == 2 ) {
89
- m_r (i) = m_a (0 ) * m_x (i);
99
+ else if constexpr (ALPHA_HINT == ScalarHint::none) {
100
+ m_r (i) = KokkosBlas::Impl::as_scalar (m_a) * m_x (i);
101
+ }
102
+ else {
103
+ static_assert (KokkosKernels::Impl::always_false_v<AV>, " Unexpected value for ALPHA_HINT" );
90
104
}
91
105
}
92
106
};
93
107
108
+ #if 0
94
109
// Partial specialization of V_Scal_Functor that lets a be a scalar
95
110
// (rather than a 1-D View, as in the most general version above).
96
111
// This functor computes any of the following:
@@ -128,13 +143,16 @@ struct V_Scal_Functor<RV, typename XV::non_const_value_type, XV, scalar_x,
128
143
}
129
144
}
130
145
};
146
+ #endif
131
147
132
148
// Variant of MV_Scal_Generic for single vectors (1-D Views) r and x.
133
149
// As above, av is either a 1-D View (and only its first entry will be
134
150
// read), or a scalar.
135
151
template <class execution_space , class RV , class AV , class XV , class SizeType >
136
152
void V_Scal_Generic (const execution_space& space, const RV& r, const AV& av,
137
- const XV& x, const SizeType startingColumn, int a = 2 ) {
153
+ const XV& x,
154
+ const SizeType startingColumn,
155
+ const KokkosKernels::Impl::ScalarHint &alphaHint) {
138
156
static_assert (Kokkos::is_view<RV>::value,
139
157
" V_Scal_Generic: RV is not a Kokkos::View." );
140
158
static_assert (Kokkos::is_view<XV>::value,
@@ -145,24 +163,23 @@ void V_Scal_Generic(const execution_space& space, const RV& r, const AV& av,
145
163
const SizeType numRows = x.extent (0 );
146
164
Kokkos::RangePolicy<execution_space, SizeType> policy (space, 0 , numRows);
147
165
148
- if (a == 0 ) {
149
- V_Scal_Functor<RV, AV, XV, 0 , SizeType> op (r, x, av, startingColumn);
166
+ if (alphaHint == KokkosKernels::Impl::ScalarHint::zero ) {
167
+ V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::zero , SizeType> op (r, x, av, startingColumn);
150
168
Kokkos::parallel_for (" KokkosBlas::Scal::S0" , policy, op);
151
169
return ;
152
170
}
153
- if (a == - 1 ) {
154
- V_Scal_Functor<RV, AV, XV, - 1 , SizeType> op (r, x, av, startingColumn);
171
+ else if (alphaHint == KokkosKernels::Impl::ScalarHint::neg_one ) {
172
+ V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::neg_one , SizeType> op (r, x, av, startingColumn);
155
173
Kokkos::parallel_for (" KokkosBlas::Scal::S1" , policy, op);
156
174
return ;
157
175
}
158
- if (a == 1 ) {
159
- V_Scal_Functor<RV, AV, XV, 1 , SizeType> op (r, x, av, startingColumn);
176
+ else if (alphaHint == KokkosKernels::Impl::ScalarHint::pos_one ) {
177
+ V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::pos_one , SizeType> op (r, x, av, startingColumn);
160
178
Kokkos::parallel_for (" KokkosBlas::Scal::S2" , policy, op);
161
179
return ;
162
180
}
163
181
164
- // a arbitrary (not -1, 0, or 1)
165
- V_Scal_Functor<RV, AV, XV, 2 , SizeType> op (r, x, av, startingColumn);
182
+ V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::none, SizeType> op (r, x, av, startingColumn);
166
183
Kokkos::parallel_for (" KokkosBlas::Scal::S3" , policy, op);
167
184
}
168
185
0 commit comments