|
5 | 5 | import pytest |
6 | 6 | from scipy import stats |
7 | 7 |
|
8 | | -from sklearn import config_context, datasets |
| 8 | +from sklearn import datasets |
9 | 9 | from sklearn.datasets import make_multilabel_classification |
10 | 10 | from sklearn.exceptions import UndefinedMetricWarning |
11 | 11 | from sklearn.linear_model import LogisticRegression |
|
28 | 28 | from sklearn.model_selection import train_test_split |
29 | 29 | from sklearn.preprocessing import label_binarize |
30 | 30 | from sklearn.random_projection import _sparse_random_matrix |
31 | | -from sklearn.utils._array_api import ( |
32 | | - _convert_to_numpy, |
33 | | - yield_namespace_device_dtype_combinations, |
34 | | -) |
35 | 31 | from sklearn.utils._testing import ( |
36 | | - _array_api_for_tests, |
37 | 32 | _convert_container, |
38 | 33 | assert_allclose, |
39 | 34 | assert_almost_equal, |
@@ -1397,33 +1392,6 @@ def test_det_curve_pos_label(): |
1397 | 1392 | assert_allclose(fnr_pos_cancer, fpr_pos_not_cancer[::-1]) |
1398 | 1393 |
|
1399 | 1394 |
|
1400 | | -@pytest.mark.parametrize( |
1401 | | - "array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations() |
1402 | | -) |
1403 | | -def test_det_curve_array_api(array_namespace, device_, dtype_name): |
1404 | | - xp = _array_api_for_tests(array_namespace, device_) |
1405 | | - |
1406 | | - y_true_np = np.array([0, 1, 0], dtype=dtype_name) |
1407 | | - y_score_np = np.array([0, 0.5, 1], dtype=dtype_name) |
1408 | | - |
1409 | | - # baseline numpy results |
1410 | | - fpr_np, fnr_np, thresholds_np = det_curve(y_true_np, y_score_np) |
1411 | | - |
1412 | | - y_true_xp = xp.asarray(y_true_np, device=device_) |
1413 | | - y_score_xp = xp.asarray(y_score_np, device=device_) |
1414 | | - |
1415 | | - with config_context(array_api_dispatch=True): |
1416 | | - fpr_xp, fnr_xp, thresholds_xp = det_curve(y_true_xp, y_score_xp) |
1417 | | - |
1418 | | - fpr_from_xp = _convert_to_numpy(fpr_xp, xp=xp) |
1419 | | - fnr_from_xp = _convert_to_numpy(fnr_xp, xp=xp) |
1420 | | - thresholds_from_xp = _convert_to_numpy(thresholds_xp, xp=xp) |
1421 | | - |
1422 | | - assert_allclose(fpr_np, fpr_from_xp) |
1423 | | - assert_allclose(fnr_np, fnr_from_xp) |
1424 | | - assert_allclose(thresholds_np, thresholds_from_xp) |
1425 | | - |
1426 | | - |
1427 | 1395 | def check_lrap_toy(lrap_score): |
1428 | 1396 | # Check on several small example that it works |
1429 | 1397 | assert_almost_equal(lrap_score([[0, 1]], [[0.25, 0.75]]), 1) |
|
0 commit comments