@@ -183,6 +183,34 @@ void CopyGradient(Context const* ctx, linalg::Matrix<GradientPair> const* in_gpa
183183 }
184184}
185185
/** Increment the prediction on GPU.
 *
 * Declared here and defined in the CUDA translation unit when CUDA support is
 * compiled in; otherwise this CPU-only stub aborts with a "GPU support" error.
 *
 * \param out_predts Prediction for the whole model.
 * \param predts Prediction for current tree.
 * \param tree_w Tree weight.
 */
void GPUDartPredictInc(common::Span<float>, common::Span<float>, float, size_t, bst_group_t,
                       bst_group_t)
#if defined(XGBOOST_USE_CUDA)
    ;  // NOLINT
#else
{
  common::AssertGPUSupport();
}
#endif
201+
/** Increment the in-place prediction on GPU.
 *
 * Declared here and defined in the CUDA translation unit when CUDA support is
 * compiled in; otherwise this CPU-only stub aborts with a "GPU support" error.
 *
 * \param out_predts Prediction for the whole model.
 * \param predts Prediction for current tree.
 * \param tree_w Tree weight.
 * \param n_rows Number of rows in the batch.
 * \param base_score Global bias (model base score), one value per output group.
 * \param n_groups Number of output groups.
 * \param group Output group of the current tree.
 */
void GPUDartInplacePredictInc(common::Span<float> /*out_predts*/, common::Span<float> /*predts*/,
                              float /*tree_w*/, size_t /*n_rows*/,
                              linalg::TensorView<float const, 1> /*base_score*/,
                              bst_group_t /*n_groups*/, bst_group_t /*group*/)
#if defined(XGBOOST_USE_CUDA)
    ;  // NOLINT
#else
{
  common::AssertGPUSupport();
}
#endif
213+
186214void GBTree::UpdateTreeLeaf (DMatrix const * p_fmat, HostDeviceVector<float > const & predictions,
187215 ObjFunction const * obj, std::int32_t group_idx,
188216 std::vector<HostDeviceVector<bst_node_t >> const & node_position,
@@ -501,49 +529,65 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien
501529}
502530
/**
 * Run batch prediction over a range of boosting layers, optionally weighting each tree.
 *
 * \param p_fmat       Input data.
 * \param out_preds    Output cache entry; `version` records how many boosting
 *                     iterations have already been accumulated into `predictions`.
 * \param is_training  Forwarded to `GetPredictor` when selecting the predictor.
 * \param layer_begin  First boosting layer (inclusive).
 * \param layer_end    One past the last layer; 0 means predict with all boosted rounds.
 * \param tree_weights Optional per-tree weights (nullptr for unweighted prediction).
 */
