Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 59 additions & 14 deletions geo/src/algorithm/kmeans/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@
//!
//! # Notes
//!
//! ## Local Minima
//!
//! _k_-means converges to a local minimum, not necessarily the global optimum. The final clustering
//! depends on the initial centroid placement. Although _k_-means++ provides good initialisation,
//! it can still occasionally place multiple initial centroids within the same "natural" cluster,
//! leading to suboptimal results. For critical applications, consider running _k_-means multiple
//! times with different seeds and selecting the result with the lowest inertia (sum of squared
//! distances to centroids).
//!
//! ## Empty Clusters
//!
//! Empty clusters may rarely occur during iteration if all points are reassigned away from a centroid.
//! When this happens, the algorithm attempts to recover by reassigning the farthest point from its
//! current centroid to the empty cluster (following scikit-learn's approach). If recovery fails,
Expand Down Expand Up @@ -359,24 +370,55 @@ where
let mut upper_bounds = vec![T::infinity(); n];
let mut lower_bounds = vec![T::zero(); n];

// First iteration: assign all points and initialize bounds
for (i, ((point, assignment), (upper, lower))) in points
// First iteration: assign all points to initial (k-means++) centroids
for (i, (point, assignment)) in points.iter().zip(assignments.iter_mut()).enumerate() {
let (nearest_idx, _, _) = find_nearest_and_second_nearest(
*point,
&centroids,
point_sq_norms[i],
&centroid_sq_norms,
);
*assignment = nearest_idx;
}

// Update centroids to be cluster means (initial centroids from k-means++ are data points,
// not true cluster means). This is required before the first convergence check because
// otherwise points at the initial centroid locations have upper_bound=0 and get pruned.
centroids = update_centroids(points, &mut assignments, &centroids, k, 0)?;
centroid_sq_norms = centroids.iter().map(|c| c.0.magnitude_squared()).collect();

// Initialise Hamerly bounds based on true cluster centroids
for (i, ((point, &assignment), (upper, lower))) in points
.iter()
.zip(assignments.iter_mut())
.zip(assignments.iter())
.zip(upper_bounds.iter_mut().zip(lower_bounds.iter_mut()))
.enumerate()
{
let (nearest_idx, nearest_sq_dist, second_nearest_sq_dist) =
find_nearest_and_second_nearest(
*point,
&centroids,
point_sq_norms[i],
&centroid_sq_norms,
);
*assignment = nearest_idx;
// Store actual distances (take sqrt of squared distances)
*upper = nearest_sq_dist.sqrt();
// Calculate actual distances to all centroids
let (_, nearest_sq_dist, second_nearest_sq_dist) = find_nearest_and_second_nearest(
*point,
&centroids,
point_sq_norms[i],
&centroid_sq_norms,
);
// Upper bound is distance to assigned centroid
let sq_dist_to_assigned = squared_distance_using_norms(
*point,
centroids[assignment],
point_sq_norms[i],
centroid_sq_norms[assignment],
);
*upper = sq_dist_to_assigned.sqrt();
*lower = second_nearest_sq_dist.sqrt();
// After centroids are updated to be the mean of their clusters, a point's
// assigned centroid might no longer be its geometrically closest centroid.
// In this case, `nearest_sq_dist` (to the new closest) will be less than
// `sq_dist_to_assigned`, and the closest centroid *other than the assigned one*
// is the nearest centroid overall rather than the second-nearest. The lower
// bound must not exceed the distance to that centroid, so we take the minimum
// of the second-nearest distance and the square root of `nearest_sq_dist`.
if nearest_sq_dist < sq_dist_to_assigned {
*lower = nearest_sq_dist.sqrt().min(*lower);
}
}

// Track convergence state
Expand Down Expand Up @@ -1005,6 +1047,8 @@ mod tests {

#[test]
fn test_kmeans_three_clusters() {
// Use a fixed seed for deterministic testing. K-means can converge to local minima
// depending on initialisation, so we use a known-good seed to test algorithm correctness.
let points = [
// Cluster 1
point!(x: 0.0, y: 0.0),
Expand All @@ -1016,7 +1060,8 @@ mod tests {
point!(x: 20.0, y: 20.0),
point!(x: 21.0, y: 20.0),
];
let labels = points.kmeans(3).unwrap();
let params = KMeansParams::new(3).seed(42);
let labels = points.kmeans_with_params(params).unwrap();

// Each pair should be in the same cluster
assert_eq!(labels[0], labels[1]);
Expand Down