Skip to content

Commit 53c7f10

Browse files
committed
Update ZipIterator:
- Fix operator[] - Change constexpr to ALPAKA_FN_HOST_ACC
1 parent f6625f8 commit 53c7f10

File tree

2 files changed

+89
-93
lines changed

2 files changed

+89
-93
lines changed

example/zipIterator/src/zipIterator-main.cpp

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ inline typename std::enable_if<I < sizeof...(Tp), void>::type forEach(std::tuple
2727
forEach<I + 1, FuncT, Tp...>(t, f);
2828
}
2929

30-
template<typename IteratorTupleVal>
31-
void printTuple(IteratorTupleVal tuple)
30+
template<typename TIteratorTupleVal>
31+
void printTuple(TIteratorTupleVal tuple)
3232
{
3333
std::cout << "(";
3434
int index = 0;
35-
int tupleSize = std::tuple_size<IteratorTupleVal>{};
35+
int tupleSize = std::tuple_size<TIteratorTupleVal>{};
3636
forEach(tuple, [&index, tupleSize](auto &x) { std::cout << x << (++index < tupleSize ? ", " : ""); });
3737
std::cout << ")";
3838
}
@@ -113,10 +113,10 @@ int main()
113113

114114
std::cout << "\nTesting zip iterator in host with tuple<uint64_t, char, double>\n\n";
115115

116-
using IteratorTuplePtr = std::tuple<uint64_t*, char*, double*>;
117-
using IteratorTupleVal = std::tuple<uint64_t, char, double>;
118-
IteratorTuplePtr zipTuple = std::make_tuple(hostNative, hostNativeChar, hostNativeDouble);
119-
vikunja::mem::iterator::ZipIterator<IteratorTuplePtr, IteratorTupleVal> zipIter(zipTuple);
116+
using TIteratorTuplePtr = std::tuple<uint64_t*, char*, double*>;
117+
using TIteratorTupleVal = std::tuple<uint64_t, char, double>;
118+
TIteratorTuplePtr zipTuple = std::make_tuple(hostNative, hostNativeChar, hostNativeDouble);
119+
vikunja::mem::iterator::ZipIterator<TIteratorTuplePtr, TIteratorTupleVal> zipIter(zipTuple);
120120

121121
std::cout << "*zipIter: ";
122122
printTuple(*zipIter);
@@ -201,19 +201,17 @@ int main()
201201
std::cout << "\n\n"
202202
<< "-----\n\n";
203203

204-
IteratorTuplePtr deviceZipTuple = std::make_tuple(deviceNative, deviceNativeChar, deviceNativeDouble);
205-
vikunja::mem::iterator::ZipIterator<IteratorTuplePtr, IteratorTupleVal> deviceZipIter(deviceZipTuple);
204+
TIteratorTuplePtr deviceZipTuple = std::make_tuple(deviceNative, deviceNativeChar, deviceNativeDouble);
205+
vikunja::mem::iterator::ZipIterator<TIteratorTuplePtr, TIteratorTupleVal> deviceZipIter(deviceZipTuple);
206206

207-
auto deviceMemResult(alpaka::allocBuf<IteratorTupleVal, Idx>(devAcc, extent));
208-
auto hostMemResult(alpaka::allocBuf<IteratorTupleVal, Idx>(devHost, extent));
209-
IteratorTupleVal* hostNativeResultPtr = alpaka::getPtrNative(hostMemResult);
210-
IteratorTupleVal* deviceNativeResultPtr = alpaka::getPtrNative(deviceMemResult);
207+
auto deviceMemResult(alpaka::allocBuf<TIteratorTupleVal, Idx>(devAcc, extent));
208+
auto hostMemResult(alpaka::allocBuf<TIteratorTupleVal, Idx>(devHost, extent));
209+
TIteratorTupleVal* hostNativeResultPtr = alpaka::getPtrNative(hostMemResult);
210+
TIteratorTupleVal* deviceNativeResultPtr = alpaka::getPtrNative(deviceMemResult);
211211

212-
auto doubleNum = [] ALPAKA_FN_HOST_ACC(IteratorTupleVal const& t)
212+
auto doubleNum = [] ALPAKA_FN_HOST_ACC(TIteratorTupleVal const& t)
213213
{
214-
// return std::make_tuple(2 * std::get<0>(t), std::get<1>(t), 2 * std::get<2>(t));
215-
// return std::make_tuple(static_cast<uint64_t>(5), 'e', static_cast<double>(14.12));
216-
return t;
214+
return std::make_tuple(2 * std::get<0>(t), std::get<1>(t), 2 * std::get<2>(t));
217215
};
218216

