Skip to content

Commit f6160c6

Browse files
committed
Revert "ensure dtypes match"
This reverts commit 54f4b53.
1 parent 7710b5f commit f6160c6

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

sklearn/utils/validation.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2612,18 +2612,15 @@ def _check_pos_label_consistency(pos_label, y_true):
26122612
# Compute classes only if pos_label is not specified:
26132613
xp, _, device = get_namespace_and_device(y_true)
26142614
classes = xp.unique_values(y_true)
2615-
cls_dtype = classes.dtype
26162615
if (
26172616
(_is_numpy_namespace(xp) and classes.dtype.kind in "OUS")
26182617
or classes.shape[0] > 2
26192618
or not (
2620-
xp.all(classes == xp.asarray([0, 1], dtype=cls_dtype, device=device))
2621-
or xp.all(
2622-
classes == xp.asarray([-1, 1], dtype=cls_dtype, device=device)
2623-
)
2624-
or xp.all(classes == xp.asarray([0], dtype=cls_dtype, device=device))
2625-
or xp.all(classes == xp.asarray([-1], dtype=cls_dtype, device=device))
2626-
or xp.all(classes == xp.asarray([1], dtype=cls_dtype, device=device))
2619+
xp.all(classes == xp.asarray([0, 1], device=device))
2620+
or xp.all(classes == xp.asarray([-1, 1], device=device))
2621+
or xp.all(classes == xp.asarray([0], device=device))
2622+
or xp.all(classes == xp.asarray([-1], device=device))
2623+
or xp.all(classes == xp.asarray([1], device=device))
26272624
)
26282625
):
26292626
classes = _convert_to_numpy(classes, xp=xp)

0 commit comments

Comments
 (0)