Skip to content

Commit 2383b91

Browse files
committed
[JTH] generalise kma test
1 parent fe86de6 commit 2383b91

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

bluemath_tk/datamining/kma.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,8 @@ def predict(self, data: pd.DataFrame) -> Tuple[np.ndarray, pd.DataFrame]:
284284
Returns
285285
-------
286286
Tuple[np.ndarray, pd.DataFrame]
287-
A tuple containing the nearest centroid index for each data point and the nearest centroids.
287+
A tuple containing the nearest centroid index for each data point,
288+
and the nearest centroids.
288289
"""
289290

290291
if self.is_fitted is False:
@@ -310,9 +311,11 @@ def fit_predict(
310311
data: pd.DataFrame,
311312
directional_variables: List[str] = [],
312313
custom_scale_factor: dict = {},
314+
min_number_of_points: int = None,
313315
) -> Tuple[np.ndarray, pd.DataFrame]:
314316
"""
315-
Fit the K-Means algorithm to the provided data and predict the nearest centroid for each data point.
317+
Fit the K-Means algorithm to the provided data and predict the nearest centroid
318+
for each data point.
316319
317320
Parameters
318321
----------
@@ -324,17 +327,22 @@ def fit_predict(
324327
custom_scale_factor : dict
325328
A dictionary specifying custom scale factors for normalization.
326329
Default is {}.
330+
min_number_of_points : int, optional
331+
The minimum number of points to consider a cluster.
332+
Default is None.
327333
328334
Returns
329335
-------
330-
Tuple[pd.DataFrame, np.ndarray, pd.DataFrame]
331-
A tuple containing the nearest centroid index for each data point, and the nearest centroids.
336+
Tuple[np.ndarray, pd.DataFrame]
337+
A tuple containing the nearest centroid index for each data point,
338+
and the nearest centroids.
332339
"""
333340

334341
self.fit(
335342
data=data,
336343
directional_variables=directional_variables,
337344
custom_scale_factor=custom_scale_factor,
345+
min_number_of_points=min_number_of_points,
338346
)
339347

340348
return self.predict(data=data)

tests/datamining/test_kma.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def setUp(self):
1616
self.kma = KMA(num_clusters=10)
1717

1818
def test_fit(self):
19-
self.kma.fit(data=self.df, min_number_of_points=80)
19+
self.kma.fit(data=self.df, min_number_of_points=50)
2020
self.assertIsInstance(self.kma.centroids, pd.DataFrame)
2121
self.assertEqual(self.kma.centroids.shape[0], 10)
2222

@@ -36,11 +36,15 @@ def test_predict(self):
3636
self.assertEqual(nearest_centroid_df.shape[0], 15)
3737

3838
def test_fit_predict(self):
39-
nearest_centroids, nearest_centroid_df = self.kma.fit_predict(data=self.df)
40-
self.assertIsInstance(nearest_centroids, np.ndarray)
41-
self.assertEqual(len(nearest_centroids), 1000)
42-
self.assertIsInstance(nearest_centroid_df, pd.DataFrame)
43-
self.assertEqual(nearest_centroid_df.shape[0], 1000)
39+
predicted_labels, predicted_labels_df = self.kma.fit_predict(
40+
data=self.df, min_number_of_points=50
41+
)
42+
_unique_labels, counts = np.unique(predicted_labels, return_counts=True)
43+
self.assertTrue(np.all(counts >= 50))
44+
self.assertIsInstance(predicted_labels, np.ndarray)
45+
self.assertEqual(len(predicted_labels), 1000)
46+
self.assertIsInstance(predicted_labels_df, pd.DataFrame)
47+
self.assertEqual(predicted_labels_df.shape[0], 1000)
4448

4549

4650
if __name__ == "__main__":

0 commit comments

Comments
 (0)