@@ -21,6 +21,47 @@ namespace GKO_DEVICE_NAMESPACE {
2121namespace bitvector {
2222
2323
24+ template <typename IndexType, typename DevicePredicate>
25+ gko::bitvector<IndexType> from_predicate (
26+ std::shared_ptr<const DefaultExecutor> exec, IndexType size,
27+ DevicePredicate device_predicate)
28+ {
29+ using storage_type = typename device_bitvector<IndexType>::storage_type;
30+ constexpr auto block_size = device_bitvector<IndexType>::block_size;
31+ const auto num_blocks = static_cast <size_type>(ceildiv (size, block_size));
32+ array<uint32> bit_array{exec, num_blocks};
33+ array<IndexType> rank_array{exec, num_blocks};
34+ const auto bits = bit_array.get_data ();
35+ const auto ranks = rank_array.get_data ();
36+ const auto queue = exec->get_queue ();
37+ queue->submit ([&](sycl::handler& cgh) {
38+ cgh.parallel_for (num_blocks, [=](sycl::id<1 > block_i) {
39+ const auto base_i = static_cast <IndexType>(block_i) * block_size;
40+ storage_type mask{};
41+ if (base_i + block_size <= size) {
42+ for (int local_i = 0 ; local_i < block_size; local_i++) {
43+ const storage_type bit =
44+ device_predicate (base_i + local_i) ? 1 : 0 ;
45+ mask |= bit << local_i;
46+ }
47+ } else {
48+ for (int local_i = 0 ; base_i + local_i < size; local_i++) {
49+ const storage_type bit =
50+ device_predicate (base_i + local_i) ? 1 : 0 ;
51+ mask |= bit << local_i;
52+ }
53+ }
54+ bits[block_i] = mask;
55+ ranks[block_i] = gko::detail::popcount (mask);
56+ });
57+ });
58+ components::prefix_sum_nonnegative (exec, ranks, num_blocks);
59+
60+ return gko::bitvector<IndexType>{std::move (bit_array),
61+ std::move (rank_array), size};
62+ }
63+
64+
2465template <typename IndexIterator>
2566gko::bitvector<typename std::iterator_traits<IndexIterator>::value_type>
2667from_sorted_indices (
0 commit comments