Skip to content

Commit 5e4673d

Browse files
allow virtual datum store to work with types reflectable by boost.pfr
1 parent 42a0817 commit 5e4673d

File tree

3 files changed

+104
-19
lines changed

3 files changed

+104
-19
lines changed

examples/nbody/nbody.cpp

+9-8
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,15 @@ namespace usellama
154154
std::normal_distribution<FP> dist(FP(0), FP(1));
155155
for (std::size_t i = 0; i < PROBLEM_SIZE; ++i)
156156
{
157-
auto p = particles(i);
158-
p(0_DC, 0_DC) = dist(engine);
159-
p(0_DC, 1_DC) = dist(engine);
160-
p(0_DC, 2_DC) = dist(engine);
161-
p(1_DC, 0_DC) = dist(engine) / FP(10);
162-
p(1_DC, 1_DC) = dist(engine) / FP(10);
163-
p(1_DC, 2_DC) = dist(engine) / FP(10);
164-
p(2_DC) = dist(engine) / FP(100);
157+
Particle p;
158+
p.pos.x = dist(engine);
159+
p.pos.y = dist(engine);
160+
p.pos.z = dist(engine);
161+
p.vel.x = dist(engine) / FP(10);
162+
p.vel.y = dist(engine) / FP(10);
163+
p.vel.z = dist(engine) / FP(10);
164+
p.mass = dist(engine) / FP(100);
165+
particles(i).store(p);
165166
}
166167
watch.printAndReset("init");
167168

include/llama/View.hpp

+45-10
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "macros.hpp"
1111
#include "mapping/One.hpp"
1212

13+
#include <boost/pfr.hpp>
1314
#include <boost/preprocessor/cat.hpp>
1415
#include <type_traits>
1516

@@ -259,18 +260,50 @@ namespace llama
259260
template <typename... Ts>
260261
constexpr inline auto dependentFalse = false;
261262

263+
template <typename T>
264+
constexpr inline auto tupleish_size = []() constexpr
265+
{
266+
if constexpr (isTupleLike<T>)
267+
return std::tuple_size_v<T>;
268+
else if constexpr (std::is_aggregate_v<T>) // TODO
269+
return boost::pfr::tuple_size_v<T>;
270+
else
271+
static_assert(
272+
dependentFalse<T>,
273+
"T is not a tuple like type or an aggregate to reflect with boost.pfr");
274+
}
275+
();
276+
277+
template <std::size_t I, typename T>
278+
decltype(auto) tupleish_get(T&& t)
279+
{
280+
using DT = std::decay_t<T>;
281+
if constexpr (isTupleLike<DT>)
282+
{
283+
using std::get;
284+
return get<I>(std::forward<T>(t));
285+
}
286+
else if constexpr (std::is_aggregate_v<DT>) // TODO
287+
return boost::pfr::get<I>(std::forward<T>(t));
288+
else
289+
static_assert(
290+
dependentFalse<T>,
291+
"T is not a tuple like type or an aggregate to reflect with boost.pfr");
292+
}
293+
262294
template <typename Tuple1, typename Tuple2, std::size_t... Is>
263295
LLAMA_FN_HOST_ACC_INLINE void assignTuples(Tuple1&& dst, Tuple2&& src, std::index_sequence<Is...>);
264296

