Skip to content

Commit d421c5c

Browse files
committed
some experiments with scalar view unification
1 parent 4d9dd63 commit d421c5c

4 files changed

+271
-34
lines changed
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
//@HEADER
2+
// ************************************************************************
3+
//
4+
// Kokkos v. 4.0
5+
// Copyright (2022) National Technology & Engineering
6+
// Solutions of Sandia, LLC (NTESS).
7+
//
8+
// Under the terms of Contract DE-NA0003525 with NTESS,
9+
// the U.S. Government retains certain rights in this software.
10+
//
11+
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12+
// See https://kokkos.org/LICENSE for license information.
13+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14+
//
15+
//@HEADER
16+
#ifndef KOKKOSBLAS1_SCAL_UNIFIED_SCALAR_VIEW_IMPL
17+
#define KOKKOSBLAS1_SCAL_UNIFIED_SCALAR_VIEW_IMPL
18+
19+
#include <KokkosKernels_config.h>
20+
#include <Kokkos_Core.hpp>
21+
22+
/*! \brief
23+
24+
25+
Implements the following table:
26+
27+
28+
Row | RMV | AV | XMV | alpha_type
29+
1 | Rank-1 | S | Rank-1 | S
30+
2 | Rank-2 | S | Rank-2 | S
31+
3 | Rank-1 | View<S, host> | Rank-1 | S
32+
4 | Rank-2 | View<S, host> | Rank-2 | S
33+
5 | Rank-1 | View<S, dev> | Rank-1 | View<S, dev>
34+
6 | Rank-2 | View<S, dev> | Rank-2 | View<S, dev>
35+
7 | Rank-1 | View<S[1], host> | Rank-1 | S
36+
8 | Rank-2 | View<S[1], host> | Rank-2 | S
37+
9 | Rank-1 | View<S*, host> | Rank-1 | S
38+
10 | Rank-2 | View<S*, host> | Rank-2 | View<S*, host>
39+
11 | Rank-1 | View<S[1], dev> | Rank-1 | View<S, dev>
40+
12 | Rank-1 | View<S*, dev> | Rank-1 | View<S, dev>
41+
13 | Rank-2 | View<S[1], dev> | Rank-2 | View<S, dev>
42+
14 | Rank-2 | View<S*, dev> | Rank-2 | View<S*, dev>
43+
44+
See comments on the implementation below for each rows
45+
*/
46+
47+
namespace KokkosKernels::Impl {
48+
49+
template <typename T, typename ExecSpace, typename Enable = void>
50+
struct is_host : std::false_type {};
51+
template <typename T, typename ExecSpace>
52+
struct is_host<
53+
T, ExecSpace,
54+
std::enable_if_t<Kokkos::is_view_v<T> &&
55+
!Kokkos::SpaceAccessibility<
56+
ExecSpace, typename T::memory_space>::accessible>>
57+
: std::true_type {};
58+
template <typename T, typename ExecSpace>
59+
constexpr inline bool is_host_v = is_host<T, ExecSpace>::value;
60+
61+
template <typename T, typename ExecSpace, typename Enable = void>
62+
struct is_rank_0_host : std::false_type {};
63+
template <typename T, typename ExecSpace>
64+
struct is_rank_0_host<T, ExecSpace,
65+
std::enable_if_t<is_host_v<T, ExecSpace> && T::rank == 0>>
66+
: std::true_type {};
67+
template <typename T, typename ExecSpace>
68+
constexpr inline bool is_rank_0_host_v = is_rank_0_host<T, ExecSpace>::value;
69+
70+
template <typename T, typename ExecSpace, typename Enable = void>
71+
struct is_rank_1_host : std::false_type {};
72+
template <typename T, typename ExecSpace>
73+
struct is_rank_1_host<T, ExecSpace,
74+
std::enable_if_t<is_host_v<T, ExecSpace> && T::rank == 1>>
75+
: std::true_type {};
76+
template <typename T, typename ExecSpace>
77+
constexpr inline bool is_rank_1_host_v = is_rank_1_host<T, ExecSpace>::value;
78+
79+
template <typename T, typename ExecSpace, typename Enable = void>
80+
struct is_rank_1_host_static : std::false_type {};
81+
template <typename T, typename ExecSpace>
82+
struct is_rank_1_host_static<T, ExecSpace,
83+
std::enable_if_t<is_rank_1_host_v<T, ExecSpace> &&
84+
T::static_extent(0) == 1>>
85+
: std::true_type {};
86+
template <typename T, typename ExecSpace>
87+
constexpr inline bool is_rank_1_host_static_v =
88+
is_rank_1_host_static<T, ExecSpace>::value;
89+
90+
template <typename T, typename ExecSpace, typename Enable = void>
91+
struct is_dev : std::false_type {};
92+
template <typename T, typename ExecSpace>
93+
struct is_dev<
94+
T, ExecSpace,
95+
std::enable_if_t<Kokkos::is_view_v<T> &&
96+
Kokkos::SpaceAccessibility<
97+
ExecSpace, typename T::memory_space>::accessible>>
98+
: std::true_type {};
99+
template <typename T, typename ExecSpace>
100+
constexpr inline bool is_dev_v = is_dev<T, ExecSpace>::value;
101+
102+
template <typename T, typename ExecSpace, typename Enable = void>
103+
struct is_rank_0_dev : std::false_type {};
104+
template <typename T, typename ExecSpace>
105+
struct is_rank_0_dev<T, ExecSpace,
106+
std::enable_if_t<is_dev_v<T, ExecSpace> && T::rank == 0>>
107+
: std::true_type {};
108+
template <typename T, typename ExecSpace>
109+
constexpr inline bool is_rank_0_dev_v = is_rank_0_dev<T, ExecSpace>::value;
110+
111+
template <typename T, typename ExecSpace, typename Enable = void>
112+
struct is_rank_1_dev : std::false_type {};
113+
template <typename T, typename ExecSpace>
114+
struct is_rank_1_dev<T, ExecSpace,
115+
std::enable_if_t<is_dev_v<T, ExecSpace> && T::rank == 1>>
116+
: std::true_type {};
117+
template <typename T, typename ExecSpace>
118+
constexpr inline bool is_rank_1_dev_v = is_rank_1_dev<T, ExecSpace>::value;
119+
120+
template <typename T, typename ExecSpace, typename Enable = void>
121+
struct is_rank_1_dev_static : std::false_type {};
122+
template <typename T, typename ExecSpace>
123+
struct is_rank_1_dev_static<
124+
T, ExecSpace,
125+
std::enable_if_t<is_rank_1_dev_v<T, ExecSpace> && T::static_extent(0) == 1>>
126+
: std::true_type {};
127+
template <typename T, typename ExecSpace>
128+
constexpr inline bool is_rank_1_dev_static_v =
129+
is_rank_1_dev_static<T, ExecSpace>::value;
130+
131+
template <typename RMV, typename AV, typename XMV, typename ExecSpace,
132+
typename Enable = void>
133+
struct scal_unified_scalar_view;
134+
135+
// Rows 1,2: AV is a scalar
136+
template <typename RMV, typename AV, typename XMV, typename ExecSpace>
137+
struct scal_unified_scalar_view<RMV, AV, XMV, ExecSpace,
138+
std::enable_if_t<!Kokkos::is_view_v<AV>>> {
139+
using alpha_type = AV;
140+
141+
static alpha_type from(const AV &av) { return av; }
142+
};
143+
144+
// Rows 3,4: AV is a rank 0 host view
145+
template <typename RMV, typename AV, typename XMV, typename ExecSpace>
146+
struct scal_unified_scalar_view<
147+
RMV, AV, XMV, ExecSpace,
148+
std::enable_if_t<is_rank_0_host_v<AV, ExecSpace>>> {
149+
using alpha_type = typename AV::data_type;
150+
151+
static alpha_type from(const AV &av) { return av(); }
152+
};
153+
154+
// Rows 5,6: AV is a rank 0 device view
155+
template <typename RMV, typename AV, typename XMV, typename ExecSpace>
156+
struct scal_unified_scalar_view<
157+
RMV, AV, XMV, ExecSpace, std::enable_if_t<is_rank_0_dev_v<AV, ExecSpace>>> {
158+
using alpha_type = Kokkos::View<const typename AV::data_type, typename AV::memory_space, Kokkos::MemoryUnmanaged>;
159+
160+
static alpha_type from(const AV &av) { return av; }
161+
};
162+
163+
// Rows 7,8: AV is a rank 1 host view with known extent
164+
template <typename RMV, typename AV, typename XMV, typename ExecSpace>
165+
struct scal_unified_scalar_view<
166+
RMV, AV, XMV, ExecSpace,
167+
std::enable_if_t<is_rank_1_host_static_v<AV, ExecSpace>>> {
168+
169+
// FIXME: const?
170+
using alpha_type = typename AV::value_type;
171+
172+
static alpha_type from(const AV &av) { return av(0); }
173+
};
174+
175+
// Row 9: AV is a rank 1 host view of unknown size, but we assume it's
176+
// a single scalar since XMV and YMV are rank 1
177+
template <typename RMV, typename AV, typename XMV, typename ExecSpace>
178+
struct scal_unified_scalar_view<
179+
RMV, AV, XMV, ExecSpace,
180+
std::enable_if_t<is_rank_1_host_v<AV, ExecSpace> && XMV::rank == 1 &&
181+
RMV::rank == 1>> {
182+
183+
// FIXME: const?
184+
using alpha_type = typename AV::value_type;
185+
186+
static alpha_type from(const AV &av) { return av(0); }
187+
};
188+
189+
// Row 10: AV is a rank 1 host view of unknown size, and we assume
190+
// each element is to scale a vector in RMV and XMV
191+
template <typename RMV, typename AV, typename XMV, typename ExecSpace>
192+
struct scal_unified_scalar_view<
193+
RMV, AV, XMV, ExecSpace,
194+
std::enable_if_t<is_rank_1_host_v<AV, ExecSpace> && XMV::rank == 2 &&
195+
RMV::rank == 2>> {
196+
197+
// FIXME: const?
198+
using alpha_type = Kokkos::View<typename AV::data_type, typename AV::memory_space, Kokkos::MemoryUnmanaged>;
199+
200+
static alpha_type from(const AV &av) { return av; }
201+
};
202+
203+
// Row 11, 12: AV is a rank 1 dev view, but we assume its
204+
// a single scalar since XMV and YMV are rank 1
205+
template <typename RMV, typename AV, typename XMV, typename ExecSpace>
206+
struct scal_unified_scalar_view<
207+
RMV, AV, XMV, ExecSpace,
208+
std::enable_if_t<is_rank_1_dev_v<AV, ExecSpace> && XMV::rank == 1 &&
209+
RMV::rank == 1>> {
210+
211+
using alpha_type =
212+
Kokkos::View<const typename AV::value_type, typename AV::memory_space,
213+
Kokkos::MemoryUnmanaged>;
214+
215+
static alpha_type from(const AV &av) { return Kokkos::subview(av, 0); }
216+
};
217+
218+
// Row 13: AV is a rank 1 dev view of static size,
219+
// so its a single scalar
220+
template <typename RMV, typename AV, typename XMV, typename ExecSpace>
221+
struct scal_unified_scalar_view<
222+
RMV, AV, XMV, ExecSpace,
223+
std::enable_if_t<is_rank_1_dev_static_v<AV, ExecSpace>>> {
224+
225+
// FIXME: const?
226+
using alpha_type =
227+
Kokkos::View<const typename AV::value_type, typename AV::memory_space,
228+
Kokkos::MemoryUnmanaged>;
229+
230+
static alpha_type from(const AV &av) { return Kokkos::subview(av, 0); }
231+
};
232+
233+
// Row 14: AV is a rank 1 dev view of unknown size,
234+
// and XMV and YMV are rank 2, so assume each entry is
235+
// used to scale each vector
236+
template <typename RMV, typename AV, typename XMV, typename ExecSpace>
237+
struct scal_unified_scalar_view<
238+
RMV, AV, XMV, ExecSpace,
239+
std::enable_if_t<is_rank_1_dev_v<AV, ExecSpace> && XMV::rank == 2 &&
240+
RMV::rank == 2>> {
241+
// FIXME: const?
242+
using alpha_type = Kokkos::View<typename AV::data_type, typename AV::memory_space, Kokkos::MemoryUnmanaged>;
243+
244+
static alpha_type from(const AV &av) { return av; }
245+
};
246+
247+
} // namespace KokkosKernels::Impl
248+
249+
#endif // KOKKOSBLAS1_SCAL_UNIFIED_SCALAR_VIEW_IMPL

