Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 92 additions & 14 deletions include/ddc/chunk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <cassert>
#include <string>
#include <type_traits>
#include <utility>

#include <Kokkos_Core.hpp>
Expand All @@ -14,6 +15,7 @@
#include "ddc/chunk_span.hpp"
#include "ddc/chunk_traits.hpp"
#include "ddc/detail/kokkos.hpp"
#include "ddc/detail/type_traits.hpp"
#include "ddc/kokkos_allocator.hpp"

namespace ddc {
Expand Down Expand Up @@ -60,6 +62,8 @@ class Chunk<ElementType, DiscreteDomain<DDims...>, Allocator>

using discrete_element_type = typename base_type::discrete_element_type;

using discrete_vector_type = typename base_type::discrete_vector_type;

using extents_type = typename base_type::extents_type;

using layout_type = typename base_type::layout_type;
Expand Down Expand Up @@ -181,38 +185,72 @@ class Chunk<ElementType, DiscreteDomain<DDims...>, Allocator>
* @param delems discrete coordinates
* @return const-reference to this element
*/
template <class... DElems>
template <
class... DElems,
std::enable_if_t<detail::all_of_v<is_discrete_element_v<DElems>...>, int> = 0>
element_type const& operator()(DElems const&... delems) const noexcept
{
static_assert(
sizeof...(DDims) == (0 + ... + DElems::size()),
"Invalid number of dimensions");
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
assert(((DiscreteElement<DDims>(take<DDims>(delems...)) >= front<DDims>(this->m_domain))
&& ...));
assert(((DiscreteElement<DDims>(take<DDims>(delems...)) <= back<DDims>(this->m_domain))
&& ...));
assert(this->m_domain.contains(delems...));
return DDC_MDSPAN_ACCESS_OP(this->m_internal_mdspan, uid<DDims>(take<DDims>(delems...))...);
}

/** Element access using a list of DiscreteVector
* @param dvects discrete vectors
* @return reference to this element
*/
template <
class... DVects,
std::enable_if_t<detail::all_of_v<is_discrete_vector_v<DVects>...>, int> = 0,
std::enable_if_t<sizeof...(DVects) != 0, int> = 0>
element_type const& operator()(DVects const&... dvects) const noexcept
{
static_assert(
sizeof...(DDims) == (0 + ... + DVects::size()),
"Invalid number of dimensions");
discrete_element_type const delem
= this->m_domain.front() + discrete_vector_type(dvects...);
return DDC_MDSPAN_ACCESS_OP(this->m_internal_mdspan, uid<DDims>(delem)...);
}

/** Element access using a list of DiscreteElement
* @param delems discrete coordinates
* @return reference to this element
*/
template <class... DElems>
template <
class... DElems,
std::enable_if_t<detail::all_of_v<is_discrete_element_v<DElems>...>, int> = 0>
element_type& operator()(DElems const&... delems) noexcept
{
static_assert(
sizeof...(DDims) == (0 + ... + DElems::size()),
"Invalid number of dimensions");
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
assert(((DiscreteElement<DDims>(take<DDims>(delems...)) >= front<DDims>(this->m_domain))
&& ...));
assert(((DiscreteElement<DDims>(take<DDims>(delems...)) <= back<DDims>(this->m_domain))
&& ...));
assert(this->m_domain.contains(delems...));
return DDC_MDSPAN_ACCESS_OP(this->m_internal_mdspan, uid<DDims>(take<DDims>(delems...))...);
}

/** Element access using a list of DiscreteVector
* @param dvects discrete vectors
* @return reference to this element
*/
template <
class... DVects,
std::enable_if_t<detail::all_of_v<is_discrete_vector_v<DVects>...>, int> = 0,
std::enable_if_t<sizeof...(DVects) != 0, int> = 0>
element_type& operator()(DVects const&... dvects) noexcept
{
static_assert(
sizeof...(DDims) == (0 + ... + DVects::size()),
"Invalid number of dimensions");
discrete_element_type const delem
= this->m_domain.front() + discrete_vector_type(dvects...);
return DDC_MDSPAN_ACCESS_OP(this->m_internal_mdspan, uid<DDims>(delem)...);
}

/** Returns the label of the Chunk
* @return c-string
*/
Expand Down Expand Up @@ -333,6 +371,8 @@ class Chunk : public ChunkCommon<ElementType, SupportType, Kokkos::layout_right>

using discrete_element_type = typename base_type::discrete_element_type;

using discrete_vector_type = typename base_type::discrete_vector_type;

using extents_type = typename base_type::extents_type;

using layout_type = typename base_type::layout_type;
Expand Down Expand Up @@ -442,36 +482,74 @@ class Chunk : public ChunkCommon<ElementType, SupportType, Kokkos::layout_right>
* @param delems discrete coordinates
* @return const-reference to this element
*/
template <class... DElems>
template <
class... DElems,
std::enable_if_t<detail::all_of_v<is_discrete_element_v<DElems>...>, int> = 0>
element_type const& operator()(DElems const&... delems) const noexcept
{
static_assert(
SupportType::rank() == (0 + ... + DElems::size()),
"Invalid number of dimensions");
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
assert(this->m_domain.contains(delems...));
return DDC_MDSPAN_ACCESS_OP(
this->m_allocation_mdspan,
detail::array(this->m_domain.distance_from_front(delems...)));
}

/** Element access using a list of DiscreteVector
* @param dvects discrete vectors
* @return reference to this element
*/
template <
class... DVects,
std::enable_if_t<detail::all_of_v<is_discrete_vector_v<DVects>...>, int> = 0,
std::enable_if_t<sizeof...(DVects) != 0, int> = 0>
element_type const& operator()(DVects const&... dvects) const noexcept
{
static_assert(
SupportType::rank() == (0 + ... + DVects::size()),
"Invalid number of dimensions");
return DDC_MDSPAN_ACCESS_OP(
this->m_allocation_mdspan,
detail::array(discrete_vector_type(dvects...)));
}

/** Element access using a list of DiscreteElement
* @param delems discrete coordinates
* @return reference to this element
*/
template <class... DElems>
template <
class... DElems,
std::enable_if_t<detail::all_of_v<is_discrete_element_v<DElems>...>, int> = 0>
element_type& operator()(DElems const&... delems) noexcept
{
static_assert(
SupportType::rank() == (0 + ... + DElems::size()),
"Invalid number of dimensions");
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
assert(this->m_domain.contains(delems...));
return DDC_MDSPAN_ACCESS_OP(
this->m_allocation_mdspan,
detail::array(this->m_domain.distance_from_front(delems...)));
}

/** Element access using a list of DiscreteVector
* @param dvects discrete vectors
* @return reference to this element
*/
template <
class... DVects,
std::enable_if_t<detail::all_of_v<is_discrete_vector_v<DVects>...>, int> = 0,
std::enable_if_t<sizeof...(DVects) != 0, int> = 0>
element_type& operator()(DVects const&... dvects) noexcept
{
static_assert(
SupportType::rank() == (0 + ... + DVects::size()),
"Invalid number of dimensions");
return DDC_MDSPAN_ACCESS_OP(
this->m_allocation_mdspan,
detail::array(discrete_vector_type(dvects...)));
}

/** Returns the label of the Chunk
* @return c-string
*/
Expand Down
4 changes: 4 additions & 0 deletions include/ddc/chunk_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class ChunkCommon<ElementType, DiscreteDomain<DDims...>, LayoutStridedPolicy>

using discrete_element_type = typename discrete_domain_type::discrete_element_type;

using discrete_vector_type = typename discrete_domain_type::discrete_vector_type;

using extents_type = typename allocation_mdspan_type::extents_type;

using layout_type = typename allocation_mdspan_type::layout_type;
Expand Down Expand Up @@ -323,6 +325,8 @@ class ChunkCommon

using discrete_element_type = typename discrete_domain_type::discrete_element_type;

using discrete_vector_type = typename discrete_domain_type::discrete_vector_type;

using extents_type = typename allocation_mdspan_type::extents_type;

using layout_type = typename allocation_mdspan_type::layout_type;
Expand Down
56 changes: 48 additions & 8 deletions include/ddc/chunk_span.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "ddc/chunk_common.hpp"
#include "ddc/detail/kokkos.hpp"
#include "ddc/detail/type_seq.hpp"
#include "ddc/detail/type_traits.hpp"
#include "ddc/discrete_domain.hpp"
#include "ddc/discrete_element.hpp"

Expand Down Expand Up @@ -79,6 +80,8 @@ class ChunkSpan<ElementType, DiscreteDomain<DDims...>, LayoutStridedPolicy, Memo

using discrete_element_type = typename discrete_domain_type::discrete_element_type;

using discrete_vector_type = typename discrete_domain_type::discrete_vector_type;

using extents_type = typename base_type::extents_type;

using layout_type = typename base_type::layout_type;
Expand Down Expand Up @@ -326,20 +329,36 @@ class ChunkSpan<ElementType, DiscreteDomain<DDims...>, LayoutStridedPolicy, Memo
* @param delems discrete elements
* @return reference to this element
*/
template <class... DElems>
template <
class... DElems,
std::enable_if_t<detail::all_of_v<is_discrete_element_v<DElems>...>, int> = 0>
KOKKOS_FUNCTION constexpr reference operator()(DElems const&... delems) const noexcept
{
static_assert(
sizeof...(DDims) == (0 + ... + DElems::size()),
"Invalid number of dimensions");
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
assert(((DiscreteElement<DDims>(take<DDims>(delems...)) >= front<DDims>(this->m_domain))
&& ...));
assert(((DiscreteElement<DDims>(take<DDims>(delems...)) <= back<DDims>(this->m_domain))
&& ...));
assert(this->m_domain.contains(delems...));
return DDC_MDSPAN_ACCESS_OP(this->m_internal_mdspan, uid<DDims>(take<DDims>(delems...))...);
}

/** Element access using a list of DiscreteVector
* @param dvects discrete vectors
* @return reference to this element
*/
template <
class... DVects,
std::enable_if_t<detail::all_of_v<is_discrete_vector_v<DVects>...>, int> = 0,
std::enable_if_t<sizeof...(DVects) != 0, int> = 0>
KOKKOS_FUNCTION constexpr reference operator()(DVects const&... dvects) const noexcept
{
static_assert(
sizeof...(DDims) == (0 + ... + DVects::size()),
"Invalid number of dimensions");
discrete_element_type const delem
= this->m_domain.front() + discrete_vector_type(dvects...);
return DDC_MDSPAN_ACCESS_OP(this->m_internal_mdspan, uid<DDims>(delem)...);
}

/** Access to the underlying allocation pointer
* @return allocation pointer
*/
Expand Down Expand Up @@ -413,6 +432,8 @@ class ChunkSpan : public ChunkCommon<ElementType, SupportType, LayoutStridedPoli

using discrete_element_type = typename discrete_domain_type::discrete_element_type;

using discrete_vector_type = typename discrete_domain_type::discrete_vector_type;

using extents_type = typename base_type::extents_type;

using layout_type = typename base_type::layout_type;
Expand Down Expand Up @@ -615,19 +636,38 @@ class ChunkSpan : public ChunkCommon<ElementType, SupportType, LayoutStridedPoli
* @param delems discrete elements
* @return reference to this element
*/
template <class... DElems>
template <
class... DElems,
std::enable_if_t<detail::all_of_v<is_discrete_element_v<DElems>...>, int> = 0>
KOKKOS_FUNCTION constexpr reference operator()(DElems const&... delems) const noexcept
{
static_assert(
SupportType::rank() == (0 + ... + DElems::size()),
"Invalid number of dimensions");
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
assert(this->m_domain.contains(delems...));
return DDC_MDSPAN_ACCESS_OP(
this->m_allocation_mdspan,
detail::array(this->m_domain.distance_from_front(delems...)));
}

/** Element access using a list of DiscreteVector
* @param dvects discrete vectors
* @return reference to this element
*/
template <
class... DVects,
std::enable_if_t<detail::all_of_v<is_discrete_vector_v<DVects>...>, int> = 0,
std::enable_if_t<sizeof...(DVects) != 0, int> = 0>
KOKKOS_FUNCTION constexpr reference operator()(DVects const&... dvects) const noexcept
{
static_assert(
SupportType::rank() == (0 + ... + DVects::size()),
"Invalid number of dimensions");
return DDC_MDSPAN_ACCESS_OP(
this->m_allocation_mdspan,
detail::array(discrete_vector_type(dvects...)));
}

/** Access to the underlying allocation pointer
* @return allocation pointer
*/
Expand Down
13 changes: 13 additions & 0 deletions include/ddc/detail/type_traits.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (C) The DDC development team, see COPYRIGHT.md file
//
// SPDX-License-Identifier: MIT

#pragma once

namespace ddc::detail {

// This helper was introduced to workaround a parsing issue with msvc.
template <bool... Bs>
inline constexpr bool all_of_v = (Bs && ...);

} // namespace ddc::detail
12 changes: 10 additions & 2 deletions tests/chunk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ TEST(Chunk1DTest, AccessConst)
chunk(ix) = factor * (ix - lbound_x);
// we expect exact equality, not EXPECT_DOUBLE_EQ: this is the same ref twice
EXPECT_EQ(chunk_cref(ix), factor * (ix - lbound_x));
EXPECT_EQ(chunk_cref(ix - lbound_x), factor * (ix - lbound_x));
}
}

