11/* *
2- * Copyright 2019-2025 , XGBoost contributors
2+ * Copyright 2019-2026 , XGBoost contributors
33 */
44#include < algorithm> // for max
55#include < cstddef> // for size_t
@@ -145,9 +145,8 @@ class EllpackHostCacheStreamImpl {
145145 ptr_ = k;
146146 }
147147
148- [[nodiscard]] bool Write (EllpackPage const & page) {
148+ [[nodiscard]] bool Write (Context const * ctx, EllpackPage const & page) {
149149 auto impl = page.Impl ();
150- auto ctx = Context{}.MakeCUDA (dh::CurrentDevice ());
151150
152151 this ->cache_ ->sizes_orig .push_back (page.Impl ()->MemCostBytes ());
153152 auto orig_ptr = this ->cache_ ->sizes_orig .size () - 1 ;
@@ -219,10 +218,10 @@ class EllpackHostCacheStreamImpl {
219218 dc::CuMemParams c_out;
220219 std::size_t constexpr kChunkSize = 1ul << 21 ;
221220 auto params = dc::CompressSnappy (
222- & ctx, old_impl->gidx_buffer .ToSpan ().subspan (n_h_bytes, n_comp_bytes), &tmp, kChunkSize );
221+ ctx, old_impl->gidx_buffer .ToSpan ().subspan (n_h_bytes, n_comp_bytes), &tmp, kChunkSize );
223222 common::RefResourceView<std::uint8_t > c_buf = dc::CoalesceCompressedBuffersToHost (
224- ctx. CUDACtx ()->Stream (), this ->cache_ ->pool , params, tmp, &c_out);
225- auto c_page = dc::MakeSnappyDecomprMgr (ctx. CUDACtx ()->Stream (), this ->cache_ ->pool ,
223+ ctx-> CUDACtx ()->Stream (), this ->cache_ ->pool , params, tmp, &c_out);
224+ auto c_page = dc::MakeSnappyDecomprMgr (ctx-> CUDACtx ()->Stream (), this ->cache_ ->pool ,
226225 std::move (c_out), c_buf.ToSpan ());
227226 CHECK_EQ (c_page.DecompressedBytes () + new_impl->gidx_buffer .size_bytes (), n_bytes);
228227
@@ -264,13 +263,13 @@ class EllpackHostCacheStreamImpl {
264263 // Push a new page
265264 auto n_bytes = this ->cache_ ->buffer_bytes .at (this ->cache_ ->h_pages .size ());
266265 auto n_samples = this ->cache_ ->buffer_rows .at (this ->cache_ ->h_pages .size ());
267- auto new_impl = std::make_unique<EllpackPageImpl>(& ctx, impl->CutsShared (), impl->IsDense (),
266+ auto new_impl = std::make_unique<EllpackPageImpl>(ctx, impl->CutsShared (), impl->IsDense (),
268267 impl->info .row_stride , n_samples);
269268 new_impl->SetBaseRowId (impl->base_rowid );
270269 new_impl->SetNumSymbols (impl->NumSymbols ());
271270 new_impl->gidx_buffer =
272- common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(& ctx, n_bytes, 0 );
273- auto offset = new_impl->Copy (& ctx, impl, 0 );
271+ common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(ctx, n_bytes, 0 );
272+ auto offset = new_impl->Copy (ctx, impl, 0 );
274273
275274 this ->cache_ ->offsets .push_back (offset);
276275
@@ -284,7 +283,7 @@ class EllpackHostCacheStreamImpl {
284283 CHECK (!this ->cache_ ->h_pages .empty ());
285284 CHECK_EQ (cache_idx, this ->cache_ ->h_pages .size () - 1 );
286285 auto & new_impl = this ->cache_ ->h_pages .back ();
287- auto offset = new_impl->Copy (& ctx, impl, this ->cache_ ->offsets .back ());
286+ auto offset = new_impl->Copy (ctx, impl, this ->cache_ ->offsets .back ());
288287 this ->cache_ ->offsets .back () += offset;
289288 }
290289
@@ -382,10 +381,24 @@ void EllpackHostCacheStream::Read(Context const* ctx, EllpackPage* page, bool pr
382381 this ->p_impl_ ->Read (ctx, page, prefetch_copy);
383382}
384383
385- [[nodiscard]] bool EllpackHostCacheStream::Write (EllpackPage const & page) {
386- return this ->p_impl_ ->Write (page);
384+ [[nodiscard]] bool EllpackHostCacheStream::Write (Context const * ctx, EllpackPage const & page) {
385+ return this ->p_impl_ ->Write (ctx, page);
387386}
388387
388+ /* *
389+ * EllpackFormatPolicy
390+ */
391+ template <typename S>
392+ void EllpackFormatPolicy<S>::DestroyPage(std::shared_ptr<S>* page) const {
393+ if (page && ctx_) {
394+ ctx_->CUDACtx ()->Stream ().Sync ();
395+ }
396+ page->reset ();
397+ }
398+
399+ template void EllpackFormatPolicy<EllpackPage>::DestroyPage(
400+ std::shared_ptr<EllpackPage>* page) const ;
401+
389402/* *
390403 * EllpackCacheStreamPolicy
391404 */
@@ -528,13 +541,14 @@ void EllpackPageSourceImpl<F>::Fetch() {
528541 // This is not read from cache so we still need it to be synced with sparse page source.
529542 CHECK_EQ (this ->Iter (), this ->source_ ->Iter ());
530543 auto const & csr = this ->source_ ->Page ();
544+ this ->DestroyPage (&this ->page_ );
531545 this ->page_ .reset (new EllpackPage{});
532546 auto * impl = this ->page_ ->Impl ();
533- Context ctx = Context{}.MakeCUDA (this ->Device ().ordinal );
534547 if (this ->GetCuts ()->HasCategorical ()) {
535548 CHECK (!this ->feature_types_ .empty ());
536549 }
537- *impl = EllpackPageImpl{&ctx, this ->GetCuts (), *csr, is_dense_, row_stride_, feature_types_};
550+ *impl =
551+ EllpackPageImpl{this ->Ctx (), this ->GetCuts (), *csr, is_dense_, row_stride_, feature_types_};
538552 this ->page_ ->SetBaseRowId (csr->base_rowid );
539553 LOG (INFO) << " Generated an Ellpack page with size: "
540554 << common::HumanMemUnit (impl->MemCostBytes ())
@@ -573,6 +587,7 @@ void ExtEllpackPageSourceImpl<F>::Fetch() {
573587 bst_idx_t row_stride = GetRowCounts (this ->ctx_ , value, row_counts_span,
574588 dh::GetDevice (this ->ctx_ ), this ->missing_ );
575589 CHECK_LE (row_stride, this ->ext_info_ .row_stride );
590+ this ->DestroyPage (&this ->page_ );
576591 this ->page_ .reset (new EllpackPage{});
577592 *this ->page_ ->Impl () = EllpackPageImpl{this ->ctx_ ,
578593 value,
0 commit comments