Skip to content

Commit 5baea99

Browse files
Fix Kmeans (#115)
* Fix kmeans

  Fix while loop convergence, kmeans++ init, tolerance value

* Update test_metrics.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Update _kmeans.py
* Update test_metrics.py
* Update test_metrics.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9e41b9c commit 5baea99

File tree

2 files changed: +46 additions, -13 deletions (lines changed)

src/scib_metrics/utils/_kmeans.py

Lines changed: 22 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,17 @@
1212
from ._utils import get_ndarray, validate_seed
1313

1414

15+
def _tolerance(X: jnp.ndarray, tol: float) -> float:
16+
"""Return a tolerance which is dependent on the dataset."""
17+
variances = np.var(X, axis=0)
18+
return np.mean(variances) * tol
19+
20+
1521
def _initialize_random(X: jnp.ndarray, n_clusters: int, key: jax.random.KeyArray) -> jnp.ndarray:
1622
"""Initialize cluster centroids randomly."""
1723
n_obs = X.shape[0]
18-
indices = jax.random.choice(key, n_obs, (n_clusters,), replace=False)
24+
key, subkey = jax.random.split(key)
25+
indices = jax.random.choice(subkey, n_obs, (n_clusters,), replace=False)
1926
initial_state = X[indices]
2027
return initial_state
2128

@@ -53,13 +60,14 @@ def _step(state, _):
5360
return state, state["centroid"]
5461

5562
_, centroids = jax.lax.scan(_step, initial_state, jnp.arange(n_clusters - 1))
63+
centroids = jnp.concatenate([initial_centroid[jnp.newaxis, :], centroids])
5664
return centroids
5765

5866

5967
@jax.jit
6068
def _get_dist_labels(X: jnp.ndarray, centroids: jnp.ndarray) -> jnp.ndarray:
6169
"""Get the distance and labels for each observation."""
62-
dist = cdist(X, centroids)
70+
dist = jnp.square(cdist(X, centroids))
6371
labels = jnp.argmin(dist, axis=1)
6472
return dist, labels
6573

@@ -94,15 +102,15 @@ def __init__(
94102
self,
95103
n_clusters: int = 8,
96104
init: Literal["k-means++", "random"] = "k-means++",
97-
n_init: int = 10,
105+
n_init: int = 1,
98106
max_iter: int = 300,
99107
tol: float = 1e-4,
100108
seed: IntOrKey = 0,
101109
):
102110
self.n_clusters = n_clusters
103111
self.n_init = n_init
104112
self.max_iter = max_iter
105-
self.tol = tol
113+
self.tol_scale = tol
106114
self.seed: jax.random.KeyArray = validate_seed(seed)
107115

108116
if init not in ["k-means++", "random"]:
@@ -115,6 +123,7 @@ def __init__(
115123
def fit(self, X: np.ndarray):
116124
"""Fit the model to the data."""
117125
X = check_array(X, dtype=np.float32, order="C")
126+
self.tol = _tolerance(X, self.tol_scale)
118127
# Subtract mean for numerical accuracy
119128
mean = X.mean(axis=0)
120129
X -= mean
@@ -136,8 +145,7 @@ def _fit(self, X: np.ndarray):
136145
@partial(jax.jit, static_argnums=(0,))
137146
def _kmeans_full_run(self, X: jnp.ndarray, key: jnp.ndarray) -> jnp.ndarray:
138147
def _kmeans_step(state):
139-
old_inertia = state[1]
140-
centroids, _, _, n_iter = state
148+
centroids, old_inertia, _, n_iter = state
141149
# TODO(adamgayoso): Efficiently compute argmin and min simultaneously.
142150
dist, new_labels = _get_dist_labels(X, centroids)
143151
# From https://colab.research.google.com/drive/1AwS4haUx6swF82w3nXr6QKhajdF8aSvA?usp=sharing
@@ -159,19 +167,22 @@ def _kmeans_step(state):
159167
)
160168
/ counts
161169
)
162-
new_inertia = jnp.mean(jnp.min(dist, axis=1))
170+
new_inertia = jnp.sum(jnp.min(dist, axis=1))
163171
n_iter = n_iter + 1
164172
return new_centroids, new_inertia, old_inertia, n_iter
165173

166174
def _kmeans_convergence(state):
167175
_, new_inertia, old_inertia, n_iter = state
168-
cond1 = jnp.abs(old_inertia - new_inertia) < self.tol
169-
cond2 = n_iter > self.max_iter
176+
cond1 = jnp.abs(old_inertia - new_inertia) > self.tol
177+
cond2 = n_iter < self.max_iter
170178
return jnp.logical_or(cond1, cond2)[0]
171179

172180
centroids = self._initialize(X, self.n_clusters, key)
173181
# centroids, new_inertia, old_inertia, n_iter
174182
state = (centroids, jnp.inf, jnp.inf, jnp.array([0.0]))
175-
state = _kmeans_step(state)
176183
state = jax.lax.while_loop(_kmeans_convergence, _kmeans_step, state)
177-
return state[0], state[1]
184+
# Compute final inertia
185+
centroids = state[0]
186+
dist, _ = _get_dist_labels(X, centroids)
187+
final_intertia = jnp.sum(jnp.min(dist, axis=1))
188+
return centroids, final_intertia

tests/test_metrics.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
from scipy.sparse import csr_matrix
88
from scipy.spatial.distance import cdist as sp_cdist
99
from scipy.spatial.distance import pdist, squareform
10+
from sklearn.cluster import KMeans as SKMeans
11+
from sklearn.datasets import make_blobs
1012
from sklearn.metrics import silhouette_samples as sk_silhouette_samples
13+
from sklearn.metrics.pairwise import pairwise_distances_argmin
1114
from sklearn.neighbors import NearestNeighbors
1215

1316
import scib_metrics
@@ -115,11 +118,30 @@ def test_isolated_labels():
115118

116119

117120
def test_kmeans():
118-
X, _ = dummy_x_labels()
119-
kmeans = scib_metrics.utils.KMeans(2)
121+
centers = [[1, 1], [-1, -1], [1, -1]]
122+
len(centers)
123+
X, labels_true = make_blobs(n_samples=3000, centers=centers, cluster_std=0.7)
124+
kmeans = scib_metrics.utils.KMeans(n_clusters=3)
120125
kmeans.fit(X)
121126
assert kmeans.labels_.shape == (X.shape[0],)
122127

128+
skmeans = SKMeans(n_clusters=3)
129+
skmeans.fit(X)
130+
sk_inertia = np.array([skmeans.inertia_])
131+
jax_inertia = np.array([kmeans.inertia_])
132+
np.testing.assert_allclose(sk_inertia, jax_inertia, atol=4e-2)
133+
134+
# Reorder cluster centroids between methods and measure accuracy
135+
k_means_cluster_centers = kmeans.cluster_centroids_
136+
order = pairwise_distances_argmin(kmeans.cluster_centroids_, skmeans.cluster_centers_)
137+
sk_means_cluster_centers = skmeans.cluster_centers_[order]
138+
139+
k_means_labels = pairwise_distances_argmin(X, k_means_cluster_centers)
140+
sk_means_labels = pairwise_distances_argmin(X, sk_means_cluster_centers)
141+
142+
accuracy = (k_means_labels == sk_means_labels).sum() / len(k_means_labels)
143+
assert accuracy > 0.999
144+
123145

124146
def test_kbet():
125147
X, _, batch = dummy_x_labels_batch(x_is_neighbors_graph=True)

0 commit comments

Comments
 (0)