Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Passed directly trait #2126

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/oneapi/dpl/iterator
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ using ::std::rbegin;
using ::std::rend;
using ::std::reverse_iterator;
using ::std::size;

using oneapi::dpl::__internal::is_passed_directly_to_device;
using oneapi::dpl::__internal::is_passed_directly_to_device_v;

} // namespace dpl
} // namespace oneapi
namespace dpl = oneapi::dpl;
Expand Down
78 changes: 14 additions & 64 deletions include/oneapi/dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,62 +200,6 @@ struct is_permutation<Iter, ::std::enable_if_t<Iter::is_permutation::value>> : :
{
};

//is_passed_directly trait definition; specializations for the oneDPL iterators

template <typename Iter, typename Void = void>
struct is_passed_directly : ::std::is_pointer<Iter>
{
};

//support legacy "is_passed_directly" trait
template <typename Iter>
struct is_passed_directly<Iter, ::std::enable_if_t<Iter::is_passed_directly::value>> : ::std::true_type
{
};

//support std::vector::iterator with usm host / shared allocator as passed directly
template <typename Iter>
struct is_passed_directly<Iter, std::enable_if_t<oneapi::dpl::__internal::__is_known_usm_vector_iter_v<Iter>>>
: std::true_type
{
};

template <typename Ip>
struct is_passed_directly<oneapi::dpl::counting_iterator<Ip>> : ::std::true_type
{
};

template <>
struct is_passed_directly<oneapi::dpl::discard_iterator> : ::std::true_type
{
};

template <typename Iter>
struct is_passed_directly<::std::reverse_iterator<Iter>> : is_passed_directly<Iter>
{
};

template <typename Iter, typename Unary>
struct is_passed_directly<oneapi::dpl::transform_iterator<Iter, Unary>> : is_passed_directly<Iter>
{
};

template <typename SourceIterator, typename IndexIterator>
struct is_passed_directly<oneapi::dpl::permutation_iterator<SourceIterator, IndexIterator>>
: ::std::conjunction<
is_passed_directly<SourceIterator>,
is_passed_directly<typename oneapi::dpl::permutation_iterator<SourceIterator, IndexIterator>::IndexMap>>
{
};

template <typename... Iters>
struct is_passed_directly<zip_iterator<Iters...>> : ::std::conjunction<is_passed_directly<Iters>...>
{
};

template <typename Iter>
inline constexpr bool is_passed_directly_v = is_passed_directly<Iter>::value;

// A trait for checking if iterator is heterogeneous or not

template <typename Iter>
Expand Down Expand Up @@ -290,7 +234,8 @@ struct is_temp_buff : ::std::false_type

template <typename _Iter>
struct is_temp_buff<_Iter, ::std::enable_if_t<!is_sycl_iterator_v<_Iter> && !::std::is_pointer_v<_Iter> &&
!is_passed_directly_v<_Iter>>> : ::std::true_type
!oneapi::dpl::__internal::is_passed_directly_to_device_v<_Iter>>>
: ::std::true_type
{
};

Expand Down Expand Up @@ -550,8 +495,10 @@ struct __get_sycl_range
}

//specialization for permutation_iterator using USM pointer or direct pass object as source
template <sycl::access::mode _LocalAccMode, typename _Iter, typename _Map,
::std::enable_if_t<!is_sycl_iterator_v<_Iter> && is_passed_directly_v<_Iter>, int> = 0>
template <
sycl::access::mode _LocalAccMode, typename _Iter, typename _Map,
::std::enable_if_t<!is_sycl_iterator_v<_Iter> && oneapi::dpl::__internal::is_passed_directly_to_device_v<_Iter>,
int> = 0>
auto
__process_input_iter(oneapi::dpl::permutation_iterator<_Iter, _Map> __first,
oneapi::dpl::permutation_iterator<_Iter, _Map> __last)
Expand All @@ -568,8 +515,10 @@ struct __get_sycl_range

// specialization for general case, permutation_iterator with base iterator that is not sycl_iterator or
// passed directly.
template <sycl::access::mode _LocalAccMode, typename _Iter, typename _Map,
::std::enable_if_t<!is_sycl_iterator_v<_Iter> && !is_passed_directly_v<_Iter>, int> = 0>
template <
sycl::access::mode _LocalAccMode, typename _Iter, typename _Map,
::std::enable_if_t<
!is_sycl_iterator_v<_Iter> && !oneapi::dpl::__internal::is_passed_directly_to_device_v<_Iter>, int> = 0>
auto
__process_input_iter(oneapi::dpl::permutation_iterator<_Iter, _Map> __first,
oneapi::dpl::permutation_iterator<_Iter, _Map> __last)
Expand All @@ -578,8 +527,8 @@ struct __get_sycl_range
assert(__n > 0);

//TODO: investigate better method of handling this specifically for fancy_iterators which are composed fully
// of a combination of fancy_iterators, sycl_iterators, and is_passed_directly types.
// Currently this relies on UB because the size of the accessor when handling sycl_iterators
// of a combination of fancy_iterators, sycl_iterators, and is_passed_directly_in_onedpl_device_policies
// types. Currently this relies on UB because the size of the accessor when handling sycl_iterators
// in recursion below this level is incorrect.
auto res_src = this->operator()(__first.base(), __first.base() + 1 /*source size*/);