blas/src/KokkosBlas1_scal.hpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include <KokkosKernels_helpers.hpp>
2424
#include <KokkosKernels_Error.hpp>
2525

26+
#include <KokkosBlas1_scal_unified_scalar_view_impl.hpp>
27+
2628
///
2729
/// General/Host Scale
2830
///
@@ -37,7 +39,7 @@ namespace KokkosBlas {
3739
/// \tparam RMV 1-D or 2-D Kokkos::View specialization.
3840
/// \tparam XMV 1-D or 2-D Kokkos::View specialization. It must have
3941
/// the same rank as RMV.
40-
/// \tparam AV 1-D or 2-D Kokkos::View specialization.
42+
/// \tparam AV a scalar, 0-D, or 1-D Kokkos::View specialization.
4143
///
4244
/// \param space [in] the execution space instance on which the kernel will run.
4345
/// \param R [in/out] view of type RMV in which the results will be stored.
@@ -103,12 +105,20 @@ void scal(const execution_space& space, const RMV& R, const AV& a,
103105
using XMV_Internal = Kokkos::View<typename XMV::const_data_type,
104106
UnifiedXLayout, typename XMV::device_type,
105107
Kokkos::MemoryTraits<Kokkos::Unmanaged> >;
108+
109+
#if 1
110+
using AlphaUnifier = KokkosKernels::Impl::scal_unified_scalar_view<RMV_Internal, AV, XMV_Internal, execution_space>;
111+
using AV_Internal =
112+
typename AlphaUnifier::alpha_type;
113+
AV_Internal a_internal = AlphaUnifier::from(a);
114+
#else
106115
using AV_Internal =
107116
typename KokkosKernels::Impl::GetUnifiedScalarViewType<AV, XMV_Internal,
108-
true>::type;
117+
true>::type;
118+
AV_Internal a_internal = a;
119+
#endif
109120

110121
RMV_Internal R_internal = R;
111-
AV_Internal a_internal = a;
112122
XMV_Internal X_internal = X;
113123

114124
Impl::Scal<execution_space, RMV_Internal, AV_Internal, XMV_Internal>::scal(

blas/tpls/KokkosBlas1_scal_tpl_spec_decl.hpp

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -256,14 +256,12 @@ KOKKOSBLAS1_CSCAL_TPL_SPEC_DECL_CUBLAS(Kokkos::LayoutLeft, Kokkos::CudaUVMSpace,
256256
namespace KokkosBlas {
257257
namespace Impl {
258258

259-
260259
/* rocBLAS documentation:
261260
"a rocBLAS handle always has one stream.""
262-
"If the handle is switching from one non-default stream to another, the old
263-
stream needs to be synchronized...next...rocblas_set_stream"
264-
Basically this means if we're switching streams, we have to fence the old one
265-
first.
266-
We also set the handle's pointer mode appropriately before invoking BLAS.
261+
"If the handle is switching from one non-default stream to another, the
262+
old stream needs to be synchronized...next...rocblas_set_stream" Basically
263+
this means if we're switching streams, we have to fence the old one first. We
264+
also set the handle's pointer mode appropriately before invoking BLAS.
267265
268266
// push_pointer_mode
269267
*/
@@ -299,18 +297,16 @@ namespace Impl {
299297
const size_type numElems = X.extent(0); \
300298
if ((numElems < static_cast<size_type>(INT_MAX)) && \
301299
(R.data() == X.data())) { \
302-
std::cerr << __FILE__ << ":" << __LINE__ << " rocBLAS scal(1)!\n"; \
303300
scal_print_specialization<RV, AS, XV>(); \
304301
const int N = static_cast<int>(numElems); \
305302
constexpr int one = 1; \
306303
KokkosBlas::Impl::RocBlasSingleton& s = \
307304
KokkosBlas::Impl::RocBlasSingleton::singleton(); \
308-
hipStream_t cur; \
309-
KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \
310-
rocblas_get_stream(s.handle, &cur)); \
311-
if (cur != space.hip_stream()) { \
312-
execution_space(cur).fence(); \
313-
} \
305+
hipStream_t cur; \
306+
KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_get_stream(s.handle, &cur)); \
307+
if (cur != space.hip_stream()) { \
308+
execution_space(cur).fence(); \
309+
} \
314310
KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \
315311
rocblas_set_stream(s.handle, space.hip_stream())); \
316312
rocblas_pointer_mode pointer_mode; \

common/src/KokkosKernels_UnifiedScalarView.hpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,24 +28,6 @@
2828
namespace KokkosKernels {
2929
namespace Impl {
3030

31-
32-
template <typename ScalarLike, typename = void>
33-
struct is_scalar : std::false_type {};
34-
35-
// kokkos complex
36-
template <typename T>
37-
struct is_scalar<T, std::enable_if_t<is_kokkos_complex_v<T>>> : std::true_type {};
38-
39-
// other scalars
40-
template <typename ScalarLike>
41-
struct is_scalar<ScalarLike, std::enable_if_t<std::is_integral_v<ScalarLike> || std::is_floating_point_v<ScalarLike>>> : std::true_type {};
42-
43-
template <typename ScalarLike>
44-
inline constexpr bool is_scalar_v = is_scalar<ScalarLike>::value;
45-
46-
47-
48-
4931
template <typename ScalarLike, typename = void>
5032
struct is_scalar_view : std::false_type {};
5133

0 commit comments

Comments
 (0)