@@ -2612,18 +2612,15 @@ def _check_pos_label_consistency(pos_label, y_true):
26122612 # Compute classes only if pos_label is not specified:
26132613 xp , _ , device = get_namespace_and_device (y_true )
26142614 classes = xp .unique_values (y_true )
2615- cls_dtype = classes .dtype
26162615 if (
26172616 (_is_numpy_namespace (xp ) and classes .dtype .kind in "OUS" )
26182617 or classes .shape [0 ] > 2
26192618 or not (
2620- xp .all (classes == xp .asarray ([0 , 1 ], dtype = cls_dtype , device = device ))
2621- or xp .all (
2622- classes == xp .asarray ([- 1 , 1 ], dtype = cls_dtype , device = device )
2623- )
2624- or xp .all (classes == xp .asarray ([0 ], dtype = cls_dtype , device = device ))
2625- or xp .all (classes == xp .asarray ([- 1 ], dtype = cls_dtype , device = device ))
2626- or xp .all (classes == xp .asarray ([1 ], dtype = cls_dtype , device = device ))
2619+ xp .all (classes == xp .asarray ([0 , 1 ], device = device ))
2620+ or xp .all (classes == xp .asarray ([- 1 , 1 ], device = device ))
2621+ or xp .all (classes == xp .asarray ([0 ], device = device ))
2622+ or xp .all (classes == xp .asarray ([- 1 ], device = device ))
2623+ or xp .all (classes == xp .asarray ([1 ], device = device ))
26272624 )
26282625 ):
26292626 classes = _convert_to_numpy (classes , xp = xp )
0 commit comments