Skip to content

Commit ccc2347

Browse files
committed
improving definitions for our types
Signed-off-by: Dan Hoeflinger <[email protected]>
1 parent cde467a commit ccc2347

File tree

1 file changed

+57
-40
lines changed

1 file changed

+57
-40
lines changed

include/oneapi/dpl/pstl/iterator_impl.h

+57-40
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ class zip_forward_iterator
171171
// On windows, this requires clause is necessary so that concepts in MSVC STL do not detect the iterator as
172172
// dereferenceable when a source iterator is a sycl_iterator, which is a supported type.
173173
reference
174-
operator*() const _ONEDPL_CPP20_REQUIRES(std::indirectly_readable<_Types> &&...)
174+
operator*() const _ONEDPL_CPP20_REQUIRES(std::indirectly_readable<_Types>&&...)
175175
{
176176
return __make_references<reference>()(__my_it_, ::std::make_index_sequence<__num_types>());
177177
}
@@ -235,8 +235,16 @@ class counting_iterator
235235
counting_iterator() : __my_counter_() {}
236236
explicit counting_iterator(_Ip __init) : __my_counter_(__init) {}
237237

238-
reference operator*() const { return __my_counter_; }
239-
reference operator[](difference_type __i) const { return *(*this + __i); }
238+
reference
239+
operator*() const
240+
{
241+
return __my_counter_;
242+
}
243+
reference
244+
operator[](difference_type __i) const
245+
{
246+
return *(*this + __i);
247+
}
240248

241249
difference_type
242250
operator-(const counting_iterator& __it) const
@@ -328,17 +336,13 @@ class counting_iterator
328336
return !(*this < __it);
329337
}
330338

339+
friend std::true_type
340+
is_passed_directly_in_onedpl_device_policies(const counting_iterator&);
341+
331342
private:
332343
_Ip __my_counter_;
333344
};
334345

335-
template <typename T>
336-
constexpr auto
337-
is_passed_directly_in_onedpl_device_policies(const oneapi::dpl::counting_iterator<T>&)
338-
{
339-
return std::true_type{};
340-
}
341-
342346
template <typename... _Types>
343347
class zip_iterator
344348
{
@@ -361,13 +365,17 @@ class zip_iterator
361365
// On windows, this requires clause is necessary so that concepts in MSVC STL do not detect the iterator as
362366
// dereferenceable when a source iterator is a sycl_iterator, which is a supported type.
363367
reference
364-
operator*() const _ONEDPL_CPP20_REQUIRES(std::indirectly_readable<_Types> &&...)
368+
operator*() const _ONEDPL_CPP20_REQUIRES(std::indirectly_readable<_Types>&&...)
365369
{
366370
return oneapi::dpl::__internal::__make_references<reference>()(__my_it_,
367371
::std::make_index_sequence<__num_types>());
368372
}
369373

370-
reference operator[](difference_type __i) const { return *(*this + __i); }
374+
reference
375+
operator[](difference_type __i) const
376+
{
377+
return *(*this + __i);
378+
}
371379

372380
difference_type
373381
operator-(const zip_iterator& __it) const
@@ -467,6 +475,10 @@ class zip_iterator
467475
return !(*this < __it);
468476
}
469477

478+
friend auto
479+
is_passed_directly_in_onedpl_device_policies(const zip_iterator&)
480+
-> std::conjunction<oneapi::dpl::__internal::is_passed_directly_to_device<_Types> ...>;
481+
470482
private:
471483
__it_types __my_it_;
472484
};
@@ -485,16 +497,6 @@ make_zip_iterator(std::tuple<_Tp...> __arg)
485497
return zip_iterator<_Tp...>(__arg);
486498
}
487499

