Skip to content

Commit a13ecd4

Browse files
add virtual datum operator-> and operator*
1 parent 4a6dcd3 commit a13ecd4

File tree

10 files changed

+145
-6
lines changed

10 files changed

+145
-6
lines changed

examples/nbody/nbody.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,13 @@ namespace usellama
107107
template <typename VirtualParticleI, typename VirtualParticleJ>
108108
LLAMA_FN_HOST_ACC_INLINE void pPInteraction(VirtualParticleI&& pi, VirtualParticleJ pj)
109109
{
110-
auto dist = pi.template loadAs<Particle>().pos - pj.template loadAs<Particle>().pos;
110+
auto dist = pi->pos - pj->pos;
111111
dist *= dist;
112112
const FP distSqr = EPS2 + dist.x + dist.y + dist.z;
113113
const FP distSixth = distSqr * distSqr * distSqr;
114114
const FP invDistCube = 1.0f / std::sqrt(distSixth);
115-
const FP sts = pj.template loadAs<Particle>().mass * invDistCube * TIMESTEP;
116-
pi(1_DC).store(pi.template loadAs<Particle>().vel + dist * sts);
115+
const FP sts = pj->mass * invDistCube * TIMESTEP;
116+
pi(1_DC).store(pi->vel + dist * sts);
117117
}
118118

119119
template <bool UseAccumulator, typename View>
@@ -143,7 +143,7 @@ namespace usellama
143143
{
144144
LLAMA_INDEPENDENT_DATA
145145
for (std::size_t i = 0; i < PROBLEM_SIZE; i++)
146-
particles(i)(0_DC) += particles(i)(1_DC) * TIMESTEP;
146+
particles(i)(0_DC).store(particles(i)->pos + particles(i)->vel * TIMESTEP);
147147
}
148148

149149
template <int Mapping, bool UseAccumulator, std::size_t AoSoALanes = 8 /*AVX2*/>

include/llama/View.hpp

+65-2
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ namespace llama
6060
template <std::size_t Dim, typename DatumDomain>
6161
LLAMA_FN_HOST_ACC_INLINE auto allocViewStack() -> decltype(auto)
6262
{
63+
using Mapping = llama::mapping::One<ArrayDomain<Dim>, DatumDomain>;
6364
using MadeDatumDomain = mapping::MakeDatumDomain<DatumDomain>; // user might pass struct to reflect
64-
using Mapping = llama::mapping::One<ArrayDomain<Dim>, MadeDatumDomain>;
6565
return allocView(Mapping{}, llama::allocator::Stack<sizeOf<MadeDatumDomain>>{});
6666
}
6767

@@ -342,8 +342,53 @@ namespace llama
342342
template <typename T, template <typename...> typename Tuple, typename... Args>
343343
constexpr inline auto
344344
isDirectListInitializableFromTuple<T, Tuple<Args...>> = isDirectListInitializable<T, Args...>;
345+
346+
template <typename Tuplish, typename Coord>
347+
struct GetNestedTuplishType;
348+
349+
template <typename Tuplish, std::size_t Head, std::size_t... Tail>
350+
struct GetNestedTuplishType<Tuplish, DatumCoord<Head, Tail...>>
351+
{
352+
using type = typename GetNestedTuplishType<
353+
std::decay_t<decltype(tupleish_get<Head>(std::declval<Tuplish>()))>,
354+
DatumCoord<Tail...>>::type;
355+
};
356+
357+
template <typename Tuplish>
358+
struct GetNestedTuplishType<Tuplish, DatumCoord<>>
359+
{
360+
using type = Tuplish;
361+
};
362+
363+
template <typename OriginalDatumDomain, typename BoundDatumDomain>
364+
struct GetValueStructOrVoid
365+
{
366+
using type = typename internal::GetNestedTuplishType<OriginalDatumDomain, BoundDatumDomain>::type;
367+
};
368+
369+
template <typename... Elements, typename BoundDatumDomain>
370+
struct GetValueStructOrVoid<DatumStruct<Elements...>, BoundDatumDomain>
371+
{
372+
using type = void;
373+
};
345374
} // namespace internal
346375

