Skip to content

Commit fa3f4d1

Browse files
committed
add array api support for det_curve
1 parent 72d97bc commit fa3f4d1

File tree

3 files changed

+55
-12
lines changed

3 files changed

+55
-12
lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ Metrics
151151
- :func:`sklearn.metrics.d2_brier_score`
152152
- :func:`sklearn.metrics.d2_log_loss_score`
153153
- :func:`sklearn.metrics.d2_tweedie_score`
154+
- :func:`sklearn.metrics.det_curve`
154155
- :func:`sklearn.metrics.explained_variance_score`
155156
- :func:`sklearn.metrics.f1_score`
156157
- :func:`sklearn.metrics.fbeta_score`

sklearn/metrics/_ranking.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -372,15 +372,17 @@ def det_curve(
372372
>>> thresholds
373373
array([0.35, 0.4 , 0.8 ])
374374
"""
375+
376+
xp, _, device = get_namespace_and_device(y_true, y_score)
375377
fps, tps, thresholds = _binary_clf_curve(
376378
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
377379
)
378380

379381
# add a threshold at inf where the clf always predicts the negative class
380382
# i.e. tps = fps = 0
381-
tps = np.concatenate(([0], tps))
382-
fps = np.concatenate(([0], fps))
383-
thresholds = np.concatenate(([np.inf], thresholds))
383+
tps = xp.concatenate((xp.asarray([0], device=device), tps))
384+
fps = xp.concatenate((xp.asarray([0], device=device), fps))
385+
thresholds = xp.concatenate((xp.asarray([np.inf], device=device), thresholds))
384386

385387
if drop_intermediate and len(fps) > 2:
386388
# Drop thresholds where true positives (tp) do not change from the
@@ -389,16 +391,20 @@ def det_curve(
389391
# false positive rate (fpr) changes, producing horizontal line segments
390392
# in the transformed (normal deviate) scale. These intermediate points
391393
# can be dropped to create lighter DET curve plots.
392-
optimal_idxs = np.where(
393-
np.concatenate(
394-
[[True], np.logical_or(np.diff(tps[:-1]), np.diff(tps[1:])), [True]]
394+
optimal_idxs = xp.where(
395+
xp.concatenate(
396+
[
397+
xp.asarray([True], device=device),
398+
xp.logical_or(xp.diff(tps[:-1]), xp.diff(tps[1:])),
399+
xp.asarray([True], device=device),
400+
]
395401
)
396402
)[0]
397403
fps = fps[optimal_idxs]
398404
tps = tps[optimal_idxs]
399405
thresholds = thresholds[optimal_idxs]
400406

401-
if len(np.unique(y_true)) != 2:
407+
if len(xp.unique(y_true)) != 2:
402408
raise ValueError(
403409
"Only one class is present in y_true. Detection error "
404410
"tradeoff curve is not defined in that case."
@@ -410,16 +416,20 @@ def det_curve(
410416

411417
# start with false positives zero, which may be at a finite threshold
412418
first_ind = (
413-
fps.searchsorted(fps[0], side="right") - 1
414-
if fps.searchsorted(fps[0], side="right") > 0
419+
xp.searchsorted(fps, fps[0], side="right") - 1
420+
if xp.searchsorted(fps, fps[0], side="right") > 0
415421
else None
416422
)
417423
# stop with false negatives zero
418-
last_ind = tps.searchsorted(tps[-1]) + 1
424+
last_ind = xp.searchsorted(tps, tps[-1]) + 1
419425
sl = slice(first_ind, last_ind)
420426

421427
# reverse the output such that list of false positives is decreasing
422-
return (fps[sl][::-1] / n_count, fns[sl][::-1] / p_count, thresholds[sl][::-1])
428+
return (
429+
xp.flip(fps[sl]) / n_count,
430+
xp.flip(fns[sl]) / p_count,
431+
xp.flip(thresholds[sl]),
432+
)
423433

424434

425435
def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None):

sklearn/metrics/tests/test_ranking.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from scipy import stats
77

8-
from sklearn import datasets
8+
from sklearn import config_context, datasets
99
from sklearn.datasets import make_multilabel_classification
1010
from sklearn.exceptions import UndefinedMetricWarning
1111
from sklearn.linear_model import LogisticRegression
@@ -28,7 +28,12 @@
2828
from sklearn.model_selection import train_test_split
2929
from sklearn.preprocessing import label_binarize
3030
from sklearn.random_projection import _sparse_random_matrix
31+
from sklearn.utils._array_api import (
32+
_convert_to_numpy,
33+
yield_namespace_device_dtype_combinations,
34+
)
3135
from sklearn.utils._testing import (
36+
_array_api_for_tests,
3237
_convert_container,
3338
assert_allclose,
3439
assert_almost_equal,
@@ -1392,6 +1397,33 @@ def test_det_curve_pos_label():
13921397
assert_allclose(fnr_pos_cancer, fpr_pos_not_cancer[::-1])
13931398

13941399

1400+
@pytest.mark.parametrize(
    "array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations()
)
def test_det_curve_array_api(array_namespace, device_, dtype_name):
    """Check `det_curve` agrees between NumPy inputs and array API inputs.

    The curve is first computed on NumPy arrays, then recomputed on the same
    data converted to the namespace/device under test with array API dispatch
    enabled. The three outputs are converted back to NumPy and compared.
    """
    xp = _array_api_for_tests(array_namespace, device_)

    labels_np = np.array([0, 1, 0], dtype=dtype_name)
    scores_np = np.array([0, 0.5, 1], dtype=dtype_name)

    # Reference results, computed on the NumPy inputs.
    fpr_ref, fnr_ref, thresholds_ref = det_curve(labels_np, scores_np)

    labels_xp = xp.asarray(labels_np, device=device_)
    scores_xp = xp.asarray(scores_np, device=device_)

    with config_context(array_api_dispatch=True):
        fpr_xp, fnr_xp, thresholds_xp = det_curve(labels_xp, scores_xp)

        # Bring every output back to NumPy before comparing.
        converted = [
            _convert_to_numpy(result, xp=xp)
            for result in (fpr_xp, fnr_xp, thresholds_xp)
        ]

        # Same comparison order as the outputs: fpr, fnr, thresholds.
        for reference, candidate in zip(
            (fpr_ref, fnr_ref, thresholds_ref), converted
        ):
            assert_allclose(reference, candidate)
1425+
1426+
13951427
def check_lrap_toy(lrap_score):
13961428
# Check on several small example that it works
13971429
assert_almost_equal(lrap_score([[0, 1]], [[0.25, 0.75]]), 1)

0 commit comments

Comments
 (0)