488-
template <typename... _Tp>
489-
constexpr auto
490-
is_passed_directly_in_onedpl_device_policies(const oneapi::dpl::zip_iterator<_Tp...>&)
491-
{
492-
if constexpr ((oneapi::dpl::__internal::is_passed_directly_to_device_v<_Tp> && ...))
493-
return std::true_type{};
494-
else
495-
return std::false_type{};
496-
}
497-
498500
template <typename _Iter, typename _UnaryFunc>
499501
class transform_iterator
500502
{
@@ -553,7 +555,11 @@ class transform_iterator
553555
{
554556
return __my_unary_func_(*__my_it_);
555557
}
556-
reference operator[](difference_type __i) const { return *(*this + __i); }
558+
reference
559+
operator[](difference_type __i) const
560+
{
561+
return *(*this + __i);
562+
}
557563
transform_iterator&
558564
operator++()
559565
{
@@ -653,15 +659,11 @@ class transform_iterator
653659
{
654660
return __my_unary_func_;
655661
}
662+
friend auto
663+
is_passed_directly_in_onedpl_device_policies(const transform_iterator&)
664+
-> oneapi::dpl::__internal::is_passed_directly_to_device<_Iter>;
656665
};
657666

658-
template <typename _It, typename _Unary>
659-
constexpr auto
660-
is_passed_directly_in_onedpl_device_policies(const oneapi::dpl::transform_iterator<_It, _Unary>&)
661-
{
662-
return __internal::is_passed_directly_to_device<_It>{};
663-
}
664-
665667
template <typename _Iter, typename _UnaryFunc>
666668
transform_iterator<_Iter, _UnaryFunc>
667669
make_transform_iterator(_Iter __it, _UnaryFunc __unary_func)
@@ -745,12 +747,16 @@ class permutation_iterator
745747
// dereferenceable when the source or map iterator is a sycl_iterator, which is a supported type for both.
746748
reference
747749
operator*() const
748-
_ONEDPL_CPP20_REQUIRES(std::indirectly_readable<SourceIterator> && std::indirectly_readable<IndexMap>)
750+
_ONEDPL_CPP20_REQUIRES(std::indirectly_readable<SourceIterator>&& std::indirectly_readable<IndexMap>)
749751
{
750752
return my_source_it[*my_index];
751753
}
752754

753-
reference operator[](difference_type __i) const { return *(*this + __i); }
755+
reference
756+
operator[](difference_type __i) const
757+
{
758+
return *(*this + __i);
759+
}
754760

755761
permutation_iterator&
756762
operator++()
@@ -851,6 +857,10 @@ class permutation_iterator
851857
return !(*this < it);
852858
}
853859

860+
friend auto is_passed_directly_in_onedpl_device_policies(const permutation_iterator&) -> std::conjunction<oneapi::dpl::__internal::is_passed_directly_to_device<SourceIterator>,
861+
oneapi::dpl::__internal::is_passed_directly_to_device<
862+
typename oneapi::dpl::permutation_iterator<SourceIterator, _Permutation>::IndexMap>>;
863+
854864
private:
855865
SourceIterator my_source_it;
856866
IndexMap my_index;
@@ -923,8 +933,16 @@ class discard_iterator
923933
discard_iterator() : __my_position_() {}
924934
explicit discard_iterator(difference_type __init) : __my_position_(__init) {}
925935

926-
reference operator*() const { return internal::ignore; }
927-
reference operator[](difference_type) const { return internal::ignore; }
936+
reference
937+
operator*() const
938+
{
939+
return internal::ignore;
940+
}
941+
reference
942+
operator[](difference_type) const
943+
{
944+
return internal::ignore;
945+
}
928946

929947
// GCC Bug 66297: constexpr non-static member functions of non-literal types
930948
#if __GNUC__ && _ONEDPL_GCC_VERSION < 70200 && !(__INTEL_COMPILER || __clang__)
@@ -1025,15 +1043,14 @@ class discard_iterator
10251043
return !(*this < __it);
10261044
}
10271045

1028-
private:
1046+
friend std::true_type
1047+
is_passed_directly_in_onedpl_device_policies(const discard_iterator&);
1048+
1049+
private:
10291050
difference_type __my_position_;
10301051
};
10311052

1032-
constexpr auto
1033-
is_passed_directly_in_onedpl_device_policies(const oneapi::dpl::discard_iterator&)
1034-
{
1035-
return std::true_type{};
1036-
}
1053+
10371054

10381055
} // namespace dpl
10391056
} // namespace oneapi

0 commit comments

Comments
 (0)