219217
vikunja::transform::deviceTransform<TAcc>(

include/vikunja/mem/iterator/ZipIterator.hpp

Lines changed: 74 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -40,117 +40,115 @@ namespace vikunja
4040
{
4141
/**
4242
* @brief A zip iterator that takes multiple input sequences and yields a sequence of tuples
43-
* @tparam IteratorTuplePtr The type of the data
44-
* @tparam IteratorTupleVal The type of the data
45-
* @tparam IdxType The type of the index
43+
* @tparam TIteratorTuplePtr The type of the data
44+
* @tparam TIteratorTupleVal The type of the data
45+
* @tparam TIdx The type of the index
4646
*/
47-
template<typename IteratorTuplePtr, typename IteratorTupleVal, typename IdxType = int64_t>
47+
template<typename TIteratorTuplePtr, typename TIteratorTupleVal, typename TIdx = int64_t>
4848
class ZipIterator
4949
{
5050
public:
5151
// Need all 5 of these types for iterator_traits
52-
using reference = IteratorTupleVal&;
53-
using value_type = IteratorTupleVal;
54-
using pointer = IteratorTupleVal*;
55-
using difference_type = IdxType;
52+
using reference = TIteratorTupleVal&;
53+
using value_type = TIteratorTupleVal;
54+
using pointer = TIteratorTupleVal*;
55+
using difference_type = TIdx;
5656
using iterator_category = std::random_access_iterator_tag;
5757

5858
/**
5959
* @brief Constructor for the ZipIterator
6060
* @param iteratorTuplePtr The tuple to initialize the iterator with
6161
* @param idx The index for the iterator, default 0
6262
*/
63-
constexpr ZipIterator(IteratorTuplePtr iteratorTuplePtr, const IdxType& idx = static_cast<IdxType>(0))
64-
: mIndex(idx)
65-
, mIteratorTuplePtr(iteratorTuplePtr)
66-
, mIteratorTupleVal(makeValueTuple(mIteratorTuplePtr))
63+
ALPAKA_FN_HOST_ACC ZipIterator(TIteratorTuplePtr iteratorTuplePtr, const TIdx& idx = static_cast<TIdx>(0))
64+
: m_index(idx)
65+
, m_iteratorTuplePtr(iteratorTuplePtr)
66+
, m_iteratorTupleVal(makeValueTuple(m_iteratorTuplePtr))
6767
{
6868
if (idx != 0)
6969
{
70-
forEach(mIteratorTuplePtr, [idx](auto &x) { x += idx; });
71-
mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
70+
forEach(m_iteratorTuplePtr, [idx](auto &x) { x += idx; });
71+
m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
7272
}
7373
}
7474

7575
/**
7676
* @brief Dereference operator to receive the stored value
7777
*/
78-
NODISCARD constexpr ALPAKA_FN_INLINE IteratorTupleVal& operator*()
78+
NODISCARD ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE TIteratorTupleVal& operator*()
7979
{
80-
return mIteratorTupleVal;
80+
return m_iteratorTupleVal;
8181
}
8282

8383
/**
8484
* @brief Index operator to get stored value at some given offset from this iterator
8585
*/
86-
NODISCARD constexpr ALPAKA_FN_INLINE const IteratorTupleVal operator[](const IdxType idx)
86+
NODISCARD ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE TIteratorTupleVal operator[](const TIdx idx) const
8787
{
88-
IteratorTuplePtr tmp = mIteratorTuplePtr;
89-
IdxType indexDiff = idx - mIndex;
90-
forEach(tmp, [indexDiff](auto &x) { x += indexDiff; });
91-
return makeValueTuple(tmp);
88+
TIdx indexDiff = idx - m_index;
89+
return (*this + indexDiff).operator*();
9290
}
9391

94-
// NODISCARD constexpr ALPAKA_FN_INLINE IteratorTupleVal& operator=(IteratorTupleVal iteratorTupleVal)
92+
// NODISCARD ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE TIteratorTupleVal& operator=(TIteratorTupleVal iteratorTupleVal)
9593
// {
9694
// updateIteratorTupleValue(iteratorTupleVal);
97-
// mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
98-
// return mIteratorTupleVal;
95+
// m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
96+
// return m_iteratorTupleVal;
9997
// }
10098

10199
#pragma region arithmeticoperators
102100
/**
103101
* @brief Prefix increment operator
104102
*/
105-
constexpr ALPAKA_FN_INLINE ZipIterator& operator++()
103+
ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator& operator++()
106104
{
107-
++mIndex;
108-
forEach(mIteratorTuplePtr, [](auto &x) { ++x; });
109-
mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
105+
++m_index;
106+
forEach(m_iteratorTuplePtr, [](auto &x) { ++x; });
107+
m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
110108
return *this;
111109
}
112110

113111
/**
114112
* @brief Postfix increment operator
115113
* @note Use prefix increment operator instead if possible to avoid copies
116114
*/
117-
constexpr ZipIterator operator++(int)
115+
ALPAKA_FN_HOST_ACC ZipIterator operator++(int)
118116
{
119117
ZipIterator tmp = *this;
120-
++mIndex;
121-
forEach(mIteratorTuplePtr, [](auto &x) { ++x; });
122-
mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
118+
++m_index;
119+
forEach(m_iteratorTuplePtr, [](auto &x) { ++x; });
120+
m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
123121
return tmp;
124122
}
125123

126124
/**
127125
* @brief Prefix decrement operator
128126
*/
129-
constexpr ALPAKA_FN_INLINE ZipIterator& operator--()
127+
ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator& operator--()
130128
{
131-
--mIndex;
132-
forEach(mIteratorTuplePtr, [](auto &x) { --x; });
133-
mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
129+
--m_index;
130+
forEach(m_iteratorTuplePtr, [](auto &x) { --x; });
131+
m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
134132
return *this;
135133
}
136134

137135
/**
138136
* @brief Postfix decrement operator
139137
* @note Use prefix decrement operator instead if possible to avoid copies
140138
*/
141-
constexpr ALPAKA_FN_INLINE ZipIterator operator--(int)
139+
ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator operator--(int)
142140
{
143141
ZipIterator tmp = *this;
144-
--mIndex;
145-
forEach(mIteratorTuplePtr, [](auto &x) { --x; });
146-
mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
142+
--m_index;
143+
forEach(m_iteratorTuplePtr, [](auto &x) { --x; });
144+
m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
147145
return tmp;
148146
}
149147

150148
/**
151149
* @brief Add an index to this iterator
152150
*/
153-
NODISCARD constexpr friend ALPAKA_FN_INLINE ZipIterator operator+(ZipIterator zipIter, const IdxType idx)
151+
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE ZipIterator operator+(ZipIterator zipIter, const TIdx idx)
154152
{
155153
zipIter += idx;
156154
return zipIter;
@@ -159,7 +157,7 @@ namespace vikunja
159157
/**
160158
* @brief Subtract an index from this iterator
161159
*/
162-
NODISCARD constexpr friend ALPAKA_FN_INLINE ZipIterator operator-(ZipIterator zipIter, const IdxType idx)
160+
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE ZipIterator operator-(ZipIterator zipIter, const TIdx idx)
163161
{
164162
zipIter -= idx;
165163
return zipIter;
@@ -168,22 +166,22 @@ namespace vikunja
168166
/**
169167
* @brief Add an index to this iterator
170168
*/
171-
constexpr ALPAKA_FN_INLINE ZipIterator& operator+=(const IdxType idx)
169+
ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator& operator+=(const TIdx idx)
172170
{
173-
mIndex += idx;
174-
forEach(mIteratorTuplePtr, [idx](auto &x) { x += idx; });
175-
mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
171+
m_index += idx;
172+
forEach(m_iteratorTuplePtr, [idx](auto &x) { x += idx; });
173+
m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
176174
return *this;
177175
}
178176

179177
/**
180178
* @brief Subtract an index from this iterator
181179
*/
182-
constexpr ALPAKA_FN_INLINE ZipIterator& operator-=(const IdxType idx)
180+
ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator& operator-=(const TIdx idx)
183181
{
184-
mIndex -= idx;
185-
forEach(mIteratorTuplePtr, [idx](auto &x) { x -= idx; });
186-
mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
182+
m_index -= idx;
183+
forEach(m_iteratorTuplePtr, [idx](auto &x) { x -= idx; });
184+
m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
187185
return *this;
188186
}
189187

@@ -196,68 +194,68 @@ namespace vikunja
196194
/**
197195
* @brief Spaceship operator for comparisons
198196
*/
199-
NODISCARD constexpr ALPAKA_FN_INLINE auto operator<=>(const ZipIterator& other) const noexcept
197+
NODISCARD ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE auto operator<=>(const ZipIterator& other) const noexcept
200198
{
201-
return mIteratorTuplePtr.operator<=>(other.mIteratorTuplePtr);
199+
return m_iteratorTuplePtr.operator<=>(other.m_iteratorTuplePtr);
202200
}
203201

204202
#else
205203

206204
/**
207205
* @brief Equality comparison, returns true if the index are the same
208206
*/
209-
NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator==(const ZipIterator& zipIter, const ZipIterator& other) noexcept
207+
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator==(const ZipIterator& zipIter, const ZipIterator& other) noexcept
210208
{
211-
return zipIter.mIndex == other.mIndex;
209+
return zipIter.m_index == other.m_index;
212210
}
213211

214212
/**
215213
* @brief Inequality comparison, negated equality operator
216214
*/
217-
NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator!=(const ZipIterator& zipIter, const ZipIterator& other) noexcept
215+
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator!=(const ZipIterator& zipIter, const ZipIterator& other) noexcept
218216
{
219217
return !operator==(zipIter, other);
220218
}
221219

222220
/**
223221
* @brief Less than comparison, index is checked
224222
*/
225-
NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator<(const ZipIterator& zipIter, const ZipIterator& other) noexcept
223+
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator<(const ZipIterator& zipIter, const ZipIterator& other) noexcept
226224
{
227-
return zipIter.mIndex < other.mIndex;
225+
return zipIter.m_index < other.m_index;
228226
}
229227

230228
/**
231229
* @brief Greater than comparison, index is checked
232230
*/
233-
NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator>(const ZipIterator& zipIter, const ZipIterator& other) noexcept
231+
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator>(const ZipIterator& zipIter, const ZipIterator& other) noexcept
234232
{
235-
return zipIter.mIndex > other.mIndex;
233+
return zipIter.m_index > other.m_index;
236234
}
237235

238236
/**
239237
* @brief Less than or equal comparison, index is checked
240238
*/
241-
NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator<=(const ZipIterator& zipIter, const ZipIterator& other) noexcept
239+
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator<=(const ZipIterator& zipIter, const ZipIterator& other) noexcept
242240
{
243-
return zipIter.mIndex <= other.mIndex;
241+
return zipIter.m_index <= other.m_index;
244242
}
245243

246244
/**
247245
* @brief Greater than or equal comparison, index is checked
248246
*/
249-
NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator>=(const ZipIterator& zipIter, const ZipIterator& other) noexcept
247+
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator>=(const ZipIterator& zipIter, const ZipIterator& other) noexcept
250248
{
251-
return zipIter.mIndex >= other.mIndex;
249+
return zipIter.m_index >= other.m_index;
252250
}
253251
#endif
254252

255253
#pragma endregion comparisonoperators
256254

257255
private:
258-
IdxType mIndex;
259-
IteratorTuplePtr mIteratorTuplePtr;
260-
IteratorTupleVal mIteratorTupleVal;
256+
TIdx m_index;
257+
TIteratorTuplePtr m_iteratorTuplePtr;
258+
TIteratorTupleVal m_iteratorTupleVal;
261259

262260
template<int... Is>
263261
struct seq { };
@@ -294,17 +292,17 @@ namespace vikunja
294292
forEach<I + 1, FuncT, Tp...>(t, f);
295293
}
296294

297-
template<std::size_t I = 0, typename... Tp>
298-
inline typename std::enable_if<I == sizeof...(Tp), void>::type updateIteratorTupleValue(std::tuple<Tp...> &) // Unused arguments are given no names
299-
{
300-
}
295+
// template<std::size_t I = 0, typename... Tp>
296+
// inline typename std::enable_if<I == sizeof...(Tp), void>::type updateIteratorTupleValue(std::tuple<Tp...> &) // Unused arguments are given no names
297+
// {
298+
// }
301299

302-
template<std::size_t I = 0, typename... Tp>
303-
inline typename std::enable_if<I < sizeof...(Tp), void>::type updateIteratorTupleValue(std::tuple<Tp...>& t)
304-
{
305-
*std::get<I>(mIteratorTuplePtr) = std::get<I>(t);
306-
updateIteratorTupleValue<I + 1, Tp...>(t);
307-
}
300+
// template<std::size_t I = 0, typename... Tp>
301+
// inline typename std::enable_if<I < sizeof...(Tp), void>::type updateIteratorTupleValue(std::tuple<Tp...>& t)
302+
// {
303+
// *std::get<I>(m_iteratorTuplePtr) = std::get<I>(t);
304+
// updateIteratorTupleValue<I + 1, Tp...>(t);
305+
// }
308306
};
309307

310308
} // namespace iterator

0 commit comments

Comments
 (0)