@@ -1321,15 +1321,16 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
1321
1321
TemporaryArray<KeyT> out (keys.size ());
1322
1322
cub::DoubleBuffer<KeyT> d_keys (const_cast <KeyT *>(keys.data ()),
1323
1323
out.data ().get ());
1324
+ TemporaryArray<IdxT> sorted_idx_out (sorted_idx.size ());
1324
1325
cub::DoubleBuffer<ValueT> d_values (const_cast <ValueT *>(sorted_idx.data ()),
1325
- sorted_idx .data ());
1326
+ sorted_idx_out .data (). get ());
1326
1327
1327
1328
if (accending) {
1328
1329
void *d_temp_storage = nullptr ;
1329
1330
safe_cuda ((cub::DispatchRadixSort<false , KeyT, ValueT, size_t >::Dispatch (
1330
1331
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 ,
1331
1332
sizeof (KeyT) * 8 , false , nullptr , false )));
1332
- dh:: TemporaryArray<char > storage (bytes);
1333
+ TemporaryArray<char > storage (bytes);
1333
1334
d_temp_storage = storage.data ().get ();
1334
1335
safe_cuda ((cub::DispatchRadixSort<false , KeyT, ValueT, size_t >::Dispatch (
1335
1336
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 ,
@@ -1339,12 +1340,15 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
1339
1340
safe_cuda ((cub::DispatchRadixSort<true , KeyT, ValueT, size_t >::Dispatch (
1340
1341
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 ,
1341
1342
sizeof (KeyT) * 8 , false , nullptr , false )));
1342
- dh:: TemporaryArray<char > storage (bytes);
1343
+ TemporaryArray<char > storage (bytes);
1343
1344
d_temp_storage = storage.data ().get ();
1344
1345
safe_cuda ((cub::DispatchRadixSort<true , KeyT, ValueT, size_t >::Dispatch (
1345
1346
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size (), 0 ,
1346
1347
sizeof (KeyT) * 8 , false , nullptr , false )));
1347
1348
}
1349
+
1350
+ safe_cuda (cudaMemcpyAsync (sorted_idx.data (), sorted_idx_out.data ().get (),
1351
+ sorted_idx.size_bytes (), cudaMemcpyDeviceToDevice));
1348
1352
}
1349
1353
1350
1354
namespace detail {
@@ -1379,14 +1383,19 @@ void SegmentedArgSort(xgboost::common::Span<U> values,
1379
1383
size_t bytes = 0 ;
1380
1384
Iota (sorted_idx);
1381
1385
TemporaryArray<std::remove_const_t <U>> values_out (values.size ());
1386
+ TemporaryArray<std::remove_const_t <IdxT>> sorted_idx_out (sorted_idx.size ());
1387
+
1382
1388
detail::DeviceSegmentedRadixSortPair<!accending>(
1383
1389
nullptr , bytes, values.data (), values_out.data ().get (), sorted_idx.data (),
1384
- sorted_idx .data (), sorted_idx.size (), n_groups, group_ptr.data (),
1390
+ sorted_idx_out .data (). get (), sorted_idx.size (), n_groups, group_ptr.data (),
1385
1391
group_ptr.data () + 1 );
1386
- dh:: TemporaryArray<xgboost::common::byte> temp_storage (bytes);
1392
+ TemporaryArray<xgboost::common::byte> temp_storage (bytes);
1387
1393
detail::DeviceSegmentedRadixSortPair<!accending>(
1388
1394
temp_storage.data ().get (), bytes, values.data (), values_out.data ().get (),
1389
- sorted_idx.data (), sorted_idx.data (), sorted_idx.size (), n_groups,
1390
- group_ptr.data (), group_ptr.data () + 1 );
1395
+ sorted_idx.data (), sorted_idx_out.data ().get (), sorted_idx.size (),
1396
+ n_groups, group_ptr.data (), group_ptr.data () + 1 );
1397
+
1398
+ safe_cuda (cudaMemcpyAsync (sorted_idx.data (), sorted_idx_out.data ().get (),
1399
+ sorted_idx.size_bytes (), cudaMemcpyDeviceToDevice));
1391
1400
}
1392
1401
} // namespace dh
0 commit comments