Skip to content

Commit 4b7f92c

Browse files
committed
all float y_score_np
I think numpy converts a mixed int/float array (used in the test_ranking.py tests for `coverage_error`) to np.float64. Which is causing errors on MPS
1 parent 1ae8c65 commit 4b7f92c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

sklearn/metrics/tests/test_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2129,7 +2129,7 @@ def check_array_api_continuous_classification_metric(
21292129
metric, array_namespace, device, dtype_name
21302130
):
21312131
y_true_np = np.array([[0, 1, 0], [1, 1, 0]])
2132-
y_score_np = np.array([[0.1, 10.0, -3], [0, 1, 3]])
2132+
y_score_np = np.array([[0.1, 10.0, -3], [0.0, 1.0, 3.0]])
21332133

21342134
check_array_api_metric(
21352135
metric,

0 commit comments

Comments
 (0)