Skip to content

Commit 85fa4f6

Browse files
committed
[JTH] little modifications in kma reg guided albas implementation
1 parent 4e3441f commit 85fa4f6

File tree

5 files changed

+87
-41
lines changed

5 files changed

+87
-41
lines changed

bluemath_tk/core/decorators.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,28 +130,42 @@ def wrapper(
130130
min_number_of_points: int = None,
131131
max_number_of_iterations: int = 10,
132132
normalize_data: bool = False,
133-
regression_guided: Dict[str, Dict[str, Any]] = {},
133+
regression_guided: Dict[str, List] = {},
134134
):
135135
if data is None:
136-
raise ValueError("Data cannot be None")
136+
raise ValueError("data cannot be None")
137137
elif not isinstance(data, pd.DataFrame):
138-
raise TypeError("Data must be a pandas DataFrame")
138+
raise TypeError("data must be a pandas DataFrame")
139139
if not isinstance(directional_variables, list):
140-
raise TypeError("Directional variables must be a list")
140+
raise TypeError("directional_variables must be a list")
141141
if not isinstance(custom_scale_factor, dict):
142-
raise TypeError("Custom scale factor must be a dict")
142+
raise TypeError("custom_scale_factor must be a dict")
143143
if min_number_of_points is not None:
144144
if not isinstance(min_number_of_points, int) or min_number_of_points <= 0:
145-
raise ValueError("Minimum number of points must be integer and > 0")
145+
raise ValueError("min_number_of_points must be integer and > 0")
146146
if (
147147
not isinstance(max_number_of_iterations, int)
148148
or max_number_of_iterations <= 0
149149
):
150-
raise ValueError("Maximum number of iterations must be integer and > 0")
150+
raise ValueError("max_number_of_iterations must be integer and > 0")
151151
if not isinstance(normalize_data, bool):
152-
raise TypeError("Normalize data must be a boolean")
152+
raise TypeError("normalize_data must be a boolean")
153153
if not isinstance(regression_guided, dict):
154154
raise TypeError("regression_guided must be a dictionary")
155+
if not all(
156+
isinstance(var, str) and var in data.columns
157+
for var in regression_guided.get("vars", [])
158+
):
159+
raise TypeError(
160+
"regression_guided vars must be a list of strings and must exist in data"
161+
)
162+
if not all(
163+
isinstance(alpha, float) and alpha >= 0 and alpha <= 1
164+
for alpha in regression_guided.get("alpha", [])
165+
):
166+
raise TypeError(
167+
"regression_guided alpha must be a list of floats between 0 and 1"
168+
)
155169
return func(
156170
self,
157171
data,
@@ -160,7 +174,7 @@ def wrapper(
160174
min_number_of_points,
161175
max_number_of_iterations,
162176
normalize_data,
163-
regression_guided
177+
regression_guided,
164178
)
165179

166180
return wrapper
@@ -388,7 +402,6 @@ def wrapper(
388402
self,
389403
data: xr.Dataset,
390404
fit_params: Dict[str, Dict[str, Any]] = {},
391-
regression_guided: Dict[str, Dict[str, Any]] = {},
392405
variable_to_sort_bmus: str = None,
393406
):
394407
if not isinstance(data, xr.Dataset):

bluemath_tk/datamining/kma.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Tuple, Any, Dict
1+
from typing import Dict, List, Tuple
22

33
import numpy as np
44
import pandas as pd
@@ -178,17 +178,32 @@ def data_to_fit(self) -> pd.DataFrame:
178178

179179
return self._data_to_fit
180180

181+
@staticmethod
182+
def add_regression_guided(
183+
data: pd.DataFrame, vars: List[str], alpha: List[float]
184+
) -> pd.DataFrame:
185+
"""
186+
Calculate regression-guided variables.
181187
182-
def add_regression_guided(self, data: pd.DataFrame, vars: List[str], alpha: List[float]) -> pd.DataFrame:
188+
Parameters
189+
----------
190+
data : pd.DataFrame
191+
The data to fit the K-Means algorithm.
192+
vars : List[str]
193+
The variables to use for regression-guided clustering.
194+
alpha : List[float]
195+
The alpha values to use for regression-guided clustering.
183196
184-
"""
185-
Help KMA clustering features with regression-guided variables.
197+
Returns
198+
-------
199+
pd.DataFrame
200+
The data with the regression-guided variables.
186201
"""
187202

188203
# Stack guiding variables into (time, n_vars) array
189204
X = data.drop(columns=vars)
190205
Y = np.stack([data[var].values for var in vars], axis=1)
191-
206+
192207
# Normalize input features
193208
X_std = X.std().replace(0, 1)
194209
X_norm = X / X_std
@@ -223,7 +238,7 @@ def fit(
223238
min_number_of_points: int = None,
224239
max_number_of_iterations: int = 10,
225240
normalize_data: bool = False,
226-
regression_guided: Dict[str, Dict[str, Any]] = {},
241+
regression_guided: Dict[str, List] = {},
227242
) -> None:
228243
"""
229244
Fit the K-Means algorithm to the provided data.
@@ -232,8 +247,7 @@ def fit(
232247
provided dataframe and custom scale factor.
233248
It normalizes the data, and returns the calculated centroids.
234249
235-
TODO: Implement KMA regression guided with variable.
236-
Add option to force KMA initialization with MDA centroids.
250+
TODO: Add option to force KMA initialization with MDA centroids.
237251
238252
Parameters
239253
----------
@@ -256,22 +270,23 @@ def fit(
256270
A flag to normalize the data. Default is False.
257271
regression_guided: dict, optional
258272
A dictionary specifying regression-guided clustering variables and relative weights.
273+
Example: {"vars":["Fe"],"alpha":[0.6]}. Default is {}.
259274
"""
260-
275+
261276
if regression_guided:
262277
data = self.add_regression_guided(
263-
data=data,
264-
vars = regression_guided.get("vars", None),
265-
alpha = regression_guided.get("alpha", None)
278+
data=data,
279+
vars=regression_guided.get("vars", None),
280+
alpha=regression_guided.get("alpha", None),
266281
)
267-
282+
268283
super().fit(
269284
data=data,
270285
directional_variables=directional_variables,
271286
custom_scale_factor=custom_scale_factor,
272287
normalize_data=normalize_data,
273288
)
274-
289+
275290
# Fit K-Means algorithm
276291
if min_number_of_points is not None:
277292
stable_kma_child = False
@@ -303,7 +318,7 @@ def fit(
303318
self.centroids = self.denormalize(
304319
normalized_data=self.normalized_centroids, scale_factor=self.scale_factor
305320
)
306-
321+
307322
for directional_variable in self.directional_variables:
308323
self.centroids[directional_variable] = self.get_degrees_from_uv(
309324
xu=self.centroids[f"{directional_variable}_u"].values,
@@ -348,7 +363,7 @@ def fit_predict(
348363
min_number_of_points: int = None,
349364
max_number_of_iterations: int = 10,
350365
normalize_data: bool = False,
351-
regression_guided: Dict[str, Dict[str, Any]] = {},
366+
regression_guided: Dict[str, List] = {},
352367
) -> Tuple[pd.DataFrame, pd.DataFrame]:
353368
"""
354369
Fit the K-Means algorithm to the provided data and predict the nearest centroid
@@ -373,22 +388,25 @@ def fit_predict(
373388
Default is 10.
374389
normalize_data : bool, optional
375390
A flag to normalize the data. Default is False.
391+
regression_guided: dict, optional
392+
A dictionary specifying regression-guided clustering variables and relative weights.
393+
Example: {"vars":["Fe"],"alpha":[0.6]}. Default is {}.
376394
377395
Returns
378396
-------
379397
Tuple[pd.DataFrame, pd.DataFrame]
380398
A tuple containing the nearest centroid index for each data point,
381399
and the nearest centroids.
382400
"""
383-
401+
384402
self.fit(
385403
data=data,
386404
directional_variables=directional_variables,
387405
custom_scale_factor=custom_scale_factor,
388406
min_number_of_points=min_number_of_points,
389407
max_number_of_iterations=max_number_of_iterations,
390408
normalize_data=normalize_data,
391-
regression_guided=regression_guided
409+
regression_guided=regression_guided,
392410
)
393411

394412
return self.predict(data=data)

bluemath_tk/predictor/xwt.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import List
2-
31
import logging
42
import warnings
53
from datetime import datetime, timedelta
@@ -303,7 +301,7 @@ def get_conditioned_probabilities(self) -> pd.DataFrame:
303301
)
304302

305303
return df_cond_probs
306-
304+
307305
@validate_data_xwt
308306
def fit(
309307
self,
@@ -327,6 +325,9 @@ def fit(
327325
------
328326
XWTError
329327
If the data is not PCA formatted.
328+
329+
TODO: Standardize PCs by first PC variance.
330+
pca.pcs_df / pca.pcs.stds.isel(n_component=0).values ??
330331
"""
331332

332333
# Make a copy of the data to avoid modifying the original dataset
@@ -346,17 +347,16 @@ def fit(
346347

347348
kma: KMA = self.steps.get("kma")
348349
self.num_clusters = kma.num_clusters
349-
# TODO: standarize PCs by first PC variance
350-
351-
data_to_kma = pca.pcs_df
352-
350+
351+
data_to_kma = pca.pcs_df.copy()
352+
353353
if "regression_guided" in fit_params.get("kma", {}):
354354
guiding_vars = fit_params["kma"]["regression_guided"].get("vars", [])
355-
356-
if guiding_vars:
355+
356+
if guiding_vars:
357357
guiding_data = pd.DataFrame(
358358
{var: data[var].values for var in guiding_vars},
359-
index=data.time.values
359+
index=data.time.values,
360360
)
361361
data_to_kma = pd.concat([data_to_kma, guiding_data], axis=1)
362362

tests/datamining/test_kma.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import unittest
2+
23
import numpy as np
34
import pandas as pd
5+
46
from bluemath_tk.datamining.kma import KMA
57

68

@@ -46,6 +48,19 @@ def test_fit_predict(self):
4648
self.assertIsInstance(predicted_labels_df, pd.DataFrame)
4749
self.assertEqual(predicted_labels_df.shape[0], 1000)
4850

51+
def test_add_regression_guided(self):
52+
data = self.df.copy()
53+
data["Fe"] = data["Hs"] ** 2 * data["Tp"]
54+
predicted_labels, predicted_labels_df = self.kma.fit_predict(
55+
data=data,
56+
directional_variables=["Dir"],
57+
regression_guided={"vars": ["Fe"], "alpha": [0.6]},
58+
)
59+
self.assertIsInstance(predicted_labels, pd.DataFrame)
60+
self.assertEqual(len(predicted_labels), 1000)
61+
self.assertIsInstance(predicted_labels_df, pd.DataFrame)
62+
self.assertEqual(predicted_labels_df.shape[0], 1000)
63+
4964

5065
if __name__ == "__main__":
5166
unittest.main()

tests/interpolation/test_rbf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ class TestRBF(unittest.TestCase):
1010
def setUp(self):
1111
self.dataset = pd.DataFrame(
1212
{
13-
"Hs": np.random.rand(1000) * 7,
14-
"Tp": np.random.rand(1000) * 20,
15-
"Dir": np.random.rand(1000) * 360,
13+
"Hs": np.random.rand(100) * 7,
14+
"Tp": np.random.rand(100) * 20,
15+
"Dir": np.random.rand(100) * 360,
1616
}
1717
)
1818
self.subset = self.dataset.sample(frac=0.25)

0 commit comments

Comments
 (0)