Skip to content

Commit e7358f9

Browse files
authored
Push batch tree weights into predictors (#12068)
1 parent 2706059 commit e7358f9

File tree

11 files changed: +362 additions, −324 deletions

include/xgboost/predictor.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ class Predictor {
107107
*/
108108
virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
109109
gbm::GBTreeModel const& model, bst_tree_t tree_begin,
110-
bst_tree_t tree_end = 0) const = 0;
110+
bst_tree_t tree_end = 0,
111+
std::vector<float> const* tree_weights = nullptr) const = 0;
111112

112113
/**
113114
* \brief Inplace prediction.

plugin/sycl/predictor/predictor.cc

File mode changed: 100755 → 100644
Lines changed: 71 additions & 100 deletions
Large diffs are not rendered by default.

src/gbm/gbtree.cc

Lines changed: 73 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,34 @@ void CopyGradient(Context const* ctx, linalg::Matrix<GradientPair> const* in_gpa
183183
}
184184
}
185185

186+
/** Increment the prediction on GPU.
187+
*
188+
* \param out_predts Prediction for the whole model.
189+
* \param predts Prediction for current tree.
190+
* \param tree_w Tree weight.
191+
*/
192+
void GPUDartPredictInc(common::Span<float>, common::Span<float>, float, size_t, bst_group_t,
193+
bst_group_t)
194+
#if defined(XGBOOST_USE_CUDA)
195+
; // NOLINT
196+
#else
197+
{
198+
common::AssertGPUSupport();
199+
}
200+
#endif
201+
202+
void GPUDartInplacePredictInc(common::Span<float> /*out_predts*/, common::Span<float> /*predts*/,
203+
float /*tree_w*/, size_t /*n_rows*/,
204+
linalg::TensorView<float const, 1> /*base_score*/,
205+
bst_group_t /*n_groups*/, bst_group_t /*group*/)
206+
#if defined(XGBOOST_USE_CUDA)
207+
; // NOLINT
208+
#else
209+
{
210+
common::AssertGPUSupport();
211+
}
212+
#endif
213+
186214
void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const& predictions,
187215
ObjFunction const* obj, std::int32_t group_idx,
188216
std::vector<HostDeviceVector<bst_node_t>> const& node_position,
@@ -501,49 +529,65 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien
501529
}
502530