void GBTree::PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
                              bst_layer_t layer_begin, bst_layer_t layer_end,
                              std::vector<float> const* tree_weights) const {
  // Unweighted prediction can reuse a cached prefix of the model output by tracking how many
  // boosting iterations have already been accumulated in `out_preds->version`.
  //
  // Weighted prediction is used by DART and does not participate in this cache, since tree
  // weights can change the accumulated output independently of the cached unweighted prefix.
  if (layer_end == 0) {
    // 0 is a sentinel meaning "up to the latest boosted round".
    layer_end = this->BoostedRounds();
  }

  auto cache_version = out_preds->version;
  // We can preserve the cache only when:
  // - prediction is unweighted
  // - prediction starts from iteration 0, so the result is a cacheable prefix
  auto preserve_cache = tree_weights == nullptr && layer_begin == 0;
  // We can reuse the existing cached prefix only when:
  // - the result itself is cacheable
  // - the requested range does not move backwards past the cached version
  auto reuse_cache = preserve_cache && layer_end >= static_cast<bst_layer_t>(cache_version);
  // Initialize output when:
  // - the cached prefix cannot be reused, or
  // - the cache is valid but still empty
  auto initialize_output = !reuse_cache || cache_version == 0;
  // When reusing the cache, prediction resumes from the cached iteration instead
  // of the requested begin layer.
  auto prediction_begin = reuse_cache ? cache_version : layer_begin;

  if (!reuse_cache) {
    // Drop the cache: the accumulated output will be rebuilt from scratch.
    out_preds->version = 0;
    cache_version = 0;
  }

  if (out_preds->predictions.Size() == 0 && p_fmat->Info().num_row_ != 0) {
    // An empty output buffer must correspond to an empty (version 0) cache.
    CHECK_EQ(out_preds->version, 0);
  }

  auto const& predictor = GetPredictor(is_training, &out_preds->predictions, p_fmat);
  if (initialize_output) {
    // out_preds->Size() can be non-zero as it's initialized here before any
    // tree is built at the 0^th iterator.
    predictor->InitOutPredictions(p_fmat->Info(), &out_preds->predictions, model_);
  }

  // Translate boosting layers into a concrete tree index range.
  auto [tree_begin, tree_end] = detail::LayerToTree(model_, prediction_begin, layer_end);
  CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
  if (tree_end > tree_begin) {
    predictor->PredictBatch(p_fmat, out_preds, model_, tree_begin, tree_end, tree_weights);
  }

  if (!preserve_cache) {
    // The result is not a cacheable prefix; invalidate the version.
    out_preds->version = 0;
  } else {
    // Advance the cached version by the number of newly accumulated layers.
    out_preds->Update(layer_end - cache_version);
  }
}
542586
/**
 * Batch prediction entry point for plain (non-DART) gradient boosting.
 *
 * Dispatches to the const implementation with no tree weights, so the
 * prediction-cache fast path in `PredictBatchImpl` can be used.
 */
