Skip to content

Commit 8147d78

Browse files
authored
[back port] Fix multiclass auc with empty dataset. (#6947) (#6960)
1 parent 651c4ac commit 8147d78

File tree

5 files changed

+83
-48
lines changed

5 files changed

+83
-48
lines changed

src/metric/auc.cc

+12-7
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ std::tuple<float, float, float> BinaryAUC(std::vector<float> const &predts,
8787
* - Kleiman, Ross and Page, David. $AUC_{\mu}$: A Performance Metric for Multi-Class
8888
* Machine Learning Models
8989
*/
90-
float MultiClassOVR(std::vector<float> const& predts, MetaInfo const& info) {
91-
auto n_classes = predts.size() / info.labels_.Size();
90+
float MultiClassOVR(std::vector<float> const& predts, MetaInfo const& info, size_t n_classes) {
9291
CHECK_NE(n_classes, 0);
9392
auto const& labels = info.labels_.ConstHostVector();
9493

@@ -230,6 +229,10 @@ class EvalAUC : public Metric {
230229
info.labels_.SetDevice(tparam_->gpu_id);
231230
info.weights_.SetDevice(tparam_->gpu_id);
232231
}
232+
// We use the global size to handle empty dataset.
233+
std::array<size_t, 2> meta{info.labels_.Size(), preds.Size()};
234+
rabit::Allreduce<rabit::op::Max>(meta.data(), meta.size());
235+
233236
if (!info.group_ptr_.empty()) {
234237
/**
235238
* learning to rank
@@ -261,16 +264,17 @@ class EvalAUC : public Metric {
261264
CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups
262265
<< ", valid groups: " << valid_groups;
263266
}
264-
} else if (info.labels_.Size() != preds.Size() &&
265-
preds.Size() % info.labels_.Size() == 0) {
267+
} else if (meta[0] != meta[1] && meta[1] % meta[0] == 0) {
266268
/**
267269
* multi class
268270
*/
271+
size_t n_classes = meta[1] / meta[0];
272+
CHECK_NE(n_classes, 0);
269273
if (tparam_->gpu_id == GenericParameter::kCpuId) {
270-
auc = MultiClassOVR(preds.ConstHostVector(), info);
274+
auc = MultiClassOVR(preds.ConstHostVector(), info, n_classes);
271275
} else {
272276
auc = GPUMultiClassAUCOVR(preds.ConstDeviceSpan(), info, tparam_->gpu_id,
273-
&this->d_cache_);
277+
&this->d_cache_, n_classes);
274278
}
275279
} else {
276280
/**
@@ -323,7 +327,8 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
323327
}
324328

325329
float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
326-
int32_t device, std::shared_ptr<DeviceAUCCache>* cache) {
330+
int32_t device, std::shared_ptr<DeviceAUCCache>* cache,
331+
size_t n_classes) {
327332
common::AssertGPUSupport();
328333
return 0;
329334
}

src/metric/auc.cu

+55-31
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,12 @@ struct DeviceAUCCache {
6161
neg_pos.resize(sorted_idx.size());
6262
if (is_multi) {
6363
predts_t.resize(sorted_idx.size());
64-
reducer.reset(new dh::AllReducer);
65-
reducer->Init(rabit::GetRank());
6664
}
6765
}
66+
if (is_multi && !reducer) {
67+
reducer.reset(new dh::AllReducer);
68+
reducer->Init(device);
69+
}
6870
}
6971
};
7072

@@ -197,12 +199,48 @@ XGBOOST_DEVICE size_t LastOf(size_t group, common::Span<Idx> indptr) {
197199
return indptr[group + 1] - 1;
198200
}
199201

202+
203+
float ScaleClasses(common::Span<float> results, common::Span<float> local_area,
204+
common::Span<float> fp, common::Span<float> tp,
205+
common::Span<float> auc, std::shared_ptr<DeviceAUCCache> cache,
206+
size_t n_classes) {
207+
dh::XGBDeviceAllocator<char> alloc;
208+
if (rabit::IsDistributed()) {
209+
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), dh::CurrentDevice());
210+
cache->reducer->AllReduceSum(results.data(), results.data(), results.size());
211+
}
212+
auto reduce_in = dh::MakeTransformIterator<thrust::pair<float, float>>(
213+
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
214+
if (local_area[i] > 0) {
215+
return thrust::make_pair(auc[i] / local_area[i] * tp[i], tp[i]);
216+
}
217+
return thrust::make_pair(std::numeric_limits<float>::quiet_NaN(), 0.0f);
218+
});
219+
220+
float tp_sum;
221+
float auc_sum;
222+
thrust::tie(auc_sum, tp_sum) = thrust::reduce(
223+
thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
224+
thrust::make_pair(0.0f, 0.0f),
225+
[=] __device__(auto const &l, auto const &r) {
226+
return thrust::make_pair(l.first + r.first, l.second + r.second);
227+
});
228+
if (tp_sum != 0 && !std::isnan(auc_sum)) {
229+
auc_sum /= tp_sum;
230+
} else {
231+
return std::numeric_limits<float>::quiet_NaN();
232+
}
233+
return auc_sum;
234+
}
235+
200236
/**
201237
* MultiClass implementation is similar to binary classification, except we need to split
202238
* up each class in all kernels.
203239
*/
204240
float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
205-
int32_t device, std::shared_ptr<DeviceAUCCache>* p_cache) {
241+
int32_t device, std::shared_ptr<DeviceAUCCache>* p_cache,
242+
size_t n_classes) {
243+
dh::safe_cuda(cudaSetDevice(device));
206244
auto& cache = *p_cache;
207245
if (!cache) {
208246
cache.reset(new DeviceAUCCache);
@@ -213,8 +251,19 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
213251
auto weights = info.weights_.ConstDeviceSpan();
214252

215253
size_t n_samples = labels.size();
216-
size_t n_classes = predts.size() / labels.size();
217-
CHECK_NE(n_classes, 0);
254+
255+
if (n_samples == 0) {
256+
dh::TemporaryArray<float> resutls(n_classes * 4, 0.0f);
257+
auto d_results = dh::ToSpan(resutls);
258+
dh::LaunchN(device, n_classes * 4, [=]__device__(size_t i) {
259+
d_results[i] = 0.0f;
260+
});
261+
auto local_area = d_results.subspan(0, n_classes);
262+
auto fp = d_results.subspan(n_classes, n_classes);
263+
auto tp = d_results.subspan(2 * n_classes, n_classes);
264+
auto auc = d_results.subspan(3 * n_classes, n_classes);
265+
return ScaleClasses(d_results, local_area, fp, tp, auc, cache, n_classes);
266+
}
218267

219268
/**
220269
* Create sorted index for each class
@@ -377,32 +426,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
377426
tp[c] = last.second;
378427
local_area[c] = last.first * last.second;
379428
});
380-
if (rabit::IsDistributed()) {
381-
cache->reducer->AllReduceSum(resutls.data().get(), resutls.data().get(),
382-
resutls.size());
383-
}
384-
auto reduce_in = dh::MakeTransformIterator<thrust::pair<float, float>>(
385-
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
386-
if (local_area[i] > 0) {
387-
return thrust::make_pair(auc[i] / local_area[i] * tp[i], tp[i]);
388-
}
389-
return thrust::make_pair(std::numeric_limits<float>::quiet_NaN(), 0.0f);
390-
});
391-
392-
float tp_sum;
393-
float auc_sum;
394-
thrust::tie(auc_sum, tp_sum) = thrust::reduce(
395-
thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
396-
thrust::make_pair(0.0f, 0.0f),
397-
[=] __device__(auto const &l, auto const &r) {
398-
return thrust::make_pair(l.first + r.first, l.second + r.second);
399-
});
400-
if (tp_sum != 0 && !std::isnan(auc_sum)) {
401-
auc_sum /= tp_sum;
402-
} else {
403-
return std::numeric_limits<float>::quiet_NaN();
404-
}
405-
return auc_sum;
429+
return ScaleClasses(d_results, local_area, fp, tp, auc, cache, n_classes);
406430
}
407431

408432
namespace {

src/metric/auc.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
2626
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache);
2727

2828
float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
29-
int32_t device, std::shared_ptr<DeviceAUCCache>* cache);
29+
int32_t device, std::shared_ptr<DeviceAUCCache>* cache,
30+
size_t n_classes);
3031

3132
std::pair<float, uint32_t>
3233
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,

tests/python-gpu/test_gpu_with_dask.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def test_dask_classifier(self, model, local_cuda_cluster: LocalCUDACluster) -> N
277277
X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X_))
278278
y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y_))
279279
w = dask_cudf.from_dask_dataframe(dd.from_dask_array(w_))
280-
run_dask_classifier(X, y, w, model, client, 10)
280+
run_dask_classifier(X, y, w, model, "gpu_hist", client, 10)
281281

282282
@pytest.mark.skipif(**tm.no_dask())
283283
@pytest.mark.skipif(**tm.no_dask_cuda())

tests/python/test_with_dask.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -317,18 +317,19 @@ def run_dask_classifier(
317317
y: xgb.dask._DaskCollection,
318318
w: xgb.dask._DaskCollection,
319319
model: str,
320+
tree_method: Optional[str],
320321
client: "Client",
321322
n_classes,
322323
) -> None:
323324
metric = "merror" if n_classes > 2 else "logloss"
324325

325326
if model == "boosting":
326327
classifier = xgb.dask.DaskXGBClassifier(
327-
verbosity=1, n_estimators=2, eval_metric=metric
328+
verbosity=1, n_estimators=2, eval_metric=metric, tree_method=tree_method
328329
)
329330
else:
330331
classifier = xgb.dask.DaskXGBRFClassifier(
331-
verbosity=1, n_estimators=2, eval_metric=metric
332+
verbosity=1, n_estimators=2, eval_metric=metric, tree_method=tree_method
332333
)
333334

334335
assert classifier._estimator_type == "classifier"
@@ -397,12 +398,12 @@ def run_dask_classifier(
397398
def test_dask_classifier(model: str, client: "Client") -> None:
398399
X, y, w = generate_array(with_weights=True)
399400
y = (y * 10).astype(np.int32)
400-
run_dask_classifier(X, y, w, model, client, 10)
401+
run_dask_classifier(X, y, w, model, None, client, 10)
401402

402403
y_bin = y.copy()
403404
y_bin[y > 5] = 1.0
404405
y_bin[y <= 5] = 0.0
405-
run_dask_classifier(X, y_bin, w, model, client, 2)
406+
run_dask_classifier(X, y_bin, w, model, None, client, 2)
406407

407408

408409
@pytest.mark.skipif(**tm.no_sklearn())
@@ -568,22 +569,26 @@ def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) ->
568569
# multiclass
569570
X_, y_ = make_classification(
570571
n_samples=n_samples,
571-
n_classes=10,
572+
n_classes=n_workers,
572573
n_informative=n_features,
573574
n_redundant=0,
574575
n_repeated=0
575576
)
577+
for i in range(y_.shape[0]):
578+
y_[i] = i % n_workers
576579
X = dd.from_array(X_, chunksize=10)
577580
y = dd.from_array(y_, chunksize=10)
578581

579582
n_samples = n_workers - 1
580583
valid_X_, valid_y_ = make_classification(
581584
n_samples=n_samples,
582-
n_classes=10,
585+
n_classes=n_workers,
583586
n_informative=n_features,
584587
n_redundant=0,
585588
n_repeated=0
586589
)
590+
for i in range(valid_y_.shape[0]):
591+
valid_y_[i] = i % n_workers
587592
valid_X = dd.from_array(valid_X_, chunksize=n_samples)
588593
valid_y = dd.from_array(valid_y_, chunksize=n_samples)
589594

@@ -594,9 +599,9 @@ def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) ->
594599

595600

596601
def test_empty_dmatrix_auc() -> None:
597-
with LocalCluster(n_workers=2) as cluster:
602+
with LocalCluster(n_workers=8) as cluster:
598603
with Client(cluster) as client:
599-
run_empty_dmatrix_auc(client, "hist", 2)
604+
run_empty_dmatrix_auc(client, "hist", 8)
600605

601606

602607
def run_auc(client: "Client", tree_method: str) -> None:

0 commit comments

Comments
 (0)