
Push batch tree weights into predictors#12068

Merged
RAMitchell merged 6 commits into dmlc:master from RAMitchell:partial-unify-dart-predict
Mar 11, 2026
Conversation

@RAMitchell
Member

This PR is a small step toward the follow-up work to remove GBDart as a separate booster and treat DART as weighted GBTree prediction/training behavior. It moves batch tree weighting into the predictor interface so GBTree::PredictBatchImpl can share one prediction path for both plain GBTree and DART-style weighted prediction. The scope here is intentionally limited to batch prediction; inplace prediction and SHAP are left unchanged for now. This is groundwork for the larger PR that removes GBDart.
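Conceptually, the change folds an optional per-tree weight into the shared accumulation loop. The following is a minimal sketch under illustrative names (not xgboost's actual `OptionalWeights`/`PredictBatch` signatures): a null weight pointer behaves exactly like plain GBTree, while DART supplies its per-tree scaling factors, so both boosters can share one prediction path.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an optional per-tree weight view:
// a null data pointer means an implicit weight of 1.0 for every tree.
struct TreeWeights {
  float const* data{nullptr};
  float operator[](std::size_t i) const { return data ? data[i] : 1.0f; }
};

// Accumulate per-tree contributions for one row; weighted and
// unweighted prediction share this single loop.
inline float PredictRow(std::vector<float> const& tree_values, TreeWeights w) {
  float sum = 0.0f;
  for (std::size_t t = 0; t < tree_values.size(); ++t) {
    sum += w[t] * tree_values[t];
  }
  return sum;
}
```

With `TreeWeights{}` this reduces to plain GBTree accumulation; passing DART's dropout-derived weights gives the weighted behavior without a separate code path.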

Contributor

Copilot AI left a comment


Pull request overview

This PR moves per-tree weighting for batch prediction into the Predictor interface so GBTree::PredictBatchImpl can use a unified prediction path for both plain GBTree and DART-style weighted prediction (as groundwork for removing GBDart).

Changes:

  • Extends Predictor::PredictBatch to accept optional per-tree weights and wires weight application into CPU/GPU batch prediction code paths.
  • Refactors GBTree::PredictBatchImpl and DART’s PredictBatch to use a single weighted/unweighted prediction flow, with cache reuse only for unweighted predictions.
  • Adds a new unit test exercising weighted batch prediction on both CPU and GPU predictors.

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 4 comments.

Per-file summary:

  • include/xgboost/predictor.h: Extends predictor batch API with optional tree weights.
  • src/gbm/gbtree.h: Threads optional weights through GBTree::PredictBatchImpl.
  • src/gbm/gbtree.cc: Unifies weighted/unweighted batch prediction (DART now supplies weights).
  • src/predictor/cpu_predictor.cc: Applies tree weights during CPU batch prediction (row-block + column-split).
  • src/predictor/gpu_predictor.cu: Applies tree weights in GPU kernels and threads weights through GPU predictor paths.
  • tests/cpp/predictor/test_predictor.h: Declares new weighted-batch test helper.
  • tests/cpp/predictor/test_predictor.cc: Implements weighted-batch test helper.
  • tests/cpp/predictor/test_cpu_predictor.cc: Adds CPU test case for weighted batch prediction.
  • tests/cpp/predictor/test_gpu_predictor.cu: Adds GPU test case for weighted batch prediction.
  • tests/cpp/gbm/test_gbtree.cc: Formatting-only changes.
Comments suppressed due to low confidence (1)

src/predictor/gpu_predictor.cu:412

  • ColumnSplitHelper::PredictLeaf calls PredictDMatrix<true>(...) without the new tree_weights argument, but PredictDMatrix now requires common::OptionalWeights tree_weights. This will fail to compile for column-split leaf prediction. Pass a default OptionalWeights{1.0f} (or add an overload/default argument) for the leaf path.
```cpp
  void PredictLeaf(DMatrix* dmat, HostDeviceVector<float>* out_preds, gbm::GBTreeModel const& model,
                   DeviceModel const& d_model) const {
    CHECK(dmat->PageExists<SparsePage>()) << "Column split for external memory is not support.";
    PredictDMatrix<true>(dmat, out_preds, d_model, model.learner_model_param->num_feature,
                         model.learner_model_param->num_output_group);
```


The diff hunk under discussion, reassembled from the scrambled review context:

```cpp
if (tree_weights != nullptr && tree_begin != tree_end) {
  pred_weights = common::MakeOptionalWeights(ctx_->Device(), weights);
}
this->PredictDMatrix(dmat, out_preds, model, tree_begin, tree_end, pred_weights);
```
Member


Is this a real concern? Is PredictDMatrix using a different CUDA stream than the one used here?

Member Author


The weights vector falls out of scope at the end of the function and can be freed, but the CPU often runs ahead of the GPU. cudaFree synchronizes, but we might be using a memory pool that can reallocate the freed block to something else, so there is a technical risk of the memory being reused before the kernel finishes. If all subsequent code used the same stream it would be fine, but we don't strictly have that guarantee.
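The hazard described above can be sketched host-side without CUDA: an asynchronous launch only queues work, so a buffer scoped to the enclosing function must be kept alive (for example via shared ownership) until the queue is drained. All names below are illustrative, not xgboost's; a plain vector of callbacks stands in for a CUDA stream.

```cpp
#include <functional>
#include <memory>
#include <vector>

// Stand-in for a CUDA stream: work is queued now, executed later.
std::vector<std::function<void()>> g_queue;

// Capture a shared_ptr so the weights buffer outlives the caller's
// scope; this is the lifetime-extension pattern the comment argues for.
void LaunchAsync(std::shared_ptr<std::vector<float>> weights, float* out) {
  g_queue.push_back([weights, out] {
    float s = 0.0f;
    for (float w : *weights) s += w;
    *out = s;
  });
}

float Predict() {
  float result = 0.0f;
  {
    auto weights = std::make_shared<std::vector<float>>(3, 1.0f);
    LaunchAsync(weights, &result);
  }  // caller's handle is gone; the queued work still owns the buffer
  for (auto& f : g_queue) f();  // "stream synchronize"
  g_queue.clear();
  return result;
}
```

Had `LaunchAsync` captured only a raw pointer into a stack-scoped vector, the deferred work would read freed (and possibly pool-reallocated) memory.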

Member Author


I forgot that we only ever use 1 stream. I might remove it.

@RAMitchell RAMitchell merged commit e7358f9 into dmlc:master Mar 11, 2026
80 checks passed
@RAMitchell RAMitchell deleted the partial-unify-dart-predict branch March 11, 2026 12:35