Skip to content

Commit b5834f3

Browse files
authored
Add DiscreteVector operator access (#842)
* Add DiscreteVector operator access * Workaround msvc parsing issue
1 parent 47e7bfc commit b5834f3

File tree

5 files changed

+167
-24
lines changed

5 files changed

+167
-24
lines changed

include/ddc/chunk.hpp

Lines changed: 92 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include <cassert>
88
#include <string>
9+
#include <type_traits>
910
#include <utility>
1011

1112
#include <Kokkos_Core.hpp>
@@ -14,6 +15,7 @@
1415
#include "ddc/chunk_span.hpp"
1516
#include "ddc/chunk_traits.hpp"
1617
#include "ddc/detail/kokkos.hpp"
18+
#include "ddc/detail/type_traits.hpp"
1719
#include "ddc/kokkos_allocator.hpp"
1820

1921
namespace ddc {
@@ -60,6 +62,8 @@ class Chunk<ElementType, DiscreteDomain<DDims...>, Allocator>
6062

6163
using discrete_element_type = typename base_type::discrete_element_type;
6264

65+
using discrete_vector_type = typename base_type::discrete_vector_type;
66+
6367
using extents_type = typename base_type::extents_type;
6468

6569
using layout_type = typename base_type::layout_type;
@@ -181,38 +185,72 @@ class Chunk<ElementType, DiscreteDomain<DDims...>, Allocator>
181185
* @param delems discrete coordinates
182186
* @return const-reference to this element
183187
*/
184-
template <class... DElems>
188+
template <
189+
class... DElems,
190+
std::enable_if_t<detail::all_of_v<is_discrete_element_v<DElems>...>, int> = 0>
185191
element_type const& operator()(DElems const&... delems) const noexcept
186192
{
187193
static_assert(
188194
sizeof...(DDims) == (0 + ... + DElems::size()),
189195
"Invalid number of dimensions");
190196
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
191-
assert(((DiscreteElement<DDims>(take<DDims>(delems...)) >= front<DDims>(this->m_domain))
192-
&& ...));
193-
assert(((DiscreteElement<DDims>(take<DDims>(delems...)) <= back<DDims>(this->m_domain))
194-
&& ...));
197+
assert(this->m_domain.contains(delems...));
195198
return DDC_MDSPAN_ACCESS_OP(this->m_internal_mdspan, uid<DDims>(take<DDims>(delems...))...);
196199
}
197200

201+
/** Element access using a list of DiscreteVector
202+
* @param dvects discrete vectors
203+
* @return reference to this element
204+
*/
205+
template <
206+
class... DVects,
207+
std::enable_if_t<detail::all_of_v<is_discrete_vector_v<DVects>...>, int> = 0,
208+
std::enable_if_t<sizeof...(DVects) != 0, int> = 0>
209+
element_type const& operator()(DVects const&... dvects) const noexcept
210+
{
211+
static_assert(
212+
sizeof...(DDims) == (0 + ... + DVects::size()),
213+
"Invalid number of dimensions");
214+
discrete_element_type const delem
215+
= this->m_domain.front() + discrete_vector_type(dvects...);
216+
return DDC_MDSPAN_ACCESS_OP(this->m_internal_mdspan, uid<DDims>(delem)...);
217+
}
218+
198219
/** Element access using a list of DiscreteElement
199220
* @param delems discrete coordinates
200221
* @return reference to this element
201222
*/
202-
template <class... DElems>
223+
template <
224+
class... DElems,
225+
std::enable_if_t<detail::all_of_v<is_discrete_element_v<DElems>...>, int> = 0>
203226
element_type& operator()(DElems const&... delems) noexcept
204227
{
205228
static_assert(
206229
sizeof...(DDims) == (0 + ... + DElems::size()),
207230
"Invalid number of dimensions");
208231
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
209-
assert(((DiscreteElement<DDims>(take<DDims>(delems...)) >= front<DDims>(this->m_domain))
210-
&& ...));
211-
assert(((DiscreteElement<DDims>(take<DDims>(delems...)) <= back<DDims>(this->m_domain))
212-
&& ...));
232+
assert(this->m_domain.contains(delems...));
213233
return DDC_MDSPAN_ACCESS_OP(this->m_internal_mdspan, uid<DDims>(take<DDims>(delems...))...);
214234
}
215235

236+
/** Element access using a list of DiscreteVector
237+
* @param dvects discrete vectors
238+
* @return reference to this element
239+
*/
240+
template <
241+
class... DVects,
242+
std::enable_if_t<detail::all_of_v<is_discrete_vector_v<DVects>...>, int> = 0,
243+
std::enable_if_t<sizeof...(DVects) != 0, int> = 0>
244+
element_type& operator()(DVects const&... dvects) noexcept
245+
{
246+
static_assert(
247+
sizeof...(DDims) == (0 + ... + DVects::size()),
248+
"Invalid number of dimensions");
249+
discrete_element_type const delem
250+
= this->m_domain.front() + discrete_vector_type(dvects...);
251+
return DDC_MDSPAN_ACCESS_OP(this->m_internal_mdspan, uid<DDims>(delem)...);
252+
}
253+
216254
/** Returns the label of the Chunk
217255
* @return c-string
218256
*/
@@ -333,6 +371,8 @@ class Chunk : public ChunkCommon<ElementType, SupportType, Kokkos::layout_right>
333371

334372
using discrete_element_type = typename base_type::discrete_element_type;
335373

374+
using discrete_vector_type = typename base_type::discrete_vector_type;
375+
336376
using extents_type = typename base_type::extents_type;
337377

338378
using layout_type = typename base_type::layout_type;
@@ -442,36 +482,74 @@ class Chunk : public ChunkCommon<ElementType, SupportType, Kokkos::layout_right>
442482
* @param delems discrete coordinates
443483
* @return const-reference to this element
444484
*/
445-
template <class... DElems>
485+
template <
486+
class... DElems,
487+
std::enable_if_t<detail::all_of_v<is_discrete_element_v<DElems>...>, int> = 0>
446488
element_type const& operator()(DElems const&... delems) const noexcept
447489
{
448490
static_assert(
449491
SupportType::rank() == (0 + ... + DElems::size()),
450492
"Invalid number of dimensions");
451-
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
452493
assert(this->m_domain.contains(delems...));
453494
return DDC_MDSPAN_ACCESS_OP(
454495
this->m_allocation_mdspan,
455496
detail::array(this->m_domain.distance_from_front(delems...)));
456497
}
457498

499+
/** Element access using a list of DiscreteVector
500+
* @param dvects discrete vectors
501+
* @return reference to this element
502+
*/
503+
template <
504+
class... DVects,
505+
std::enable_if_t<detail::all_of_v<is_discrete_vector_v<DVects>...>, int> = 0,
506+
std::enable_if_t<sizeof...(DVects) != 0, int> = 0>
507+
element_type const& operator()(DVects const&... dvects) const noexcept
508+
{
509+
static_assert(
510+
SupportType::rank() == (0 + ... + DVects::size()),
511+
"Invalid number of dimensions");
512+
return DDC_MDSPAN_ACCESS_OP(
513+
this->m_allocation_mdspan,
514+
detail::array(discrete_vector_type(dvects...)));
515+
}
516+
458517
/** Element access using a list of DiscreteElement
459518
* @param delems discrete coordinates
460519
* @return reference to this element
461520
*/
462-
template <class... DElems>
521+
template <
522+
class... DElems,
523+
std::enable_if_t<detail::all_of_v<is_discrete_element_v<DElems>...>, int> = 0>
463524
element_type& operator()(DElems const&... delems) noexcept
464525
{
465526
static_assert(
466527
SupportType::rank() == (0 + ... + DElems::size()),
467528
"Invalid number of dimensions");
468-
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
469529
assert(this->m_domain.contains(delems...));
470530
return DDC_MDSPAN_ACCESS_OP(
471531
this->m_allocation_mdspan,
472532
detail::array(this->m_domain.distance_from_front(delems...)));
473533
}
474534

535+
/** Element access using a list of DiscreteVector
536+
* @param dvects discrete vectors
537+
* @return reference to this element
538+
*/
539+
template <
540+
class... DVects,
541+
std::enable_if_t<detail::all_of_v<is_discrete_vector_v<DVects>...>, int> = 0,
542+
std::enable_if_t<sizeof...(DVects) != 0, int> = 0>
543+
element_type& operator()(DVects const&... dvects) noexcept
544+
{
545+
static_assert(
546+
SupportType::rank() == (0 + ... + DVects::size()),
547+
"Invalid number of dimensions");
548+
return DDC_MDSPAN_ACCESS_OP(
549+
this->m_allocation_mdspan,
550+
detail::array(discrete_vector_type(dvects...)));
551+
}
552+
475553
/** Returns the label of the Chunk
476554
* @return c-string
477555
*/

include/ddc/chunk_common.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class ChunkCommon<ElementType, DiscreteDomain<DDims...>, LayoutStridedPolicy>
5858

5959
using discrete_element_type = typename discrete_domain_type::discrete_element_type;
6060

61+
using discrete_vector_type = typename discrete_domain_type::discrete_vector_type;
62+
6163
using extents_type = typename allocation_mdspan_type::extents_type;
6264

6365
using layout_type = typename allocation_mdspan_type::layout_type;
@@ -323,6 +325,8 @@ class ChunkCommon
323325

324326
using discrete_element_type = typename discrete_domain_type::discrete_element_type;
325327

328+
using discrete_vector_type = typename discrete_domain_type::discrete_vector_type;
329+
326330
using extents_type = typename allocation_mdspan_type::extents_type;
327331

328332
using layout_type = typename allocation_mdspan_type::layout_type;

include/ddc/chunk_span.hpp

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "ddc/chunk_common.hpp"
1616
#include "ddc/detail/kokkos.hpp"
1717
#include "ddc/detail/type_seq.hpp"
18+
#include "ddc/detail/type_traits.hpp"
1819
#include "ddc/discrete_domain.hpp"
1920
#include "ddc/discrete_element.hpp"
2021

@@ -79,6 +80,8 @@ class ChunkSpan<ElementType, DiscreteDomain<DDims...>, LayoutStridedPolicy, Memo
7980

8081
using discrete_element_type = typename discrete_domain_type::discrete_element_type;
8182

83+
using discrete_vector_type = typename discrete_domain_type::discrete_vector_type;
84+
8285
using extents_type = typename base_type::extents_type;
8386

8487
using layout_type = typename base_type::layout_type;
@@ -326,20 +329,36 @@ class ChunkSpan<ElementType, DiscreteDomain<DDims...>, LayoutStridedPolicy, Memo
326329
* @param delems discrete elements
327330
* @return reference to this element
328331
*/
329-
template <class... DElems>
332+
template <
333+
class... DElems,
334+
std::enable_if_t<detail::all_of_v<is_discrete_element_v<DElems>...>, int> = 0>
330335
KOKKOS_FUNCTION constexpr reference operator()(DElems const&... delems) const noexcept
331336
{
332337
static_assert(
333338
sizeof...(DDims) == (0 + ... + DElems::size()),
334339
"Invalid number of dimensions");
335-
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
336-
assert(((DiscreteElement<DDims>(take<DDims>(delems...)) >= front<DDims>(this->m_domain))
337-
&& ...));
338-
assert(((DiscreteElement<DDims>(take<DDims>(delems...)) <= back<DDims>(this->m_domain))
339-
&& ...));
340+
assert(this->m_domain.contains(delems...));
340341
return DDC_MDSPAN_ACCESS_OP(this->m_internal_mdspan, uid<DDims>(take<DDims>(delems...))...);
341342
}
342343

344+
/** Element access using a list of DiscreteVector
345+
* @param dvects discrete vectors
346+
* @return reference to this element
347+
*/
348+
template <
349+
class... DVects,
350+
std::enable_if_t<detail::all_of_v<is_discrete_vector_v<DVects>...>, int> = 0,
351+
std::enable_if_t<sizeof...(DVects) != 0, int> = 0>
352+
KOKKOS_FUNCTION constexpr reference operator()(DVects const&... dvects) const noexcept
353+
{
354+
static_assert(
355+
sizeof...(DDims) == (0 + ... + DVects::size()),
356+
"Invalid number of dimensions");
357+
discrete_element_type const delem
358+
= this->m_domain.front() + discrete_vector_type(dvects...);
359+
return DDC_MDSPAN_ACCESS_OP(this->m_internal_mdspan, uid<DDims>(delem)...);
360+
}
361+
343362
/** Access to the underlying allocation pointer
344363
* @return allocation pointer
345364
*/
@@ -413,6 +432,8 @@ class ChunkSpan : public ChunkCommon<ElementType, SupportType, LayoutStridedPoli
413432

414433
using discrete_element_type = typename discrete_domain_type::discrete_element_type;
415434

435+
using discrete_vector_type = typename discrete_domain_type::discrete_vector_type;
436+
416437
using extents_type = typename base_type::extents_type;
417438

418439
using layout_type = typename base_type::layout_type;
@@ -615,19 +636,38 @@ class ChunkSpan : public ChunkCommon<ElementType, SupportType, LayoutStridedPoli
615636
* @param delems discrete elements
616637
* @return reference to this element
617638
*/
618-
template <class... DElems>
639+
template <
640+
class... DElems,
641+
std::enable_if_t<detail::all_of_v<is_discrete_element_v<DElems>...>, int> = 0>
619642
KOKKOS_FUNCTION constexpr reference operator()(DElems const&... delems) const noexcept
620643
{
621644
static_assert(
622645
SupportType::rank() == (0 + ... + DElems::size()),
623646
"Invalid number of dimensions");
624-
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
625647
assert(this->m_domain.contains(delems...));
626648
return DDC_MDSPAN_ACCESS_OP(
627649
this->m_allocation_mdspan,
628650
detail::array(this->m_domain.distance_from_front(delems...)));
629651
}
630652

653+
/** Element access using a list of DiscreteVector
654+
* @param dvects discrete vectors
655+
* @return reference to this element
656+
*/
657+
template <
658+
class... DVects,
659+
std::enable_if_t<detail::all_of_v<is_discrete_vector_v<DVects>...>, int> = 0,
660+
std::enable_if_t<sizeof...(DVects) != 0, int> = 0>
661+
KOKKOS_FUNCTION constexpr reference operator()(DVects const&... dvects) const noexcept
662+
{
663+
static_assert(
664+
SupportType::rank() == (0 + ... + DVects::size()),
665+
"Invalid number of dimensions");
666+
return DDC_MDSPAN_ACCESS_OP(
667+
this->m_allocation_mdspan,
668+
detail::array(discrete_vector_type(dvects...)));
669+
}
670+
631671
/** Access to the underlying allocation pointer
632672
* @return allocation pointer
633673
*/

include/ddc/detail/type_traits.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// Copyright (C) The DDC development team, see COPYRIGHT.md file
2+
//
3+
// SPDX-License-Identifier: MIT
4+
5+
#pragma once
6+
7+
namespace ddc::detail {
8+
9+
// This helper was introduced to workaround a parsing issue with msvc.
10+
template <bool... Bs>
11+
inline constexpr bool all_of_v = (Bs && ...);
12+
13+
} // namespace ddc::detail

tests/chunk.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ TEST(Chunk1DTest, AccessConst)
211211
chunk(ix) = factor * (ix - lbound_x);
212212
// we expect exact equality, not EXPECT_DOUBLE_EQ: this is the same ref twice
213213
EXPECT_EQ(chunk_cref(ix), factor * (ix - lbound_x));
214+
EXPECT_EQ(chunk_cref(ix - lbound_x), factor * (ix - lbound_x));
214215
}
215216
}
216217

