Skip to content

Commit e73a665

Browse files
authored
Merge pull request #117 from polis-community/116-fix-bestkmeans
Improve error message for invalid init strategy
2 parents c3a13ab + 4c2b74c commit e73a665

4 files changed

Lines changed: 22 additions & 12 deletions

File tree

CLAUDE.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,7 @@ Tests use real Polis API data fixtures in `tests/fixtures/`. The test suite incl
6262
- `plots`: matplotlib, seaborn, concave-hull (visualization)
6363
- `dev`: pytest, mkdocs, nbmake (development)
6464
- `all`: everything
65+
66+
## Git Conventions
67+
68+
- When working on a branch that references an issue (e.g., `116-fix-bestkmeans`), include `Closes #116` in the commit message or PR description to auto-close the issue when merged

reddwarf/sklearn/cluster.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import List, Literal, Optional
22

33
import numpy as np
44
from numpy.typing import ArrayLike, NDArray
@@ -9,6 +9,9 @@
99

1010
from reddwarf.sklearn.model_selection import GridSearchNonCV
1111

12+
InitStrategy = Literal["k-means++", "random", "polis"]
13+
VALID_INIT_STRATEGIES: List[str] = ["k-means++", "random", "polis"]
14+
1215

1316
def _to_range(r) -> range:
1417
"""
@@ -76,8 +79,8 @@ class PolisKMeans(KMeans):
7679
def __init__(
7780
self,
7881
n_clusters=8,
79-
init="k-means++", # or 'random', 'polis'
80-
init_centers: Optional[ArrayLike] = None, # array-like, optional
82+
init: InitStrategy = "k-means++",
83+
init_centers: Optional[ArrayLike] = None,
8184
n_init="auto",
8285
max_iter=300,
8386
tol=1e-4,
@@ -120,7 +123,10 @@ def _generate_centers(self, X, x_squared_norms, n_to_generate, random_state) ->
120123
raise ValueError("Not enough unique rows in X for 'polis' strategy.")
121124
centers = unique_X[:n_to_generate]
122125
else:
123-
raise ValueError(f"Unsupported init strategy: {self._init_strategy}")
126+
raise ValueError(
127+
f"Unsupported init strategy: {self._init_strategy!r}. "
128+
f"Valid options are: {VALID_INIT_STRATEGIES}"
129+
)
124130
return centers
125131

126132
def fit(self, X, y=None, sample_weight=None):
@@ -178,7 +184,7 @@ def __init__(
178184
self,
179185
n_clusters: int = 100,
180186
random_state: Optional[int] = None,
181-
init: str = "k-means++",
187+
init: InitStrategy = "k-means++",
182188
init_centers: Optional[ArrayLike] = None,
183189
):
184190
self.n_clusters = n_clusters
@@ -232,7 +238,7 @@ class BestPolisKMeans(BaseEstimator):
232238
def __init__(
233239
self,
234240
k_bounds: Optional[List[int]] = None,
235-
init: str = "polis",
241+
init: InitStrategy = "polis",
236242
init_centers: Optional[ArrayLike] = None,
237243
random_state: Optional[int] = None,
238244
):

reddwarf/utils/clusterer/kmeans.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pandas as pd
33
import numpy as np
44
from reddwarf.sklearn.model_selection import GridSearchNonCV
5-
from reddwarf.sklearn.cluster import PolisKMeans
5+
from reddwarf.sklearn.cluster import InitStrategy, PolisKMeans
66
from sklearn.metrics import silhouette_score
77
from typing import List, Optional
88

@@ -33,7 +33,7 @@ def to_range(r: RangeLike) -> range:
3333
def run_kmeans(
3434
dataframe: pd.DataFrame,
3535
n_clusters: int = 2,
36-
init="k-means++",
36+
init: InitStrategy = "k-means++",
3737
# TODO: Improve this type. 3d?
3838
init_centers: Optional[List] = None,
3939
random_state: Optional[int] = None,
@@ -66,7 +66,7 @@ def run_kmeans(
6666
def find_best_kmeans(
6767
X_to_cluster: NDArray,
6868
k_bounds: RangeLike = [2, 5],
69-
init="k-means++",
69+
init: InitStrategy = "k-means++",
7070
init_centers: Optional[List] = None,
7171
random_state: Optional[int] = None,
7272
) -> tuple[int, float, PolisKMeans | None]:

tests/sklearn/test_cluster.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,12 @@ def test_init_centers_wrong_n_features(self, simple_data):
184184
pkm.fit(X)
185185

186186
def test_unsupported_init_strategy(self, simple_data):
187-
"""Test that unsupported init strategy raises error."""
187+
"""Test that unsupported init strategy raises error with valid options."""
188188
X, _ = simple_data
189-
pkm = PolisKMeans(n_clusters=3, init="invalid")
189+
pkm = PolisKMeans(n_clusters=3, init="k-means++")
190190
pkm._init_strategy = "invalid" # Bypass __init__ validation
191191

192-
with pytest.raises(ValueError, match="Unsupported init strategy"):
192+
with pytest.raises(ValueError, match=r"Unsupported init strategy.*k-means\+\+.*random.*polis"):
193193
pkm.fit(X)
194194

195195
def test_reproducibility_with_random_state(self, simple_data):

0 commit comments

Comments
 (0)