@@ -156,49 +156,99 @@ namespace __copy_if
156156 return output + num_selected;
157157 }
158158
159- template <typename Derived, typename InputIt, typename OutputIt, typename Predicate>
159+ template <typename Derived, typename InputIt, typename OutputIt, typename Predicate, typename PredicateInputIt >
160160 THRUST_HIP_RUNTIME_FUNCTION auto
161- copy_if (execution_policy<Derived>& policy, InputIt first, InputIt last, OutputIt output, Predicate predicate)
161+ copy_if_common (execution_policy<Derived>& policy, InputIt first, InputIt last, OutputIt output, Predicate predicate, PredicateInputIt predicate_input )
162162 -> std::enable_if_t <!(sizeof (typename std::iterator_traits<InputIt>::value_type) < 512 ), OutputIt>
163163 {
164164 using namespace thrust ::system::hip_rocprim::temp_storage;
165165 using size_type = typename iterator_traits<InputIt>::difference_type;
166+ using pos_type = thrust::detail::uint32_t ;
167+ using flag_type = thrust::detail::uint8_t ;
166168
167169 size_type num_items = thrust::distance (first, last);
168170 hipStream_t stream = hip_rocprim::stream (policy);
169171 bool debug_sync = THRUST_HIP_DEBUG_SYNC_FLAG;
170172
171- thrust::detail::temporary_array<thrust::detail::uint8_t , Derived> flags (policy, num_items);
173+ if (num_items == 0 )
174+ return output;
175+
176+ // Note: although flags can be stored in a uint8_t, in the inclusive scan performed on flags below,
177+ // the scan accumulator type to something larger (flag_type) to prevent overflow.
178+ // For this reason, we call rocprim::inclusive_scan directly here and pass in the accumulator type as template argument.
179+ thrust::detail::temporary_array<flag_type, Derived> flags (policy, num_items);
172180
173- hip_rocprim::throw_on_error (rocprim::transform (first ,
181+ hip_rocprim::throw_on_error (rocprim::transform (predicate_input ,
174182 flags.begin (),
175183 num_items,
176184 [predicate] __host__ __device__ (auto const & val){ return predicate (val) ? 1 : 0 ; },
177185 stream,
178186 debug_sync),
179187 " copy_if failed on transform" );
180188
181- thrust::detail::temporary_array<thrust::detail::uint32_t , Derived> pos (policy, num_items);
189+ thrust::detail::temporary_array<pos_type, Derived> pos (policy, num_items);
190+
191+ // Get the required temporary storage size.
192+ size_t storage_size = 0 ;
193+ hip_rocprim::throw_on_error (rocprim::inclusive_scan<rocprim::default_config,
194+ typename thrust::detail::temporary_array<flag_type, Derived>::iterator,
195+ typename thrust::detail::temporary_array<pos_type, Derived>::iterator,
196+ rocprim::plus<pos_type>,
197+ pos_type>(nullptr , storage_size, flags.begin (), pos.begin (), num_items, rocprim::plus<pos_type>{}, stream, debug_sync),
198+ " copy_if failed while determining inclusive scan storage size" );
182199
183- thrust::inclusive_scan (policy, flags.begin (), flags.end (), pos.begin ());
200+ // Allocate temporary storage.
201+ thrust::detail::temporary_array<thrust::detail::uint8_t , Derived> tmp (policy, storage_size);
202+ void *ptr = static_cast <void *>(tmp.data ().get ());
184203
204+ // Perform a scan on the positions.
205+ hip_rocprim::throw_on_error (rocprim::inclusive_scan<rocprim::default_config,
206+ typename thrust::detail::temporary_array<flag_type, Derived>::iterator,
207+ typename thrust::detail::temporary_array<pos_type, Derived>::iterator,
208+ rocprim::plus<pos_type>,
209+ pos_type>(ptr, storage_size, flags.begin (), pos.begin (), num_items, rocprim::plus<pos_type>{}, stream, debug_sync),
210+ " copy_if failed on inclusive scan" );
211+
212+ // Pull out the values for which the predicate evaluated to true and compact them into the output array.
185213 constexpr static size_t items_per_thread = 16 ;
186214 constexpr static size_t threads_per_block = 256 ;
187- const size_t block_size = std::ceil (static_cast <float >(num_items) / 16 / threads_per_block);
215+ const size_t block_size = std::ceil (static_cast <float >(num_items) / items_per_thread / threads_per_block);
188216
189217 copy_if_kernel<items_per_thread><<<block_size, threads_per_block>>>(first, flags.begin (), pos.begin (), num_items, output);
190218
191- return output + pos[num_items-1 ];
219+ return output + pos[num_items - 1 ];
220+ }
221+
222+ template <typename Derived, typename InputIt, typename OutputIt, typename Predicate>
223+ THRUST_HIP_RUNTIME_FUNCTION auto
224+ copy_if (execution_policy<Derived>& policy, InputIt first, InputIt last, OutputIt output, Predicate predicate)
225+ -> std::enable_if_t <!(sizeof (typename std::iterator_traits<InputIt>::value_type) < 512 ), OutputIt>
226+ {
227+ return copy_if_common (policy, first, last, output, predicate, first);
228+ }
229+
230+ template <typename Derived, typename InputIt, typename StencilIt, typename OutputIt, typename Predicate>
231+ THRUST_HIP_RUNTIME_FUNCTION auto
232+ copy_if (execution_policy<Derived>& policy,
233+ InputIt first,
234+ InputIt last,
235+ StencilIt stencil,
236+ OutputIt output,
237+ Predicate predicate)
238+ -> std::enable_if_t <!(sizeof (typename std::iterator_traits<InputIt>::value_type) < 512 ), OutputIt>
239+ {
240+ return copy_if_common (policy, first, last, output, predicate, stencil);
192241 }
193242
194243 template <typename Derived, typename InputIt, typename StencilIt, typename OutputIt, typename Predicate>
195- THRUST_HIP_RUNTIME_FUNCTION OutputIt
244+ THRUST_HIP_RUNTIME_FUNCTION auto
196245 copy_if (execution_policy<Derived>& policy,
197246 InputIt first,
198247 InputIt last,
199248 StencilIt stencil,
200249 OutputIt output,
201250 Predicate predicate)
251+ -> std::enable_if_t <(sizeof (typename std::iterator_traits<InputIt>::value_type) < 512 ), OutputIt>
202252 {
203253 using namespace thrust ::system::hip_rocprim::temp_storage;
204254 typedef typename iterator_traits<InputIt>::difference_type size_type;
@@ -264,7 +314,6 @@ namespace __copy_if
264314
265315 return output + num_selected;
266316 }
267-
268317} // namespace __copy_if
269318
270319// -------------------------
0 commit comments