Skip to content

Commit ab6d3c6

Browse files
authored
Use context for CUDA external memory DMatrix. (#12137)
1 parent 3a809d8 commit ab6d3c6

17 files changed

Lines changed: 113 additions & 100 deletions

src/data/ellpack_page.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2019-2024, XGBoost contributors
2+
* Copyright 2019-2026, XGBoost contributors
33
*/
44
#ifndef XGBOOST_USE_CUDA
55

src/data/ellpack_page.cu

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include "../common/compressed_iterator.h" // for CompressedIterator
1919
#include "../common/cuda_context.cuh" // for CUDAContext
2020
#include "../common/cuda_rt_utils.h" // for SetDevice
21-
#include "../common/cuda_stream.h" // for DefaultStream
21+
#include "../common/cuda_stream.h" // for StreamRef
2222
#include "../common/hist_util.cuh" // for HistogramCuts
2323
#include "../common/ref_resource_view.cuh" // for MakeFixedVecWithCudaMalloc
2424
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
@@ -32,6 +32,8 @@
3232
namespace xgboost {
3333
EllpackPage::EllpackPage() : impl_{new EllpackPageImpl{}} {}
3434

35+
EllpackPageImpl::EllpackPageImpl() = default;
36+
3537
EllpackPage::EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param)
3638
: impl_{new EllpackPageImpl{ctx, dmat, param}} {}
3739

@@ -500,17 +502,7 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag
500502
this->monitor_.Stop("CopyGHistToEllpack");
501503
}
502504

503-
EllpackPageImpl::~EllpackPageImpl() noexcept(false) {
504-
// Sync the stream to make sure all running CUDA kernels finish before deallocation.
505-
auto status = curt::DefaultStream().Sync(false);
506-
if (status != cudaSuccess) {
507-
auto str = cudaGetErrorString(status);
508-
// For external-memory, throwing here can trigger a series of calls to
509-
// `std::terminate` by various destructors. For now, we just log the error.
510-
LOG(WARNING) << "Ran into CUDA error:" << str << "\nXGBoost is likely to abort.";
511-
}
512-
dh::safe_cuda(status);
513-
}
505+
EllpackPageImpl::~EllpackPageImpl() noexcept(false) = default;
514506

515507
// A functor that copies the data from one EllpackPage to another.
516508
template <typename IterT>

src/data/ellpack_page.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2019-2025, XGBoost Contributors
2+
* Copyright 2019-2026, XGBoost Contributors
33
*/
44
#ifndef XGBOOST_DATA_ELLPACK_PAGE_CUH_
55
#define XGBOOST_DATA_ELLPACK_PAGE_CUH_
@@ -186,7 +186,7 @@ class EllpackPageImpl {
186186
* This is used in the external memory case. An empty ELLPACK page is constructed with its content
187187
* set later by the reader.
188188
*/
189-
EllpackPageImpl() = default;
189+
EllpackPageImpl();
190190

191191
/**
192192
* @brief Constructor from existing ellpack matrices.

src/data/ellpack_page.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2017-2023 by XGBoost Contributors
2+
* Copyright 2017-2026, XGBoost Contributors
33
*/
44
#ifndef XGBOOST_DATA_ELLPACK_PAGE_H_
55
#define XGBOOST_DATA_ELLPACK_PAGE_H_

src/data/ellpack_page_raw_format.cu

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
/**
2-
* Copyright 2019-2025, XGBoost contributors
2+
* Copyright 2019-2026, XGBoost contributors
33
*/
44
#include <dmlc/registry.h>
55

66
#include <cstddef> // for size_t
77
#include <vector> // for vector
88

9-
#include "../common/cuda_rt_utils.h"
9+
#include "../common/cuda_context.cuh" // for CUDAContext
1010
#include "../common/cuda_stream.h" // for Event
1111
#include "../common/io.h" // for AlignedResourceReadStream, AlignedFileWriteStream
1212
#include "../common/ref_resource_view.cuh" // for MakeFixedVecWithCudaMalloc
@@ -21,7 +21,7 @@ DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format);
2121
namespace {
2222
// Function to support system without HMM or ATS
2323
template <typename T>
24-
[[nodiscard]] bool ReadDeviceVec(common::AlignedResourceReadStream* fi,
24+
[[nodiscard]] bool ReadDeviceVec(Context const* ctx, common::AlignedResourceReadStream* fi,
2525
common::RefResourceView<T>* vec) {
2626
xgboost_NVTX_FN_RANGE();
2727

@@ -42,7 +42,7 @@ template <typename T>
4242

4343
*vec = common::MakeFixedVecWithCudaMalloc<T>(n);
4444
dh::safe_cuda(
45-
cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, curt::DefaultStream()));
45+
cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
4646
return true;
4747
}
4848
} // namespace
@@ -62,7 +62,7 @@ template <typename T>
6262
RET_IF_NOT(fi->Read(&impl->info.row_stride));
6363

