@@ -253,17 +253,32 @@ class CPUPredictor : public Predictor {
253
253
gbm::GBTreeModel const &model, int32_t tree_begin,
254
254
int32_t tree_end) const {
255
255
const int threads = omp_get_max_threads ();
256
+ constexpr double kDensityThresh = .5 ;
257
+ size_t total = std::max (p_fmat->Info ().num_row_ * p_fmat->Info ().num_col_ ,
258
+ static_cast <uint64_t >(1 ));
259
+ double density = static_cast <double >(p_fmat->Info ().num_nonzero_ ) /
260
+ static_cast <double >(total);
261
+ bool blocked = density > kDensityThresh ;
262
+
256
263
std::vector<RegTree::FVec> feat_vecs;
257
- InitThreadTemp (threads * kBlockOfRowsSize ,
264
+ InitThreadTemp (threads * (blocked ? kBlockOfRowsSize : 1 ) ,
258
265
model.learner_model_param ->num_feature , &feat_vecs);
259
- for (auto const & batch : p_fmat->GetBatches <SparsePage>()) {
266
+ for (auto const & batch : p_fmat->GetBatches <SparsePage>()) {
260
267
CHECK_EQ (out_preds->size (),
261
- p_fmat->Info ().num_row_ * model.learner_model_param ->num_output_group );
268
+ p_fmat->Info ().num_row_ *
269
+ model.learner_model_param ->num_output_group );
262
270
size_t constexpr kUnroll = 8 ;
263
- PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll >,
264
- kBlockOfRowsSize >(SparsePageView<kUnroll >{&batch},
265
- out_preds, model, tree_begin,
266
- tree_end, &feat_vecs);
271
+ if (blocked) {
272
+ PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll >,
273
+ kBlockOfRowsSize >(
274
+ SparsePageView<kUnroll >{&batch}, out_preds, model, tree_begin,
275
+ tree_end, &feat_vecs);
276
+
277
+ } else {
278
+ PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll >, 1 >(
279
+ SparsePageView<kUnroll >{&batch}, out_preds, model, tree_begin,
280
+ tree_end, &feat_vecs);
281
+ }
267
282
}
268
283
}
269
284
@@ -316,7 +331,7 @@ class CPUPredictor : public Predictor {
316
331
tree_end);
317
332
}
318
333
319
- template <typename Adapter>
334
+ template <typename Adapter, size_t kBlockSize >
320
335
void DispatchedInplacePredict (dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
321
336
const gbm::GBTreeModel &model, float missing,
322
337
PredictionCacheEntry *out_preds,
@@ -336,9 +351,9 @@ class CPUPredictor : public Predictor {
336
351
std::vector<Entry> workspace (m->NumColumns () * 8 * threads);
337
352
auto &predictions = out_preds->predictions .HostVector ();
338
353
std::vector<RegTree::FVec> thread_temp;
339
- InitThreadTemp (threads * kBlockOfRowsSize ,
340
- model. learner_model_param -> num_feature , &thread_temp);
341
- PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockOfRowsSize >(
354
+ InitThreadTemp (threads * kBlockSize , model. learner_model_param -> num_feature ,
355
+ &thread_temp);
356
+ PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockSize >(
342
357
AdapterView<Adapter>(m.get (), missing, common::Span<Entry>{workspace}),
343
358
&predictions, model, tree_begin, tree_end, &thread_temp);
344
359
}
@@ -348,16 +363,16 @@ class CPUPredictor : public Predictor {
348
363
PredictionCacheEntry *out_preds, uint32_t tree_begin,
349
364
unsigned tree_end) const override {
350
365
if (x.type () == typeid (std::shared_ptr<data::DenseAdapter>)) {
351
- this ->DispatchedInplacePredict <data::DenseAdapter>(
366
+ this ->DispatchedInplacePredict <data::DenseAdapter, kBlockOfRowsSize >(
352
367
x, p_m, model, missing, out_preds, tree_begin, tree_end);
353
368
} else if (x.type () == typeid (std::shared_ptr<data::CSRAdapter>)) {
354
- this ->DispatchedInplacePredict <data::CSRAdapter>(
369
+ this ->DispatchedInplacePredict <data::CSRAdapter, 1 >(
355
370
x, p_m, model, missing, out_preds, tree_begin, tree_end);
356
371
} else if (x.type () == typeid (std::shared_ptr<data::ArrayAdapter>)) {
357
- this ->DispatchedInplacePredict <data::ArrayAdapter> (
372
+ this ->DispatchedInplacePredict <data::ArrayAdapter, kBlockOfRowsSize > (
358
373
x, p_m, model, missing, out_preds, tree_begin, tree_end);
359
374
} else if (x.type () == typeid (std::shared_ptr<data::CSRArrayAdapter>)) {
360
- this ->DispatchedInplacePredict <data::CSRArrayAdapter> (
375
+ this ->DispatchedInplacePredict <data::CSRArrayAdapter, 1 > (
361
376
x, p_m, model, missing, out_preds, tree_begin, tree_end);
362
377
} else {
363
378
return false ;
0 commit comments