diff --git a/include/llama/RecordRef.hpp b/include/llama/RecordRef.hpp index 5893f710c8..bc2c55946e 100644 --- a/include/llama/RecordRef.hpp +++ b/include/llama/RecordRef.hpp @@ -330,6 +330,125 @@ namespace llama LLAMA_FN_HOST_ACC_INLINE void storeSimdRecord(const Simd& srcSimd, T&& dstRef, RecordCoord rc); } // namespace internal + template + struct PointerRef + { + using ArrayIndex = typename View::ArrayIndex; + + private: + using ArrayIndexQual = std::conditional_t; + inline static constexpr auto maxAiVal = std::numeric_limits::max(); + + public: + LLAMA_FN_HOST_ACC_INLINE constexpr PointerRef(std::nullptr_t) : ai{}, view(nullptr) + { + setNull(); + static_assert(OwnIndex); + } + + constexpr PointerRef(const PointerRef& other) = default; + + LLAMA_FN_HOST_ACC_INLINE constexpr PointerRef(ArrayIndexQual ai, View& view) : ai(ai), view(&view) + { + } + + LLAMA_FN_HOST_ACC_INLINE constexpr auto operator=(const PointerRef& other) -> PointerRef& + { + checkViews(view, other.view); + ai = other.ai; + view = other.view; // adopt other view (in case this was a nullptr) + return *this; + } + + // retarget pointer + template + LLAMA_FN_HOST_ACC_INLINE constexpr auto operator=( + const PointerRef& other) -> PointerRef& + { + checkViews(view, other.view); + ai = other.ai; + view = other.view; // adopt other view (in case this was a nullptr) + return *this; + } + + LLAMA_FN_HOST_ACC_INLINE constexpr auto isNull() const + { + return ai[0] == maxAiVal; + } + + LLAMA_FN_HOST_ACC_INLINE constexpr void setNull() + { + ai[0] = maxAiVal; + } + + LLAMA_FN_HOST_ACC_INLINE constexpr auto operator*() + { + assert(!isNull() && "nullptr dereferences"); + return RecordRef{ai, *view}; + } + + LLAMA_FN_HOST_ACC_INLINE constexpr auto operator*() const + { + assert(!isNull() && "nullptr dereferences"); + return RecordRef{ai, *view}; + } + + LLAMA_FN_HOST_ACC_INLINE constexpr auto operator=(std::nullptr_t) -> PointerRef& + { + static_assert(ArrayIndex::rank > 0, "nullptr is not supported for 
zero-dim views"); + setNull(); + return *this; + } + LLAMA_FN_HOST_ACC_INLINE friend constexpr auto operator==(PointerRef a, PointerRef b) -> bool + { + checkViews(a.view, b.view); + return a.ai == b.ai; + } + + LLAMA_FN_HOST_ACC_INLINE friend constexpr auto operator==(PointerRef a, std::nullptr_t) -> bool + { + return a.isNull(); + } + + LLAMA_FN_HOST_ACC_INLINE friend constexpr auto operator==(std::nullptr_t, PointerRef b) -> bool + { + return b.isNull(); + } + + LLAMA_FN_HOST_ACC_INLINE friend constexpr auto operator!=(PointerRef a, PointerRef b) -> bool + { + return !(a == b); + } + + LLAMA_FN_HOST_ACC_INLINE friend constexpr auto operator!=(std::nullptr_t a, PointerRef b) -> bool + { + return !(a == b); + } + + LLAMA_FN_HOST_ACC_INLINE friend constexpr auto operator!=(PointerRef a, std::nullptr_t b) -> bool + { + return !(a == b); + } + + ArrayIndexQual ai; + View* view; + + private: + LLAMA_FN_HOST_ACC_INLINE static void checkViews([[maybe_unused]] const View* a, [[maybe_unused]] const View* b) + { + assert((a == nullptr || b == nullptr || a == b) && "Mixing pointers into different views is not allowed"); + } + }; + + template + using Pointer = PointerRef; + + template + inline constexpr bool isPointerRef = false; + + template + inline constexpr bool isPointerRef> = true; + /// Record reference type returned by \ref View after resolving an array dimensions coordinate or partially /// resolving a \ref RecordCoord. A record reference does not hold data itself, it just binds enough information /// (array dimensions coord and partial record coord) to retrieve it later from a \ref View. Records references @@ -338,6 +457,9 @@ namespace llama template struct RecordRef : private TView::Mapping::ArrayIndex { + template + friend struct PointerRef; + using View = TView; ///< View this record reference points into. using BoundRecordCoord = TBoundRecordCoord; ///< Record coords into View::RecordDim which are already bound by this RecordRef. 
@@ -383,6 +505,11 @@ namespace llama ~RecordRef() = default; + LLAMA_FN_HOST_ACC_INLINE constexpr auto arrayIndex() -> ArrayIndex& + { + return static_cast(*this); + } + LLAMA_FN_HOST_ACC_INLINE constexpr auto arrayIndex() const -> ArrayIndex { return static_cast(*this); @@ -430,6 +557,14 @@ namespace llama LLAMA_FORCE_INLINE_RECURSIVE return RecordRef{arrayIndex(), this->view}; } + // else if constexpr(/*isPointer*/ std::is_same_v) + // { + // const ArrayIndex dstAi = this->view.access(arrayIndex(), AbsolutCoord{}); + // // using DstRecordDim = typename AccessedType::type; + // // static_assert(std::is_same_v, "Implementation limit"); + // using Pointee = RecordRef, false>; + // return Pointer{Pointee{dstAi, view}}; + // } else { LLAMA_FORCE_INLINE_RECURSIVE @@ -448,6 +583,14 @@ namespace llama LLAMA_FORCE_INLINE_RECURSIVE return RecordRef{arrayIndex(), this->view}; } + // else if constexpr(/*isPointer*/ std::is_same_v) + // { + // const ArrayIndex dstAi = this->view.access(arrayIndex(), AbsolutCoord{}); + // // using DstRecordDim = typename AccessedType::type; + // // static_assert(std::is_same_v, "Implementation limit"); + // using Pointee = RecordRef, false>; + // return Pointer{Pointee{dstAi, view}}; + // } else { LLAMA_FORCE_INLINE_RECURSIVE @@ -726,6 +869,16 @@ namespace llama internal::assignTuples(asTuple(), t, std::make_index_sequence>{}); } + LLAMA_FN_HOST_ACC_INLINE constexpr auto operator&() + { + return Pointer{arrayIndex(), view}; + } + + LLAMA_FN_HOST_ACC_INLINE constexpr auto operator&() const + { + return Pointer{arrayIndex(), view}; + } + // swap for equal RecordRef LLAMA_FN_HOST_ACC_INLINE friend void swap( std::conditional_t a, diff --git a/include/llama/llama.hpp b/include/llama/llama.hpp index 8fde8c1e6b..ca37c517b8 100644 --- a/include/llama/llama.hpp +++ b/include/llama/llama.hpp @@ -71,6 +71,7 @@ #include "mapping/Null.hpp" #include "mapping/One.hpp" #include "mapping/PermuteArrayIndex.hpp" +#include "mapping/PointerToIndex.hpp" #include 
"mapping/Projection.hpp" #include "mapping/SoA.hpp" #include "mapping/Split.hpp" diff --git a/include/llama/mapping/PointerToIndex.hpp b/include/llama/mapping/PointerToIndex.hpp new file mode 100644 index 0000000000..c7559eedc4 --- /dev/null +++ b/include/llama/mapping/PointerToIndex.hpp @@ -0,0 +1,89 @@ +// Copyright 2022 Bernhard Manfred Gruber +// SPDX-License-Identifier: LGPL-3.0-or-later + +#pragma once + +#include "../ProxyRefOpMixin.hpp" +#include "Common.hpp" + +namespace llama::mapping +{ + struct PointerToRecordDim + { + }; + + namespace internal + { + template + struct ReplacePointerImpl + { + using type = std::conditional_t, Replacement, T>; + }; + + template + struct ReplacePointerImpl> + { + using type = Record< + Field, typename ReplacePointerImpl>::type>...>; + }; + + template + using ReplacePointer = typename ReplacePointerImpl::type; + } // namespace internal + + template typename InnerMapping> + struct PointerToIndex + : private InnerMapping> + { + using Inner = InnerMapping>; + using ArrayExtents = typename Inner::ArrayExtents; + using ArrayIndex = typename Inner::ArrayIndex; + using RecordDim = TRecordDim; // hide Inner::RecordDim + using Inner::blobCount; + using Inner::blobSize; + using Inner::extents; + using Inner::Inner; + + template + LLAMA_FN_HOST_ACC_INLINE static constexpr auto isComputed(RecordCoord) -> bool + { + return std::is_same_v, PointerToRecordDim>; + } + + template + LLAMA_FN_HOST_ACC_INLINE constexpr auto compute( + ArrayIndex ai, + RecordCoord rc, + BlobArray& blobs) const + { + static_assert(isComputed(rc)); + using View = llama::View, accessor::Default>; + ArrayIndex& dstAi = mapToMemory(static_cast(*this), ai, rc, blobs); + auto& view = const_cast(reinterpret_cast(*this)); + return PointerRef>{dstAi, view}; + } + + template + LLAMA_FN_HOST_ACC_INLINE constexpr auto blobNrAndOffset(ArrayIndex ai, RecordCoord rc = {}) + const -> NrAndOffset + { + static_assert(!isComputed(rc)); + return Inner::blobNrAndOffset(ai, rc); + } 
+ }; + + /// Binds parameters to a \ref PointerToIndex mapping except for array and record dimension, producing a quoted + /// meta function accepting the latter two. Useful to prepare this mapping for a meta mapping. + template typename InnerMapping> + struct BindPointerToIndex + { + template + using fn = PointerToIndex; + }; + + template + inline constexpr bool isPointerToIndex = false; + + template typename InnerMapping> + inline constexpr bool isPointerToIndex> = true; +} // namespace llama::mapping diff --git a/tests/recordref.cpp b/tests/recordref.cpp index 71d27fe64e..c9329778b9 100644 --- a/tests/recordref.cpp +++ b/tests/recordref.cpp @@ -1291,3 +1291,142 @@ TEST_CASE("ScopedUpdate.RecordRef") test(v); test(v()); } + +TEST_CASE("RecordRef.Pointer") +{ + auto view = llama::allocView( + llama::mapping::PointerToIndex, Vec3I, llama::mapping::BindAoS<>::fn>{{1}}); + + auto rr = view[0]; + auto p = &rr; + auto p2 = p; + CHECK(p2 == p); + p = nullptr; + CHECK(p == nullptr); + CHECK(p2 != p); + CHECK(p2 == &rr); + auto rr2 = *p2; + p2 = nullptr; + CHECK(p2 == p); + STATIC_REQUIRE(std::is_same_v); + CHECK(rr == rr2); + + llama::Pointer> p3 = nullptr; + CHECK(p3 == nullptr); + p3 = &rr; + CHECK(p3 != nullptr); +} + +struct Next +{ +}; + +struct Value +{ +}; + +using Node = llama::Record, llama::Field>; + +TEST_CASE("RecordRef.Pointer.ForwardList") +{ + auto view = llama::allocView( + llama::mapping::PointerToIndex, Node, llama::mapping::BindAoS<>::fn>{{10}}); + + for(int i = 0; i < 10; i++) + { + view[i](Value{}) = i; + if(i < 9) + view[i](Next{}) = &view[i + 1]; + else + view[i](Next{}) = nullptr; + } + + CHECK(view[0](Value{}) == 0); + CHECK((*view[0](Next{}))(Value{}) == 1); + CHECK((*(*view[0](Next{}))(Next{}))(Value{}) == 2); + + // walk list + int i = 0; + auto head = &view[0]; + while(head != nullptr) + { + CHECK((*head)(Value{}) == i); + head = (*head)(Next{}); + i++; + } +} + +struct Left +{ +}; + +struct Right +{ +}; + +using TreeNode = llama::Record< +
llama::Field, + llama::Field, + llama::Field>; + +TEST_CASE("RecordRef.Pointer.BinaryTree") +{ + auto nodes = llama::allocView( + llama::mapping::PointerToIndex, TreeNode, llama::mapping::BindAoS<>::fn>{ + {100}}); + int count = 0; + + auto createNode = [&](int value) + { + assert(count < nodes.extents()[0]); + auto p = &nodes[count]; + count++; + (*p)(Value{}) = value; + (*p)(Left{}) = nullptr; + (*p)(Right{}) = nullptr; + return p; + }; + + auto insert = [](auto& tree, auto z) + { + llama::Pointer> y = nullptr; + auto x = tree; + while(x != nullptr) + { + y = x; + if((*z)(Value{}) < (*x)(Value{})) + x = (*x)(Left{}); + else + x = (*x)(Right{}); + } + if(y == nullptr) + tree = z; + else if((*z)(Value{}) < (*y)(Value{})) + (*y)(Left{}) = z; + else + (*y)(Right{}) = z; + }; + + llama::Pointer> t = nullptr; + std::vector values{42, 3, 8, 75, 42, 63, 15, 7, 22}; + for(const int i : values) + insert(t, createNode(i)); + + auto visitInOrder = [](auto tree, auto f, auto&& self) + { + if(tree == nullptr) + return; + self((*tree)(Left{}), f, self); + f((*tree)(Value{})); + self((*tree)(Right{}), f, self); + }; + + std::vector valuesOut; + visitInOrder( + t, + [&](int i) { valuesOut.push_back(i); }, + visitInOrder); + + std::sort(begin(values), end(values)); + CHECK(values == valuesOut); +}