6464
if (this->param_.prefetch_copy || !has_hmm_ats_) {
65-
RET_IF_NOT(ReadDeviceVec(fi, &impl->gidx_buffer));
65+
RET_IF_NOT(ReadDeviceVec(ctx_, fi, &impl->gidx_buffer));
6666
} else {
6767
RET_IF_NOT(common::ReadVec(fi, &impl->gidx_buffer));
6868
}
@@ -73,7 +73,7 @@ template <typename T>
7373

7474
impl->SetCuts(this->cuts_);
7575

76-
curt::DefaultStream().Sync();
76+
ctx_->CUDACtx()->Stream().Sync();
7777
return true;
7878
}
7979

@@ -87,14 +87,13 @@ template <typename T>
8787
bytes += fo->Write(impl->is_dense);
8888
bytes += fo->Write(impl->info.row_stride);
8989
std::vector<common::CompressedByteT> h_gidx_buffer;
90-
Context ctx = Context{}.MakeCUDA(curt::CurrentDevice());
9190
// write data into the h_gidx_buffer
92-
[[maybe_unused]] auto h_accessor = impl->GetHostEllpack(&ctx, &h_gidx_buffer);
91+
[[maybe_unused]] auto h_accessor = impl->GetHostEllpack(ctx_, &h_gidx_buffer);
9392
bytes += common::WriteVec(fo, h_gidx_buffer);
9493
bytes += fo->Write(impl->base_rowid);
9594
bytes += fo->Write(impl->NumSymbols());
9695

97-
curt::DefaultStream().Sync();
96+
ctx_->CUDACtx()->Stream().Sync();
9897
return bytes;
9998
}
10099

@@ -104,21 +103,21 @@ template <typename T>
104103
auto* impl = page->Impl();
105104
CHECK(this->cuts_->cut_values_.DeviceCanRead());
106105

107-
auto ctx = Context{}.MakeCUDA(curt::CurrentDevice());
106+
auto stream = ctx_->CUDACtx()->Stream();
108107

109108
auto dispatch = [&] {
110-
fi->Read(&ctx, page, this->param_.prefetch_copy || !this->has_hmm_ats_);
109+
fi->Read(ctx_, page, this->param_.prefetch_copy || !this->has_hmm_ats_);
111110
impl->SetCuts(this->cuts_);
112111
};
113112

114113
if (ConsoleLogger::GlobalVerbosity() == ConsoleLogger::LogVerbosity::kDebug) {
115114
curt::Event start{false}, stop{false};
116115
float milliseconds = 0;
117-
start.Record(ctx.CUDACtx()->Stream());
116+
start.Record(stream);
118117

119118
dispatch();
120119

121-
stop.Record(ctx.CUDACtx()->Stream());
120+
stop.Record(stream);
122121
stop.Sync();
123122
dh::safe_cuda(cudaEventElapsedTime(&milliseconds, start, stop));
124123
double n_bytes = page->Impl()->MemCostBytes();
@@ -128,7 +127,7 @@ template <typename T>
128127
dispatch();
129128
}
130129

131-
curt::DefaultStream().Sync();
130+
stream.Sync();
132131