Expand All @@ -222,6 +223,7 @@ TEST(Chunk1DTest, Access)
chunk(ix) = factor * (ix - lbound_x);
// we expect exact equality, not EXPECT_DOUBLE_EQ: this is the same ref twice
EXPECT_EQ(chunk(ix), factor * (ix - lbound_x));
EXPECT_EQ(chunk(ix - lbound_x), factor * (ix - lbound_x));
}
}

Expand Down Expand Up @@ -434,9 +436,11 @@ TEST(Chunk2DTest, Access)
ChunkXY<double> chunk(dom_x_y);
for (DElemX const ix : chunk.domain<DDimX>()) {
for (DElemY const iy : chunk.domain<DDimY>()) {
chunk(ix, iy) = 1.357 * (ix - lbound_x) + 1.159 * (iy - lbound_y);
double const value = 1.357 * (ix - lbound_x) + 1.159 * (iy - lbound_y);
chunk(ix, iy) = value;
// we expect exact equality, not EXPECT_DOUBLE_EQ: this is the same ref twice
EXPECT_EQ(chunk(ix, iy), chunk(ix, iy));
EXPECT_EQ(chunk(ix, iy), value);
EXPECT_EQ(chunk(ix - lbound_x, iy - lbound_y), value);
}
}
}
Expand All @@ -449,6 +453,8 @@ TEST(Chunk2DTest, AccessReordered)
chunk(ix, iy) = 1.455 * (ix - lbound_x) + 1.522 * (iy - lbound_y);
// we expect exact equality, not EXPECT_DOUBLE_EQ: this is the same ref twice
EXPECT_EQ(chunk(iy, ix), chunk(ix, iy));
EXPECT_EQ(chunk(iy - lbound_y, ix - lbound_x), chunk(ix, iy));
EXPECT_EQ(chunk(iy - lbound_y, ix - lbound_x), chunk(ix - lbound_x, iy - lbound_y));
}
}
}
Expand Down Expand Up @@ -611,6 +617,7 @@ TEST(Chunk3DTest, AccessFromDiscreteElements)
DDomZ const dom_z(lbound_z, ddc::DiscreteVector<DDimZ>(4));
ddc::Chunk<double, DDomXYZ> chunk(DDomXYZ(dom_x_y, dom_z));
ddc::ChunkSpan const chunk_span = chunk.span_cview();
ddc::DiscreteElement<DDimZ, DDimX> const lbound_zx(lbound_z, lbound_x);
for (DElemX const ix : chunk.domain<DDimX>()) {
for (DElemY const iy : chunk.domain<DDimY>()) {
for (DElemZ const iz : chunk.domain<DDimZ>()) {
Expand All @@ -620,6 +627,7 @@ TEST(Chunk3DTest, AccessFromDiscreteElements)
// we expect exact equality, not EXPECT_DOUBLE_EQ: this is the same ref twice
EXPECT_EQ(chunk(ix, iy, iz), chunk(iy, izx));
EXPECT_EQ(chunk(ix, iy, iz), chunk_span(iy, izx));
EXPECT_EQ(chunk(ix, iy, iz), chunk(iy - lbound_y, izx - lbound_zx));
}
}
}
Expand Down