Skip to content

Commit 40d0f79

Browse files
committed
remove specific numpy test use comon
1 parent 6db2edc commit 40d0f79

File tree

2 files changed

+2
-33
lines changed

2 files changed

+2
-33
lines changed

sklearn/metrics/tests/test_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2252,6 +2252,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
22522252
check_array_api_binary_classification_metric,
22532253
check_array_api_multiclass_classification_metric,
22542254
],
2255+
det_curve: [check_array_api_binary_classification_metric],
22552256
f1_score: [
22562257
check_array_api_binary_classification_metric,
22572258
check_array_api_multiclass_classification_metric,

sklearn/metrics/tests/test_ranking.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from scipy import stats
77

8-
from sklearn import config_context, datasets
8+
from sklearn import datasets
99
from sklearn.datasets import make_multilabel_classification
1010
from sklearn.exceptions import UndefinedMetricWarning
1111
from sklearn.linear_model import LogisticRegression
@@ -28,12 +28,7 @@
2828
from sklearn.model_selection import train_test_split
2929
from sklearn.preprocessing import label_binarize
3030
from sklearn.random_projection import _sparse_random_matrix
31-
from sklearn.utils._array_api import (
32-
_convert_to_numpy,
33-
yield_namespace_device_dtype_combinations,
34-
)
3531
from sklearn.utils._testing import (
36-
_array_api_for_tests,
3732
_convert_container,
3833
assert_allclose,
3934
assert_almost_equal,
@@ -1397,33 +1392,6 @@ def test_det_curve_pos_label():
13971392
assert_allclose(fnr_pos_cancer, fpr_pos_not_cancer[::-1])
13981393

13991394

1400-
@pytest.mark.parametrize(
1401-
"array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations()
1402-
)
1403-
def test_det_curve_array_api(array_namespace, device_, dtype_name):
1404-
xp = _array_api_for_tests(array_namespace, device_)
1405-
1406-
y_true_np = np.array([0, 1, 0], dtype=dtype_name)
1407-
y_score_np = np.array([0, 0.5, 1], dtype=dtype_name)
1408-
1409-
# baseline numpy results
1410-
fpr_np, fnr_np, thresholds_np = det_curve(y_true_np, y_score_np)
1411-
1412-
y_true_xp = xp.asarray(y_true_np, device=device_)
1413-
y_score_xp = xp.asarray(y_score_np, device=device_)
1414-
1415-
with config_context(array_api_dispatch=True):
1416-
fpr_xp, fnr_xp, thresholds_xp = det_curve(y_true_xp, y_score_xp)
1417-
1418-
fpr_from_xp = _convert_to_numpy(fpr_xp, xp=xp)
1419-
fnr_from_xp = _convert_to_numpy(fnr_xp, xp=xp)
1420-
thresholds_from_xp = _convert_to_numpy(thresholds_xp, xp=xp)
1421-
1422-
assert_allclose(fpr_np, fpr_from_xp)
1423-
assert_allclose(fnr_np, fnr_from_xp)
1424-
assert_allclose(thresholds_np, thresholds_from_xp)
1425-
1426-
14271395
def check_lrap_toy(lrap_score):
14281396
# Check on several small example that it works
14291397
assert_almost_equal(lrap_score([[0, 1]], [[0.25, 0.75]]), 1)

0 commit comments

Comments
 (0)