Skip to content

Commit 34f6a8c

Browse files
committed
add transform_iterator
1 parent 94eee10 commit 34f6a8c

3 files changed

Lines changed: 311 additions & 16 deletions

File tree

core/base/iterator_factory.hpp

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -705,6 +705,145 @@ permute_iterator<IteratorType, PermutationFn> make_permute_iterator(
705705
}
706706

707707

708+
/**
709+
* A transform_iterator is a read-only iterator that wraps an existing iterator
710+
* and applies a transformation to that iterator's value before returning it on
711+
* dereference.
712+
*/
713+
template <typename IteratorType, typename TransformFn>
714+
class transform_iterator {
715+
public:
716+
using difference_type =
717+
typename std::iterator_traits<IteratorType>::difference_type;
718+
using value_type = decltype(std::declval<TransformFn>()(
719+
std::declval<
720+
typename std::iterator_traits<IteratorType>::reference>()));
721+
using pointer = value_type*;
722+
using reference = value_type&;
723+
using iterator_category = std::random_access_iterator_tag;
724+
725+
constexpr transform_iterator() = default;
726+
727+
constexpr explicit transform_iterator(IteratorType it,
728+
TransformFn transform)
729+
: it_{std::move(it)}, transform_{std::move(transform)}
730+
{}
731+
732+
constexpr transform_iterator& operator+=(difference_type i)
733+
{
734+
it_ += i;
735+
return *this;
736+
}
737+
738+
constexpr transform_iterator& operator-=(difference_type i)
739+
{
740+
it_ -= i;
741+
return *this;
742+
}
743+
744+
constexpr transform_iterator& operator++()
745+
{
746+
it_++;
747+
return *this;
748+
}
749+
750+
constexpr transform_iterator operator++(int)
751+
{
752+
auto tmp = *this;
753+
++(*this);
754+
return tmp;
755+
}
756+
757+
constexpr transform_iterator& operator--()
758+
{
759+
it_--;
760+
return *this;
761+
}
762+
763+
constexpr transform_iterator operator--(int)
764+
{
765+
auto tmp = *this;
766+
--(*this);
767+
return tmp;
768+
}
769+
770+
constexpr transform_iterator operator+(difference_type i) const
771+
{
772+
auto tmp = *this;
773+
tmp += i;
774+
return tmp;
775+
}
776+
777+
constexpr friend transform_iterator operator+(
778+
difference_type i, const transform_iterator& iter)
779+
{
780+
return iter + i;
781+
}
782+
783+
constexpr transform_iterator operator-(difference_type i) const
784+
{
785+
auto tmp = *this;
786+
tmp -= i;
787+
return tmp;
788+
}
789+
790+
constexpr difference_type operator-(const transform_iterator& other) const
791+
{
792+
return this->it_ - other.it_;
793+
}
794+
795+
constexpr value_type operator*() const { return transform_(*it_); }
796+
797+
constexpr value_type operator[](difference_type i) const
798+
{
799+
return *(*this + i);
800+
}
801+
802+
constexpr bool operator==(const transform_iterator& other) const
803+
{
804+
return this->it_ == other.it_;
805+
}
806+
807+
constexpr bool operator!=(const transform_iterator& other) const
808+
{
809+
return !(*this == other);
810+
}
811+
812+
constexpr bool operator<(const transform_iterator& other) const
813+
{
814+
return this->it_ < other.it_;
815+
}
816+
817+
constexpr bool operator<=(const transform_iterator& other) const
818+
{
819+
return this->it_ <= other.it_;
820+
}
821+
822+
constexpr bool operator>(const transform_iterator& other) const
823+
{
824+
return !(*this <= other);
825+
}
826+
827+
constexpr bool operator>=(const transform_iterator& other) const
828+
{
829+
return !(*this < other);
830+
}
831+
832+
private:
833+
IteratorType it_;
834+
copy_assignable<TransformFn> transform_;
835+
};
836+
837+
838+
template <typename IteratorType, typename TransformFn>
839+
transform_iterator<IteratorType, TransformFn> make_transform_iterator(
840+
IteratorType it, TransformFn transform)
841+
{
842+
return transform_iterator<IteratorType, TransformFn>{std::move(it),
843+
std::move(transform)};
844+
}
845+
846+
708847
} // namespace detail
709848

710849

core/test/base/iterator_factory.cpp

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

55
#include "core/base/iterator_factory.hpp"
66

77
#include <algorithm>
88
#include <complex>
9+
#include <iterator>
910
#include <numeric>
1011
#include <vector>
1112

@@ -511,4 +512,141 @@ TYPED_TEST(PermuteIterator, DecreasingIterator)
511512
}
512513

