Skip to content

Commit 18686e3

Browse files
committed
cleanup
1 parent 7a97277 commit 18686e3

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

sklearn/metrics/_ranking.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1425,11 +1425,11 @@ def coverage_error(y_true, y_score, *, sample_weight=None):
14251425
if y_true.shape != y_score.shape:
14261426
raise ValueError("y_true and y_score have different shape")
14271427

1428-
y_true_logical_not = xp.astype(y_true, xp.bool, device=device_, copy=False)
1429-
inf_val = xp.asarray(xp.inf, dtype=y_score.dtype, device=device_)
1430-
y_score_relevant_only = xp.where(y_true_logical_not, y_score, inf_val)
1431-
y_min_relevant = xp.reshape(xp.min(y_score_relevant_only, axis=1), (-1, 1))
1428+
y_true_bool = xp.astype(y_true, xp.bool, device=device_, copy=False)
1429+
y_score_masked = xp.where(y_true_bool, y_score, xp.inf)
1430+
y_min_relevant = xp.reshape(xp.min(y_score_masked, axis=1), (-1, 1))
14321431
coverage = xp.count_nonzero(y_score >= y_min_relevant, axis=1)
1432+
14331433
return _average(coverage, weights=sample_weight)
14341434

14351435

0 commit comments

Comments
 (0)