376+
template <typename T>
377+
struct IndirectValue
378+
{
379+
T value;
380+
381+
auto operator->() -> T*
382+
{
383+
return &value;
384+
}
385+
386+
auto operator->() const -> const T*
387+
{
388+
return &value;
389+
}
390+
};
391+
347392
/// Virtual data type returned by \ref View after resolving a user domain
348393
/// coordinate or partially resolving a \ref DatumCoord. A virtual datum
349394
/// does not hold data itself (thus named "virtual"), it just binds enough
@@ -359,6 +404,8 @@ namespace llama
359404
private:
360405
using ArrayDomain = typename View::Mapping::ArrayDomain;
361406
using DatumDomain = typename View::Mapping::DatumDomain;
407+
using ValueStruct = typename internal::
408+
GetValueStructOrVoid<typename View::Mapping::OriginalDatumDomain, BoundDatumDomain>::type;
362409

363410
const ArrayDomain userDomainPos;
364411
std::conditional_t<OwnView, View, View&> view;
@@ -369,10 +416,12 @@ namespace llama
369416
/// AccessibleDatumDomain is the same as `Mapping::DatumDomain`.
370417
using AccessibleDatumDomain = GetType<DatumDomain, BoundDatumDomain>;
371418

419+
static constexpr auto supportsValueLoad = !std::is_void_v<ValueStruct>;
420+
372421
LLAMA_FN_HOST_ACC_INLINE VirtualDatum()
373422
/* requires(OwnView) */
374423
: userDomainPos({})
375-
, view{allocViewStack<1, DatumDomain>()}
424+
, view{allocViewStack<1, std::conditional_t<supportsValueLoad, ValueStruct, DatumDomain>>()}
376425
{
377426
static_assert(OwnView, "The default constructor of VirtualDatum is only available if the ");
378427
}
@@ -691,6 +740,20 @@ namespace llama
691740
return {*this};
692741
}
693742

