Skip to content

Commit c19a8a7

Browse files
patconclaude
andcommitted
Improve error message for invalid init strategy
Add InitStrategy Literal type and show valid options in error message when an unsupported init strategy is passed. This helps catch common typos like "kmeans++" instead of "k-means++". Closes #116 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c3a13ab commit c19a8a7

3 files changed

Lines changed: 18 additions & 12 deletions

File tree

reddwarf/sklearn/cluster.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import List, Literal, Optional
22

33
import numpy as np
44
from numpy.typing import ArrayLike, NDArray
@@ -9,6 +9,9 @@
99

1010
from reddwarf.sklearn.model_selection import GridSearchNonCV
1111

12+
InitStrategy = Literal["k-means++", "random", "polis"]
13+
VALID_INIT_STRATEGIES: List[str] = ["k-means++", "random", "polis"]
14+
1215

1316
def _to_range(r) -> range:
1417
"""
@@ -76,8 +79,8 @@ class PolisKMeans(KMeans):
7679
def __init__(
7780
self,
7881
n_clusters=8,
79-
init="k-means++", # or 'random', 'polis'
80-
init_centers: Optional[ArrayLike] = None, # array-like, optional
82+
init: InitStrategy = "k-means++",
83+
init_centers: Optional[ArrayLike] = None,
8184
n_init="auto",
8285
max_iter=300,
8386
tol=1e-4,
@@ -120,7 +123,10 @@ def _generate_centers(self, X, x_squared_norms, n_to_generate, random_state) ->
120123
raise ValueError("Not enough unique rows in X for 'polis' strategy.")
121124
centers = unique_X[:n_to_generate]
122125
else:
123-
raise ValueError(f"Unsupported init strategy: {self._init_strategy}")
126+
raise ValueError(
127+
f"Unsupported init strategy: {self._init_strategy!r}. "
128+
f"Valid options are: {VALID_INIT_STRATEGIES}"
129+
)
124130
return centers
125131

126132
def fit(self, X, y=None, sample_weight=None):
@@ -178,7 +184,7 @@ def __init__(
178184
self,
179185
n_clusters: int = 100,
180186
random_state: Optional[int] = None,
181-
init: str = "k-means++",
187+
init: InitStrategy = "k-means++",
182188
init_centers: Optional[ArrayLike] = None,
183189
):
184190
self.n_clusters = n_clusters
@@ -232,7 +238,7 @@ class BestPolisKMeans(BaseEstimator):
232238
def __init__(
233239
self,
234240
k_bounds: Optional[List[int]] = None,
235-
init: str = "polis",
241+
init: InitStrategy = "polis",
236242
init_centers: Optional[ArrayLike] = None,
237243
random_state: Optional[int] = None,
238244
):

reddwarf/utils/clusterer/kmeans.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pandas as pd
33
import numpy as np
44
from reddwarf.sklearn.model_selection import GridSearchNonCV
5-
from reddwarf.sklearn.cluster import PolisKMeans
5+
from reddwarf.sklearn.cluster import InitStrategy, PolisKMeans
66
from sklearn.metrics import silhouette_score
77
from typing import List, Optional
88

@@ -33,7 +33,7 @@ def to_range(r: RangeLike) -> range:
3333
def run_kmeans(
3434
dataframe: pd.DataFrame,
3535
n_clusters: int = 2,
36-
init="k-means++",
36+
init: InitStrategy = "k-means++",
3737
# TODO: Improve this type. 3d?
3838
init_centers: Optional[List] = None,
3939
random_state: Optional[int] = None,
@@ -66,7 +66,7 @@ def run_kmeans(
6666
def find_best_kmeans(
6767
X_to_cluster: NDArray,
6868
k_bounds: RangeLike = [2, 5],
69-
init="k-means++",
69+
init: InitStrategy = "k-means++",
7070
init_centers: Optional[List] = None,
7171
random_state: Optional[int] = None,
7272
) -> tuple[int, float, PolisKMeans | None]:

tests/sklearn/test_cluster.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,12 @@ def test_init_centers_wrong_n_features(self, simple_data):
184184
pkm.fit(X)
185185

186186
def test_unsupported_init_strategy(self, simple_data):
187-
"""Test that unsupported init strategy raises error."""
187+
"""Test that unsupported init strategy raises error with valid options."""
188188
X, _ = simple_data
189-
pkm = PolisKMeans(n_clusters=3, init="invalid")
189+
pkm = PolisKMeans(n_clusters=3, init="k-means++")
190190
pkm._init_strategy = "invalid" # Bypass __init__ validation
191191

192-
with pytest.raises(ValueError, match="Unsupported init strategy"):
192+
with pytest.raises(ValueError, match=r"Unsupported init strategy.*k-means\+\+.*random.*polis"):
193193
pkm.fit(X)
194194

195195
def test_reproducibility_with_random_state(self, simple_data):

0 commit comments

Comments
 (0)