Skip to content

Commit a97f3d4

Browse files
committed
[JTH] add kma max iterations flag
1 parent de8eeb7 commit a97f3d4

File tree

5 files changed

+475
-275
lines changed

5 files changed

+475
-275
lines changed

bluemath_tk/core/decorators.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def wrapper(
120120
directional_variables: List[str] = [],
121121
custom_scale_factor: dict = {},
122122
min_number_of_points: int = None,
123+
max_number_of_iterations: int = 10,
123124
normalize_data: bool = True,
124125
):
125126
if data is None:
@@ -133,6 +134,11 @@ def wrapper(
133134
if min_number_of_points is not None:
134135
if not isinstance(min_number_of_points, int) or min_number_of_points <= 0:
135136
raise ValueError("Minimum number of points must be integer and > 0")
137+
if (
138+
not isinstance(max_number_of_iterations, int)
139+
or max_number_of_iterations <= 0
140+
):
141+
raise ValueError("Maximum number of iterations must be integer and > 0")
136142
if not isinstance(normalize_data, bool):
137143
raise TypeError("Normalize data must be a boolean")
138144
return func(

bluemath_tk/datamining/kma.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def fit(
182182
directional_variables: List[str] = [],
183183
custom_scale_factor: dict = {},
184184
min_number_of_points: int = None,
185+
max_number_of_iterations: int = 10,
185186
normalize_data: bool = True,
186187
) -> None:
187188
"""
@@ -206,6 +207,10 @@ def fit(
206207
min_number_of_points : int, optional
207208
The minimum number of points to consider a cluster.
208209
Default is None.
210+
max_number_of_iterations : int, optional
211+
The maximum number of iterations for the K-Means algorithm.
212+
This is used when min_number_of_points is not None.
213+
Default is 10.
209214
normalize_data : bool, optional
210215
A flag to normalize the data. Default is True.
211216
"""
@@ -248,9 +253,10 @@ def fit(
248253
if np.all(counts >= min_number_of_points):
249254
stable_kma_child = True
250255
number_of_tries += 1
251-
if number_of_tries > 10:
256+
if number_of_tries > max_number_of_iterations:
252257
raise ValueError(
253-
"Failed to find a stable K-Means configuration after 10 attempts."
258+
f"Failed to find a stable K-Means configuration after {max_number_of_iterations} attempts."
259+
"Change max_number_of_iterations or min_number_of_points."
254260
)
255261
self.logger.info(
256262
f"Found a stable K-Means configuration after {number_of_tries} attempts."

0 commit comments

Comments
 (0)