Skip to content

Commit adfa00a

Browse files
committed
Unify DART SHAP forwarding
1 parent c98d476 commit adfa00a

File tree

5 files changed

+51
-20
lines changed

5 files changed

+51
-20
lines changed

src/gbm/gbtree.cc

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -797,23 +797,9 @@ class Dart : public GBTree {
797797
}
798798
}
799799

800-
void PredictContribution(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
801-
bst_layer_t layer_begin, bst_layer_t layer_end,
802-
bool approximate) override {
803-
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
804-
cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_, tree_end, &weight_drop_,
805-
approximate);
806-
}
807-
808-
void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
809-
bst_layer_t layer_begin, bst_layer_t layer_end,
810-
bool approximate) override {
811-
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
812-
cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_, tree_end,
813-
&weight_drop_, approximate);
814-
}
815-
816800
protected:
801+
[[nodiscard]] std::vector<float> const* TreeWeights() const override { return &weight_drop_; }
802+
817803
// commit new trees all at once
818804
void CommitModel(TreesOneIter&& new_trees) override {
819805
auto n_new_trees = model_.CommitModel(std::forward<TreesOneIter>(new_trees));

src/gbm/gbtree.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,8 +303,8 @@ class GBTree : public GradientBooster {
303303
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
304304
CHECK_EQ(tree_begin, 0) << "Predict contribution supports only iteration end: [0, "
305305
"n_iteration), using model slicing instead.";
306-
this->GetPredictor(false)->PredictContribution(p_fmat, out_contribs, model_, tree_end, nullptr,
307-
approximate);
306+
this->GetPredictor(false)->PredictContribution(p_fmat, out_contribs, model_, tree_end,
307+
this->TreeWeights(), approximate);
308308
}
309309

310310
void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
@@ -313,8 +313,8 @@ class GBTree : public GradientBooster {
313313
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
314314
CHECK_EQ(tree_begin, 0) << "Predict interaction contribution supports only iteration end: [0, "
315315
"n_iteration), using model slicing instead.";
316-
this->GetPredictor(false)->PredictInteractionContributions(p_fmat, out_contribs, model_,
317-
tree_end, nullptr, approximate);
316+
this->GetPredictor(false)->PredictInteractionContributions(
317+
p_fmat, out_contribs, model_, tree_end, this->TreeWeights(), approximate);
318318
}
319319

320320
[[nodiscard]] std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
@@ -323,6 +323,8 @@ class GBTree : public GradientBooster {
323323
}
324324

325325
protected:
326+
[[nodiscard]] virtual std::vector<float> const* TreeWeights() const { return nullptr; }
327+
326328
void BoostNewTrees(GradientContainer* gpair, DMatrix* p_fmat, int bst_group,
327329
std::vector<HostDeviceVector<bst_node_t>>* out_position,
328330
std::vector<std::unique_ptr<RegTree>>* ret);

tests/cpp/predictor/test_shap.cc

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,38 @@ void CheckShapOutput(DMatrix* dmat, Args const& model_args) {
215215
CheckShapAdditivity(kRows, kCols, shap_interactions, margin_predt);
216216
}
217217

218+
void CheckDartShapOutput(Context const* ctx) {
219+
size_t constexpr kRows = 64, kCols = 8;
220+
auto dmat = RandomDataGenerator(kRows, kCols, 0.0).Device(ctx->Device()).GenerateDMatrix();
221+
SetLabels(dmat.get(), 1);
222+
223+
std::unique_ptr<Learner> learner{Learner::Create({dmat})};
224+
learner->SetParams(Args{{"booster", "dart"},
225+
{"objective", "binary:logistic"},
226+
{"max_depth", "3"},
227+
{"rate_drop", "0.5"},
228+
{"sample_type", "uniform"},
229+
{"normalize_type", "tree"},
230+
{"device", ctx->IsSycl() ? "cpu" : ctx->DeviceName()}});
231+
learner->Configure();
232+
for (size_t i = 0; i < 4; ++i) {
233+
learner->UpdateOneIter(i, dmat);
234+
}
235+
236+
HostDeviceVector<float> margin_predt;
237+
learner->Predict(dmat, true, &margin_predt, 0, 0, false, false, false, false, false);
238+
239+
HostDeviceVector<float> shap_values;
240+
learner->Predict(dmat, false, &shap_values, 0, 0, false, false, true, false, false);
241+
ASSERT_EQ(shap_values.Size(), kRows * (kCols + 1));
242+
CheckShapAdditivity(kRows, kCols, shap_values, margin_predt);
243+
244+
HostDeviceVector<float> shap_interactions;
245+
learner->Predict(dmat, false, &shap_interactions, 0, 0, false, false, false, false, true);
246+
ASSERT_EQ(shap_interactions.Size(), kRows * (kCols + 1) * (kCols + 1));
247+
CheckShapAdditivity(kRows, kCols, shap_interactions, margin_predt);
248+
}
249+
218250
void CheckShapAdditivity(size_t rows, size_t cols, HostDeviceVector<float> const& shap_values,
219251
HostDeviceVector<float> const& margin_predt) {
220252
auto const& h_shap = shap_values.ConstHostVector();
@@ -254,6 +286,11 @@ TEST(Predictor, ShapOutputCasesCPU) {
254286
}
255287
}
256288

289+
TEST(Predictor, DartShapOutputCPU) {
290+
Context ctx;
291+
CheckDartShapOutput(&ctx);
292+
}
293+
257294
TEST(Predictor, ApproxContribsBasic) {
258295
Context ctx;
259296
size_t constexpr kRows = 64;

tests/cpp/predictor/test_shap.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,9 @@ TEST(GPUPredictor, ShapOutputCasesGPU) {
8484
CheckShapOutput(dmat.get(), args);
8585
}
8686
}
87+
88+
TEST(GPUPredictor, DartShapOutputGPU) {
89+
auto ctx = MakeCUDACtx(0);
90+
CheckDartShapOutput(&ctx);
91+
}
8792
} // namespace xgboost

tests/cpp/predictor/test_shap.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class Learner;
1919

2020
namespace xgboost {
2121
void CheckShapOutput(DMatrix* dmat, Args const& model_args);
22+
void CheckDartShapOutput(Context const* ctx);
2223
void CheckShapAdditivity(size_t rows, size_t cols, HostDeviceVector<float> const& shap_values,
2324
HostDeviceVector<float> const& margin_predt);
2425

0 commit comments

Comments
 (0)