
Commit deef14c

array api compat for ch score
1 parent b3f2904 commit deef14c

File tree

2 files changed: +101 −7 lines


sklearn/metrics/cluster/_unsupervised.py

Lines changed: 8 additions & 6 deletions

@@ -16,7 +16,7 @@
 )
 from sklearn.preprocessing import LabelEncoder
 from sklearn.utils import _safe_indexing, check_random_state, check_X_y
-from sklearn.utils._array_api import xpx
+from sklearn.utils._array_api import get_namespace_and_device, xpx
 from sklearn.utils._param_validation import Interval, StrOptions, validate_params
 
 
@@ -362,22 +362,24 @@ def calinski_harabasz_score(X, labels):
     >>> calinski_harabasz_score(X, kmeans.labels_)
     114.8...
     """
+
+    xp, _, _ = get_namespace_and_device(X, labels)
     X, labels = check_X_y(X, labels)
     le = LabelEncoder()
     labels = le.fit_transform(labels)
 
     n_samples, _ = X.shape
-    n_labels = len(le.classes_)
+    n_labels = le.classes_.shape[0]
 
     check_number_of_labels(n_labels, n_samples)
 
     extra_disp, intra_disp = 0.0, 0.0
-    mean = np.mean(X, axis=0)
+    mean = xp.sum(X, axis=0) / X.shape[0]
     for k in range(n_labels):
         cluster_k = X[labels == k]
-        mean_k = np.mean(cluster_k, axis=0)
-        extra_disp += len(cluster_k) * np.sum((mean_k - mean) ** 2)
-        intra_disp += np.sum((cluster_k - mean_k) ** 2)
+        mean_k = xp.sum(cluster_k, axis=0) / cluster_k.shape[0]
+        extra_disp += cluster_k.shape[0] * xp.sum((mean_k - mean) ** 2)
+        intra_disp += xp.sum((cluster_k - mean_k) ** 2)
 
     return float(
         1.0
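With get_namespace_and_device, the per-cluster means and dispersions are computed in whatever namespace the inputs belong to (xp.sum(...) / n rather than np.mean), so the score no longer hard-codes NumPy. A minimal sketch of exercising that path, assuming array_api_strict and scikit-learn's array API dependencies are installed and that the full validation chain (check_X_y, LabelEncoder) dispatches for this metric:

import array_api_strict as xp
import numpy as np

from sklearn import config_context
from sklearn.metrics import calinski_harabasz_score

rng = np.random.RandomState(0)
X_np = rng.rand(10, 3)
labels_np = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 2])

# Same data in an alternative array namespace.
X_xp = xp.asarray(X_np)
labels_xp = xp.asarray(labels_np)

with config_context(array_api_dispatch=True):
    # The computation stays in the inputs' namespace; the result is a Python float.
    score_xp = calinski_harabasz_score(X_xp, labels_xp)

# Should match the plain NumPy computation.
assert np.isclose(score_xp, calinski_harabasz_score(X_np, labels_np))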

sklearn/metrics/cluster/tests/test_common.py

Lines changed: 93 additions & 1 deletion

@@ -4,6 +4,7 @@
 import numpy as np
 import pytest
 
+from sklearn._config import config_context
 from sklearn.metrics.cluster import (
     adjusted_mutual_info_score,
     adjusted_rand_score,
@@ -18,7 +19,13 @@
     silhouette_score,
     v_measure_score,
 )
-from sklearn.utils._testing import assert_allclose
+from sklearn.utils._array_api import (
+    _atol_for_type,
+    _convert_to_numpy,
+    _get_namespace_device_dtype_ids,
+    yield_namespace_device_dtype_combinations,
+)
+from sklearn.utils._testing import _array_api_for_tests, assert_allclose
 
 # Dictionaries of metrics
 # ------------------------
@@ -232,3 +239,88 @@ def test_returned_value_consistency(name):
 
     assert isinstance(score, float)
     assert not isinstance(score, (np.float64, np.float32))
+
+
+def check_array_api_metric(
+    metric, array_namespace, device, dtype_name, a_np, b_np, **metric_kwargs
+):
+    xp = _array_api_for_tests(array_namespace, device)
+
+    a_xp = xp.asarray(a_np, device=device)
+    b_xp = xp.asarray(b_np, device=device)
+
+    metric_np = metric(a_np, b_np, **metric_kwargs)
+
+    # When array API dispatch is disabled, and np.asarray works (for example PyTorch
+    # with CPU device), calling the metric function with such numpy compatible inputs
+    # should work (albeit by implicitly converting to numpy arrays instead of
+    # dispatching to the array library).
+    try:
+        np.asarray(a_xp)
+        np.asarray(b_xp)
+        numpy_as_array_works = True
+    except (TypeError, RuntimeError, ValueError):
+        # PyTorch with CUDA device and CuPy raise TypeError consistently.
+        # array-api-strict chose to raise RuntimeError instead. NumPy raises
+        # a ValueError if the `__array__` dunder does not return an array.
+        # Exception type may need to be updated in the future for other libraries.
+        numpy_as_array_works = False
+
+    def _check_metric_matches(metric_a, metric_b, convert_a=False):
+        if convert_a:
+            metric_a = _convert_to_numpy(xp.asarray(metric_a), xp)
+        assert_allclose(metric_a, metric_b, atol=_atol_for_type(dtype_name))
+
+    if numpy_as_array_works:
+        metric_xp = metric(a_xp, b_xp, **metric_kwargs)
+
+        # Handle cases where multiple return values are not of the same shape,
+        # e.g. precision_recall_curve:
+        _check_metric_matches(metric_xp, metric_np)
+
+        metric_xp_mixed_1 = metric(a_np, b_xp, **metric_kwargs)
+        _check_metric_matches(metric_xp_mixed_1, metric_np)
+
+        metric_xp_mixed_2 = metric(a_xp, b_np, **metric_kwargs)
+        _check_metric_matches(metric_xp_mixed_2, metric_np)
+
+    with config_context(array_api_dispatch=True):
+        metric_xp = metric(a_xp, b_xp, **metric_kwargs)
+        _check_metric_matches(metric_xp, metric_np, convert_a=True)
+
+
+def check_array_api_unsupervised_metric(metric, array_namespace, device, dtype_name):
+    y_pred = np.array([1, 0, 1, 0, 1, 1, 0])
+    X = np.random.randint(10, size=(7, 10))
+
+    check_array_api_metric(
+        metric,
+        array_namespace,
+        device,
+        dtype_name,
+        a_np=X,
+        b_np=y_pred,
+    )
+
+
+array_api_metric_checkers = {
+    calinski_harabasz_score: [
+        check_array_api_unsupervised_metric,
+    ]
+}
+
+
+def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers):
+    for metric, checkers in metric_checkers.items():
+        for checker in checkers:
+            yield metric, checker
+
+
+@pytest.mark.parametrize(
+    "array_namespace, device, dtype_name",
+    yield_namespace_device_dtype_combinations(),
+    ids=_get_namespace_device_dtype_ids,
+)
+@pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
+def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):
+    check_func(metric, array_namespace, device, dtype_name)
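The new checker is registered for calinski_harabasz_score only and wired into a single parametrized test, so the array API coverage can be run on its own. A possible invocation (which namespace/device/dtype combinations actually execute depends on the array libraries installed locally):

pytest sklearn/metrics/cluster/tests/test_common.py -k test_array_api_compliance -v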
