Skip to content

Indirectly Device Accessible Iterator Trait and ADL Customization Point #2126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 39 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
40ecbb2
is passed directly implementation
danhoeflinger Mar 10, 2025
9803d11
passed directly initial draft
danhoeflinger Mar 10, 2025
89bc2d8
improve default impl of customization point
danhoeflinger Mar 11, 2025
ffe8b63
tests for passed directly and fix for reverse iterator
danhoeflinger Mar 11, 2025
c4cb1b7
hiding sycl specifics in iterator_impl
danhoeflinger Mar 12, 2025
db149d7
spelling
danhoeflinger Mar 12, 2025
be71462
formatting
danhoeflinger Mar 12, 2025
50a48c9
adding test for custom wrapped iter type
danhoeflinger Mar 14, 2025
9c7b75a
update impl to match RFC
danhoeflinger Mar 21, 2025
28e28c1
remove injection into std namespace
danhoeflinger Mar 21, 2025
c3b1177
clang format
danhoeflinger Mar 21, 2025
9b5296d
adding (and fixing) test cases
danhoeflinger Mar 21, 2025
7d7aaf5
formatting
danhoeflinger Mar 21, 2025
16b8c7f
adding hidden friend without body
danhoeflinger Mar 21, 2025
992c996
codespell
danhoeflinger Mar 21, 2025
b5e9b37
improving definitions for our types
danhoeflinger Mar 21, 2025
1ba5b4b
improving internal traits to depend upon
danhoeflinger Apr 4, 2025
3988cb2
implementing as body-less functions
danhoeflinger Apr 4, 2025
508fb1e
formatting
danhoeflinger Apr 4, 2025
062ac5a
::std -> std
danhoeflinger Apr 4, 2025
82be2d0
formatting
danhoeflinger Apr 4, 2025
6e56639
formatting
danhoeflinger Apr 4, 2025
2f38e16
adjusting to new naming in spec PR
danhoeflinger Apr 4, 2025
58adaa1
enforce trait is a bool_constant
danhoeflinger Apr 11, 2025
106142f
Formatting
danhoeflinger Apr 11, 2025
6904ec5
adjust naming to match spec
danhoeflinger Apr 14, 2025
7efa1b7
moving definition to public namespace
danhoeflinger Apr 14, 2025
6e82c08
formatting
danhoeflinger Apr 14, 2025
dee32ad
sycl_iterator is indirecly accessible but not passed directly
danhoeflinger Apr 14, 2025
0661d31
static_assert to enforce bool_constant
danhoeflinger Apr 17, 2025
aa62583
formatting
danhoeflinger Apr 17, 2025
2bf81cc
revert comment change
danhoeflinger Apr 17, 2025
3a2a45c
adding fully qualified name
danhoeflinger Apr 18, 2025
162aa24
renaming test
danhoeflinger Apr 18, 2025
f80ab15
address feedback
danhoeflinger Apr 18, 2025
aa80e2e
Fixing naming to match spec (remove "_iterator")
danhoeflinger Apr 18, 2025
c3c2f02
formatting
danhoeflinger Apr 18, 2025
6a85701
using type alias instead of struct
danhoeflinger Apr 18, 2025
6e3b95d
fix is_sycl_iterator
danhoeflinger Apr 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions include/oneapi/dpl/pstl/hetero/dpcpp/sycl_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,21 +107,25 @@ struct sycl_iterator
return *this - it < 0;
}

// This function is required for types for which oneapi::dpl::__ranges::is_sycl_iterator_v = true to ensure proper
// handling by oneapi::dpl::__ranges::__get_sycl_range
// This function is required for types for which oneapi::dpl::__ranges::is_sycl_or_hetero_iterator = true to ensure
// proper handling by oneapi::dpl::__ranges::__get_sycl_range
sycl::buffer<T, dim, Allocator>
get_buffer() const
{
return buffer;
}

// This function is required for types for which oneapi::dpl::__ranges::is_sycl_iterator_v = true to ensure proper
// handling by oneapi::dpl::__ranges::__get_sycl_range
// This function is required for types for which oneapi::dpl::__ranges::is_sycl_or_hetero_iterator = true to ensure
// proper handling by oneapi::dpl::__ranges::__get_sycl_range
Size
get_idx() const
{
return idx;
}

