Skip to content

Commit 7db72a4

Browse files
committed
use xp.mean
1 parent 996cf26 commit 7db72a4

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

sklearn/metrics/cluster/_unsupervised.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
)
1717
from sklearn.preprocessing import LabelEncoder
1818
from sklearn.utils import _safe_indexing, check_random_state, check_X_y
19-
from sklearn.utils._array_api import get_namespace_and_device, xpx
19+
from sklearn.utils._array_api import (
20+
_max_precision_float_dtype,
21+
get_namespace_and_device,
22+
xpx,
23+
)
2024
from sklearn.utils._param_validation import Interval, StrOptions, validate_params
2125

2226

@@ -363,7 +367,8 @@ def calinski_harabasz_score(X, labels):
363367
114.8...
364368
"""
365369

366-
xp, _, _ = get_namespace_and_device(X, labels)
370+
xp, _, device_ = get_namespace_and_device(X, labels)
371+
X = xp.astype(X, _max_precision_float_dtype(xp, device_), copy=False)
367372
X, labels = check_X_y(X, labels)
368373
le = LabelEncoder()
369374
labels = le.fit_transform(labels)
@@ -374,10 +379,10 @@ def calinski_harabasz_score(X, labels):
374379
check_number_of_labels(n_labels, n_samples)
375380

376381
extra_disp, intra_disp = 0.0, 0.0
377-
mean = xp.sum(X, axis=0) / X.shape[0]
382+
mean = xp.mean(X, axis=0)
378383
for k in range(n_labels):
379384
cluster_k = X[labels == k]
380-
mean_k = xp.sum(cluster_k, axis=0) / cluster_k.shape[0]
385+
mean_k = xp.mean(cluster_k, axis=0)
381386
extra_disp += cluster_k.shape[0] * xp.sum((mean_k - mean) ** 2)
382387
intra_disp += xp.sum((cluster_k - mean_k) ** 2)
383388

0 commit comments

Comments
 (0)