Skip to content

Commit 4eddf48

Browse files
committed
[oneDPL][ranges] + support sized output range for copy_if; dpcpp backend, part 2
1 parent 3e86e51 commit 4eddf48

File tree

5 files changed

+29
-13
lines changed

5 files changed

+29
-13
lines changed

include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,7 @@ __pattern_copy_if(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Iterato
901901
if (__first == __last)
902902
return __result_first;
903903

904+
auto __n = __last - __first;
904905
auto __keep1 = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read, _Iterator1>();
905906
auto __buf1 = __keep1(__first, __last);
906907
auto __keep2 = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::write, _Iterator2>();

include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -540,18 +540,18 @@ std::pair<oneapi::dpl::__internal::__difference_t<_Range1>, oneapi::dpl::__inter
540540
__pattern_copy_if(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2,
541541
_Predicate __pred, _Assign __assign)
542542
{
543-
using _Index = oneapi::dpl::__internal::__difference_t<_Range2>;
543+
using _Index = std::size_t; //TODO
544544
_Index __n = __rng1.size();
545545
if (__n == 0 || __rng2.empty())
546546
return {0, 0};
547547

548548
auto __res = oneapi::dpl::__par_backend_hetero::__parallel_copy_if_out_lim(
549549
_BackendTag{}, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1),
550-
std::forward<_Range2>(__rng2), __pred, __assign).get();
550+
std::forward<_Range2>(__rng2), __pred, __assign);
551551

552-
std::array<_Index, _2> __idx;
552+
std::array<_Index, 2> __idx;
553553
__res.get_values(__idx); //a blocking call
554-
return {__idx[0], __idx[1];
554+
return {__idx[1], __idx[0]}; //__parallel_copy_if_out_lim returns {last index in output, last index in input}
555555
}
556556

557557
#if _ONEDPL_CPP20_RANGES_PRESENT

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ struct __parallel_scan_submitter<_CustomName, __internal::__optional_kernel_name
316316
// Storage for the results of scan for each workgroup
317317

318318
using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, _Type>;
319-
__result_and_scratch_storage_t __result_and_scratch{__exec, 1, __n_groups + 1};
319+
__result_and_scratch_storage_t __result_and_scratch{__exec, 2, __n_groups + 1};
320320

321321
_PRINT_INFO_IN_DEBUG_MODE(__exec, __wgroup_size, __max_cu);
322322

@@ -1235,6 +1235,7 @@ __parallel_scan_copy(oneapi::dpl::__internal::__device_backend_tag __backend_tag
12351235
_InRng&& __in_rng, _OutRng&& __out_rng, _CreateMaskOp __create_mask_op,
12361236
_CopyByMaskOp __copy_by_mask_op)
12371237
{
1238+
using _Size = decltype(__out_rng.size());
12381239
using _ReduceOp = std::plus<_Size>;
12391240
using _Assigner = unseq_backend::__scan_assigner;
12401241
using _NoAssign = unseq_backend::__scan_no_assign;
@@ -1370,6 +1371,7 @@ auto
13701371
__parallel_copy_if_out_lim(oneapi::dpl::__internal::__device_backend_tag __backend_tag, _ExecutionPolicy&& __exec,
13711372
_InRng&& __in_rng, _OutRng&& __out_rng, _Pred __pred, _Assign __assign = _Assign{})
13721373
{
1374+
using _Size = decltype(__out_rng.size());
13731375
using _ReduceOp = std::plus<_Size>;
13741376
using _CreateOp = unseq_backend::__create_mask<_Pred, _Size>;
13751377
using _CopyOp = unseq_backend::__copy_by_mask<_ReduceOp, _Assign,

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#include <tuple>
2323
#include <algorithm>
2424

25-
#include "../../iterator_impl.h"
25+
#include "../../iterator_impl.h"
2626

2727
#include "sycl_defs.h"
2828
#include "execution_sycl_defs.h"
@@ -683,8 +683,8 @@ struct __result_and_scratch_storage : __result_and_scratch_storage_base
683683
}
684684
}
685685

686-
template <typename _T, std::size_t _N>
687-
void get_values(std::array<_T, _N>& __arr)
686+
template <std::size_t _N>
687+
void get_values(std::array<_T, _N>& __arr) const
688688
{
689689
assert(__result_n > 0);
690690
assert(_N == __result_n);
@@ -713,14 +713,14 @@ struct __result_and_scratch_storage : __result_and_scratch_storage_base
713713
return __get_value(idx);
714714
}
715715

716-
template <typename _Event, typename _T, std::size_t _N>
716+
template <typename _Event, std::size_t _N>
717717
void
718718
__wait_and_get_value(_Event&& __event, std::array<_T, _N>& __arr) const
719719
{
720720
if (is_USM())
721721
__event.wait_and_throw();
722722

723-
return get_values(__arr);
723+
get_values(__arr);
724724
}
725725
};
726726

@@ -745,7 +745,7 @@ struct __wait_and_get_value
745745
constexpr void
746746
operator()(auto&& __event, const __result_and_scratch_storage<_ExecutionPolicy, _T>& __storage, std::array<_T, _N>& __arr)
747747
{
748-
return __storage.__wait_and_get_value(__event, __arr);
748+
__storage.__wait_and_get_value(__event, __arr);
749749
}
750750

751751
template <typename _T>
@@ -812,9 +812,11 @@ class __future : private std::tuple<_Args...>
812812
}
813813

814814
template <typename _T, std::size_t _N>
815-
std::enable_if_t<sizeof...(_Args) > 0>
815+
void
816816
get_values(std::array<_T, _N>& __arr)
817817
{
818+
static_assert(sizeof...(_Args) > 0);
819+
auto& __val = std::get<0>(*this);
818820
__wait_and_get_value{}(event(), __val, __arr);
819821
}
820822

include/oneapi/dpl/pstl/hetero/dpcpp/unseq_backend_sycl.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -625,12 +625,23 @@ struct __copy_by_mask
625625
// ::std::tuple as operands, in all the other cases this is not necessary and no conversion
626626
// is performed (i.e. __tuple_type is the same type as its operand).
627627
if(__out_idx < __out_acc.size())
628+
{
628629
__assigner(static_cast<__tuple_type>(get<0>(__in_acc[__item_idx])), __out_acc[__out_idx]);
630+
auto __last_out_idx = __wg_sums_ptr[(__n - 1) / __size_per_wg];
631+
if(__out_idx + 1 == __last_out_idx)
632+
{
633+
__ret_ptr[0] = __item_idx + 1, __ret_ptr[1] = __last_out_idx;
634+
}
635+
}
636+
else if(__out_idx == __out_acc.size())
637+
{
638+
__ret_ptr[0] = __item_idx, __ret_ptr[1] = __out_idx;
639+
}
629640
}
630641
if (__item_idx == 0)
631642
{
632643
//copy final result to output
633-
*__ret_ptr = __wg_sums_ptr[(__n - 1) / __size_per_wg];
644+
__ret_ptr[1] = __wg_sums_ptr[(__n - 1) / __size_per_wg];
634645
}
635646
}
636647
};

0 commit comments

Comments
 (0)