 from ._utils import get_ndarray, validate_seed
 
 
+def _tolerance(X: jnp.ndarray, tol: float) -> float:
+    """Return a tolerance which is dependent on the dataset."""
+    variances = np.var(X, axis=0)
+    return np.mean(variances) * tol
+
+
 def _initialize_random(X: jnp.ndarray, n_clusters: int, key: jax.random.KeyArray) -> jnp.ndarray:
     """Initialize cluster centroids randomly."""
     n_obs = X.shape[0]
-    indices = jax.random.choice(key, n_obs, (n_clusters,), replace=False)
+    key, subkey = jax.random.split(key)
+    indices = jax.random.choice(subkey, n_obs, (n_clusters,), replace=False)
     initial_state = X[indices]
     return initial_state
 
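The new `_tolerance` helper follows scikit-learn's convention of scaling the user-supplied `tol` by the mean per-feature variance, so the stopping threshold adapts to the scale of the data rather than being an absolute number. A minimal sketch of the effect (illustrative values only, not part of this commit):

```python
import numpy as np

# Toy data whose per-feature variances are [1.0, 4.0].
X = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 4.0], [2.0, 4.0]], dtype=np.float32)

# Effective threshold = mean(variances) * tol = 2.5 * 1e-4.
effective_tol = np.mean(np.var(X, axis=0)) * 1e-4
print(effective_tol)  # 0.00025
```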
@@ -53,13 +60,14 @@ def _step(state, _):
         return state, state["centroid"]
 
     _, centroids = jax.lax.scan(_step, initial_state, jnp.arange(n_clusters - 1))
+    centroids = jnp.concatenate([initial_centroid[jnp.newaxis, :], centroids])
     return centroids
 
 
 @jax.jit
 def _get_dist_labels(X: jnp.ndarray, centroids: jnp.ndarray) -> jnp.ndarray:
     """Get the distance and labels for each observation."""
-    dist = cdist(X, centroids)
+    dist = jnp.square(cdist(X, centroids))
     labels = jnp.argmin(dist, axis=1)
     return dist, labels
 
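Squaring the output of `cdist` means the per-observation minimum is the squared Euclidean distance to the nearest centroid, so summing those minima yields the standard k-means inertia. A self-contained sketch of that objective (with an explicit pairwise-distance helper standing in for the repo's `cdist`, which is not shown in this diff):

```python
import jax.numpy as jnp

def pairwise_sq_dists(X, centroids):
    # (n_obs, n_clusters) matrix of squared Euclidean distances.
    diff = X[:, None, :] - centroids[None, :, :]
    return jnp.sum(diff**2, axis=-1)

X = jnp.array([[0.0, 0.0], [1.0, 0.0], [10.0, 0.0]])
centroids = jnp.array([[0.0, 0.0], [10.0, 0.0]])
dist = pairwise_sq_dists(X, centroids)
labels = jnp.argmin(dist, axis=1)         # [0, 0, 1]
inertia = jnp.sum(jnp.min(dist, axis=1))  # 0.0 + 1.0 + 0.0 = 1.0
```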
@@ -94,15 +102,15 @@ def __init__(
         self,
         n_clusters: int = 8,
         init: Literal["k-means++", "random"] = "k-means++",
-        n_init: int = 10,
+        n_init: int = 1,
         max_iter: int = 300,
         tol: float = 1e-4,
         seed: IntOrKey = 0,
     ):
         self.n_clusters = n_clusters
         self.n_init = n_init
         self.max_iter = max_iter
-        self.tol = tol
+        self.tol_scale = tol
         self.seed: jax.random.KeyArray = validate_seed(seed)
 
         if init not in ["k-means++", "random"]:
@@ -115,6 +123,7 @@ def __init__(
     def fit(self, X: np.ndarray):
         """Fit the model to the data."""
         X = check_array(X, dtype=np.float32, order="C")
+        self.tol = _tolerance(X, self.tol_scale)
         # Subtract mean for numerical accuracy
         mean = X.mean(axis=0)
         X -= mean
@@ -136,8 +145,7 @@ def _fit(self, X: np.ndarray):
     @partial(jax.jit, static_argnums=(0,))
     def _kmeans_full_run(self, X: jnp.ndarray, key: jnp.ndarray) -> jnp.ndarray:
         def _kmeans_step(state):
-            old_inertia = state[1]
-            centroids, _, _, n_iter = state
+            centroids, old_inertia, _, n_iter = state
             # TODO(adamgayoso): Efficiently compute argmin and min simultaneously.
             dist, new_labels = _get_dist_labels(X, centroids)
             # From https://colab.research.google.com/drive/1AwS4haUx6swF82w3nXr6QKhajdF8aSvA?usp=sharing
@@ -159,19 +167,22 @@ def _kmeans_step(state):
                 )
                 / counts
             )
-            new_inertia = jnp.mean(jnp.min(dist, axis=1))
+            new_inertia = jnp.sum(jnp.min(dist, axis=1))
             n_iter = n_iter + 1
             return new_centroids, new_inertia, old_inertia, n_iter
 
         def _kmeans_convergence(state):
             _, new_inertia, old_inertia, n_iter = state
-            cond1 = jnp.abs(old_inertia - new_inertia) < self.tol
-            cond2 = n_iter > self.max_iter
+            cond1 = jnp.abs(old_inertia - new_inertia) > self.tol
+            cond2 = n_iter < self.max_iter
             return jnp.logical_or(cond1, cond2)[0]
 
         centroids = self._initialize(X, self.n_clusters, key)
         # centroids, new_inertia, old_inertia, n_iter
         state = (centroids, jnp.inf, jnp.inf, jnp.array([0.0]))
-        state = _kmeans_step(state)
         state = jax.lax.while_loop(_kmeans_convergence, _kmeans_step, state)
-        return state[0], state[1]
+        # Compute final inertia
+        centroids = state[0]
+        dist, _ = _get_dist_labels(X, centroids)
+        final_inertia = jnp.sum(jnp.min(dist, axis=1))
+        return centroids, final_inertia
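Overall, a fit now scales `tol` by the data variance, defaults to a single initialization (`n_init=1`), and reports the summed squared-distance inertia of the final centroids. A rough usage sketch; the import path is an assumption, since the module name is not visible in this diff:

```python
import numpy as np
# Hypothetical import path, not confirmed by the diff:
# from scib_metrics.utils import KMeans

X = np.random.default_rng(0).normal(size=(500, 16)).astype(np.float32)
km = KMeans(n_clusters=8, init="k-means++", n_init=1, max_iter=300, tol=1e-4, seed=0)
km.fit(X)  # centers X, then runs the jitted k-means loop shown above
```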