Commit c53d8b7

Fix k-means premature convergence bug
When k-means++ initialisation selects data points as initial centroids, points at those
locations have upper_bound=0 in Hamerly's algorithm, causing them to be incorrectly pruned
from reassignment checks. This could cause the algorithm to declare convergence on the
first iteration without ever computing true cluster centroids.

This fix updates centroids to be cluster means immediately after the initial assignment,
before entering the main convergence loop. This ensures Hamerly bounds are computed against
true centroids rather than the k-means++ selected data points.

Document k-means local minima and use fixed seed in test:

- Add documentation explaining that k-means converges to local minima and may produce
  suboptimal results depending on initialisation
- Update test_kmeans_three_clusters to use a fixed seed (42) for deterministic testing,
  which should avoid intermittent failures from unlucky k-means++ initialisation

Signed-off-by: Stephan Hügel <[email protected]>
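For context: Hamerly's accelerator skips the exact nearest-centroid search for a point
whenever its upper bound (distance to its assigned centroid) is no greater than its lower
bound (distance to the second-closest centroid). A minimal, self-contained sketch of that
pruning test, with illustrative values rather than the crate's internals, shows why a data
point chosen as its own initial centroid is never re-examined:

// Sketch only: illustrative values, not the geo crate's internals.
fn main() {
    // k-means++ picked this point itself as a centroid, so its distance to
    // its assigned centroid -- Hamerly's upper bound -- is exactly zero.
    let upper_bound: f64 = 0.0;
    // Distance to the second-closest centroid (Hamerly's lower bound).
    let lower_bound: f64 = 3.5;

    // The core pruning test: if the upper bound cannot exceed the lower
    // bound, the point provably cannot change cluster, so the exact
    // distance computations are skipped.
    if upper_bound <= lower_bound {
        // With upper_bound = 0 this branch is always taken: the point is
        // pruned every iteration, and if every point is pruned the loop
        // sees no reassignments and declares convergence immediately.
        println!("pruned: assignment cannot change");
    }
}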
1 parent 240f203 commit c53d8b7

1 file changed: +59, -14 lines
geo/src/algorithm/kmeans/mod.rs
@@ -54,6 +54,17 @@
 //!
 //! # Notes
 //!
+//! ## Local Minima
+//!
+//! _k_-means converges to a local minimum, not necessarily the global optimum. The final clustering
+//! depends on the initial centroid placement. Although _k_-means++ provides good initialisation,
+//! it can still occasionally place multiple initial centroids within the same "natural" cluster,
+//! leading to suboptimal results. For critical applications, consider running _k_-means multiple
+//! times with different seeds and selecting the result with lowest inertia (sum of squared
+//! distances to centroids).
+//!
+//! ## Empty Clusters
+//!
 //! Empty clusters may rarely occur during iteration if all points are reassigned away from a centroid.
 //! When this happens, the algorithm attempts to recover by reassigning the farthest point from its
 //! current centroid to the empty cluster (following scikit-learn's approach). If recovery fails,
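The new "Local Minima" note suggests multiple restarts keyed on inertia. A self-contained
sketch of that selection step, using plain arrays rather than the crate's point type; the
labellings here are hand-written stand-ins for the output of two differently-seeded runs:

// Sketch: given candidate labellings from several k-means runs, keep the one
// with the lowest inertia. Independent of the geo crate; the labellings are
// assumed inputs, not real k-means output.
fn inertia(points: &[[f64; 2]], labels: &[usize], k: usize) -> f64 {
    // Compute each cluster's mean.
    let mut sums = vec![[0.0f64; 2]; k];
    let mut counts = vec![0usize; k];
    for (p, &l) in points.iter().zip(labels) {
        sums[l][0] += p[0];
        sums[l][1] += p[1];
        counts[l] += 1;
    }
    let means: Vec<[f64; 2]> = sums
        .iter()
        .zip(&counts)
        .map(|(s, &c)| {
            if c > 0 {
                [s[0] / c as f64, s[1] / c as f64]
            } else {
                [f64::NAN, f64::NAN] // empty cluster: contributes nothing below
            }
        })
        .collect();
    // Inertia: sum of squared distances to each point's assigned cluster mean.
    points
        .iter()
        .zip(labels)
        .map(|(p, &l)| {
            let (dx, dy) = (p[0] - means[l][0], p[1] - means[l][1]);
            dx * dx + dy * dy
        })
        .sum()
}

fn main() {
    let points = [[0.0, 0.0], [1.0, 0.0], [10.0, 10.0], [11.0, 10.0]];
    let natural = [0usize, 0, 1, 1]; // the natural two-cluster split
    let split = [0usize, 1, 1, 1]; // a local minimum that breaks up one cluster
    let (a, b) = (inertia(&points, &natural, 2), inertia(&points, &split, 2));
    assert!(a < b); // the natural split has much lower inertia
    println!("natural split inertia = {a}, bad split inertia = {b}");
}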
@@ -359,24 +370,55 @@ where
     let mut upper_bounds = vec![T::infinity(); n];
     let mut lower_bounds = vec![T::zero(); n];
 
-    // First iteration: assign all points and initialize bounds
-    for (i, ((point, assignment), (upper, lower))) in points
+    // First iteration: assign all points to initial (k-means++) centroids
+    for (i, (point, assignment)) in points.iter().zip(assignments.iter_mut()).enumerate() {
+        let (nearest_idx, _, _) = find_nearest_and_second_nearest(
+            *point,
+            &centroids,
+            point_sq_norms[i],
+            &centroid_sq_norms,
+        );
+        *assignment = nearest_idx;
+    }
+
+    // Update centroids to be cluster means (initial centroids from k-means++ are data points,
+    // not true cluster means). This is required before the first convergence check because
+    // otherwise points at the initial centroid locations have upper_bound=0 and get pruned.
+    centroids = update_centroids(points, &mut assignments, &centroids, k, 0)?;
+    centroid_sq_norms = centroids.iter().map(|c| c.0.magnitude_squared()).collect();
+
+    // Initialise Hamerly bounds based on true cluster centroids
+    for (i, ((point, &assignment), (upper, lower))) in points
         .iter()
-        .zip(assignments.iter_mut())
+        .zip(assignments.iter())
         .zip(upper_bounds.iter_mut().zip(lower_bounds.iter_mut()))
         .enumerate()
     {
-        let (nearest_idx, nearest_sq_dist, second_nearest_sq_dist) =
-            find_nearest_and_second_nearest(
-                *point,
-                &centroids,
-                point_sq_norms[i],
-                &centroid_sq_norms,
-            );
-        *assignment = nearest_idx;
-        // Store actual distances (take sqrt of squared distances)
-        *upper = nearest_sq_dist.sqrt();
+        // Calculate actual distances to all centroids
+        let (_, nearest_sq_dist, second_nearest_sq_dist) = find_nearest_and_second_nearest(
+            *point,
+            &centroids,
+            point_sq_norms[i],
+            &centroid_sq_norms,
+        );
+        // Upper bound is distance to assigned centroid
+        let sq_dist_to_assigned = squared_distance_using_norms(
+            *point,
+            centroids[assignment],
+            point_sq_norms[i],
+            centroid_sq_norms[assignment],
+        );
+        *upper = sq_dist_to_assigned.sqrt();
         *lower = second_nearest_sq_dist.sqrt();
+        // After centroids are updated to be the mean of their clusters, a point's
+        // assigned centroid might no longer be its geometrically closest centroid.
+        // In this case, `nearest_sq_dist` (to the new closest) will be less than
+        // `sq_dist_to_assigned`. The lower bound must be the distance to the true
+        // second-closest centroid, so we take the minimum of the old second-closest
+        // and the new `nearest_sq_dist`.
+        if nearest_sq_dist < sq_dist_to_assigned {
+            *lower = nearest_sq_dist.sqrt().min(*lower);
+        }
     }
 
     // Track convergence state
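The `point_sq_norms` / `centroid_sq_norms` vectors and the `squared_distance_using_norms`
helper suggest the standard expansion ||p - c||^2 = ||p||^2 - 2<p, c> + ||c||^2, which lets
both squared norms be precomputed so each distance costs a single dot product. The helper's
body is not shown in this diff, so the following is an assumed reconstruction for 2-D points:

// Assumed reconstruction of the norm-based distance trick; not the crate's
// actual helper. p_sq and c_sq are the precomputed squared norms.
fn squared_distance_using_norms(p: [f64; 2], c: [f64; 2], p_sq: f64, c_sq: f64) -> f64 {
    let dot = p[0] * c[0] + p[1] * c[1];
    // Clamp at zero: floating-point cancellation can yield tiny negatives.
    (p_sq - 2.0 * dot + c_sq).max(0.0)
}

fn main() {
    let (p, c) = ([3.0, 4.0], [1.0, 1.0]);
    let direct = (p[0] - c[0]).powi(2) + (p[1] - c[1]).powi(2);
    let via_norms = squared_distance_using_norms(
        p,
        c,
        p[0] * p[0] + p[1] * p[1],
        c[0] * c[0] + c[1] * c[1],
    );
    assert!((direct - via_norms).abs() < 1e-12);
    println!("direct = {direct}, via norms = {via_norms}"); // both 13
}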
@@ -1005,6 +1047,8 @@ mod tests {
 
     #[test]
     fn test_kmeans_three_clusters() {
+        // Use a fixed seed for deterministic testing. K-means can converge to local minima
+        // depending on initialisation, so we use a known-good seed to test algorithm correctness.
         let points = [
             // Cluster 1
             point!(x: 0.0, y: 0.0),
@@ -1016,7 +1060,8 @@ mod tests {
             point!(x: 20.0, y: 20.0),
             point!(x: 21.0, y: 20.0),
         ];
-        let labels = points.kmeans(3).unwrap();
+        let params = KMeansParams::new(3).seed(42);
+        let labels = points.kmeans_with_params(params).unwrap();
 
         // Each pair should be in the same cluster
         assert_eq!(labels[0], labels[1]);