Skip to content

Commit d05c47d

Browse files
authored
[back port] Copy output data for argsort. (#6866) (#6870)
Fix GPU AUC.
1 parent 9f5e2c5 commit d05c47d

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

src/common/device_helpers.cuh

+16 -7
Original file line number | Diff line number | Diff line change
@@ -1321,15 +1321,16 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
13211321
TemporaryArray<KeyT> out(keys.size());
13221322
cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()),
13231323
out.data().get());
1324+
TemporaryArray<IdxT> sorted_idx_out(sorted_idx.size());
13241325
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()),
1325-
sorted_idx.data());
1326+
sorted_idx_out.data().get());
13261327

13271328
if (accending) {
13281329
void *d_temp_storage = nullptr;
13291330
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
13301331
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
13311332
sizeof(KeyT) * 8, false, nullptr, false)));
1332-
dh::TemporaryArray<char> storage(bytes);
1333+
TemporaryArray<char> storage(bytes);
13331334
d_temp_storage = storage.data().get();
13341335
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
13351336
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
@@ -1339,12 +1340,15 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
13391340
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, size_t>::Dispatch(
13401341
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
13411342
sizeof(KeyT) * 8, false, nullptr, false)));
1342-
dh::TemporaryArray<char> storage(bytes);
1343+
TemporaryArray<char> storage(bytes);
13431344
d_temp_storage = storage.data().get();
13441345
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, size_t>::Dispatch(
13451346
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
13461347
sizeof(KeyT) * 8, false, nullptr, false)));
13471348
}
1349+
1350+
safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
1351+
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
13481352
}
13491353

13501354
namespace detail {
@@ -1379,14 +1383,19 @@ void SegmentedArgSort(xgboost::common::Span<U> values,
13791383
size_t bytes = 0;
13801384
Iota(sorted_idx);
13811385
TemporaryArray<std::remove_const_t<U>> values_out(values.size());
1386+
TemporaryArray<std::remove_const_t<IdxT>> sorted_idx_out(sorted_idx.size());
1387+
13821388
detail::DeviceSegmentedRadixSortPair<!accending>(
13831389
nullptr, bytes, values.data(), values_out.data().get(), sorted_idx.data(),
1384-
sorted_idx.data(), sorted_idx.size(), n_groups, group_ptr.data(),
1390+
sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
13851391
group_ptr.data() + 1);
1386-
dh::TemporaryArray<xgboost::common::byte> temp_storage(bytes);
1392+
TemporaryArray<xgboost::common::byte> temp_storage(bytes);
13871393
detail::DeviceSegmentedRadixSortPair<!accending>(
13881394
temp_storage.data().get(), bytes, values.data(), values_out.data().get(),
1389-
sorted_idx.data(), sorted_idx.data(), sorted_idx.size(), n_groups,
1390-
group_ptr.data(), group_ptr.data() + 1);
1395+
sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(),
1396+
n_groups, group_ptr.data(), group_ptr.data() + 1);
1397+
1398+
safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
1399+
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
13911400
}
13921401
} // namespace dh

src/metric/auc.cu

+1 -1
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
269269
});
270270

271271
// unique values are sparse, so we need a CSR style indptr
272-
dh::TemporaryArray<uint32_t> unique_class_ptr(class_ptr.size() + 1);
272+
dh::TemporaryArray<uint32_t> unique_class_ptr(class_ptr.size());
273273
auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr);
274274
auto n_uniques = dh::SegmentedUniqueByKey(
275275
thrust::cuda::par(alloc),

0 commit comments

Comments
 (0)