Skip to content

Commit c806fb5

Browse files
authored
Extend fallback coverage for copy_if (#512)
We recently added a "fallback" implementation for thrust::copy_if that is invoked when copying a custom type that's too large to fit in shared memory. This change extends the fallback slightly so that it can be used with an overload of copy_if that accepts a stencil buffer (to copy by key). It also adds a unit test to cover this case. It also fixes a small bug in the fallback implementation that could cause the scan accumulator type to overflow when the results are compacted.
1 parent 8a0fb9d commit c806fb5

File tree

2 files changed

+119
-10
lines changed

2 files changed

+119
-10
lines changed

test/test_copy.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,66 @@ TYPED_TEST(CopyIntegerTests, TestCopyIf)
444444
}
445445
}
446446

447+
TEST(CopyLargeTypesTests, TestCopyIfStencilLargeType)
448+
{
449+
using T = large_data;
450+
451+
SCOPED_TRACE(testing::Message() << "with device_id= " << test::set_device_from_ctest());
452+
453+
for(auto size : get_sizes())
454+
{
455+
SCOPED_TRACE(testing::Message() << "with size= " << size);
456+
457+
thrust::host_vector<T> h_data(size);
458+
thrust::sequence(h_data.begin(), h_data.end());
459+
thrust::device_vector<T> d_data(size);
460+
thrust::sequence(d_data.begin(), d_data.end());
461+
462+
for(auto seed : get_seeds())
463+
{
464+
SCOPED_TRACE(testing::Message() << "with seed= " << seed);
465+
466+
thrust::host_vector<T> h_stencil = get_random_data<int>(size, std::numeric_limits<int>::min(), std::numeric_limits<int>::max(), seed);;
467+
thrust::device_vector<T> d_stencil = h_stencil;
468+
469+
typename thrust::host_vector<T>::iterator h_new_end;
470+
typename thrust::device_vector<T>::iterator d_new_end;
471+
472+
// test with Predicate that returns a bool
473+
{
474+
thrust::host_vector<T> h_result(size);
475+
thrust::device_vector<T> d_result(size);
476+
477+
h_new_end
478+
= thrust::copy_if(h_data.begin(), h_data.end(), h_stencil.begin(), h_result.begin(), is_even<T>());
479+
d_new_end
480+
= thrust::copy_if(d_data.begin(), d_data.end(), d_stencil.begin(), d_result.begin(), is_even<T>());
481+
482+
h_result.resize(h_new_end - h_result.begin());
483+
d_result.resize(d_new_end - d_result.begin());
484+
485+
ASSERT_EQ(h_result, d_result);
486+
}
487+
488+
// test with Predicate that returns a non-bool
489+
{
490+
thrust::host_vector<T> h_result(size);
491+
thrust::device_vector<T> d_result(size);
492+
493+
h_new_end
494+
= thrust::copy_if(h_data.begin(), h_data.end(), h_stencil.begin(), h_result.begin(), mod_3<T>());
495+
d_new_end
496+
= thrust::copy_if(d_data.begin(), d_data.end(), d_stencil.begin(), d_result.begin(), mod_3<T>());
497+
498+
h_result.resize(h_new_end - h_result.begin());
499+
d_result.resize(d_new_end - d_result.begin());
500+
501+
ASSERT_EQ(h_result, d_result);
502+
}
503+
}
504+
}
505+
}
506+
447507
TYPED_TEST(CopyIntegerTests, TestCopyIfStencil)
448508
{
449509
using T = typename TestFixture::input_type;

thrust/system/hip/detail/copy_if.h

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -156,49 +156,99 @@ namespace __copy_if
156156
return output + num_selected;
157157
}
158158