743+
// template <typename = std::enable_if_t<supportsValueLoad>>
744+
auto operator->() const -> IndirectValue<ValueStruct>
745+
{
746+
static_assert(supportsValueLoad);
747+
return {loadAs<ValueStruct>()};
748+
}
749+
750+
// template <typename = std::enable_if_t<supportsValueLoad>>
751+
auto operator*() const -> ValueStruct
752+
{
753+
static_assert(supportsValueLoad);
754+
return loadAs<ValueStruct>();
755+
}
756+
694757
template <typename TupleLike>
695758
void store(const TupleLike& t)
696759
{

include/llama/mapping/AoS.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ namespace llama::mapping
1919
{
2020
using ArrayDomain = T_ArrayDomain;
2121
using DatumDomain = MakeDatumDomain<T_DatumDomain>;
22+
using OriginalDatumDomain = T_DatumDomain;
2223
static constexpr std::size_t blobCount = 1;
2324

2425
constexpr AoS() = default;

include/llama/mapping/AoSoA.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ namespace llama::mapping
2222
{
2323
using ArrayDomain = T_ArrayDomain;
2424
using DatumDomain = MakeDatumDomain<T_DatumDomain>;
25+
using OriginalDatumDomain = T_DatumDomain;
2526
static constexpr std::size_t blobCount = 1;
2627

2728
constexpr AoSoA() = default;

include/llama/mapping/One.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ namespace llama::mapping
1616
{
1717
using ArrayDomain = T_ArrayDomain;
1818
using DatumDomain = MakeDatumDomain<T_DatumDomain>;
19+
using OriginalDatumDomain = T_DatumDomain;
1920

2021
static constexpr std::size_t blobCount = 1;
2122

include/llama/mapping/SoA.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ namespace llama::mapping
2222
{
2323
using ArrayDomain = T_ArrayDomain;
2424
using DatumDomain = MakeDatumDomain<T_DatumDomain>;
25+
using OriginalDatumDomain = T_DatumDomain;
2526
static constexpr std::size_t blobCount = []() constexpr
2627
{
2728
if constexpr (SeparateBuffers::value)

include/llama/mapping/SplitMapping.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ namespace llama::mapping
5454
{
5555
using ArrayDomain = T_ArrayDomain;
5656
using DatumDomain = MakeDatumDomain<T_DatumDomain>;
57+
using OriginalDatumDomain = T_DatumDomain;
5758

5859
using DatumDomainPartitions = decltype(internal::partitionDatumDomain(DatumDomain{}, DatumCoordForMapping1{}));
5960
using DatumDomain1 = boost::mp11::mp_first<DatumDomainPartitions>;

include/llama/mapping/Trace.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ namespace llama::mapping
3636
{
3737
using ArrayDomain = typename Mapping::ArrayDomain;
3838
using DatumDomain = typename Mapping::DatumDomain;
39+
using OriginalDatumDomain = typename Mapping::OriginalDatumDomain;
3940
static constexpr std::size_t blobCount = Mapping::blobCount;
4041

4142
constexpr Trace() = default;

include/llama/mapping/tree/Mapping.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ namespace llama::mapping::tree
173173
{
174174
using ArrayDomain = T_ArrayDomain;
175175
using DatumDomain = T_DatumDomain;
176+
using OriginalDatumDomain = T_DatumDomain;
176177
using BasicTree = TreeFromDomains<ArrayDomain, DatumDomain>;
177178
// TODO, support more than one blob
178179
static constexpr std::size_t blobCount = 1;

tests/virtualdatum.cpp

+69
Original file line numberDiff line numberDiff line change
@@ -927,3 +927,72 @@ TEST_CASE("VirtualDatum.store.aggregate")
927927
CHECK(datum(tag::Weight{}) == 7);
928928
}
929929
}
930+
931+
struct NameStruct
932+
{
933+
struct Pos
934+
{
935+
int a;
936+
int y;
937+
} pos;
938+
939+
struct Vel
940+
{
941+
int x;
942+
int y;
943+
int z;
944+
} vel;
945+
int weight;
946+
};
947+
948+
TEST_CASE("VirtualDatum.operator->")
949+
{
950+
llama::One<NameStruct> datum;
951+
datum.store(NameStruct{1, 2, 3, 4, 5, 6});
952+
953+
auto test = [](auto&& datum) {
954+
using namespace llama::literals;
955+
956+
CHECK(datum->pos.a == 1);
957+
CHECK(datum->pos.y == 2);
958+
CHECK(datum->vel.x == 3);
959+
CHECK(datum->vel.y == 4);
960+
CHECK(datum->vel.z == 5);
961+
CHECK(datum->weight == 6);
962+
963+
CHECK(datum(0_DC)->a == 1);
964+
CHECK(datum(0_DC)->y == 2);
965+
CHECK(datum(1_DC)->x == 3);
966+
CHECK(datum(1_DC)->y == 4);
967+
CHECK(datum(1_DC)->z == 5);
968+
};
969+
test(datum);
970+
test(std::as_const(datum));
971+
}
972+
973+
TEST_CASE("VirtualDatum.operator*")
974+
{
975+
llama::One<NameStruct> datum;
976+
datum.store(NameStruct{1, 2, 3, 4, 5, 6});
977+
978+
auto test = [](auto&& datum) {
979+
using namespace llama::literals;
980+
981+
const NameStruct ns = *datum;
982+
CHECK(ns.pos.a == 1);
983+
CHECK(ns.pos.y == 2);
984+
CHECK(ns.vel.x == 3);
985+
CHECK(ns.vel.y == 4);
986+
CHECK(ns.vel.z == 5);
987+
CHECK(ns.weight == 6);
988+
const NameStruct::Pos pos = *datum(0_DC);
989+
CHECK(pos.a == 1);
990+
CHECK(pos.y == 2);
991+
const NameStruct::Vel vel = *datum(1_DC);
992+
CHECK(vel.x == 3);
993+
CHECK(vel.y == 4);
994+
CHECK(vel.z == 5);
995+
};
996+
test(datum);
997+
test(std::as_const(datum));
998+
}

0 commit comments

Comments
 (0)