Skip to content

Commit 39ddf40

Browse files
authored
[backport] Optimize prediction with QuantileDMatrix. (#9096) (#9303)
1 parent 573f1c7 commit 39ddf40

File tree

5 files changed

+64
-37
lines changed

5 files changed

+64
-37
lines changed

include/xgboost/tree_model.h

+5-8
Original file line number | Diff line number | Diff line change
@@ -508,7 +508,7 @@ class RegTree : public Model {
508508
* \brief drop the trace after fill, must be called after fill.
509509
* \param inst The sparse instance to drop.
510510
*/
511-
void Drop(const SparsePage::Inst& inst);
511+
void Drop();
512512
/*!
513513
* \brief returns the size of the feature vector
514514
* \return the size of the feature vector
@@ -709,13 +709,10 @@ inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
709709
has_missing_ = data_.size() != feature_count;
710710
}
711711

712-
inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
713-
for (auto const& entry : inst) {
714-
if (entry.index >= data_.size()) {
715-
continue;
716-
}
717-
data_[entry.index].flag = -1;
718-
}
712+
inline void RegTree::FVec::Drop() {
713+
Entry e{};
714+
e.flag = -1;
715+
std::fill_n(data_.data(), data_.size(), e);
719716
has_missing_ = true;
720717
}
721718

src/data/gradient_index.cc

+35-14
Original file line number | Diff line number | Diff line change
@@ -149,10 +149,28 @@ common::ColumnMatrix const &GHistIndexMatrix::Transpose() const {
149149
return *columns_;
150150
}
151151

152+
bst_bin_t GHistIndexMatrix::GetGindex(size_t ridx, size_t fidx) const {
153+
auto begin = RowIdx(ridx);
154+
if (IsDense()) {
155+
return static_cast<bst_bin_t>(index[begin + fidx]);
156+
}
157+
auto end = RowIdx(ridx + 1);
158+
auto const& cut_ptrs = cut.Ptrs();
159+
auto f_begin = cut_ptrs[fidx];
160+
auto f_end = cut_ptrs[fidx + 1];
161+
return BinarySearchBin(begin, end, index, f_begin, f_end);
162+
}
163+
152164
float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
153165
auto const &values = cut.Values();
154166
auto const &mins = cut.MinValues();
155167
auto const &ptrs = cut.Ptrs();
168+
return this->GetFvalue(ptrs, values, mins, ridx, fidx, is_cat);
169+
}
170+
171+
float GHistIndexMatrix::GetFvalue(std::vector<std::uint32_t> const &ptrs,
172+
std::vector<float> const &values, std::vector<float> const &mins,
173+
bst_row_t ridx, bst_feature_t fidx, bool is_cat) const {
156174
if (is_cat) {
157175
auto f_begin = ptrs[fidx];
158176
auto f_end = ptrs[fidx + 1];
@@ -172,24 +190,27 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
172190
}
173191
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
174192
};
175-
176-
if (columns_->GetColumnType(fidx) == common::kDenseColumn) {
177-
if (columns_->AnyMissing()) {
178-
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
179-
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
180-
return get_bin_val(column);
181-
});
182-
} else {
193+
switch (columns_->GetColumnType(fidx)) {
194+
case common::kDenseColumn: {
195+
if (columns_->AnyMissing()) {
196+
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
197+
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
198+
return get_bin_val(column);
199+
});
200+
} else {
201+
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
202+
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
203+
auto bin_idx = column[ridx];
204+
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
205+
});
206+
}
207+
}
208+
case common::kSparseColumn: {
183209
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
184-
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
210+
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
185211
return get_bin_val(column);
186212
});
187213
}
188-
} else {
189-
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
190-
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
191-
return get_bin_val(column);
192-
});
193214
}
194215

195216
SPAN_CHECK(false);

src/data/gradient_index.h

+5
Original file line number | Diff line number | Diff line change
@@ -227,7 +227,12 @@ class GHistIndexMatrix {
227227

228228
common::ColumnMatrix const& Transpose() const;
229229

230+
bst_bin_t GetGindex(size_t ridx, size_t fidx) const;
231+
230232
float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const;
233+
float GetFvalue(std::vector<std::uint32_t> const& ptrs, std::vector<float> const& values,
234+
std::vector<float> const& mins, bst_row_t ridx, bst_feature_t fidx,
235+
bool is_cat) const;
231236

232237
private:
233238
std::unique_ptr<common::ColumnMatrix> columns_;

src/predictor/cpu_predictor.cc

+18-14
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ bst_float PredValue(const SparsePage::Inst &inst,
6363
psum += (*trees[i])[nidx].LeafValue();
6464
}
6565
}
66-
p_feats->Drop(inst);
66+
p_feats->Drop();
6767
return psum;
6868
}
6969

@@ -116,13 +116,11 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_
116116
}
117117
}
118118

119-
template <typename DataView>
120-
void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batch,
121-
const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
119+
void FVecDrop(std::size_t const block_size, std::size_t const fvec_offset,
120+
std::vector<RegTree::FVec> *p_feats) {
122121
for (size_t i = 0; i < block_size; ++i) {
123122
RegTree::FVec &feats = (*p_feats)[fvec_offset + i];
124-
const SparsePage::Inst inst = (*batch)[batch_offset + i];
125-
feats.Drop(inst);
123+
feats.Drop();
126124
}
127125
}
128126

@@ -142,11 +140,15 @@ struct SparsePageView {
142140
struct GHistIndexMatrixView {
143141
private:
144142
GHistIndexMatrix const &page_;
145-
uint64_t n_features_;
143+
std::uint64_t const n_features_;
146144
common::Span<FeatureType const> ft_;
147145
common::Span<Entry> workspace_;
148146
std::vector<size_t> current_unroll_;
149147

148+
std::vector<std::uint32_t> const& ptrs_;
149+
std::vector<float> const& mins_;
150+
std::vector<float> const& values_;
151+
150152
public:
151153
size_t base_rowid;
152154

@@ -159,6 +161,9 @@ struct GHistIndexMatrixView {
159161
ft_{ft},
160162
workspace_{workplace},
161163
current_unroll_(n_threads > 0 ? n_threads : 1, 0),
164+
ptrs_{_page.cut.Ptrs()},
165+
mins_{_page.cut.MinValues()},
166+
values_{_page.cut.Values()},
162167
base_rowid{_page.base_rowid} {}
163168

164169
SparsePage::Inst operator[](size_t r) {
@@ -167,7 +172,7 @@ struct GHistIndexMatrixView {
167172
size_t non_missing{beg};
168173

169174
for (bst_feature_t c = 0; c < n_features_; ++c) {
170-
float f = page_.GetFvalue(r, c, common::IsCat(ft_, c));
175+
float f = page_.GetFvalue(ptrs_, values_, mins_, r, c, common::IsCat(ft_, c));
171176
if (!common::CheckNAN(f)) {
172177
workspace_[non_missing] = Entry{c, f};
173178
++non_missing;
@@ -250,10 +255,9 @@ void PredictBatchByBlockOfRowsKernel(
250255
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset,
251256
p_thread_temp);
252257
// process block of rows through all trees to keep cache locality
253-
PredictByAllTrees(model, tree_begin, tree_end, out_preds,
254-
batch_offset + batch.base_rowid, num_group, thread_temp,
255-
fvec_offset, block_size);
256-
FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
258+
PredictByAllTrees(model, tree_begin, tree_end, out_preds, batch_offset + batch.base_rowid,
259+
num_group, thread_temp, fvec_offset, block_size);
260+
FVecDrop(block_size, fvec_offset, p_thread_temp);
257261
});
258262
}
259263

@@ -470,7 +474,7 @@ class CPUPredictor : public Predictor {
470474
bst_node_t tid = GetLeafIndex<true, true>(tree, feats, cats);
471475
preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
472476
}
473-
feats.Drop(page[i]);
477+
feats.Drop();
474478
});
475479
}
476480
}
@@ -544,7 +548,7 @@ class CPUPredictor : public Predictor {
544548
(tree_weights == nullptr ? 1 : (*tree_weights)[j]);
545549
}
546550
}
547-
feats.Drop(page[i]);
551+
feats.Drop();
548552
// add base margin to BIAS
549553
if (base_margin.Size() != 0) {
550554
CHECK_EQ(base_margin.Shape(1), ngroup);

src/tree/updater_refresh.cc

+1-1
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ class TreeRefresher : public TreeUpdater {
8989
dmlc::BeginPtr(stemp[tid]) + offset);
9090
offset += tree->param.num_nodes;
9191
}
92-
feats.Drop(inst);
92+
feats.Drop();
9393
});
9494
}
9595
// aggregate the statistics

0 commit comments

Comments (0)