Skip to content

Commit 7a97277

Browse files
committed
wip
1 parent d41f734 commit 7a97277

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

sklearn/metrics/_ranking.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
column_or_1d,
3030
)
3131
from sklearn.utils._array_api import (
32+
_average,
3233
_max_precision_float_dtype,
3334
get_namespace_and_device,
3435
size,
@@ -1412,6 +1413,7 @@ def coverage_error(y_true, y_score, *, sample_weight=None):
14121413
>>> coverage_error(y_true, y_score)
14131414
1.5
14141415
"""
1416+
xp, _, device_ = get_namespace_and_device(y_true, y_score, sample_weight)
14151417
y_true = check_array(y_true, ensure_2d=True)
14161418
y_score = check_array(y_score, ensure_2d=True)
14171419
check_consistent_length(y_true, y_score, sample_weight)
@@ -1423,12 +1425,12 @@ def coverage_error(y_true, y_score, *, sample_weight=None):
14231425
if y_true.shape != y_score.shape:
14241426
raise ValueError("y_true and y_score have different shape")
14251427

1426-
y_score_mask = np.ma.masked_array(y_score, mask=np.logical_not(y_true))
1427-
y_min_relevant = y_score_mask.min(axis=1).reshape((-1, 1))
1428-
coverage = (y_score >= y_min_relevant).sum(axis=1)
1429-
coverage = coverage.filled(0)
1430-
1431-
return float(np.average(coverage, weights=sample_weight))
1428+
y_true_logical_not = xp.astype(y_true, xp.bool, device=device_, copy=False)
1429+
inf_val = xp.asarray(xp.inf, dtype=y_score.dtype, device=device_)
1430+
y_score_relevant_only = xp.where(y_true_logical_not, y_score, inf_val)
1431+
y_min_relevant = xp.reshape(xp.min(y_score_relevant_only, axis=1), (-1, 1))
1432+
coverage = xp.count_nonzero(y_score >= y_min_relevant, axis=1)
1433+
return _average(coverage, weights=sample_weight)
14321434

14331435

14341436
@validate_params(

sklearn/metrics/tests/test_common.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2125,6 +2125,35 @@ def check_array_api_multilabel_classification_metric(
21252125
)
21262126

21272127

2128+
def check_array_api_continuous_classification_metric(
2129+
metric, array_namespace, device, dtype_name
2130+
):
2131+
y_true_np = np.array([[0, 1, 0], [1, 1, 0]])
2132+
y_score_np = np.array([[0.1, 10.0, -3], [0, 1, 3]])
2133+
2134+
check_array_api_metric(
2135+
metric,
2136+
array_namespace,
2137+
device,
2138+
dtype_name,
2139+
a_np=y_true_np,
2140+
b_np=y_score_np,
2141+
sample_weight=None,
2142+
)
2143+
2144+
sample_weight = np.array([0.1, 2.0], dtype=dtype_name)
2145+
2146+
check_array_api_metric(
2147+
metric,
2148+
array_namespace,
2149+
device,
2150+
dtype_name,
2151+
a_np=y_true_np,
2152+
b_np=y_score_np,
2153+
sample_weight=sample_weight,
2154+
)
2155+
2156+
21282157
def check_array_api_regression_metric(metric, array_namespace, device, dtype_name):
21292158
func_name = metric.func.__name__ if isinstance(metric, partial) else metric.__name__
21302159
if func_name == "mean_poisson_deviance" and sp_version < parse_version("1.14.0"):

0 commit comments

Comments
 (0)