513514

515+
template <typename ValueType>
516+
class TransformIterator : public ::testing::Test {
517+
protected:
518+
using value_type = ValueType;
519+
};
520+
521+
TYPED_TEST_SUITE(TransformIterator, gko::test::ComplexAndPODTypes,
522+
TypenameNameGenerator);
523+
524+
525+
TYPED_TEST(TransformIterator, EmptyIterator)
526+
{
527+
auto test_iter = gko::detail::make_transform_iterator<TypeParam*>(
528+
nullptr, [](TypeParam v) { return v; });
529+
530+
ASSERT_NO_THROW((void)std::find(test_iter, test_iter, TypeParam{}));
531+
}
532+
533+
534+
TYPED_TEST(TransformIterator, CopyingWithIdentityFunction)
535+
{
536+
std::vector<TypeParam> vec{6, 2, 5, 2, 4};
537+
std::vector<TypeParam> result;
538+
auto test_iter = gko::detail::make_transform_iterator(
539+
vec.begin(), [](TypeParam v) { return v; });
540+
541+
std::copy(test_iter, test_iter + vec.size(), std::back_inserter(result));
542+
543+
ASSERT_EQ(vec, result);
544+
}
545+
546+
547+
TYPED_TEST(TransformIterator, CopyingWithStatefulFunctor)
548+
{
549+
TypeParam scale = 3;
550+
std::vector<TypeParam> vec{6, 2, 5, 2, 4};
551+
std::vector<TypeParam> ref{18, 6, 15, 6, 12};
552+
std::vector<TypeParam> result;
553+
auto test_iter = gko::detail::make_transform_iterator(
554+
vec.begin(), [scale](TypeParam v) { return scale * v; });
555+
556+
std::copy(test_iter, test_iter + vec.size(), std::back_inserter(result));
557+
558+
ASSERT_EQ(ref, result);
559+
}
560+
561+
562+
TYPED_TEST(TransformIterator, IncreasingIterator)
563+
{
564+
std::vector<TypeParam> vec{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
565+
auto transform = [](TypeParam v) { return TypeParam{2} * v; };
566+
567+
auto test_iter =
568+
gko::detail::make_transform_iterator(vec.begin(), transform);
569+
auto begin = test_iter;
570+
auto plus_2 = begin + 2;
571+
auto plus_2_rev = 2 + begin;
572+
auto plus_minus_2 = plus_2 - 2;
573+
auto increment_pre_2 = begin;
574+
++increment_pre_2;
575+
++increment_pre_2;
576+
auto increment_post_2 = begin;
577+
increment_post_2++;
578+
increment_post_2++;
579+
auto increment_pre_test = begin;
580+
auto increment_post_test = begin;
581+
582+
// check results for equality
583+
ASSERT_TRUE(begin == plus_minus_2);
584+
ASSERT_TRUE(plus_2 == increment_pre_2);
585+
ASSERT_TRUE(plus_2_rev == increment_pre_2);
586+
ASSERT_TRUE(increment_pre_2 == increment_post_2);
587+
ASSERT_TRUE(begin == increment_post_test++);
588+
ASSERT_TRUE(begin + 1 == ++increment_pre_test);
589+
ASSERT_TRUE(*plus_2 == TypeParam{2} * vec[2]);
590+
// check other comparison operators and difference
591+
std::vector<gko::detail::transform_iterator<
592+
typename std::vector<TypeParam>::iterator, decltype(transform)>>
593+
its{begin,
594+
plus_2,
595+
plus_2_rev,
596+
plus_minus_2,
597+
increment_pre_2,
598+
increment_post_2,
599+
increment_pre_test,
600+
increment_post_test,
601+
begin + 5,
602+
begin + 9};
603+
std::sort(its.begin(), its.end());
604+
std::vector<int> dists;
605+
std::vector<int> ref_dists{0, 1, 0, 1, 0, 0, 0, 3, 4};
606+
for (int i = 0; i < its.size() - 1; i++) {
607+
SCOPED_TRACE(i);
608+
dists.push_back(its[i + 1] - its[i]);
609+
auto equal = dists.back() > 0;
610+
ASSERT_EQ(its[i + 1] > its[i], equal);
611+
ASSERT_EQ(its[i] < its[i + 1], equal);
612+
ASSERT_EQ(its[i] != its[i + 1], equal);
613+
ASSERT_EQ(its[i] == its[i + 1], !equal);
614+
ASSERT_EQ(its[i] >= its[i + 1], !equal);
615+
ASSERT_EQ(its[i + 1] <= its[i], !equal);
616+
ASSERT_TRUE(its[i + 1] >= its[i]);
617+
ASSERT_TRUE(its[i] <= its[i + 1]);
618+
}
619+
ASSERT_EQ(dists, ref_dists);
620+
}
621+
622+
623+
TYPED_TEST(TransformIterator, DecreasingIterator)
624+
{
625+
std::vector<TypeParam> vec{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
626+
auto transform = [](TypeParam v) { return TypeParam{2} * v; };
627+
628+
auto test_iter =
629+
gko::detail::make_transform_iterator(vec.begin(), transform);
630+
631+
auto iter = test_iter + 5;
632+
auto minus_2 = iter - 2;
633+
auto minus_plus_2 = minus_2 + 2;
634+
auto decrement_pre_2 = iter;
635+
--decrement_pre_2;
636+
--decrement_pre_2;
637+
auto decrement_post_2 = iter;
638+
decrement_post_2--;
639+
decrement_post_2--;
640+
auto decrement_pre_test = iter;
641+
auto decrement_post_test = iter;
642+
643+
ASSERT_TRUE(iter == minus_plus_2);
644+
ASSERT_TRUE(minus_2 == decrement_pre_2);
645+
ASSERT_TRUE(decrement_pre_2 == decrement_post_2);
646+
ASSERT_TRUE(iter == decrement_post_test--);
647+
ASSERT_TRUE(iter - 1 == --decrement_pre_test);
648+
ASSERT_TRUE(*minus_2 == TypeParam{2} * vec[3]);
649+
}
650+
651+
514652
} // namespace

test/base/iterator_factory.cpp

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,7 @@
1717
#include "test/utils/common_fixture.hpp"
1818

1919

20-
class IteratorFactory : public CommonTestFixture {
21-
public:
22-
IteratorFactory()
23-
: key_array{exec, {6, 2, 3, 8, 1, 0, 2}},
24-
value_array{exec, {9, 5, 7, 2, 4, 7, 2}},
25-
expected_key_array{ref, {7, 1, 2, 2, 3, 6, 8}},
26-
expected_value_array{ref, {7, 4, 2, 5, 7, 9, 2}}
27-
{}
28-
29-
gko::array<int> key_array;
30-
gko::array<int> value_array;
31-
gko::array<int> expected_key_array;
32-
gko::array<int> expected_value_array;
33-
};
20+
class IteratorFactory : public CommonTestFixture {};
3421

3522

3623
// nvcc doesn't like device lambdas declared in complex classes, move it out
@@ -65,8 +52,39 @@ void run_zip_iterator(std::shared_ptr<gko::EXEC_TYPE> exec,
6552

6653
TEST_F(IteratorFactory, KernelRunsZipIterator)
6754
{
55+
gko::array<int> key_array{exec, {6, 2, 3, 8, 1, 0, 2}};
56+
gko::array<int> value_array{exec, {9, 5, 7, 2, 4, 7, 2}};
57+
gko::array<int> expected_key_array{ref, {7, 1, 2, 2, 3, 6, 8}};
58+
gko::array<int> expected_value_array{ref, {7, 4, 2, 5, 7, 9, 2}};
59+
6860
run_zip_iterator(exec, key_array, value_array);
6961

7062
GKO_ASSERT_ARRAY_EQ(key_array, expected_key_array);
7163
GKO_ASSERT_ARRAY_EQ(value_array, expected_value_array);
7264
}
65+
66+
67+
// nvcc doesn't like device lambdas declared in complex classes, move it out
68+
void run_transform_iterator(std::shared_ptr<gko::EXEC_TYPE> exec,
69+
gko::array<int>& in_array,
70+
gko::array<int>& out_array)
71+
{
72+
gko::kernels::GKO_DEVICE_NAMESPACE::run_kernel(
73+
exec, [] GKO_KERNEL(auto i, auto it, auto out) { out[i] = it[i]; },
74+
in_array.get_size(),
75+
gko::detail::make_transform_iterator(
76+
in_array.get_data(), [] GKO_KERNEL(auto v) { return -v; }),
77+
out_array);
78+
}
79+
80+
81+
TEST_F(IteratorFactory, KernelRunsTransformIterator)
82+
{
83+
gko::array<int> in_array{exec, {6, 2, 3, 8, 1, 0, 2}};
84+
gko::array<int> out_array{exec, in_array.get_size()};
85+
gko::array<int> ref_array{ref, {-6, -2, -3, -8, -1, -0, -2}};
86+
87+
run_transform_iterator(exec, in_array, out_array);
88+
89+
GKO_ASSERT_ARRAY_EQ(out_array, ref_array);
90+
}

0 commit comments

Comments
 (0)