159-
template <typename Derived, typename InputIt, typename OutputIt, typename Predicate>
159+
template <typename Derived, typename InputIt, typename OutputIt, typename Predicate, typename PredicateInputIt>
160160
THRUST_HIP_RUNTIME_FUNCTION auto
161-
copy_if(execution_policy<Derived>& policy, InputIt first, InputIt last, OutputIt output, Predicate predicate)
161+
copy_if_common(execution_policy<Derived>& policy, InputIt first, InputIt last, OutputIt output, Predicate predicate, PredicateInputIt predicate_input)
162162
-> std::enable_if_t<!(sizeof(typename std::iterator_traits<InputIt>::value_type) < 512), OutputIt>
163163
{
164164
using namespace thrust::system::hip_rocprim::temp_storage;
165165
using size_type = typename iterator_traits<InputIt>::difference_type;
166+
using pos_type = thrust::detail::uint32_t;
167+
using flag_type = thrust::detail::uint8_t;
166168

167169
size_type num_items = thrust::distance(first, last);
168170
hipStream_t stream = hip_rocprim::stream(policy);
169171
bool debug_sync = THRUST_HIP_DEBUG_SYNC_FLAG;
170172

171-
thrust::detail::temporary_array<thrust::detail::uint8_t, Derived> flags(policy, num_items);
173+
if(num_items == 0)
174+
return output;
175+
176+
// Note: although flags can be stored in a uint8_t, in the inclusive scan performed on flags below,
177+
// the scan accumulator type to something larger (flag_type) to prevent overflow.
178+
// For this reason, we call rocprim::inclusive_scan directly here and pass in the accumulator type as template argument.
179+
thrust::detail::temporary_array<flag_type, Derived> flags(policy, num_items);
172180

173-
hip_rocprim::throw_on_error(rocprim::transform(first,
181+
hip_rocprim::throw_on_error(rocprim::transform(predicate_input,
174182
flags.begin(),
175183
num_items,
176184
[predicate] __host__ __device__ (auto const & val){ return predicate(val) ? 1 : 0; },
177185
stream,
178186
debug_sync),
179187
"copy_if failed on transform");
180188

181-
thrust::detail::temporary_array<thrust::detail::uint32_t, Derived> pos(policy, num_items);
189+
thrust::detail::temporary_array<pos_type, Derived> pos(policy, num_items);
190+
191+
// Get the required temporary storage size.
192+
size_t storage_size = 0;
193+
hip_rocprim::throw_on_error(rocprim::inclusive_scan<rocprim::default_config,
194+
typename thrust::detail::temporary_array<flag_type, Derived>::iterator,
195+
typename thrust::detail::temporary_array<pos_type, Derived>::iterator,
196+
rocprim::plus<pos_type>,
197+
pos_type>(nullptr, storage_size, flags.begin(), pos.begin(), num_items, rocprim::plus<pos_type>{}, stream, debug_sync),
198+
"copy_if failed while determining inclusive scan storage size");
182199

183-
thrust::inclusive_scan(policy, flags.begin(), flags.end(), pos.begin());
200+
// Allocate temporary storage.
201+
thrust::detail::temporary_array<thrust::detail::uint8_t, Derived> tmp(policy, storage_size);
202+
void *ptr = static_cast<void*>(tmp.data().get());
184203

204+
// Perform a scan on the positions.
205+
hip_rocprim::throw_on_error(rocprim::inclusive_scan<rocprim::default_config,
206+
typename thrust::detail::temporary_array<flag_type, Derived>::iterator,
207+
typename thrust::detail::temporary_array<pos_type, Derived>::iterator,
208+
rocprim::plus<pos_type>,
209+
pos_type>(ptr, storage_size, flags.begin(), pos.begin(), num_items, rocprim::plus<pos_type>{}, stream, debug_sync),
210+
"copy_if failed on inclusive scan");
211+
212+
// Pull out the values for which the predicate evaluated to true and compact them into the output array.
185213
constexpr static size_t items_per_thread = 16;
186214
constexpr static size_t threads_per_block = 256;
187-
const size_t block_size = std::ceil(static_cast<float>(num_items) / 16 / threads_per_block);
215+
const size_t block_size = std::ceil(static_cast<float>(num_items) / items_per_thread / threads_per_block);
188216

189217
copy_if_kernel<items_per_thread><<<block_size, threads_per_block>>>(first, flags.begin(), pos.begin(), num_items, output);
190218

191-
return output + pos[num_items-1];
219+
return output + pos[num_items - 1];
220+
}
221+
222+
template <typename Derived, typename InputIt, typename OutputIt, typename Predicate>
223+
THRUST_HIP_RUNTIME_FUNCTION auto
224+
copy_if(execution_policy<Derived>& policy, InputIt first, InputIt last, OutputIt output, Predicate predicate)
225+
-> std::enable_if_t<!(sizeof(typename std::iterator_traits<InputIt>::value_type) < 512), OutputIt>
226+
{
227+
return copy_if_common(policy, first, last, output, predicate, first);
228+
}
229+
230+
template <typename Derived, typename InputIt, typename StencilIt, typename OutputIt, typename Predicate>
231+
THRUST_HIP_RUNTIME_FUNCTION auto
232+
copy_if(execution_policy<Derived>& policy,
233+
InputIt first,
234+
InputIt last,
235+
StencilIt stencil,
236+
OutputIt output,
237+
Predicate predicate)
238+
-> std::enable_if_t<!(sizeof(typename std::iterator_traits<InputIt>::value_type) < 512), OutputIt>
239+
{
240+
return copy_if_common(policy, first, last, output, predicate, stencil);
192241
}
193242

194243
template <typename Derived, typename InputIt, typename StencilIt, typename OutputIt, typename Predicate>
195-
THRUST_HIP_RUNTIME_FUNCTION OutputIt
244+
THRUST_HIP_RUNTIME_FUNCTION auto
196245
copy_if(execution_policy<Derived>& policy,
197246
InputIt first,
198247
InputIt last,
199248
StencilIt stencil,
200249
OutputIt output,
201250
Predicate predicate)
251+
-> std::enable_if_t<(sizeof(typename std::iterator_traits<InputIt>::value_type) < 512), OutputIt>
202252
{
203253
using namespace thrust::system::hip_rocprim::temp_storage;
204254
typedef typename iterator_traits<InputIt>::difference_type size_type;
@@ -264,7 +314,6 @@ namespace __copy_if
264314

265315
return output + num_selected;
266316
}
267-
268317
} // namespace __copy_if
269318

270319
//-------------------------

0 commit comments

Comments
 (0)