Push batch tree weights into predictors (#12068)
Conversation
Pull request overview
This PR moves per-tree weighting for batch prediction into the Predictor interface so GBTree::PredictBatchImpl can use a unified prediction path for both plain GBTree and DART-style weighted prediction (as groundwork for removing GBDart).
Changes:
- Extends `Predictor::PredictBatch` to accept optional per-tree weights and wires weight application into the CPU/GPU batch prediction code paths.
- Refactors `GBTree::PredictBatchImpl` and DART's `PredictBatch` to use a single weighted/unweighted prediction flow, with cache reuse only for unweighted predictions.
- Adds a new unit test exercising weighted batch prediction on both CPU and GPU predictors.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| include/xgboost/predictor.h | Extends predictor batch API with optional tree weights. |
| src/gbm/gbtree.h | Threads optional weights through GBTree::PredictBatchImpl. |
| src/gbm/gbtree.cc | Unifies weighted/unweighted batch prediction (DART now supplies weights). |
| src/predictor/cpu_predictor.cc | Applies tree weights during CPU batch prediction (row-block + column-split). |
| src/predictor/gpu_predictor.cu | Applies tree weights in GPU kernels and threads weights through GPU predictor paths. |
| tests/cpp/predictor/test_predictor.h | Declares new weighted-batch test helper. |
| tests/cpp/predictor/test_predictor.cc | Implements weighted-batch test helper. |
| tests/cpp/predictor/test_cpu_predictor.cc | Adds CPU test case for weighted batch prediction. |
| tests/cpp/predictor/test_gpu_predictor.cu | Adds GPU test case for weighted batch prediction. |
| tests/cpp/gbm/test_gbtree.cc | Formatting-only changes. |
Comments suppressed due to low confidence (1)
src/predictor/gpu_predictor.cu:412
`ColumnSplitHelper::PredictLeaf` calls `PredictDMatrix<true>(...)` without the new `tree_weights` argument, but `PredictDMatrix` now requires `common::OptionalWeights tree_weights`. This will fail to compile for column-split leaf prediction. Pass a default `OptionalWeights{1.0f}` (or add an overload/default argument) for the leaf path.
```cpp
void PredictLeaf(DMatrix* dmat, HostDeviceVector<float>* out_preds, gbm::GBTreeModel const& model,
                 DeviceModel const& d_model) const {
  CHECK(dmat->PageExists<SparsePage>()) << "Column split for external memory is not support.";
  PredictDMatrix<true>(dmat, out_preds, d_model, model.learner_model_param->num_feature,
                       model.learner_model_param->num_output_group);
```
```cpp
  pred_weights = common::MakeOptionalWeights(ctx_->Device(), weights);
}
this->PredictDMatrix(dmat, out_preds, model, tree_begin, tree_end, pred_weights);
if (tree_weights != nullptr && tree_begin != tree_end) {
```
Is this a real concern? Is PredictDMatrix using a different CUDA stream than the one used here?
`weights` falls out of scope at the end of the function and can be deleted, but the CPU often runs ahead of the GPU. `cudaFree` synchronizes, but we might be using a memory pool that can reallocate the freed memory to something else, so there is a technical risk of the memory being reused before the kernel finishes. If all subsequent code used the same stream it would be fine, but we don't strictly have that guarantee.
I forgot that we only ever use 1 stream. I might remove it.
This PR is a small step toward the follow-up work of removing GBDart as a separate booster and treating DART as weighted GBTree prediction/training behavior. It moves batch tree weighting into the predictor interface so that GBTree::PredictBatchImpl can share one prediction path for both plain GBTree and DART-style weighted prediction. The scope is intentionally limited to batch prediction; inplace prediction and SHAP are left unchanged for now and will be addressed in the larger PR that removes GBDart.