Skip to content

Commit 32e4b80

Browse files
authored
Merge pull request #392 from E3SM-Project/bartgol/repack-simplification
Simplify and upgrade implementation of repack utility
2 parents 7156c94 + 853596e commit 32e4b80

File tree

1 file changed

+34
-176
lines changed

1 file changed

+34
-176
lines changed

src/pack/ekat_pack_kokkos.hpp

Lines changed: 34 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
#define EKAT_PACK_KOKKOS_HPP
33

44
#include "ekat_pack.hpp"
5+
#include "ekat_pack_utils.hpp"
56
#include "ekat_kokkos_meta.hpp"
67
#include "ekat_kernel_assert.hpp"
8+
#include "ekat_scalar_traits.hpp"
79
#include "ekat_assert.hpp"
810

911
#include <vector>
@@ -290,187 +292,43 @@ scalarize (const Kokkos::View<ValueT*, Parms...>& v)
290292
return ScalarizeHelper<ValueT>::scalarize(v);
291293
}
292294

293-
// Turn a View of Pack<T,N>s into a View of Pack<T,M>s.
294-
// Requirement: the smaller number must divide the larger one:
295-
// max(M,N) % min(M,N) == 0.
296-
// Example: const auto b = repack<4>(a);
297-
298-
// Helper struct
299-
template<int N, typename OldType>
300-
struct RepackType {
301-
using type =
302-
typename std::conditional<std::is_const<OldType>::value,
303-
const Pack<std::remove_const_t<OldType>,N>,
304-
Pack<OldType,N>>::type;
305-
};
306-
307-
template<int N, typename T, int M>
308-
struct RepackType <N,Pack<T,M>>{
309-
using type = Pack<T,N>;
310-
};
311-
template<int N, typename T, int M>
312-
struct RepackType <N,const Pack<T,M>>{
313-
using type = const Pack<T,N>;
314-
};
315-
316-
// 2d shrinking
317-
template <typename NewPackT, typename OldPackT, typename... ViewProps>
318-
KOKKOS_FORCEINLINE_FUNCTION
319-
typename std::enable_if<NewPackT::packtag && OldPackT::packtag &&
320-
std::is_same<typename NewPackT::scalar,typename OldPackT::scalar>::value &&
321-
(OldPackT::n > NewPackT::n),
322-
Unmanaged<Kokkos::View<NewPackT**,ViewProps...>>
323-
>::type
324-
repack_impl (const Kokkos::View<OldPackT**, ViewProps...>& vp) {
325-
constexpr int new_pack_size = NewPackT::n;
326-
constexpr int old_pack_size = OldPackT::n;
327-
static_assert(new_pack_size > 0, "New pack size must be positive");
328-
329-
// It's overly restrictive to check compatibility between pack sizes.
330-
// What really matters is that the new pack size divides the last extent
331-
// of the "scalarized" view.
332-
// This MUST be a runtime check.
333-
assert ( (vp.extent_int(1)*old_pack_size) % new_pack_size == 0);
334-
335-
return Unmanaged<Kokkos::View<NewPackT**, ViewProps...> >(
336-
reinterpret_cast<NewPackT*>(vp.data()),
337-
vp.extent_int(0),
338-
(old_pack_size / new_pack_size) * vp.extent_int(1));
339-
}
340-
341-
// 2d growing
342-
template <typename NewPackT, typename OldPackT, typename... ViewProps>
343-
KOKKOS_FORCEINLINE_FUNCTION
344-
typename std::enable_if<NewPackT::packtag && OldPackT::packtag &&
345-
std::is_same<typename NewPackT::scalar,typename OldPackT::scalar>::value &&
346-
(OldPackT::n < NewPackT::n),
347-
Unmanaged<Kokkos::View<NewPackT**,ViewProps...>>
348-
>::type
349-
repack_impl (const Kokkos::View<OldPackT**, ViewProps...>& vp) {
350-
constexpr int new_pack_size = NewPackT::n;
351-
constexpr int old_pack_size = OldPackT::n;
352-
static_assert(new_pack_size > 0, "New pack size must be positive");
353-
// It's not enough to check that the new pack is a multiple of the old pack.
354-
// We actually need to check that the new pack size divides the last extent
355-
// of the "scalarized" view.
356-
// This MUST be a runtime check.
357-
assert ( (vp.extent_int(1)*old_pack_size) % new_pack_size == 0);
358-
359-
return Unmanaged<Kokkos::View<NewPackT**, ViewProps...> >(
360-
reinterpret_cast<NewPackT*>(vp.data()),
361-
vp.extent_int(0),
362-
vp.extent_int(1) / (new_pack_size / old_pack_size) );
363-
}
364-
365-
// 2d staying the same
366-
template <typename NewPackT, typename OldPackT, typename... ViewProps>
367-
KOKKOS_FORCEINLINE_FUNCTION
368-
typename std::enable_if<NewPackT::packtag && OldPackT::packtag &&
369-
std::is_same<typename NewPackT::scalar,typename OldPackT::scalar>::value &&
370-
(OldPackT::n == NewPackT::n),
371-
Unmanaged<Kokkos::View<NewPackT**,ViewProps...>>
372-
>::type
373-
repack_impl (const Kokkos::View<OldPackT**, ViewProps...>& vp) {
374-
return vp;
375-
}
295+
// Turn a View of Pack<T,M>s into a View of Pack<T,N>s,
296+
// or a View of T into a View of Pack<T,N> (provided T is not a Pack itself). N must divide the last extent of the scalarized view (checked at runtime).
297+
template<int N, typename DT, typename... Props>
298+
KOKKOS_INLINE_FUNCTION
299+
auto repack (const Kokkos::View<DT,Props...>& src)
300+
{
301+
using src_view_t = Kokkos::View<DT,Props...>;
302+
using src_value_t = typename src_view_t::traits::value_type;
303+
using src_scalar_t = typename ScalarTraits<src_value_t>::scalar_type;
304+
using array_layout = typename src_view_t::traits::array_layout;
376305

377-
// General access point for repack (calls one of the three above)
378-
template <int N, typename OldValueT, typename... ViewProps>
379-
KOKKOS_FORCEINLINE_FUNCTION
380-
Unmanaged<Kokkos::View<typename RepackType<N,OldValueT>::type**,ViewProps...>>
381-
repack (const Kokkos::View<OldValueT**, ViewProps...>& v) {
382-
using OldPackT =
383-
typename std::conditional<IsPack<OldValueT>::value,
384-
OldValueT,
385-
typename RepackType<1,OldValueT>::type
386-
>::type;
387-
388-
// We are not changing the layout of the view, since
389-
// - if OldValueT was a pack, OldPackT=OldValueT
390-
// - if OldValueT was NOT a pack, OldPackT has size 1
391-
Kokkos::View<OldPackT**,ViewProps...> vp(
392-
reinterpret_cast<OldPackT*>(v.data()),v.extent(0),v.extent(1));
393-
return repack_impl<typename RepackType<N,OldPackT>::type>(vp);
394-
}
306+
if constexpr (ekat::ScalarTraits<src_value_t>::is_simd) {
307+
return repack<N>(scalarize(src));
308+
}
309+
constexpr int rank = src_view_t::rank();
395310

396-
// 1d shrinking
397-
template <typename NewPackT, typename OldPackT, typename... ViewProps>
398-
KOKKOS_FORCEINLINE_FUNCTION
399-
typename std::enable_if<NewPackT::packtag && OldPackT::packtag &&
400-
std::is_same<typename NewPackT::scalar,typename OldPackT::scalar>::value &&
401-
(OldPackT::n > NewPackT::n),
402-
Unmanaged<Kokkos::View<NewPackT*,ViewProps...>>
403-
>::type
404-
repack_impl (const Kokkos::View<OldPackT*, ViewProps...>& vp) {
405-
constexpr int new_pack_size = NewPackT::n;
406-
constexpr int old_pack_size = OldPackT::n;
407-
static_assert(new_pack_size > 0, "New pack size must be positive");
408-
409-
// It's overly restrictive to check compatibility between pack sizes.
410-
// What really matters is that the new pack size divides the last extent
411-
// of the "scalarized" view.
412-
// This MUST be a runtime check.
413-
assert ( (vp.extent_int(0)*old_pack_size) % new_pack_size == 0);
414-
415-
return Unmanaged<Kokkos::View<NewPackT*, ViewProps...> >(
416-
reinterpret_cast<NewPackT*>(vp.data()),
417-
(old_pack_size / new_pack_size) * vp.extent_int(0));
418-
}
311+
using nonconst_dst_value_t = ekat::Pack<src_scalar_t,N>;
312+
using dst_value_t = std::conditional_t<std::is_const_v<src_value_t>,
313+
std::add_const_t<nonconst_dst_value_t>,
314+
nonconst_dst_value_t>;
315+
using dst_data_type = typename ekat::DataND<dst_value_t,rank>::type;
316+
using dst_view_t = Unmanaged<Kokkos::View<dst_data_type,Props...>>;
317+
int packed_dim = rank-1;
419318

420-
// 1d growing
421-
template <typename NewPackT, typename OldPackT, typename... ViewProps>
422-
KOKKOS_FORCEINLINE_FUNCTION
423-
typename std::enable_if<NewPackT::packtag && OldPackT::packtag &&
424-
std::is_same<typename NewPackT::scalar,typename OldPackT::scalar>::value &&
425-
(OldPackT::n < NewPackT::n),
426-
Unmanaged<Kokkos::View<NewPackT*,ViewProps...>>
427-
>::type
428-
repack_impl (const Kokkos::View<OldPackT*, ViewProps...>& vp) {
429-
constexpr int new_pack_size = NewPackT::n;
430-
constexpr int old_pack_size = OldPackT::n;
431-
static_assert(new_pack_size > 0, "New pack size must be positive");
432-
433-
// It's not enough to check that the new pack is a multiple of the old pack.
434-
// We actually need to check that the new pack size divides the last extent
435-
// of the "scalarized" view.
436-
// This MUST be a runtime check.
437-
assert ( (vp.extent_int(0)*old_pack_size) % new_pack_size == 0);
438-
439-
EKAT_KERNEL_ASSERT(vp.extent_int(0) % (new_pack_size / old_pack_size) == 0);
440-
return Unmanaged<Kokkos::View<NewPackT*, ViewProps...> >(
441-
reinterpret_cast<NewPackT*>(vp.data()),
442-
vp.extent_int(0) / (new_pack_size / old_pack_size));
443-
}
319+
EKAT_KERNEL_REQUIRE_MSG (src.extent(packed_dim) % N == 0,
320+
"Error! Cannot pack input view, as the pack size does not divide the last dimension.\n");
444321

445-
// 1d staying the same
446-
template <typename NewPackT, typename OldPackT, typename... ViewProps>
447-
KOKKOS_FORCEINLINE_FUNCTION
448-
typename std::enable_if<NewPackT::packtag && OldPackT::packtag &&
449-
std::is_same<typename NewPackT::scalar,typename OldPackT::scalar>::value &&
450-
(OldPackT::n == NewPackT::n),
451-
Unmanaged<Kokkos::View<NewPackT*,ViewProps...>>
452-
>::type
453-
repack_impl (const Kokkos::View<OldPackT*, ViewProps...>& vp) {
454-
return vp;
455-
}
322+
auto data = src.data();
323+
auto packed_data = reinterpret_cast<dst_value_t*>(data);
456324

457-
// General access point for repack (calls one of the three above)
458-
template <int N, typename OldValueT, typename... ViewProps>
459-
KOKKOS_FORCEINLINE_FUNCTION
460-
Unmanaged<Kokkos::View<typename RepackType<N,OldValueT>::type*,ViewProps...>>
461-
repack (const Kokkos::View<OldValueT*, ViewProps...>& v) {
462-
using OldPackT =
463-
typename std::conditional<IsPack<OldValueT>::value,
464-
OldValueT,
465-
typename RepackType<1,OldValueT>::type
466-
>::type;
467-
468-
// We are not changing the layout of the view, since
469-
// - if OldValueT was a pack, OldPackT=OldValueT
470-
// - if OldValueT was NOT a pack, OldPackT has size 1
471-
Kokkos::View<OldPackT*,ViewProps...> vp(
472-
reinterpret_cast<OldPackT*>(v.data()),v.extent(0));
473-
return repack_impl<typename RepackType<N,OldPackT>::type>(vp);
325+
array_layout layout;
326+
for (int i=0; i<rank-1; ++i) {
327+
layout.dimension[i] = src.extent(i);
328+
}
329+
layout.dimension[rank-1] = ekat::PackInfo<N>::num_packs(src.extent(rank-1));
330+
331+
return dst_view_t(packed_data,layout);
474332
}
475333

476334
//

0 commit comments

Comments
 (0)