@@ -2612,15 +2612,18 @@ def _check_pos_label_consistency(pos_label, y_true):
26122612 # Compute classes only if pos_label is not specified:
26132613 xp , _ , device = get_namespace_and_device (y_true )
26142614 classes = xp .unique_values (y_true )
2615+ cls_dtype = classes .dtype
26152616 if (
26162617 (_is_numpy_namespace (xp ) and classes .dtype .kind in "OUS" )
26172618 or classes .shape [0 ] > 2
26182619 or not (
2619- xp .all (classes == xp .asarray ([0 , 1 ], device = device ))
2620- or xp .all (classes == xp .asarray ([- 1 , 1 ], device = device ))
2621- or xp .all (classes == xp .asarray ([0 ], device = device ))
2622- or xp .all (classes == xp .asarray ([- 1 ], device = device ))
2623- or xp .all (classes == xp .asarray ([1 ], device = device ))
2620+ xp .all (classes == xp .asarray ([0 , 1 ], dtype = cls_dtype , device = device ))
2621+ or xp .all (
2622+ classes == xp .asarray ([- 1 , 1 ], dtype = cls_dtype , device = device )
2623+ )
2624+ or xp .all (classes == xp .asarray ([0 ], dtype = cls_dtype , device = device ))
2625+ or xp .all (classes == xp .asarray ([- 1 ], dtype = cls_dtype , device = device ))
2626+ or xp .all (classes == xp .asarray ([1 ], dtype = cls_dtype , device = device ))
26242627 )
26252628 ):
26262629 classes = _convert_to_numpy (classes , xp = xp )
0 commit comments