503531
void GBTree::PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
504-
bst_layer_t layer_begin, bst_layer_t layer_end) const {
532+
bst_layer_t layer_begin, bst_layer_t layer_end,
533+
std::vector<float> const* tree_weights) const {
534+
// Unweighted prediction can reuse a cached prefix of the model output by tracking how many
535+
// boosting iterations have already been accumulated in `out_preds->version`.
536+
//
537+
// Weighted prediction is used by DART and does not participate in this cache, since tree
538+
// weights can change the accumulated output independently of the cached unweighted prefix.
505539
if (layer_end == 0) {
506540
layer_end = this->BoostedRounds();
507541
}
508-
if (layer_begin != 0 || layer_end < static_cast<bst_layer_t>(out_preds->version)) {
509-
// cache is dropped.
542+
543+
auto cache_version = out_preds->version;
544+
// We can preserve the cache only when:
545+
// - prediction is unweighted
546+
// - prediction starts from iteration 0, so the result is a cacheable prefix
547+
auto preserve_cache = tree_weights == nullptr && layer_begin == 0;
548+
// We can reuse the existing cached prefix only when:
549+
// - the result itself is cacheable
550+
// - the requested range does not move backwards past the cached version
551+
auto reuse_cache = preserve_cache && layer_end >= static_cast<bst_layer_t>(cache_version);
552+
// Initialize output when:
553+
// - the cached prefix cannot be reused, or
554+
// - the cache is valid but still empty
555+
auto initialize_output = !reuse_cache || cache_version == 0;
556+
auto prediction_begin = reuse_cache ? cache_version : layer_begin;
557+
558+
if (!reuse_cache) {
510559
out_preds->version = 0;
560+
cache_version = 0;
511561
}
512-
bool reset = false;
513-
if (layer_begin == 0) {
514-
layer_begin = out_preds->version;
515-
} else {
516-
// When begin layer is not 0, the cache is not useful.
517-
reset = true;
518-
}
562+
519563
if (out_preds->predictions.Size() == 0 && p_fmat->Info().num_row_ != 0) {
520564
CHECK_EQ(out_preds->version, 0);
521565
}
522566

523567
auto const& predictor = GetPredictor(is_training, &out_preds->predictions, p_fmat);
524-
if (out_preds->version == 0) {
568+
if (initialize_output) {
525569
// out_preds->Size() can be non-zero as it's initialized here before any
526570
// tree is built at the 0^th iterator.
527571
predictor->InitOutPredictions(p_fmat->Info(), &out_preds->predictions, model_);
528572
}
529573

530-
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
574+
auto [tree_begin, tree_end] = detail::LayerToTree(model_, prediction_begin, layer_end);
531575
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
532576
if (tree_end > tree_begin) {
533-
predictor->PredictBatch(p_fmat, out_preds, model_, tree_begin, tree_end);
577+
predictor->PredictBatch(p_fmat, out_preds, model_, tree_begin, tree_end, tree_weights);
534578
}
535-
if (reset) {
579+
580+
if (!preserve_cache) {
536581
out_preds->version = 0;
537582
} else {
538-
std::uint32_t delta = layer_end - out_preds->version;
539-
out_preds->Update(delta);
583+
out_preds->Update(layer_end - cache_version);
540584
}
541585
}
542586

543587
void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
544588
bst_layer_t layer_begin, bst_layer_t layer_end) {
545589
// dispatch to const function.
546-
this->PredictBatchImpl(p_fmat, out_preds, is_training, layer_begin, layer_end);
590+
this->PredictBatchImpl(p_fmat, out_preds, is_training, layer_begin, layer_end, nullptr);
547591
}
548592

549593
void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
@@ -646,34 +690,6 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
646690
return cpu_predictor_;
647691
}
648692

649-
/** Increment the prediction on GPU.
650-
*
651-
* \param out_predts Prediction for the whole model.
652-
* \param predts Prediction for current tree.
653-
* \param tree_w Tree weight.
654-
*/
655-
void GPUDartPredictInc(common::Span<float>, common::Span<float>, float, size_t, bst_group_t,
656-
bst_group_t)
657-
#if defined(XGBOOST_USE_CUDA)
658-
; // NOLINT
659-
#else
660-
{
661-
common::AssertGPUSupport();
662-
}
663-
#endif
664-
665-
void GPUDartInplacePredictInc(common::Span<float> /*out_predts*/, common::Span<float> /*predts*/,
666-
float /*tree_w*/, size_t /*n_rows*/,
667-
linalg::TensorView<float const, 1> /*base_score*/,
668-
bst_group_t /*n_groups*/, bst_group_t /*group*/)
669-
#if defined(XGBOOST_USE_CUDA)
670-
; // NOLINT
671-
#else
672-
{
673-
common::AssertGPUSupport();
674-
}
675-
#endif
676-
677693
class Dart : public GBTree {
678694
public:
679695
explicit Dart(LearnerModelParam const* booster_config, Context const* ctx)
@@ -737,63 +753,19 @@ class Dart : public GBTree {
737753
out["dart_train_param"] = ToJson(dparam_);
738754
}
739755

740-
// An independent const function to make sure it's thread safe.
741-
void PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* p_out_preds, bool training,
742-
bst_layer_t layer_begin, bst_layer_t layer_end) const {
743-
CHECK(!this->model_.learner_model_param->IsVectorLeaf()) << "dart" << MTNotImplemented();
744-
auto& predictor = this->GetPredictor(training, &p_out_preds->predictions, p_fmat);
745-
CHECK(predictor);
746-
predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_);
747-
p_out_preds->version = 0;
748-
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
749-
auto n_groups = model_.learner_model_param->num_output_group;
750-
751-
PredictionCacheEntry predts; // temporary storage for prediction
752-
if (!ctx_->IsCPU()) {
753-
predts.predictions.SetDevice(ctx_->Device());
754-
}
755-
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
756-
// multi-target is not yet supported.
757-
auto layer_trees = [&]() {
758-
return model_.param.num_parallel_tree * model_.learner_model_param->OutputLength();
759-
};
760-
auto const& h_tree_info = this->model_.tree_info.ConstHostVector();
761-
for (bst_tree_t i = tree_begin; i < tree_end; i += 1) {
762-
if (training && std::binary_search(idx_drop_.cbegin(), idx_drop_.cend(), i)) {
763-
continue;
764-
}
765-
766-
CHECK_GE(i, p_out_preds->version);
767-
auto version = i / layer_trees();
768-
p_out_preds->version = version;
769-
predts.predictions.Fill(0);
770-
predictor->PredictBatch(p_fmat, &predts, model_, i, i + 1);
771-
772-
// Multiple the weight to output prediction.
773-
auto w = this->weight_drop_.at(i);
774-
auto grp_idx = h_tree_info.at(i);
775-
CHECK_EQ(p_out_preds->predictions.Size(), predts.predictions.Size());
776-
777-
size_t n_rows = p_fmat->Info().num_row_;
778-
if (predts.predictions.Device().IsCUDA()) {
779-
p_out_preds->predictions.SetDevice(predts.predictions.Device());
780-
GPUDartPredictInc(p_out_preds->predictions.DeviceSpan(), predts.predictions.DeviceSpan(), w,
781-
n_rows, n_groups, grp_idx);
782-
} else {
783-
auto& h_out_predts = p_out_preds->predictions.HostVector();
784-
auto& h_predts = predts.predictions.ConstHostVector();
785-
common::ParallelFor(p_fmat->Info().num_row_, ctx_->Threads(), [&](auto ridx) {
786-
const size_t offset = ridx * n_groups + grp_idx;
787-
h_out_predts[offset] += (h_predts[offset] * w);
788-
});
789-
}
790-
}
791-
}
792-
793756
void PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* p_out_preds, bool training,
794757
bst_layer_t layer_begin, bst_layer_t layer_end) override {
795758
DropTrees(training);
796-
this->PredictBatchImpl(p_fmat, p_out_preds, training, layer_begin, layer_end);
759+
auto const* tree_weights = &weight_drop_;
760+
std::vector<float> dropped_weights;
761+
if (training && !idx_drop_.empty()) {
762+
dropped_weights = weight_drop_;
763+
for (auto idx : idx_drop_) {
764+
dropped_weights.at(idx) = 0.0f;
765+
}
766+
tree_weights = &dropped_weights;
767+
}
768+
this->PredictBatchImpl(p_fmat, p_out_preds, training, layer_begin, layer_end, tree_weights);
797769
}
798770

799771
void InplacePredict(std::shared_ptr<DMatrix> p_fmat, float missing,
@@ -808,7 +780,8 @@ class Dart : public GBTree {
808780
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_fmat);
809781
CHECK(proxy) << error::InplacePredictProxy();
810782
auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
811-
this->PredictBatchImpl(p_fmat.get(), p_out_preds, false, layer_begin, layer_end);
783+
this->PredictBatchImpl(p_fmat.get(), p_out_preds, false, layer_begin, layer_end,
784+
&weight_drop_);
812785
return;
813786
}
814787

src/gbm/gbtree.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@ class GBTree : public GradientBooster {
207207
}
208208

209209
void PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
210-
bst_layer_t layer_begin, bst_layer_t layer_end) const;
210+
bst_layer_t layer_begin, bst_layer_t layer_end,
211+
std::vector<float> const* tree_weights = nullptr) const;
211212

212213
void PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool training,
213214
bst_layer_t layer_begin, bst_layer_t layer_end) override;

0 commit comments

Comments (0)