// While sycl_iterator cannot be "passed directly" because it is not device_copyable or a random access iterator,
// it does represent indirectly device accessible data.
friend std::true_type is_onedpl_indirectly_device_accessible_iterator(sycl_iterator);
};

// map access_mode tag to access_mode value
Expand Down Expand Up @@ -162,21 +166,28 @@ template <typename Iter, typename ValueType = std::decay_t<typename std::iterato
using __usm_host_alloc_vec_iter =
typename std::vector<ValueType, typename sycl::usm_allocator<ValueType, sycl::usm::alloc::host>>::iterator;

// Evaluates to true if the provided type is an iterator with a value_type and if the implementation of a
// Evaluates to true_type if the provided type is an iterator with a value_type and if the implementation of a
// std::vector<value_type, Alloc>::iterator can be distinguished between three different allocators, the
// default, usm_shared, and usm_host. If all are distinct, it is very unlikely any non-usm based allocator
// could be confused with a usm allocator.
template <typename Iter>
constexpr bool __vector_iter_distinguishes_by_allocator_v =
!std::is_same_v<__default_alloc_vec_iter<Iter>, __usm_shared_alloc_vec_iter<Iter>> &&
!std::is_same_v<__default_alloc_vec_iter<Iter>, __usm_host_alloc_vec_iter<Iter>> &&
!std::is_same_v<__usm_host_alloc_vec_iter<Iter>, __usm_shared_alloc_vec_iter<Iter>>;
using __vector_iter_distinguishes_by_allocator =
std::conjunction<std::negation<std::is_same<__default_alloc_vec_iter<Iter>, __usm_shared_alloc_vec_iter<Iter>>>,
std::negation<std::is_same<__default_alloc_vec_iter<Iter>, __usm_host_alloc_vec_iter<Iter>>>,
std::negation<std::is_same<__usm_host_alloc_vec_iter<Iter>, __usm_shared_alloc_vec_iter<Iter>>>>;

template <typename Iter>
inline constexpr bool __vector_iter_distinguishes_by_allocator_v =
__vector_iter_distinguishes_by_allocator<Iter>::value;

template <typename Iter>
using __is_known_usm_vector_iter =
std::conjunction<__vector_iter_distinguishes_by_allocator<Iter>,
std::disjunction<std::is_same<Iter, oneapi::dpl::__internal::__usm_shared_alloc_vec_iter<Iter>>,
std::is_same<Iter, oneapi::dpl::__internal::__usm_host_alloc_vec_iter<Iter>>>>;

template <typename Iter>
constexpr bool __is_known_usm_vector_iter_v =
oneapi::dpl::__internal::__vector_iter_distinguishes_by_allocator_v<Iter> &&
(std::is_same_v<Iter, oneapi::dpl::__internal::__usm_shared_alloc_vec_iter<Iter>> ||
std::is_same_v<Iter, oneapi::dpl::__internal::__usm_host_alloc_vec_iter<Iter>>);
inline constexpr bool __is_known_usm_vector_iter_v = __is_known_usm_vector_iter<Iter>::value;

} // namespace __internal

Expand Down
97 changes: 29 additions & 68 deletions include/oneapi/dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,62 +196,6 @@ struct is_permutation<Iter, ::std::enable_if_t<Iter::is_permutation::value>> : :
{
};

//is_passed_directly trait definition; specializations for the oneDPL iterators

template <typename Iter, typename Void = void>
struct is_passed_directly : ::std::is_pointer<Iter>
{
};

//support legacy "is_passed_directly" trait
template <typename Iter>
struct is_passed_directly<Iter, ::std::enable_if_t<Iter::is_passed_directly::value>> : ::std::true_type
{
};

//support std::vector::iterator with usm host / shared allocator as passed directly
template <typename Iter>
struct is_passed_directly<Iter, std::enable_if_t<oneapi::dpl::__internal::__is_known_usm_vector_iter_v<Iter>>>
: std::true_type
{
};

template <typename Ip>
struct is_passed_directly<oneapi::dpl::counting_iterator<Ip>> : ::std::true_type
{
};