133132
return true;
134133
}
@@ -137,8 +136,8 @@ template <typename T>
137136
EllpackHostCacheStream* fo) const {
138137
xgboost_NVTX_FN_RANGE_C(3, 252, 198);
139138

140-
bool new_page = fo->Write(page);
141-
curt::DefaultStream().Sync();
139+
bool new_page = fo->Write(ctx_, page);
140+
ctx_->CUDACtx()->Stream().Sync();
142141

143142
if (new_page) {
144143
auto cache = fo->Share();

src/data/ellpack_page_raw_format.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,17 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
2929
BatchParam param_;
3030
// Supports CUDA HMM or ATS
3131
bool has_hmm_ats_{false};
32+
Context const* ctx_;
3233

3334
public:
34-
explicit EllpackPageRawFormat(std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device,
35+
explicit EllpackPageRawFormat(Context const* ctx,
36+
std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device,
3537
BatchParam param, bool has_hmm_ats)
3638
: cuts_{std::move(cuts)},
3739
device_{device},
3840
param_{std::move(param)},
39-
has_hmm_ats_{has_hmm_ats} {}
41+
has_hmm_ats_{has_hmm_ats},
42+
ctx_{ctx} {}
4043
[[nodiscard]] bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override;
4144
[[nodiscard]] std::size_t Write(EllpackPage const& page,
4245
common::AlignedFileWriteStream* fo) override;

src/data/ellpack_page_source.cu

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2019-2025, XGBoost contributors
2+
* Copyright 2019-2026, XGBoost contributors
33
*/
44
#include <algorithm> // for max
55
#include <cstddef> // for size_t
@@ -145,9 +145,8 @@ class EllpackHostCacheStreamImpl {
145145
ptr_ = k;
146146
}
147147

148-
[[nodiscard]] bool Write(EllpackPage const& page) {
148+
[[nodiscard]] bool Write(Context const* ctx, EllpackPage const& page) {
149149
auto impl = page.Impl();
150-
auto ctx = Context{}.MakeCUDA(dh::CurrentDevice());
151150

152151
this->cache_->sizes_orig.push_back(page.Impl()->MemCostBytes());
153152
auto orig_ptr = this->cache_->sizes_orig.size() - 1;
@@ -219,10 +218,10 @@ class EllpackHostCacheStreamImpl {
219218
dc::CuMemParams c_out;
220219
std::size_t constexpr kChunkSize = 1ul << 21;
221220
auto params = dc::CompressSnappy(
222-
&ctx, old_impl->gidx_buffer.ToSpan().subspan(n_h_bytes, n_comp_bytes), &tmp, kChunkSize);
221+
ctx, old_impl->gidx_buffer.ToSpan().subspan(n_h_bytes, n_comp_bytes), &tmp, kChunkSize);
223222
common::RefResourceView<std::uint8_t> c_buf = dc::CoalesceCompressedBuffersToHost(
224-
ctx.CUDACtx()->Stream(), this->cache_->pool, params, tmp, &c_out);
225-
auto c_page = dc::MakeSnappyDecomprMgr(ctx.CUDACtx()->Stream(), this->cache_->pool,
223+
ctx->CUDACtx()->Stream(), this->cache_->pool, params, tmp, &c_out);
224+
auto c_page = dc::MakeSnappyDecomprMgr(ctx->CUDACtx()->Stream(), this->cache_->pool,
226225
std::move(c_out), c_buf.ToSpan());
227226
CHECK_EQ(c_page.DecompressedBytes() + new_impl->gidx_buffer.size_bytes(), n_bytes);
228227

@@ -264,13 +263,13 @@ class EllpackHostCacheStreamImpl {
264263
// Push a new page
265264
auto n_bytes = this->cache_->buffer_bytes.at(this->cache_->h_pages.size());
266265
auto n_samples = this->cache_->buffer_rows.at(this->cache_->h_pages.size());
267-
auto new_impl = std::make_unique<EllpackPageImpl>(&ctx, impl->CutsShared(), impl->IsDense(),
266+
auto new_impl = std::make_unique<EllpackPageImpl>(ctx, impl->CutsShared(), impl->IsDense(),
268267
impl->info.row_stride, n_samples);
269268
new_impl->SetBaseRowId(impl->base_rowid);
270269
new_impl->SetNumSymbols(impl->NumSymbols());
271270
new_impl->gidx_buffer =
272-
common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(&ctx, n_bytes, 0);
273-
auto offset = new_impl->Copy(&ctx, impl, 0);
271+
common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(ctx, n_bytes, 0);
272+
auto offset = new_impl->Copy(ctx, impl, 0);
274273

275274
this->cache_->offsets.push_back(offset);
276275

@@ -284,7 +283,7 @@ class EllpackHostCacheStreamImpl {
284283
CHECK(!this->cache_->h_pages.empty());
285284
CHECK_EQ(cache_idx, this->cache_->h_pages.size() - 1);
286285
auto& new_impl = this->cache_->h_pages.back();
287-
auto offset = new_impl->Copy(&ctx, impl, this->cache_->offsets.back());
286+
auto offset = new_impl->Copy(ctx, impl, this->cache_->offsets.back());
288287
this->cache_->offsets.back() += offset;
289288
}
290289

@@ -382,10 +381,24 @@ void EllpackHostCacheStream::Read(Context const* ctx, EllpackPage* page, bool pr
382381
this->p_impl_->Read(ctx, page, prefetch_copy);
383382
}
384383

385-
[[nodiscard]] bool EllpackHostCacheStream::Write(EllpackPage const& page) {
386-
return this->p_impl_->Write(page);
384+
[[nodiscard]] bool EllpackHostCacheStream::Write(Context const* ctx, EllpackPage const& page) {
385+
return this->p_impl_->Write(ctx, page);
387386
}
388387

388+
/**
389+
* EllpackFormatPolicy
390+
*/
391+
template <typename S>
392+
void EllpackFormatPolicy<S>::DestroyPage(std::shared_ptr<S>* page) const {
393+
if (page && ctx_) {
394+
ctx_->CUDACtx()->Stream().Sync();
395+
}
396+
page->reset();
397+
}
398+
399+
template void EllpackFormatPolicy<EllpackPage>::DestroyPage(
400+
std::shared_ptr<EllpackPage>* page) const;
401+
389402
/**
390403
* EllpackCacheStreamPolicy
391404
*/
@@ -528,13 +541,14 @@ void EllpackPageSourceImpl<F>::Fetch() {
528541
// This is not read from cache so we still need it to be synced with sparse page source.
529542
CHECK_EQ(this->Iter(), this->source_->Iter());
530543
auto const& csr = this->source_->Page();
544+
this->DestroyPage(&this->page_);
531545
this->page_.reset(new EllpackPage{});
532546
auto* impl = this->page_->Impl();
533-
Context ctx = Context{}.MakeCUDA(this->Device().ordinal);
534547
if (this->GetCuts()->HasCategorical()) {
535548
CHECK(!this->feature_types_.empty());
536549
}
537-
*impl = EllpackPageImpl{&ctx, this->GetCuts(), *csr, is_dense_, row_stride_, feature_types_};
550+
*impl =
551+
EllpackPageImpl{this->Ctx(), this->GetCuts(), *csr, is_dense_, row_stride_, feature_types_};
538552
this->page_->SetBaseRowId(csr->base_rowid);
539553
LOG(INFO) << "Generated an Ellpack page with size: "
540554
<< common::HumanMemUnit(impl->MemCostBytes())
@@ -573,6 +587,7 @@ void ExtEllpackPageSourceImpl<F>::Fetch() {
573587
bst_idx_t row_stride = GetRowCounts(this->ctx_, value, row_counts_span,
574588
dh::GetDevice(this->ctx_), this->missing_);
575589
CHECK_LE(row_stride, this->ext_info_.row_stride);
590+
this->DestroyPage(&this->page_);
576591
this->page_.reset(new EllpackPage{});
577592
*this->page_->Impl() = EllpackPageImpl{this->ctx_,
578593
value,

src/data/ellpack_page_source.h

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2019-2025, XGBoost Contributors
2+
* Copyright 2019-2026, XGBoost Contributors
33
*/
44

55
#ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
@@ -164,7 +164,7 @@ class EllpackHostCacheStream {
164164
* @return Whether a new cache page is created. False if the new page is appended to the
165165
* previous one.
166166
*/
167-
[[nodiscard]] bool Write(EllpackPage const& page);
167+
[[nodiscard]] bool Write(Context const* ctx, EllpackPage const& page);
168168
};
169169

170170
namespace detail {
@@ -177,6 +177,7 @@ class EllpackFormatPolicy {
177177
std::shared_ptr<common::HistogramCuts const> cuts_{nullptr};
178178
DeviceOrd device_;
179179
bool has_hmm_{curt::SupportsPageableMem()};
180+
Context const* ctx_{nullptr};
180181

181182
EllpackCacheInfo cache_info_;
182183
static_assert(std::is_same_v<S, EllpackPage>);
@@ -214,11 +215,12 @@ class EllpackFormatPolicy {
214215

215216
[[nodiscard]] auto CreatePageFormat(BatchParam const& param) const {
216217
CHECK_EQ(cuts_->cut_values_.Device(), device_);
217-
std::unique_ptr<FormatT> fmt{new EllpackPageRawFormat{cuts_, device_, param, has_hmm_}};
218+
std::unique_ptr<FormatT> fmt{new EllpackPageRawFormat{ctx_, cuts_, device_, param, has_hmm_}};
218219
return fmt;
219220
}
220-
void SetCuts(std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device,
221-
EllpackCacheInfo cinfo) {
221+
void SetCuts(Context const* ctx, std::shared_ptr<common::HistogramCuts const> cuts,
222+
DeviceOrd device, EllpackCacheInfo cinfo) {
223+
this->ctx_ = ctx;
222224
std::swap(this->cuts_, cuts);
223225
this->device_ = device;
224226
CHECK(this->device_.IsCUDA());
@@ -230,6 +232,8 @@ class EllpackFormatPolicy {
230232
}
231233
[[nodiscard]] auto Device() const { return this->device_; }
232234
[[nodiscard]] auto const& CacheInfo() { return this->cache_info_; }
235+
[[nodiscard]] auto Ctx() const { return this->ctx_; }
236+
void DestroyPage(std::shared_ptr<S>* page) const;
233237
};
234238

235239
template <typename S, template <typename> typename F>
@@ -311,7 +315,7 @@ class EllpackPageSourceImpl : public PageSourceIncMixIn<EllpackPage, F> {
311315
feature_types_{feature_types} {
312316
this->source_ = source;
313317
cuts->SetDevice(ctx->Device());
314-
this->SetCuts(std::move(cuts), ctx->Device(), cinfo);
318+
this->SetCuts(ctx, std::move(cuts), ctx->Device(), cinfo);
315319
this->Fetch();
316320
}
317321

@@ -353,7 +357,7 @@ class ExtEllpackPageSourceImpl : public ExtQantileSourceMixin<EllpackPage, Forma
353357
info_{info},
354358
ext_info_{std::move(ext_info)} {
355359
cuts->SetDevice(ctx->Device());
356-
this->SetCuts(std::move(cuts), ctx->Device(), cinfo);
360+
this->SetCuts(ctx, std::move(cuts), ctx->Device(), cinfo);
357361
CHECK(!this->cache_info_->written);
358362
this->source_->Reset();
359363
CHECK(this->source_->Next());
@@ -383,6 +387,11 @@ using ExtEllpackPageSource =
383387
ExtEllpackPageSourceImpl<EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
384388

385389
#if !defined(XGBOOST_USE_CUDA)
390+
template <typename S>
391+
inline void EllpackFormatPolicy<S>::DestroyPage(std::shared_ptr<S>* page) const {
392+
page->reset();
393+
}
394+
386395
template <typename F>
387396
inline void EllpackPageSourceImpl<F>::Fetch() {
388397
// Silence the warning about unused variables.

0 commit comments

Comments
 (0)