|
1 | | -from typing import List, Optional |
| 1 | +from typing import List, Literal, Optional |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 | from numpy.typing import ArrayLike, NDArray |
|
9 | 9 |
|
10 | 10 | from reddwarf.sklearn.model_selection import GridSearchNonCV |
11 | 11 |
|
| 12 | +InitStrategy = Literal["k-means++", "random", "polis"] |
| 13 | +VALID_INIT_STRATEGIES: List[str] = ["k-means++", "random", "polis"] |
| 14 | + |
12 | 15 |
|
13 | 16 | def _to_range(r) -> range: |
14 | 17 | """ |
@@ -76,8 +79,8 @@ class PolisKMeans(KMeans): |
76 | 79 | def __init__( |
77 | 80 | self, |
78 | 81 | n_clusters=8, |
79 | | - init="k-means++", # or 'random', 'polis' |
80 | | - init_centers: Optional[ArrayLike] = None, # array-like, optional |
| 82 | + init: InitStrategy = "k-means++", |
| 83 | + init_centers: Optional[ArrayLike] = None, |
81 | 84 | n_init="auto", |
82 | 85 | max_iter=300, |
83 | 86 | tol=1e-4, |
@@ -120,7 +123,10 @@ def _generate_centers(self, X, x_squared_norms, n_to_generate, random_state) -> |
120 | 123 | raise ValueError("Not enough unique rows in X for 'polis' strategy.") |
121 | 124 | centers = unique_X[:n_to_generate] |
122 | 125 | else: |
123 | | - raise ValueError(f"Unsupported init strategy: {self._init_strategy}") |
| 126 | + raise ValueError( |
| 127 | + f"Unsupported init strategy: {self._init_strategy!r}. " |
| 128 | + f"Valid options are: {VALID_INIT_STRATEGIES}" |
| 129 | + ) |
124 | 130 | return centers |
125 | 131 |
|
126 | 132 | def fit(self, X, y=None, sample_weight=None): |
@@ -178,7 +184,7 @@ def __init__( |
178 | 184 | self, |
179 | 185 | n_clusters: int = 100, |
180 | 186 | random_state: Optional[int] = None, |
181 | | - init: str = "k-means++", |
| 187 | + init: InitStrategy = "k-means++", |
182 | 188 | init_centers: Optional[ArrayLike] = None, |
183 | 189 | ): |
184 | 190 | self.n_clusters = n_clusters |
@@ -232,7 +238,7 @@ class BestPolisKMeans(BaseEstimator): |
232 | 238 | def __init__( |
233 | 239 | self, |
234 | 240 | k_bounds: Optional[List[int]] = None, |
235 | | - init: str = "polis", |
| 241 | + init: InitStrategy = "polis", |
236 | 242 | init_centers: Optional[ArrayLike] = None, |
237 | 243 | random_state: Optional[int] = None, |
238 | 244 | ): |
|
0 commit comments