1
+ // @HEADER
2
+ // ************************************************************************
3
+ //
4
+ // Kokkos v. 4.0
5
+ // Copyright (2022) National Technology & Engineering
6
+ // Solutions of Sandia, LLC (NTESS).
7
+ //
8
+ // Under the terms of Contract DE-NA0003525 with NTESS,
9
+ // the U.S. Government retains certain rights in this software.
10
+ //
11
+ // Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12
+ // See https://kokkos.org/LICENSE for license information.
13
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14
+ //
15
+ // @HEADER
16
+ #ifndef KOKKOSBLAS1_SCAL_UNIFIED_SCALAR_VIEW_IMPL
17
+ #define KOKKOSBLAS1_SCAL_UNIFIED_SCALAR_VIEW_IMPL
18
+
19
+ #include < KokkosKernels_config.h>
20
+ #include < Kokkos_Core.hpp>
21
+
22
+ /* ! \brief
23
+
24
+
25
+ Implements the following table:
26
+
27
+
28
+ Row | RMV | AV | XMV | alpha_type
29
+ 1 | Rank-1 | S | Rank-1 | S
30
+ 2 | Rank-2 | S | Rank-2 | S
31
+ 3 | Rank-1 | View<S, host> | Rank-1 | S
32
+ 4 | Rank-2 | View<S, host> | Rank-2 | S
33
+ 5 | Rank-1 | View<S, dev> | Rank-1 | View<S, dev>
34
+ 6 | Rank-2 | View<S, dev> | Rank-2 | View<S, dev>
35
+ 7 | Rank-1 | View<S[1], host> | Rank-1 | S
36
+ 8 | Rank-2 | View<S[1], host> | Rank-2 | S
37
+ 9 | Rank-1 | View<S*, host> | Rank-1 | S
38
+ 10 | Rank-2 | View<S*, host> | Rank-2 | View<S*, host>
39
+ 11 | Rank-1 | View<S[1], dev> | Rank-1 | View<S, dev>
40
+ 12 | Rank-1 | View<S*, dev> | Rank-1 | View<S, dev>
41
+ 13 | Rank-2 | View<S[1], dev> | Rank-2 | View<S, dev>
42
+ 14 | Rank-2 | View<S*, dev> | Rank-2 | View<S*, dev>
43
+
44
+ See comments on the implementation below for each rows
45
+ */
46
+
47
+ namespace KokkosKernels ::Impl {
48
+
49
+ template <typename T, typename ExecSpace, typename Enable = void >
50
+ struct is_host : std::false_type {};
51
+ template <typename T, typename ExecSpace>
52
+ struct is_host <
53
+ T, ExecSpace,
54
+ std::enable_if_t <Kokkos::is_view_v<T> &&
55
+ !Kokkos::SpaceAccessibility<
56
+ ExecSpace, typename T::memory_space>::accessible>>
57
+ : std::true_type {};
58
+ template <typename T, typename ExecSpace>
59
+ constexpr inline bool is_host_v = is_host<T, ExecSpace>::value;
60
+
61
+ template <typename T, typename ExecSpace, typename Enable = void >
62
+ struct is_rank_0_host : std::false_type {};
63
+ template <typename T, typename ExecSpace>
64
+ struct is_rank_0_host <T, ExecSpace,
65
+ std::enable_if_t <is_host_v<T, ExecSpace> && T::rank == 0 >>
66
+ : std::true_type {};
67
+ template <typename T, typename ExecSpace>
68
+ constexpr inline bool is_rank_0_host_v = is_rank_0_host<T, ExecSpace>::value;
69
+
70
+ template <typename T, typename ExecSpace, typename Enable = void >
71
+ struct is_rank_1_host : std::false_type {};
72
+ template <typename T, typename ExecSpace>
73
+ struct is_rank_1_host <T, ExecSpace,
74
+ std::enable_if_t <is_host_v<T, ExecSpace> && T::rank == 1 >>
75
+ : std::true_type {};
76
+ template <typename T, typename ExecSpace>
77
+ constexpr inline bool is_rank_1_host_v = is_rank_1_host<T, ExecSpace>::value;
78
+
79
+ template <typename T, typename ExecSpace, typename Enable = void >
80
+ struct is_rank_1_host_static : std::false_type {};
81
+ template <typename T, typename ExecSpace>
82
+ struct is_rank_1_host_static <T, ExecSpace,
83
+ std::enable_if_t <is_rank_1_host_v<T, ExecSpace> &&
84
+ T::static_extent (0 ) == 1>>
85
+ : std::true_type {};
86
+ template <typename T, typename ExecSpace>
87
+ constexpr inline bool is_rank_1_host_static_v =
88
+ is_rank_1_host_static<T, ExecSpace>::value;
89
+
90
+ template <typename T, typename ExecSpace, typename Enable = void >
91
+ struct is_dev : std::false_type {};
92
+ template <typename T, typename ExecSpace>
93
+ struct is_dev <
94
+ T, ExecSpace,
95
+ std::enable_if_t <Kokkos::is_view_v<T> &&
96
+ Kokkos::SpaceAccessibility<
97
+ ExecSpace, typename T::memory_space>::accessible>>
98
+ : std::true_type {};
99
+ template <typename T, typename ExecSpace>
100
+ constexpr inline bool is_dev_v = is_dev<T, ExecSpace>::value;
101
+
102
+ template <typename T, typename ExecSpace, typename Enable = void >
103
+ struct is_rank_0_dev : std::false_type {};
104
+ template <typename T, typename ExecSpace>
105
+ struct is_rank_0_dev <T, ExecSpace,
106
+ std::enable_if_t <is_dev_v<T, ExecSpace> && T::rank == 0 >>
107
+ : std::true_type {};
108
+ template <typename T, typename ExecSpace>
109
+ constexpr inline bool is_rank_0_dev_v = is_rank_0_dev<T, ExecSpace>::value;
110
+
111
+ template <typename T, typename ExecSpace, typename Enable = void >
112
+ struct is_rank_1_dev : std::false_type {};
113
+ template <typename T, typename ExecSpace>
114
+ struct is_rank_1_dev <T, ExecSpace,
115
+ std::enable_if_t <is_dev_v<T, ExecSpace> && T::rank == 1 >>
116
+ : std::true_type {};
117
+ template <typename T, typename ExecSpace>
118
+ constexpr inline bool is_rank_1_dev_v = is_rank_1_dev<T, ExecSpace>::value;
119
+
120
+ template <typename T, typename ExecSpace, typename Enable = void >
121
+ struct is_rank_1_dev_static : std::false_type {};
122
+ template <typename T, typename ExecSpace>
123
+ struct is_rank_1_dev_static <
124
+ T, ExecSpace,
125
+ std::enable_if_t <is_rank_1_dev_v<T, ExecSpace> && T::static_extent(0 ) == 1 >>
126
+ : std::true_type {};
127
+ template <typename T, typename ExecSpace>
128
+ constexpr inline bool is_rank_1_dev_static_v =
129
+ is_rank_1_dev_static<T, ExecSpace>::value;
130
+
131
+ template <typename RMV, typename AV, typename XMV, typename ExecSpace,
132
+ typename Enable = void >
133
+ struct scal_unified_scalar_view ;
134
+
135
+ // Rows 1,2: AV is a scalar
136
+ template <typename RMV, typename AV, typename XMV, typename ExecSpace>
137
+ struct scal_unified_scalar_view <RMV, AV, XMV, ExecSpace,
138
+ std::enable_if_t <!Kokkos::is_view_v<AV>>> {
139
+ using alpha_type = AV;
140
+
141
+ static alpha_type from (const AV &av) { return av; }
142
+ };
143
+
144
+ // Rows 3,4: AV is a rank 0 host view
145
+ template <typename RMV, typename AV, typename XMV, typename ExecSpace>
146
+ struct scal_unified_scalar_view <
147
+ RMV, AV, XMV, ExecSpace,
148
+ std::enable_if_t <is_rank_0_host_v<AV, ExecSpace>>> {
149
+ using alpha_type = typename AV::data_type;
150
+
151
+ static alpha_type from (const AV &av) { return av (); }
152
+ };
153
+
154
+ // Rows 5,6: AV is a rank 0 device view
155
+ template <typename RMV, typename AV, typename XMV, typename ExecSpace>
156
+ struct scal_unified_scalar_view <
157
+ RMV, AV, XMV, ExecSpace, std::enable_if_t <is_rank_0_dev_v<AV, ExecSpace>>> {
158
+ using alpha_type = Kokkos::View<const typename AV::data_type, typename AV::memory_space, Kokkos::MemoryUnmanaged>;
159
+
160
+ static alpha_type from (const AV &av) { return av; }
161
+ };
162
+
163
+ // Rows 7,8: AV is a rank 1 host view with known extent
164
+ template <typename RMV, typename AV, typename XMV, typename ExecSpace>
165
+ struct scal_unified_scalar_view <
166
+ RMV, AV, XMV, ExecSpace,
167
+ std::enable_if_t <is_rank_1_host_static_v<AV, ExecSpace>>> {
168
+
169
+ // FIXME: const?
170
+ using alpha_type = typename AV::value_type;
171
+
172
+ static alpha_type from (const AV &av) { return av (0 ); }
173
+ };
174
+
175
+ // Row 9: AV is a rank 1 host view of unknown size, but we assume it's
176
+ // a single scalar since XMV and YMV are rank 1
177
+ template <typename RMV, typename AV, typename XMV, typename ExecSpace>
178
+ struct scal_unified_scalar_view <
179
+ RMV, AV, XMV, ExecSpace,
180
+ std::enable_if_t <is_rank_1_host_v<AV, ExecSpace> && XMV::rank == 1 &&
181
+ RMV::rank == 1 >> {
182
+
183
+ // FIXME: const?
184
+ using alpha_type = typename AV::value_type;
185
+
186
+ static alpha_type from (const AV &av) { return av (0 ); }
187
+ };
188
+
189
+ // Row 10: AV is a rank 1 host view of unknown size, and we assume
190
+ // each element is to scale a vector in RMV and XMV
191
+ template <typename RMV, typename AV, typename XMV, typename ExecSpace>
192
+ struct scal_unified_scalar_view <
193
+ RMV, AV, XMV, ExecSpace,
194
+ std::enable_if_t <is_rank_1_host_v<AV, ExecSpace> && XMV::rank == 2 &&
195
+ RMV::rank == 2 >> {
196
+
197
+ // FIXME: const?
198
+ using alpha_type = Kokkos::View<typename AV::data_type, typename AV::memory_space, Kokkos::MemoryUnmanaged>;
199
+
200
+ static alpha_type from (const AV &av) { return av; }
201
+ };
202
+
203
+ // Row 11, 12: AV is a rank 1 dev view, but we assume its
204
+ // a single scalar since XMV and YMV are rank 1
205
+ template <typename RMV, typename AV, typename XMV, typename ExecSpace>
206
+ struct scal_unified_scalar_view <
207
+ RMV, AV, XMV, ExecSpace,
208
+ std::enable_if_t <is_rank_1_dev_v<AV, ExecSpace> && XMV::rank == 1 &&
209
+ RMV::rank == 1 >> {
210
+
211
+ using alpha_type =
212
+ Kokkos::View<const typename AV::value_type, typename AV::memory_space,
213
+ Kokkos::MemoryUnmanaged>;
214
+
215
+ static alpha_type from (const AV &av) { return Kokkos::subview (av, 0 ); }
216
+ };
217
+
218
+ // Row 13: AV is a rank 1 dev view of static size,
219
+ // so its a single scalar
220
+ template <typename RMV, typename AV, typename XMV, typename ExecSpace>
221
+ struct scal_unified_scalar_view <
222
+ RMV, AV, XMV, ExecSpace,
223
+ std::enable_if_t <is_rank_1_dev_static_v<AV, ExecSpace>>> {
224
+
225
+ // FIXME: const?
226
+ using alpha_type =
227
+ Kokkos::View<const typename AV::value_type, typename AV::memory_space,
228
+ Kokkos::MemoryUnmanaged>;
229
+
230
+ static alpha_type from (const AV &av) { return Kokkos::subview (av, 0 ); }
231
+ };
232
+
233
+ // Row 14: AV is a rank 1 dev view of unknown size,
234
+ // and XMV and YMV are rank 2, so assume each entry is
235
+ // used to scale each vector
236
+ template <typename RMV, typename AV, typename XMV, typename ExecSpace>
237
+ struct scal_unified_scalar_view <
238
+ RMV, AV, XMV, ExecSpace,
239
+ std::enable_if_t <is_rank_1_dev_v<AV, ExecSpace> && XMV::rank == 2 &&
240
+ RMV::rank == 2 >> {
241
+ // FIXME: const?
242
+ using alpha_type = Kokkos::View<typename AV::data_type, typename AV::memory_space, Kokkos::MemoryUnmanaged>;
243
+
244
+ static alpha_type from (const AV &av) { return av; }
245
+ };
246
+
247
+ } // namespace KokkosKernels::Impl
248
+
249
+ #endif // KOKKOSBLAS1_SCAL_UNIFIED_SCALAR_VIEW_IMPL
0 commit comments