Skip to content

Commit 54f4b53

Browse files
committed
ensure dtypes match
1 parent b8a222d commit 54f4b53

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

sklearn/utils/validation.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2612,15 +2612,18 @@ def _check_pos_label_consistency(pos_label, y_true):
26122612
# Compute classes only if pos_label is not specified:
26132613
xp, _, device = get_namespace_and_device(y_true)
26142614
classes = xp.unique_values(y_true)
2615+
cls_dtype = classes.dtype
26152616
if (
26162617
(_is_numpy_namespace(xp) and classes.dtype.kind in "OUS")
26172618
or classes.shape[0] > 2
26182619
or not (
2619-
xp.all(classes == xp.asarray([0, 1], device=device))
2620-
or xp.all(classes == xp.asarray([-1, 1], device=device))
2621-
or xp.all(classes == xp.asarray([0], device=device))
2622-
or xp.all(classes == xp.asarray([-1], device=device))
2623-
or xp.all(classes == xp.asarray([1], device=device))
2620+
xp.all(classes == xp.asarray([0, 1], dtype=cls_dtype, device=device))
2621+
or xp.all(
2622+
classes == xp.asarray([-1, 1], dtype=cls_dtype, device=device)
2623+
)
2624+
or xp.all(classes == xp.asarray([0], dtype=cls_dtype, device=device))
2625+
or xp.all(classes == xp.asarray([-1], dtype=cls_dtype, device=device))
2626+
or xp.all(classes == xp.asarray([1], dtype=cls_dtype, device=device))
26242627
)
26252628
):
26262629
classes = _convert_to_numpy(classes, xp=xp)

0 commit comments

Comments
 (0)