265297
template <typename T1, typename T2>
266298
LLAMA_FN_HOST_ACC_INLINE void assignTupleElement(T1&& dst, T2&& src)
267299
{
268-
if constexpr (isTupleLike<std::decay_t<T1>> && isTupleLike<std::decay_t<T2>>)
269-
{
270-
static_assert(std::tuple_size_v<std::decay_t<T1>> == std::tuple_size_v<std::decay_t<T2>>);
271-
assignTuples(dst, src, std::make_index_sequence<std::tuple_size_v<std::decay_t<T1>>>{});
272-
}
273-
else if constexpr (!isTupleLike<std::decay_t<T1>> && !isTupleLike<std::decay_t<T2>>)
300+
using DT1 = std::decay_t<T1>;
301+
using DT2 = std::decay_t<T2>;
302+
constexpr auto isTupleish1 = isTupleLike<DT1> || std::is_aggregate_v<DT1>;
303+
constexpr auto isTupleish2 = isTupleLike<DT2> || std::is_aggregate_v<DT2>;
304+
if constexpr (isTupleish1 && isTupleish2)
305+
assignTuples(dst, src, std::make_index_sequence<tupleish_size<DT1>>{});
306+
else if constexpr (!isTupleish1 && !isTupleish2)
274307
std::forward<T1>(dst) = std::forward<T2>(src);
275308
else
276309
static_assert(dependentFalse<T1, T2>, "Elements to assign are not tuple/tuple or non-tuple/non-tuple.");
@@ -279,9 +312,11 @@ namespace llama
279312
template <typename Tuple1, typename Tuple2, std::size_t... Is>
280313
LLAMA_FN_HOST_ACC_INLINE void assignTuples(Tuple1&& dst, Tuple2&& src, std::index_sequence<Is...>)
281314
{
282-
static_assert(std::tuple_size_v<std::decay_t<Tuple1>> == std::tuple_size_v<std::decay_t<Tuple2>>);
283-
using std::get;
284-
(assignTupleElement(get<Is>(std::forward<Tuple1>(dst)), get<Is>(std::forward<Tuple2>(src))), ...);
315+
static_assert(tupleish_size<std::decay_t<Tuple1>> == tupleish_size<std::decay_t<Tuple2>>);
316+
(assignTupleElement(
317+
tupleish_get<Is>(std::forward<Tuple1>(dst)),
318+
tupleish_get<Is>(std::forward<Tuple2>(src))),
319+
...);
285320
}
286321

287322
template <typename T, typename Tuple, std::size_t... Is>
@@ -659,7 +694,7 @@ namespace llama
659694
template <typename TupleLike>
660695
void store(const TupleLike& t)
661696
{
662-
internal::assignTuples(asTuple(), t, std::make_index_sequence<std::tuple_size_v<TupleLike>>{});
697+
internal::assignTuples(asTuple(), t, std::make_index_sequence<std::tuple_size_v<decltype(asTuple())>>{});
663698
}
664699
};
665700
} // namespace llama

tests/virtualdatum.cpp

+50-1
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,7 @@ TEST_CASE("VirtualDatum.load.constref")
797797
}
798798
}
799799

800-
TEST_CASE("VirtualDatum.store")
800+
TEST_CASE("VirtualDatum.store.tuplelike")
801801
{
802802
llama::One<Name> datum;
803803

@@ -878,3 +878,52 @@ TEST_CASE("VirtualDatum.loadAs.constref")
878878
CHECK(pos.y == 1);
879879
}
880880
}
881+
882+
TEST_CASE("VirtualDatum.store.aggregate")
883+
{
884+
struct MyPosAgg
885+
{
886+
int a;
887+
int y;
888+
};
889+
890+
struct MyVelAgg
891+
{
892+
int x;
893+
int y;
894+
int z;
895+
};
896+
897+
struct MyDatumAgg
898+
{
899+
MyPosAgg pos;
900+
MyVelAgg vel;
901+
int weight;
902+
};
903+
904+
llama::One<Name> datum;
905+
906+
datum = 1;
907+
{
908+
MyPosAgg pos{2, 3};
909+
datum(tag::Pos{}).store(pos);
910+
CHECK(datum(tag::Pos{}, tag::A{}) == 2);
911+
CHECK(datum(tag::Pos{}, tag::Y{}) == 3);
912+
CHECK(datum(tag::Vel{}, tag::X{}) == 1);
913+
CHECK(datum(tag::Vel{}, tag::Y{}) == 1);
914+
CHECK(datum(tag::Vel{}, tag::Z{}) == 1);
915+
CHECK(datum(tag::Weight{}) == 1);
916+
}
917+
918+
datum = 1;
919+
{
920+
MyDatumAgg d{{2, 3}, {4, 5, 6}, 7};
921+
datum.store(d);
922+
CHECK(datum(tag::Pos{}, tag::A{}) == 2);
923+
CHECK(datum(tag::Pos{}, tag::Y{}) == 3);
924+
CHECK(datum(tag::Vel{}, tag::X{}) == 4);
925+
CHECK(datum(tag::Vel{}, tag::Y{}) == 5);
926+
CHECK(datum(tag::Vel{}, tag::Z{}) == 6);
927+
CHECK(datum(tag::Weight{}) == 7);
928+
}
929+
}

0 commit comments

Comments
 (0)