1111#include < raft/core/handle.hpp>
1212
1313#include < cuda/functional>
14- #include < cuda/std/functional>
15- #include < cuda/std/iterator>
14+ #include < cuda/iterator>
1615#include < thrust/copy.h>
1716#include < thrust/iterator/counting_iterator.h>
18- #include < thrust/iterator/transform_iterator.h>
1917#include < thrust/tabulate.h>
2018#include < thrust/transform.h>
2119#include < thrust/transform_reduce.h>
@@ -132,15 +130,15 @@ __device__ size_t copy_if_mask_set(InputIterator input_first,
132130
133131 return static_cast <size_t >(cuda::std::distance (
134132 output_first + output_start_offset,
135- thrust::copy_if (thrust::seq,
136- input_first + input_start_offset ,
137- input_first + ( input_start_offset + num_items) ,
138- thrust::make_transform_iterator (
139- thrust::make_counting_iterator (size_t {0 }),
140- check_bit_set_t <MaskIterator, size_t >{mask_first, size_t {0 }}) +
141- input_start_offset,
142- output_first + output_start_offset,
143- is_equal_t <bool >{true })));
133+ thrust::copy_if (
134+ thrust::seq ,
135+ input_first + input_start_offset,
136+ input_first + (input_start_offset + num_items),
137+ cuda::make_transform_iterator ( thrust::make_counting_iterator (size_t {0 }),
138+ check_bit_set_t <MaskIterator, size_t >{mask_first, size_t {0 }}) +
139+ input_start_offset,
140+ output_first + output_start_offset,
141+ is_equal_t <bool >{true })));
144142}
145143
146144template <typename MaskIterator> // should be packed bool
@@ -177,8 +175,8 @@ OutputIterator copy_if_mask_set(raft::handle_t const& handle,
177175 handle.get_thrust_policy (),
178176 input_first,
179177 input_last,
180- thrust ::make_transform_iterator (thrust::make_counting_iterator (size_t {0 }),
181- check_bit_set_t <MaskIterator, size_t >{mask_first, size_t {0 }}),
178+ cuda ::make_transform_iterator (thrust::make_counting_iterator (size_t {0 }),
179+ check_bit_set_t <MaskIterator, size_t >{mask_first, size_t {0 }}),
182180 output_first,
183181 is_equal_t <bool >{true });
184182}
@@ -196,8 +194,8 @@ OutputIterator copy_if_mask_unset(raft::handle_t const& handle,
196194 handle.get_thrust_policy (),
197195 input_first,
198196 input_last,
199- thrust ::make_transform_iterator (thrust::make_counting_iterator (size_t {0 }),
200- check_bit_set_t <MaskIterator, size_t >{mask_first, size_t {0 }}),
197+ cuda ::make_transform_iterator (thrust::make_counting_iterator (size_t {0 }),
198+ check_bit_set_t <MaskIterator, size_t >{mask_first, size_t {0 }}),
201199 output_first,
202200 is_equal_t <bool >{false });
203201}
0 commit comments