template <>
struct is_passed_directly<oneapi::dpl::discard_iterator> : ::std::true_type
{
};

template <typename Iter>
struct is_passed_directly<::std::reverse_iterator<Iter>> : is_passed_directly<Iter>
{
};

template <typename Iter, typename Unary>
struct is_passed_directly<oneapi::dpl::transform_iterator<Iter, Unary>> : is_passed_directly<Iter>
{
};

template <typename SourceIterator, typename IndexIterator>
struct is_passed_directly<oneapi::dpl::permutation_iterator<SourceIterator, IndexIterator>>
: ::std::conjunction<
is_passed_directly<SourceIterator>,
is_passed_directly<typename oneapi::dpl::permutation_iterator<SourceIterator, IndexIterator>::IndexMap>>
{
};

template <typename... Iters>
struct is_passed_directly<zip_iterator<Iters...>> : ::std::conjunction<is_passed_directly<Iters>...>
{
};

template <typename Iter>
inline constexpr bool is_passed_directly_v = is_passed_directly<Iter>::value;

// A trait for checking if iterator is heterogeneous or not

template <typename Iter>
Expand All @@ -275,7 +219,18 @@ struct is_hetero_legacy_trait<Iter, ::std::enable_if_t<Iter::is_hetero::value>>
};

template <typename Iter>
inline constexpr bool is_sycl_iterator_v = is_sycl_iterator<Iter>::value || is_hetero_legacy_trait<Iter>::value;
using is_sycl_or_hetero_iterator = std::disjunction<is_sycl_iterator<Iter>, is_hetero_legacy_trait<Iter>>;

template <typename Iter>
inline constexpr bool is_sycl_or_hetero_iterator_v = is_sycl_or_hetero_iterator<Iter>::value;

template <typename _Iter>
using __is_passed_directly_device_ready =
std::conjunction<oneapi::dpl::is_indirectly_device_accessible<_Iter>, sycl::is_device_copyable<_Iter>,
oneapi::dpl::__internal::__is_random_access_iterator<_Iter>>;

template <typename _Iter>
inline constexpr bool __is_passed_directly_device_ready_v = __is_passed_directly_device_ready<_Iter>::value;

//A trait for checking if it needs to create a temporary SYCL buffer or not

Expand All @@ -285,8 +240,9 @@ struct is_temp_buff : ::std::false_type
};

template <typename _Iter>
struct is_temp_buff<_Iter, ::std::enable_if_t<!is_sycl_iterator_v<_Iter> && !::std::is_pointer_v<_Iter> &&
!is_passed_directly_v<_Iter>>> : ::std::true_type
struct is_temp_buff<
_Iter, std::enable_if_t<!oneapi::dpl::__ranges::is_sycl_or_hetero_iterator_v<_Iter> && !std::is_pointer_v<_Iter> &&
!oneapi::dpl::__ranges::__is_passed_directly_device_ready_v<_Iter>>> : std::true_type
{
};

Expand Down Expand Up @@ -519,15 +475,15 @@ struct __get_sycl_range

//specialization for permutation_iterator using sycl_iterator as source
template <sycl::access::mode _LocalAccMode, typename _It, typename _Map,
::std::enable_if_t<is_sycl_iterator_v<_It>, int> = 0>
::std::enable_if_t<oneapi::dpl::__ranges::is_sycl_or_hetero_iterator_v<_It>, int> = 0>
auto
__process_input_iter(oneapi::dpl::permutation_iterator<_It, _Map> __first,
oneapi::dpl::permutation_iterator<_It, _Map> __last)
{
auto __n = __last - __first;
assert(__n > 0);

// Types for which oneapi::dpl::__ranges::is_sycl_iterator_v = true should have both:
// Types for which oneapi::dpl::__ranges::is_sycl_or_hetero_iterator_v = true should have both:
// "get_buffer()" to return the buffer they are base upon and
// "get_idx()" to return the buffer offset

Expand All @@ -547,7 +503,9 @@ struct __get_sycl_range

//specialization for permutation_iterator using USM pointer or direct pass object as source
template <sycl::access::mode _LocalAccMode, typename _Iter, typename _Map,
::std::enable_if_t<!is_sycl_iterator_v<_Iter> && is_passed_directly_v<_Iter>, int> = 0>
std::enable_if_t<!oneapi::dpl::__ranges::is_sycl_or_hetero_iterator_v<_Iter> &&
oneapi::dpl::__ranges::__is_passed_directly_device_ready_v<_Iter>,
int> = 0>
auto
__process_input_iter(oneapi::dpl::permutation_iterator<_Iter, _Map> __first,
oneapi::dpl::permutation_iterator<_Iter, _Map> __last)
Expand All @@ -563,9 +521,11 @@ struct __get_sycl_range
}