void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
                          bst_layer_t layer_begin, bst_layer_t layer_end) {
  // dispatch to const function.
  this->PredictBatchImpl(p_fmat, out_preds, is_training, layer_begin, layer_end, nullptr);
}
548592
549593void GBTree::InplacePredict (std::shared_ptr<DMatrix> p_m, float missing,
@@ -646,34 +690,6 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
646690 return cpu_predictor_;
647691}
648692
649- /* * Increment the prediction on GPU.
650- *
651- * \param out_predts Prediction for the whole model.
652- * \param predts Prediction for current tree.
653- * \param tree_w Tree weight.
654- */
655- void GPUDartPredictInc (common::Span<float >, common::Span<float >, float , size_t , bst_group_t ,
656- bst_group_t )
657- #if defined(XGBOOST_USE_CUDA)
658- ; // NOLINT
659- #else
660- {
661- common::AssertGPUSupport ();
662- }
663- #endif
664-
665- void GPUDartInplacePredictInc (common::Span<float > /* out_predts*/ , common::Span<float > /* predts*/ ,
666- float /* tree_w*/ , size_t /* n_rows*/ ,
667- linalg::TensorView<float const , 1 > /* base_score*/ ,
668- bst_group_t /* n_groups*/ , bst_group_t /* group*/ )
669- #if defined(XGBOOST_USE_CUDA)
670- ; // NOLINT
671- #else
672- {
673- common::AssertGPUSupport ();
674- }
675- #endif
676-
677693class Dart : public GBTree {
678694 public:
679695 explicit Dart (LearnerModelParam const * booster_config, Context const * ctx)
@@ -737,63 +753,19 @@ class Dart : public GBTree {
737753 out[" dart_train_param" ] = ToJson (dparam_);
738754 }
739755
740- // An independent const function to make sure it's thread safe.
741- void PredictBatchImpl (DMatrix* p_fmat, PredictionCacheEntry* p_out_preds, bool training,
742- bst_layer_t layer_begin, bst_layer_t layer_end) const {
743- CHECK (!this ->model_ .learner_model_param ->IsVectorLeaf ()) << " dart" << MTNotImplemented ();
744- auto & predictor = this ->GetPredictor (training, &p_out_preds->predictions , p_fmat);
745- CHECK (predictor);
746- predictor->InitOutPredictions (p_fmat->Info (), &p_out_preds->predictions , model_);
747- p_out_preds->version = 0 ;
748- auto [tree_begin, tree_end] = detail::LayerToTree (model_, layer_begin, layer_end);
749- auto n_groups = model_.learner_model_param ->num_output_group ;
750-
751- PredictionCacheEntry predts; // temporary storage for prediction
752- if (!ctx_->IsCPU ()) {
753- predts.predictions .SetDevice (ctx_->Device ());
754- }
755- predts.predictions .Resize (p_fmat->Info ().num_row_ * n_groups, 0 );
756- // multi-target is not yet supported.
757- auto layer_trees = [&]() {
758- return model_.param .num_parallel_tree * model_.learner_model_param ->OutputLength ();
759- };
760- auto const & h_tree_info = this ->model_ .tree_info .ConstHostVector ();
761- for (bst_tree_t i = tree_begin; i < tree_end; i += 1 ) {
762- if (training && std::binary_search (idx_drop_.cbegin (), idx_drop_.cend (), i)) {
763- continue ;
764- }
765-
766- CHECK_GE (i, p_out_preds->version );
767- auto version = i / layer_trees ();
768- p_out_preds->version = version;
769- predts.predictions .Fill (0 );
770- predictor->PredictBatch (p_fmat, &predts, model_, i, i + 1 );
771-
772- // Multiple the weight to output prediction.
773- auto w = this ->weight_drop_ .at (i);
774- auto grp_idx = h_tree_info.at (i);
775- CHECK_EQ (p_out_preds->predictions .Size (), predts.predictions .Size ());
776-
777- size_t n_rows = p_fmat->Info ().num_row_ ;
778- if (predts.predictions .Device ().IsCUDA ()) {
779- p_out_preds->predictions .SetDevice (predts.predictions .Device ());
780- GPUDartPredictInc (p_out_preds->predictions .DeviceSpan (), predts.predictions .DeviceSpan (), w,
781- n_rows, n_groups, grp_idx);
782- } else {
783- auto & h_out_predts = p_out_preds->predictions .HostVector ();
784- auto & h_predts = predts.predictions .ConstHostVector ();
785- common::ParallelFor (p_fmat->Info ().num_row_ , ctx_->Threads (), [&](auto ridx) {
786- const size_t offset = ridx * n_groups + grp_idx;
787- h_out_predts[offset] += (h_predts[offset] * w);
788- });
789- }
790- }
791- }
792-
  /**
   * Batch prediction for DART, where each tree's output is scaled by its weight.
   *
   * During training, trees selected by `DropTrees` (recorded in `idx_drop_`)
   * must not contribute to the prediction, so their weights are zeroed in a
   * local copy of `weight_drop_` before delegating to `PredictBatchImpl`.
   */
  void PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* p_out_preds, bool training,
                    bst_layer_t layer_begin, bst_layer_t layer_end) override {
    DropTrees(training);
    auto const* tree_weights = &weight_drop_;
    std::vector<float> dropped_weights;
    if (training && !idx_drop_.empty()) {
      // Copy the weights and silence the dropped trees with a zero weight.
      dropped_weights = weight_drop_;
      for (auto idx : idx_drop_) {
        dropped_weights.at(idx) = 0.0f;
      }
      tree_weights = &dropped_weights;
    }
    this->PredictBatchImpl(p_fmat, p_out_preds, training, layer_begin, layer_end, tree_weights);
  }
798770
799771 void InplacePredict (std::shared_ptr<DMatrix> p_fmat, float missing,
@@ -808,7 +780,8 @@ class Dart : public GBTree {
808780 auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_fmat);
809781 CHECK (proxy) << error::InplacePredictProxy ();
810782 auto p_fmat = data::CreateDMatrixFromProxy (ctx_, proxy, missing);
811- this ->PredictBatchImpl (p_fmat.get (), p_out_preds, false , layer_begin, layer_end);
783+ this ->PredictBatchImpl (p_fmat.get (), p_out_preds, false , layer_begin, layer_end,
784+ &weight_drop_);
812785 return ;
813786 }
814787
0 commit comments