Expand All @@ -604,7 +553,8 @@ struct __get_sycl_range

// for raw pointers and direct pass objects (for example, counting_iterator, iterator of USM-containers)
template <sycl::access::mode _LocalAccMode, typename _Iter>
::std::enable_if_t<is_passed_directly_v<_Iter>, __range_holder<oneapi::dpl::__ranges::guard_view<_Iter>>>
::std::enable_if_t<oneapi::dpl::__internal::is_passed_directly_to_device_v<_Iter>,
__range_holder<oneapi::dpl::__ranges::guard_view<_Iter>>>
__process_input_iter(_Iter __first, _Iter __last)
{
assert(__first < __last);
Expand Down
84 changes: 83 additions & 1 deletion include/oneapi/dpl/pstl/iterator_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,68 @@ struct __make_references
}
};

template <typename _Iter, typename _Void = void>
struct __is_legacy_passed_directly : std::false_type
{
};

template <typename _Iter>
struct __is_legacy_passed_directly<_Iter, ::std::enable_if_t<_Iter::is_passed_directly::value>> : std::true_type
{
};

template <typename _T>
struct __is_reverse_iterator_passed_directly;

template <typename T>
constexpr auto
is_passed_directly_in_onedpl_device_policies(const T&)
{
if constexpr (std::is_pointer<std::decay_t<T>>::value)
return std::true_type{};
#if _ONEDPL_BACKEND_SYCL
// TODO: hide this better in sycl backend, either all passed directly functions, or just this
else if constexpr (oneapi::dpl::__internal::__is_known_usm_vector_iter_v<std::decay_t<T>>)
return std::true_type{};
#endif
else if constexpr (__is_legacy_passed_directly<std::decay_t<T>>::value)
return std::true_type{};
else if constexpr (__is_reverse_iterator_passed_directly<std::decay_t<T>>::value)
return std::true_type{};
else
return std::false_type{};
}

struct __is_passed_directly_in_onedpl_device_policies_fn
{
template <typename T>
constexpr auto
operator()(const T& t) const
{
return is_passed_directly_in_onedpl_device_policies(t);
}
};

inline constexpr __is_passed_directly_in_onedpl_device_policies_fn __is_passed_directly_in_onedpl_device_policies;

template <typename T>
struct is_passed_directly_to_device
: decltype(oneapi::dpl::__internal::__is_passed_directly_in_onedpl_device_policies(std::declval<T>())){};

template <typename T>
inline constexpr bool is_passed_directly_to_device_v = is_passed_directly_to_device<T>::value;

template <typename _T>
struct __is_reverse_iterator_passed_directly : std::false_type
{
};

template <typename _BaseIter>
struct __is_reverse_iterator_passed_directly<std::reverse_iterator<_BaseIter>>
: oneapi::dpl::__internal::is_passed_directly_to_device<_BaseIter>
{
};

//zip_iterator version for forward iterator
//== and != comparison is performed only on the first element of the tuple
//
Expand Down Expand Up @@ -266,6 +328,9 @@ class counting_iterator
return !(*this < __it);
}

friend std::true_type
is_passed_directly_in_onedpl_device_policies(const counting_iterator&);

private:
_Ip __my_counter_;
};
Expand Down Expand Up @@ -398,6 +463,10 @@ class zip_iterator
return !(*this < __it);
}

friend auto
is_passed_directly_in_onedpl_device_policies(const zip_iterator&)
-> std::conjunction<oneapi::dpl::__internal::is_passed_directly_to_device<_Types> ...>;

private:
__it_types __my_it_;
};
Expand Down Expand Up @@ -574,6 +643,9 @@ class transform_iterator
{
return __my_unary_func_;
}
friend auto
is_passed_directly_in_onedpl_device_policies(const transform_iterator&)
-> oneapi::dpl::__internal::is_passed_directly_to_device<_Iter>;
};

template <typename _Iter, typename _UnaryFunc>
Expand Down Expand Up @@ -765,11 +837,16 @@ class permutation_iterator
return !(*this < it);
}

friend auto is_passed_directly_in_onedpl_device_policies(const permutation_iterator&) -> std::conjunction<oneapi::dpl::__internal::is_passed_directly_to_device<SourceIterator>,
oneapi::dpl::__internal::is_passed_directly_to_device<
typename oneapi::dpl::permutation_iterator<SourceIterator, _Permutation>::IndexMap>>;

private:
SourceIterator my_source_it;
IndexMap my_index;
};


template <typename SourceIterator, typename IndexMap, typename... StartIndex>
permutation_iterator<SourceIterator, IndexMap>
make_permutation_iterator(SourceIterator source, IndexMap map, StartIndex... idx)
Expand Down Expand Up @@ -927,10 +1004,15 @@ class discard_iterator
return !(*this < __it);
}

private:
friend std::true_type
is_passed_directly_in_onedpl_device_policies(const discard_iterator&);

private:
difference_type __my_position_;
};



} // namespace dpl
} // namespace oneapi

Expand Down
Loading
Loading