// specialization for general case, permutation_iterator with base iterator that is not sycl_iterator or
// passed directly.
// device accessible content iterators.
template <sycl::access::mode _LocalAccMode, typename _Iter, typename _Map,
::std::enable_if_t<!is_sycl_iterator_v<_Iter> && !is_passed_directly_v<_Iter>, int> = 0>
std::enable_if_t<!oneapi::dpl::__ranges::is_sycl_or_hetero_iterator_v<_Iter> &&
!oneapi::dpl::__ranges::__is_passed_directly_device_ready_v<_Iter>,
int> = 0>
auto
__process_input_iter(oneapi::dpl::permutation_iterator<_Iter, _Map> __first,
oneapi::dpl::permutation_iterator<_Iter, _Map> __last)
Expand All @@ -574,7 +534,7 @@ struct __get_sycl_range
assert(__n > 0);

//TODO: investigate better method of handling this specifically for fancy_iterators which are composed fully
// of a combination of fancy_iterators, sycl_iterators, and is_passed_directly types.
// of a combination of fancy_iterators, sycl_iterators, and passed_directly types.
// Currently this relies on UB because the size of the accessor when handling sycl_iterators
// in recursion below this level is incorrect.
auto res_src = this->operator()(__first.base(), __first.base() + 1 /*source size*/);
Expand All @@ -600,7 +560,8 @@ struct __get_sycl_range

// for raw pointers and direct pass objects (for example, counting_iterator, iterator of USM-containers)
template <sycl::access::mode _LocalAccMode, typename _Iter>
::std::enable_if_t<is_passed_directly_v<_Iter>, __range_holder<oneapi::dpl::__ranges::guard_view<_Iter>>>
std::enable_if_t<oneapi::dpl::__ranges::__is_passed_directly_device_ready_v<_Iter>,
__range_holder<oneapi::dpl::__ranges::guard_view<_Iter>>>
__process_input_iter(_Iter __first, _Iter __last)
{
assert(__first < __last);
Expand All @@ -612,13 +573,13 @@ struct __get_sycl_range
template <sycl::access::mode _LocalAccMode, typename _Iter>
auto
__process_input_iter(_Iter __first, _Iter __last)
-> ::std::enable_if_t<is_sycl_iterator_v<_Iter>,
-> ::std::enable_if_t<oneapi::dpl::__ranges::is_sycl_or_hetero_iterator_v<_Iter>,
__range_holder<oneapi::dpl::__ranges::all_view<val_t<_Iter>, _LocalAccMode>>>
{
assert(__first < __last);
using value_type = val_t<_Iter>;

// Types for which oneapi::dpl::__ranges::is_sycl_iterator_v = true should have both:
// Types for which oneapi::dpl::__ranges::is_sycl_or_hetero_iterator_v = true should have both:
// "get_buffer()" to return the buffer they are base upon and
// "get_idx()" to return the buffer offset

Expand Down
78 changes: 78 additions & 0 deletions include/oneapi/dpl/pstl/iterator_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <iterator>
#include <tuple>
#include <cassert>
#include <type_traits>

#include "onedpl_config.h"
#include "utils.h"
Expand Down Expand Up @@ -75,6 +76,38 @@ struct __make_references
}
};

template <typename _Iter, typename = void>
struct __is_legacy_passed_directly : std::false_type
{
};

template <typename _Iter>
struct __is_legacy_passed_directly<_Iter, ::std::enable_if_t<_Iter::is_passed_directly::value>> : std::true_type
{
};

template <typename _T>
struct __is_reversed_indirectly_device_accessible_it;

