diff --git a/include/oneapi/dpl/iterator b/include/oneapi/dpl/iterator index 5568292b1c1..45b7e1ccf0f 100644 --- a/include/oneapi/dpl/iterator +++ b/include/oneapi/dpl/iterator @@ -56,6 +56,10 @@ using ::std::rbegin; using ::std::rend; using ::std::reverse_iterator; using ::std::size; + +using oneapi::dpl::__internal::is_passed_directly_to_device; +using oneapi::dpl::__internal::is_passed_directly_to_device_v; + } // namespace dpl } // namespace oneapi namespace dpl = oneapi::dpl; diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h index d11a2676715..6ffbadb8da0 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h @@ -200,62 +200,6 @@ struct is_permutation> : : { }; -//is_passed_directly trait definition; specializations for the oneDPL iterators - -template -struct is_passed_directly : ::std::is_pointer -{ -}; - -//support legacy "is_passed_directly" trait -template -struct is_passed_directly> : ::std::true_type -{ -}; - -//support std::vector::iterator with usm host / shared allocator as passed directly -template -struct is_passed_directly>> - : std::true_type -{ -}; - -template -struct is_passed_directly> : ::std::true_type -{ -}; - -template <> -struct is_passed_directly : ::std::true_type -{ -}; - -template -struct is_passed_directly<::std::reverse_iterator> : is_passed_directly -{ -}; - -template -struct is_passed_directly> : is_passed_directly -{ -}; - -template -struct is_passed_directly> - : ::std::conjunction< - is_passed_directly, - is_passed_directly::IndexMap>> -{ -}; - -template -struct is_passed_directly> : ::std::conjunction...> -{ -}; - -template -inline constexpr bool is_passed_directly_v = is_passed_directly::value; - // A trait for checking if iterator is heterogeneous or not template @@ -290,7 +234,8 @@ struct is_temp_buff : ::std::false_type template struct is_temp_buff<_Iter, ::std::enable_if_t && !::std::is_pointer_v<_Iter> && - !is_passed_directly_v<_Iter>>> : ::std::true_type + !oneapi::dpl::__internal::is_passed_directly_to_device_v<_Iter>>> + : ::std::true_type { }; @@ -550,8 +495,10 @@ struct __get_sycl_range } //specialization for permutation_iterator using USM pointer or direct pass object as source - template && is_passed_directly_v<_Iter>, int> = 0> + template < + sycl::access::mode _LocalAccMode, typename _Iter, typename _Map, + ::std::enable_if_t && oneapi::dpl::__internal::is_passed_directly_to_device_v<_Iter>, + int> = 0> auto __process_input_iter(oneapi::dpl::permutation_iterator<_Iter, _Map> __first, oneapi::dpl::permutation_iterator<_Iter, _Map> __last) @@ -568,8 +515,10 @@ struct __get_sycl_range // specialization for general case, permutation_iterator with base iterator that is not sycl_iterator or // passed directly. - template && !is_passed_directly_v<_Iter>, int> = 0> + template < + sycl::access::mode _LocalAccMode, typename _Iter, typename _Map, + ::std::enable_if_t< + !is_sycl_iterator_v<_Iter> && !oneapi::dpl::__internal::is_passed_directly_to_device_v<_Iter>, int> = 0> auto __process_input_iter(oneapi::dpl::permutation_iterator<_Iter, _Map> __first, oneapi::dpl::permutation_iterator<_Iter, _Map> __last) @@ -578,8 +527,8 @@ struct __get_sycl_range assert(__n > 0); //TODO: investigate better method of handling this specifically for fancy_iterators which are composed fully - // of a combination of fancy_iterators, sycl_iterators, and is_passed_directly types. - // Currently this relies on UB because the size of the accessor when handling sycl_iterators + // of a combination of fancy_iterators, sycl_iterators, and is_passed_directly_in_onedpl_device_policies + // types. Currently this relies on UB because the size of the accessor when handling sycl_iterators // in recursion below this level is incorrect. auto res_src = this->operator()(__first.base(), __first.base() + 1 /*source size*/); @@ -604,7 +553,8 @@ struct __get_sycl_range // for raw pointers and direct pass objects (for example, counting_iterator, iterator of USM-containers) template - ::std::enable_if_t, __range_holder>> + ::std::enable_if_t, + __range_holder>> __process_input_iter(_Iter __first, _Iter __last) { assert(__first < __last); diff --git a/include/oneapi/dpl/pstl/iterator_impl.h b/include/oneapi/dpl/pstl/iterator_impl.h index 63ce678e4ae..f5b75592b8b 100644 --- a/include/oneapi/dpl/pstl/iterator_impl.h +++ b/include/oneapi/dpl/pstl/iterator_impl.h @@ -75,6 +75,68 @@ struct __make_references } }; +template +struct __is_legacy_passed_directly : std::false_type +{ +}; + +template +struct __is_legacy_passed_directly<_Iter, ::std::enable_if_t<_Iter::is_passed_directly::value>> : std::true_type +{ +}; + +template +struct __is_reverse_iterator_passed_directly; + +template +constexpr auto +is_passed_directly_in_onedpl_device_policies(const T&) +{ + if constexpr (std::is_pointer>::value) + return std::true_type{}; +#if _ONEDPL_BACKEND_SYCL + // TODO: hide this better in sycl backend, either all passed directly functions, or just this + else if constexpr (oneapi::dpl::__internal::__is_known_usm_vector_iter_v>) + return std::true_type{}; +#endif + else if constexpr (__is_legacy_passed_directly>::value) + return std::true_type{}; + else if constexpr (__is_reverse_iterator_passed_directly>::value) + return std::true_type{}; + else + return std::false_type{}; +} + +struct __is_passed_directly_in_onedpl_device_policies_fn +{ + template + constexpr auto + operator()(const T& t) const + { + return is_passed_directly_in_onedpl_device_policies(t); + } +}; + +inline constexpr __is_passed_directly_in_onedpl_device_policies_fn __is_passed_directly_in_onedpl_device_policies; + +template +struct is_passed_directly_to_device + : decltype(oneapi::dpl::__internal::__is_passed_directly_in_onedpl_device_policies(std::declval())){}; + +template +inline constexpr bool is_passed_directly_to_device_v = is_passed_directly_to_device::value; + +template +struct __is_reverse_iterator_passed_directly : std::false_type +{ +}; + +template +struct __is_reverse_iterator_passed_directly> + : oneapi::dpl::__internal::is_passed_directly_to_device<_BaseIter> +{ +}; + //zip_iterator version for forward iterator //== and != comparison is performed only on the first element of the tuple // @@ -266,6 +328,9 @@ class counting_iterator return !(*this < __it); } + friend std::true_type + is_passed_directly_in_onedpl_device_policies(const counting_iterator&); + private: _Ip __my_counter_; }; @@ -398,6 +463,10 @@ class zip_iterator return !(*this < __it); } + friend auto + is_passed_directly_in_onedpl_device_policies(const zip_iterator&) + -> std::conjunction ...>; + private: __it_types __my_it_; }; @@ -574,6 +643,9 @@ class transform_iterator { return __my_unary_func_; } + friend auto + is_passed_directly_in_onedpl_device_policies(const transform_iterator&) + -> oneapi::dpl::__internal::is_passed_directly_to_device<_Iter>; }; template @@ -765,11 +837,16 @@ class permutation_iterator return !(*this < it); } + friend auto is_passed_directly_in_onedpl_device_policies(const permutation_iterator&) -> std::conjunction, + oneapi::dpl::__internal::is_passed_directly_to_device< + typename oneapi::dpl::permutation_iterator::IndexMap>>; + private: SourceIterator my_source_it; IndexMap my_index; }; + template permutation_iterator make_permutation_iterator(SourceIterator source, IndexMap map, StartIndex... idx) @@ -927,10 +1004,15 @@ class discard_iterator return !(*this < __it); } - private: + friend std::true_type + is_passed_directly_in_onedpl_device_policies(const discard_iterator&); + + private: difference_type __my_position_; }; + + } // namespace dpl } // namespace oneapi diff --git a/test/general/implementation_details/passed_directly.pass.cpp b/test/general/implementation_details/passed_directly.pass.cpp new file mode 100644 index 00000000000..1cc35aabd78 --- /dev/null +++ b/test/general/implementation_details/passed_directly.pass.cpp @@ -0,0 +1,380 @@ +// -*- C++ -*- +//===-- passed_directly.pass.cpp -----------------------------------------------===// +// +// Copyright (C) Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// This file incorporates work covered by the following copyright and permission +// notice: +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// +//===----------------------------------------------------------------------===// +#include +#include "support/test_config.h" + +#include _PSTL_TEST_HEADER(iterator) + +#include "support/utils_device_copyable.h" +#include "support/utils.h" + +#if TEST_DPCPP_BACKEND_PRESENT + +struct simple_passed_directly_iterator +{ + using iterator_category = std::input_iterator_tag; + using value_type = int; + using difference_type = std::ptrdiff_t; + using pointer = int*; + using reference = int&; + + using is_passed_directly = std::true_type; + + simple_passed_directly_iterator(int start = 0) : value(start) {} + + int + operator*() const + { + return value; + } + + simple_passed_directly_iterator& + operator++() + { + ++value; + return *this; + } + + simple_passed_directly_iterator + operator++(int) + { + simple_passed_directly_iterator tmp = *this; + ++(*this); + return tmp; + } + + friend bool + operator==(const simple_passed_directly_iterator& a, const simple_passed_directly_iterator& b) + { + return a.value == b.value; + } + + friend bool + operator!=(const simple_passed_directly_iterator& a, const simple_passed_directly_iterator& b) + { + return !(a == b); + } + + private: + int value; +}; + +struct simple_explicitly_not_passed_directly_iterator +{ + using iterator_category = std::input_iterator_tag; + using value_type = int; + using difference_type = std::ptrdiff_t; + using pointer = int*; + using reference = int&; + + using is_passed_directly = std::false_type; + + simple_explicitly_not_passed_directly_iterator(int start = 0) : value(start) {} + + int + operator*() const + { + return value; + } + + simple_explicitly_not_passed_directly_iterator& + operator++() + { + ++value; + return *this; + } + + simple_explicitly_not_passed_directly_iterator + operator++(int) + { + simple_explicitly_not_passed_directly_iterator tmp = *this; + ++(*this); + return tmp; + } + + friend bool + operator==(const simple_explicitly_not_passed_directly_iterator& a, + const simple_explicitly_not_passed_directly_iterator& b) + { + return a.value == b.value; + } + + friend bool + operator!=(const simple_explicitly_not_passed_directly_iterator& a, + const simple_explicitly_not_passed_directly_iterator& b) + { + return !(a == b); + } + + private: + int value; +}; + +struct simple_implicitly_not_passed_directly_iterator +{ + using iterator_category = std::input_iterator_tag; + using value_type = int; + using difference_type = std::ptrdiff_t; + using pointer = int*; + using reference = int&; + + using is_passed_directly = std::false_type; + + simple_implicitly_not_passed_directly_iterator(int start = 0) : value(start) {} + + int + operator*() const + { + return value; + } + + simple_implicitly_not_passed_directly_iterator& + operator++() + { + ++value; + return *this; + } + + simple_implicitly_not_passed_directly_iterator + operator++(int) + { + simple_implicitly_not_passed_directly_iterator tmp = *this; + ++(*this); + return tmp; + } + + friend bool + operator==(const simple_implicitly_not_passed_directly_iterator& a, + const simple_implicitly_not_passed_directly_iterator& b) + { + return a.value == b.value; + } + + friend bool + operator!=(const simple_implicitly_not_passed_directly_iterator& a, + const simple_implicitly_not_passed_directly_iterator& b) + { + return !(a == b); + } + + private: + int value; +}; + +namespace custom_user +{ +template +struct base_strided_iterator +{ + using iterator_category = std::input_iterator_tag; + using value_type = typename std::iterator_traits::value_type; + + base_strided_iterator(BaseIter base, int stride) : base(base), stride(stride) {} + + int + operator*() const + { + return *base; + } + + base_strided_iterator& + operator++() + { + std::advance(base, stride); + return *this; + } + + base_strided_iterator + operator++(int) + { + base_strided_iterator tmp = *this; + ++(*this); + return tmp; + } + + friend bool + operator==(const base_strided_iterator& a, const base_strided_iterator& b) + { + return a.base == b.base; + } + + friend bool + operator!=(const base_strided_iterator& a, const base_strided_iterator& b) + { + return !(a == b); + } + + private: + BaseIter base; + int stride; +}; + +template +struct first_strided_iterator : public base_strided_iterator +{ + first_strided_iterator(BaseIter base, int stride) : base_strided_iterator(base, stride) {} +}; + +template +auto +is_passed_directly_in_onedpl_device_policies(const first_strided_iterator&) +{ + return oneapi::dpl::is_passed_directly_to_device{}; +} + +template +struct second_strided_iterator : public base_strided_iterator +{ + second_strided_iterator(BaseIter base, int stride) : base_strided_iterator(base, stride) {} +}; + +template +auto +is_passed_directly_in_onedpl_device_policies(const second_strided_iterator&) + -> decltype(oneapi::dpl::is_passed_directly_to_device{}); + +template +struct third_strided_iterator : public base_strided_iterator +{ + third_strided_iterator(BaseIter base, int stride) : base_strided_iterator(base, stride) {} + friend auto + is_passed_directly_in_onedpl_device_policies(const third_strided_iterator&) + { + return oneapi::dpl::is_passed_directly_to_device{}; + } +}; + +template +struct fourth_strided_iterator : public base_strided_iterator +{ + fourth_strided_iterator(BaseIter base, int stride) : base_strided_iterator(base, stride) {} + friend auto + is_passed_directly_in_onedpl_device_policies(const fourth_strided_iterator&) + -> oneapi::dpl::is_passed_directly_to_device; +}; + +} // namespace custom_user + +template +void +test_with_base_iterator() +{ + //test assumption about base iterator passed directly + static_assert(oneapi::dpl::is_passed_directly_to_device_v == base_passed_directly, + "is_passed_directly_in_onedpl_device_policies is not working correctly for base iterator"); + + // test wrapping base in transform_iterator + using TransformIter = oneapi::dpl::transform_iterator; + static_assert(oneapi::dpl::is_passed_directly_to_device_v == base_passed_directly, + "is_passed_directly_in_onedpl_device_policies is not working correctly for transform iterator"); + + // test wrapping base in permutation_iterator with counting iter + using PermutationIter = oneapi::dpl::permutation_iterator>; + static_assert(oneapi::dpl::is_passed_directly_to_device_v == base_passed_directly, + "is_passed_directly_in_onedpl_device_policies is not working correctly for permutation iterator"); + + // test wrapping base in permutation_iter with functor + using PermutationIterFunctor = oneapi::dpl::permutation_iterator; + static_assert( + oneapi::dpl::is_passed_directly_to_device_v == base_passed_directly, + "is_passed_directly_in_onedpl_device_policies is not working correctly for permutation iterator with functor"); + + // test wrapping base in zip_iterator + using ZipIter = oneapi::dpl::zip_iterator; + static_assert(oneapi::dpl::is_passed_directly_to_device_v == base_passed_directly, + "is_passed_directly_in_onedpl_device_policies is not working correctly for zip iterator"); + + // test wrapping base in zip_iterator with counting_iterator first + using ZipIterCounting = oneapi::dpl::zip_iterator, BaseIter>; + static_assert(oneapi::dpl::is_passed_directly_to_device_v == base_passed_directly, + "is_passed_directly_in_onedpl_device_policies is not working correctly for zip iterator with " + "counting iterator first"); + + // test wrapping base in zip_iterator with counting_iterator second + using ZipIterCounting2 = oneapi::dpl::zip_iterator>; + static_assert(oneapi::dpl::is_passed_directly_to_device_v == base_passed_directly, + "is_passed_directly_in_onedpl_device_policies is not working correctly for zip iterator with " + "counting iterator first"); + + // test wrapping base in reverse_iterator + using ReverseIter = std::reverse_iterator; + static_assert(oneapi::dpl::is_passed_directly_to_device_v == base_passed_directly, + "is_passed_directly_in_onedpl_device_policies is not working correctly for reverse iterator"); + + // test custom user first strided iterator with normal ADL function + using FirstStridedIter = custom_user::first_strided_iterator; + static_assert( + oneapi::dpl::is_passed_directly_to_device_v == base_passed_directly, + "is_passed_directly_in_onedpl_device_policies is not working correctly for custom user strided iterator"); + + // test custom user second strided iterator (no body for is_passed_directly_in_onedpl_device_policies) + using SecondStridedIter = custom_user::second_strided_iterator; + static_assert(oneapi::dpl::is_passed_directly_to_device_v == base_passed_directly, + "is_passed_directly_in_onedpl_device_policies is not working correctly for custom user strided " + "iterator with no body in ADL function definition"); + + // test custom user first strided iterator with hidden friend ADL function + using ThirdStridedIter = custom_user::third_strided_iterator; + static_assert(oneapi::dpl::is_passed_directly_to_device_v == base_passed_directly, + "is_passed_directly_in_onedpl_device_policies is not working correctly for custom user strided " + "iterator with hidden friend ADL function"); + + // test custom user first strided iterator with hidden friend ADL function without body + using FourthStridedIter = custom_user::fourth_strided_iterator; + static_assert(oneapi::dpl::is_passed_directly_to_device_v == base_passed_directly, + "is_passed_directly_in_onedpl_device_policies is not working correctly for custom user strided " + "iterator with hidden friend ADL function without a body"); +} + +#endif // TEST_DPCPP_BACKEND_PRESENT + +int +main() +{ + +#if TEST_DPCPP_BACKEND_PRESENT + // counting_iterator + test_with_base_iterator>(); + + // pointer (USM assumed) + test_with_base_iterator(); + + // create a usm allocated vector + sycl::queue q; + sycl::usm_allocator alloc(q); + std::vector> vec(alloc); + test_with_base_iterator, + decltype(vec.begin())>(); + + // custom iter type with legacy is_passed_directly trait defined + test_with_base_iterator(); + + // custom iter type with explicit is_passed_directly trait defined as false + test_with_base_iterator(); + + // custom iter type implicitly not passed directly + test_with_base_iterator(); + + // std vector with normal allocator + std::vector vec2(10); + test_with_base_iterator(); + + // test discard_iterator + static_assert(oneapi::dpl::is_passed_directly_to_device_v == true, + "is_passed_directly_in_onedpl_device_policies is not working correctly for discard iterator"); + +#endif // TEST_DPCPP_BACKEND_PRESENT + return TestUtils::done(TEST_DPCPP_BACKEND_PRESENT); +}