| 
33 | 33 |     _bincount,  | 
34 | 34 |     _convert_to_numpy,  | 
35 | 35 |     _count_nonzero,  | 
 | 36 | +    _fill_diagonal,  | 
36 | 37 |     _find_matching_floating_dtype,  | 
37 | 38 |     _is_numpy_namespace,  | 
38 | 39 |     _is_xp_namespace,  | 
@@ -970,23 +971,30 @@ class labels [2]_.  | 
970 | 971 |             raise ValueError(msg) from e  | 
971 | 972 |         raise  | 
972 | 973 | 
 
  | 
 | 974 | +    xp, _, device_ = get_namespace_and_device(y1, y2)  | 
973 | 975 |     n_classes = confusion.shape[0]  | 
974 |  | -    sum0 = np.sum(confusion, axis=0)  | 
975 |  | -    sum1 = np.sum(confusion, axis=1)  | 
976 |  | -    expected = np.outer(sum0, sum1) / np.sum(sum0)  | 
 | 976 | +    # array_api_strict only supports floating point dtypes for __truediv__  | 
 | 977 | +    # which is used below to compute `expected` as well as `k`. Therefore  | 
 | 978 | +    # we use the maximum floating point dtype available for relevant arrays  | 
 | 979 | +    # to avoid running into this problem.  | 
 | 980 | +    max_float_dtype = _max_precision_float_dtype(xp, device=device_)  | 
 | 981 | +    confusion = xp.astype(confusion, max_float_dtype, copy=False)  | 
 | 982 | +    sum0 = xp.sum(confusion, axis=0)  | 
 | 983 | +    sum1 = xp.sum(confusion, axis=1)  | 
 | 984 | +    expected = xp.linalg.outer(sum0, sum1) / xp.sum(sum0)  | 
977 | 985 | 
 
  | 
978 | 986 |     if weights is None:  | 
979 |  | -        w_mat = np.ones([n_classes, n_classes], dtype=int)  | 
980 |  | -        w_mat.flat[:: n_classes + 1] = 0  | 
 | 987 | +        w_mat = xp.ones([n_classes, n_classes], dtype=max_float_dtype, device=device_)  | 
 | 988 | +        _fill_diagonal(w_mat, 0, xp=xp)  | 
981 | 989 |     else:  # "linear" or "quadratic"  | 
982 |  | -        w_mat = np.zeros([n_classes, n_classes], dtype=int)  | 
983 |  | -        w_mat += np.arange(n_classes)  | 
 | 990 | +        w_mat = xp.zeros([n_classes, n_classes], dtype=max_float_dtype, device=device_)  | 
 | 991 | +        w_mat += xp.arange(n_classes)  | 
984 | 992 |         if weights == "linear":  | 
985 |  | -            w_mat = np.abs(w_mat - w_mat.T)  | 
 | 993 | +            w_mat = xp.abs(w_mat - w_mat.T)  | 
986 | 994 |         else:  | 
987 | 995 |             w_mat = (w_mat - w_mat.T) ** 2  | 
988 | 996 | 
 
  | 
989 |  | -    k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)  | 
 | 997 | +    k = xp.sum(w_mat * confusion) / xp.sum(w_mat * expected)  | 
990 | 998 |     return float(1 - k)  | 
991 | 999 | 
 
  | 
992 | 1000 | 
 
  | 
 | 
0 commit comments