Skip to content

Commit ca39ad1

Browse files
authored
FEA Add array API support for cohen_kappa_score (scikit-learn#32619)
1 parent e40c216 commit ca39ad1

File tree

4 files changed

+24
-9
lines changed

4 files changed

+24
-9
lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ Metrics
149149
- :func:`sklearn.metrics.accuracy_score`
150150
- :func:`sklearn.metrics.balanced_accuracy_score`
151151
- :func:`sklearn.metrics.brier_score_loss`
152+
- :func:`sklearn.metrics.cohen_kappa_score`
152153
- :func:`sklearn.metrics.confusion_matrix`
153154
- :func:`sklearn.metrics.d2_brier_score`
154155
- :func:`sklearn.metrics.d2_log_loss_score`
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- :func:`sklearn.metrics.cohen_kappa_score` now supports array API-compatible inputs.
2+
By :user:`Omar Salman <OmarManzoor>`.

sklearn/metrics/_classification.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
_bincount,
3434
_convert_to_numpy,
3535
_count_nonzero,
36+
_fill_diagonal,
3637
_find_matching_floating_dtype,
3738
_is_numpy_namespace,
3839
_is_xp_namespace,
@@ -970,23 +971,30 @@ class labels [2]_.
970971
raise ValueError(msg) from e
971972
raise
972973

974+
xp, _, device_ = get_namespace_and_device(y1, y2)
973975
n_classes = confusion.shape[0]
974-
sum0 = np.sum(confusion, axis=0)
975-
sum1 = np.sum(confusion, axis=1)
976-
expected = np.outer(sum0, sum1) / np.sum(sum0)
976+
# array_api_strict only supports floating point dtypes for __truediv__
977+
# which is used below to compute `expected` as well as `k`. Therefore
978+
# we use the maximum floating point dtype available for relevant arrays
979+
# to avoid running into this problem.
980+
max_float_dtype = _max_precision_float_dtype(xp, device=device_)
981+
confusion = xp.astype(confusion, max_float_dtype, copy=False)
982+
sum0 = xp.sum(confusion, axis=0)
983+
sum1 = xp.sum(confusion, axis=1)
984+
expected = xp.linalg.outer(sum0, sum1) / xp.sum(sum0)
977985

978986
if weights is None:
979-
w_mat = np.ones([n_classes, n_classes], dtype=int)
980-
w_mat.flat[:: n_classes + 1] = 0
987+
w_mat = xp.ones([n_classes, n_classes], dtype=max_float_dtype, device=device_)
988+
_fill_diagonal(w_mat, 0, xp=xp)
981989
else: # "linear" or "quadratic"
982-
w_mat = np.zeros([n_classes, n_classes], dtype=int)
983-
w_mat += np.arange(n_classes)
990+
w_mat = xp.zeros([n_classes, n_classes], dtype=max_float_dtype, device=device_)
991+
w_mat += xp.arange(n_classes)
984992
if weights == "linear":
985-
w_mat = np.abs(w_mat - w_mat.T)
993+
w_mat = xp.abs(w_mat - w_mat.T)
986994
else:
987995
w_mat = (w_mat - w_mat.T) ** 2
988996

989-
k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
997+
k = xp.sum(w_mat * confusion) / xp.sum(w_mat * expected)
990998
return float(1 - k)
991999

9921000

sklearn/metrics/tests/test_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2255,6 +2255,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
22552255
check_array_api_binary_classification_metric,
22562256
check_array_api_multiclass_classification_metric,
22572257
],
2258+
cohen_kappa_score: [
2259+
check_array_api_binary_classification_metric,
2260+
check_array_api_multiclass_classification_metric,
2261+
],
22582262
confusion_matrix: [
22592263
check_array_api_binary_classification_metric,
22602264
check_array_api_multiclass_classification_metric,

0 commit comments

Comments
 (0)