@@ -222,6 +223,7 @@ TEST(Chunk1DTest, Access)
222223
chunk(ix) = factor * (ix - lbound_x);
223224
// we expect exact equality, not EXPECT_DOUBLE_EQ: this is the same ref twice
224225
EXPECT_EQ(chunk(ix), factor * (ix - lbound_x));
226+
EXPECT_EQ(chunk(ix - lbound_x), factor * (ix - lbound_x));
225227
}
226228
}
227229

@@ -434,9 +436,11 @@ TEST(Chunk2DTest, Access)
434436
ChunkXY<double> chunk(dom_x_y);
435437
for (DElemX const ix : chunk.domain<DDimX>()) {
436438
for (DElemY const iy : chunk.domain<DDimY>()) {
437-
chunk(ix, iy) = 1.357 * (ix - lbound_x) + 1.159 * (iy - lbound_y);
439+
double const value = 1.357 * (ix - lbound_x) + 1.159 * (iy - lbound_y);
440+
chunk(ix, iy) = value;
438441
// we expect exact equality, not EXPECT_DOUBLE_EQ: this is the same ref twice
439-
EXPECT_EQ(chunk(ix, iy), chunk(ix, iy));
442+
EXPECT_EQ(chunk(ix, iy), value);
443+
EXPECT_EQ(chunk(ix - lbound_x, iy - lbound_y), value);
440444
}
441445
}
442446
}
@@ -449,6 +453,8 @@ TEST(Chunk2DTest, AccessReordered)
449453
chunk(ix, iy) = 1.455 * (ix - lbound_x) + 1.522 * (iy - lbound_y);
450454
// we expect exact equality, not EXPECT_DOUBLE_EQ: this is the same ref twice
451455
EXPECT_EQ(chunk(iy, ix), chunk(ix, iy));
456+
EXPECT_EQ(chunk(iy - lbound_y, ix - lbound_x), chunk(ix, iy));
457+
EXPECT_EQ(chunk(iy - lbound_y, ix - lbound_x), chunk(ix - lbound_x, iy - lbound_y));
452458
}
453459
}
454460
}
@@ -611,6 +617,7 @@ TEST(Chunk3DTest, AccessFromDiscreteElements)
611617
DDomZ const dom_z(lbound_z, ddc::DiscreteVector<DDimZ>(4));
612618
ddc::Chunk<double, DDomXYZ> chunk(DDomXYZ(dom_x_y, dom_z));
613619
ddc::ChunkSpan const chunk_span = chunk.span_cview();
620+
ddc::DiscreteElement<DDimZ, DDimX> const lbound_zx(lbound_z, lbound_x);
614621
for (DElemX const ix : chunk.domain<DDimX>()) {
615622
for (DElemY const iy : chunk.domain<DDimY>()) {
616623
for (DElemZ const iz : chunk.domain<DDimZ>()) {
@@ -620,6 +627,7 @@ TEST(Chunk3DTest, AccessFromDiscreteElements)
620627
// we expect exact equality, not EXPECT_DOUBLE_EQ: this is the same ref twice
621628
EXPECT_EQ(chunk(ix, iy, iz), chunk(iy, izx));
622629
EXPECT_EQ(chunk(ix, iy, iz), chunk_span(iy, izx));
630+
EXPECT_EQ(chunk(ix, iy, iz), chunk(iy - lbound_y, izx - lbound_zx));
623631
}
624632
}
625633
}

0 commit comments

Comments
 (0)