|
1 | 1 | import pytest |
| 2 | +import numpy as np |
2 | 3 | from reddwarf.utils.clusterer.kmeans import run_kmeans, find_best_kmeans |
3 | 4 | from tests.fixtures import polis_convo_data |
4 | 5 | from tests.helpers import transform_base_clusters_to_participant_coords |
5 | 6 | import pandas as pd |
6 | 7 |
|
| 8 | + |
def test_find_best_kmeans_rejects_singleton_clusters():
    """find_best_kmeans should never select a k that produces a singleton cluster.

    The input is two 30-point Gaussian blobs plus one isolated point at
    (10, 10). The isolated point is what would tempt a clusterer into
    giving it a cluster of its own, so the selected model must not contain
    any cluster with fewer than two members.
    """
    # Use a local Generator rather than np.random.seed so this test does not
    # reseed NumPy's global RNG for every test that runs after it.
    rng = np.random.default_rng(42)
    cluster1 = rng.normal(loc=[0, 0], scale=0.3, size=(30, 2))
    cluster2 = rng.normal(loc=[5, 5], scale=0.3, size=(30, 2))
    outlier = np.array([[10, 10]])
    X = np.vstack([cluster1, cluster2, outlier])

    best_k, _, best_kmeans = find_best_kmeans(
        X_to_cluster=X,
        k_bounds=[2, 5],
        random_state=42,
    )

    # Fail loudly instead of skipping the check: with the assertion hidden
    # behind an `if`, a None model would make this test pass vacuously.
    # NOTE(review): presumably at least k=2 is acceptable for this data
    # (the outlier can join the nearer blob) -- confirm against
    # find_best_kmeans' contract.
    assert best_kmeans is not None, (
        "find_best_kmeans returned no model, so the singleton check could not run"
    )

    unique, counts = np.unique(best_kmeans.labels_, return_counts=True)
    assert counts.min() >= 2, (
        f"k={best_k} produced singleton cluster(s): {dict(zip(unique, counts))}"
    )
| 28 | + |
7 | 29 | @pytest.mark.parametrize("polis_convo_data", ["small"], indirect=True) |
8 | 30 | def test_run_kmeans_real_data_reproducible(polis_convo_data): |
9 | 31 | fixture = polis_convo_data |
|
0 commit comments