template <typename T>
constexpr auto is_onedpl_indirectly_device_accessible_iterator(T)
-> std::disjunction<
#if _ONEDPL_BACKEND_SYCL
oneapi::dpl::__internal::__is_known_usm_vector_iter<std::decay_t<T>>, // USM vector iterator
#endif // _ONEDPL_BACKEND_SYCL
std::is_pointer<std::decay_t<T>>, // USM pointer
oneapi::dpl::__internal::__is_legacy_passed_directly<std::decay_t<T>>, // legacy passed directly iter
oneapi::dpl::__internal::__is_reversed_indirectly_device_accessible_it<std::decay_t<T>>>; // reverse iterator

struct __is_onedpl_indirectly_device_accessible_iterator_fn
{
template <typename T>
constexpr auto
operator()(const T& t) const -> decltype(is_onedpl_indirectly_device_accessible_iterator(t));
};

inline constexpr __is_onedpl_indirectly_device_accessible_iterator_fn __is_onedpl_indirectly_device_accessible_iterator;

//zip_iterator version for forward iterator
//== and != comparison is performed only on the first element of the tuple
//
Expand Down Expand Up @@ -157,6 +190,22 @@ namespace oneapi
{
namespace dpl
{

template <typename T>
struct is_indirectly_device_accessible
: decltype(oneapi::dpl::__internal::__is_onedpl_indirectly_device_accessible_iterator(std::declval<T>()))
{
static_assert(
std::is_same_v<decltype(decltype(oneapi::dpl::__internal::__is_onedpl_indirectly_device_accessible_iterator(
std::declval<T>()))::value),
const bool>,
"Return type of is_onedpl_indirectly_device_accessible_iterator does not have the characteristics of a "
"bool_constant");
};

template <typename T>
inline constexpr bool is_indirectly_device_accessible_v = is_indirectly_device_accessible<T>::value;

template <typename _Ip>
class counting_iterator
{
Expand Down Expand Up @@ -266,6 +315,9 @@ class counting_iterator
return !(*this < __it);
}

friend std::true_type
is_onedpl_indirectly_device_accessible_iterator(const counting_iterator&);

private:
_Ip __my_counter_;
};
Expand Down Expand Up @@ -398,6 +450,10 @@ class zip_iterator
return !(*this < __it);
}

friend auto
is_onedpl_indirectly_device_accessible_iterator(const zip_iterator&)
-> std::conjunction<oneapi::dpl::is_indirectly_device_accessible<_Types>...>;

private:
__it_types __my_it_;
};
Expand Down Expand Up @@ -574,6 +630,9 @@ class transform_iterator
{
return __my_unary_func_;
}
friend auto
is_onedpl_indirectly_device_accessible_iterator(const transform_iterator&)
-> oneapi::dpl::is_indirectly_device_accessible<_Iter>;
};

template <typename _Iter, typename _UnaryFunc>
Expand Down Expand Up @@ -765,6 +824,12 @@ class permutation_iterator
return !(*this < it);
}

friend auto
is_onedpl_indirectly_device_accessible_iterator(const permutation_iterator&)
-> std::conjunction<oneapi::dpl::is_indirectly_device_accessible<SourceIterator>,
oneapi::dpl::is_indirectly_device_accessible<
typename oneapi::dpl::permutation_iterator<SourceIterator, _Permutation>::IndexMap>>;

private:
SourceIterator my_source_it;
IndexMap my_index;
Expand Down Expand Up @@ -927,6 +992,9 @@ class discard_iterator
return !(*this < __it);
}

friend std::true_type
is_onedpl_indirectly_device_accessible_iterator(const discard_iterator&);

private:
difference_type __my_position_;
};
Expand All @@ -940,6 +1008,16 @@ namespace dpl
{
namespace __internal
{
template <typename _T>
struct __is_reversed_indirectly_device_accessible_it : std::false_type
{
};

template <typename _BaseIter>
struct __is_reversed_indirectly_device_accessible_it<std::reverse_iterator<_BaseIter>>
: oneapi::dpl::is_indirectly_device_accessible<_BaseIter>
{
};

struct make_zipiterator_functor
{
Expand Down
Loading
Loading