Skip to content

Commit f19b625

Browse files
authored
Merge pull request #119 from nicobao/fix/reject-singleton-clusters
fix(kmeans): reject k values that produce singleton clusters
2 parents 8bd5881 + 41e5504 commit f19b625

2 files changed

Lines changed: 25 additions & 0 deletions

File tree

reddwarf/utils/clusterer/kmeans.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def find_best_kmeans(
9090

9191
def scoring_function(estimator, X):
9292
labels = estimator.fit_predict(X)
93+
unique, counts = np.unique(labels, return_counts=True)
94+
if counts.min() < 2:
95+
return -1
9396
return silhouette_score(X, labels)
9497

9598
search = GridSearchNonCV(

tests/utils/clusterer/test_kmeans.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,31 @@
11
import pytest
2+
import numpy as np
23
from reddwarf.utils.clusterer.kmeans import run_kmeans, find_best_kmeans
34
from tests.fixtures import polis_convo_data
45
from tests.helpers import transform_base_clusters_to_participant_coords
56
import pandas as pd
67

8+
9+
def test_find_best_kmeans_rejects_singleton_clusters():
10+
"""find_best_kmeans should never select a k that produces a singleton cluster."""
11+
np.random.seed(42)
12+
cluster1 = np.random.normal(loc=[0, 0], scale=0.3, size=(30, 2))
13+
cluster2 = np.random.normal(loc=[5, 5], scale=0.3, size=(30, 2))
14+
outlier = np.array([[10, 10]])
15+
X = np.vstack([cluster1, cluster2, outlier])
16+
17+
best_k, _, best_kmeans = find_best_kmeans(
18+
X_to_cluster=X,
19+
k_bounds=[2, 5],
20+
random_state=42,
21+
)
22+
23+
if best_kmeans is not None:
24+
unique, counts = np.unique(best_kmeans.labels_, return_counts=True)
25+
assert counts.min() >= 2, (
26+
f"k={best_k} produced singleton cluster(s): {dict(zip(unique, counts))}"
27+
)
28+
729
@pytest.mark.parametrize("polis_convo_data", ["small"], indirect=True)
830
def test_run_kmeans_real_data_reproducible(polis_convo_data):
931
fixture = polis_convo_data

0 commit comments

Comments
 (0)