1- from typing import List , Tuple , Any , Dict
1+ from typing import Dict , List , Tuple
22
33import numpy as np
44import pandas as pd
@@ -178,17 +178,32 @@ def data_to_fit(self) -> pd.DataFrame:
178178
179179 return self ._data_to_fit
180180
181+ @staticmethod
182+ def add_regression_guided (
183+ data : pd .DataFrame , vars : List [str ], alpha : List [float ]
184+ ) -> pd .DataFrame :
185+ """
186+ Calculate regression-guided variables.
181187
182- def add_regression_guided (self , data : pd .DataFrame , vars : List [str ], alpha : List [float ]) -> pd .DataFrame :
188+ Parameters
189+ ----------
190+ data : pd.DataFrame
191+ The data to fit the K-Means algorithm.
192+ vars : List[str]
193+ The variables to use for regression-guided clustering.
194+ alpha : List[float]
195+ The alpha values to use for regression-guided clustering.
183196
184- """
185- Help KMA clustering features with regression-guided variables.
197+ Returns
198+ -------
199+ pd.DataFrame
200+ The data with the regression-guided variables.
186201 """
187202
188203 # Stack guiding variables into (time, n_vars) array
189204 X = data .drop (columns = vars )
190205 Y = np .stack ([data [var ].values for var in vars ], axis = 1 )
191-
206+
192207 # Normalize input features
193208 X_std = X .std ().replace (0 , 1 )
194209 X_norm = X / X_std
@@ -223,7 +238,7 @@ def fit(
223238 min_number_of_points : int = None ,
224239 max_number_of_iterations : int = 10 ,
225240 normalize_data : bool = False ,
226- regression_guided : Dict [str , Dict [ str , Any ] ] = {},
241+ regression_guided : Dict [str , List ] = {},
227242 ) -> None :
228243 """
229244 Fit the K-Means algorithm to the provided data.
@@ -232,8 +247,7 @@ def fit(
232247 provided dataframe and custom scale factor.
233248 It normalizes the data and stores the calculated centroids on the instance (``self.centroids``).
234249
235- TODO: Implement KMA regression guided with variable.
236- Add option to force KMA initialization with MDA centroids.
250+ TODO: Add option to force KMA initialization with MDA centroids.
237251
238252 Parameters
239253 ----------
@@ -256,22 +270,23 @@ def fit(
256270 A flag to normalize the data. Default is False.
257271 regression_guided : dict, optional
258272 A dictionary specifying regression-guided clustering variables and relative weights.
273+ Example: {"vars":["Fe"],"alpha":[0.6]}. Default is {}.
259274 """
260-
275+
261276 if regression_guided :
262277 data = self .add_regression_guided (
263- data = data ,
264- vars = regression_guided .get ("vars" , None ),
265- alpha = regression_guided .get ("alpha" , None )
278+ data = data ,
279+ vars = regression_guided .get ("vars" , None ),
280+ alpha = regression_guided .get ("alpha" , None ),
266281 )
267-
282+
268283 super ().fit (
269284 data = data ,
270285 directional_variables = directional_variables ,
271286 custom_scale_factor = custom_scale_factor ,
272287 normalize_data = normalize_data ,
273288 )
274-
289+
275290 # Fit K-Means algorithm
276291 if min_number_of_points is not None :
277292 stable_kma_child = False
@@ -303,7 +318,7 @@ def fit(
303318 self .centroids = self .denormalize (
304319 normalized_data = self .normalized_centroids , scale_factor = self .scale_factor
305320 )
306-
321+
307322 for directional_variable in self .directional_variables :
308323 self .centroids [directional_variable ] = self .get_degrees_from_uv (
309324 xu = self .centroids [f"{ directional_variable } _u" ].values ,
@@ -348,7 +363,7 @@ def fit_predict(
348363 min_number_of_points : int = None ,
349364 max_number_of_iterations : int = 10 ,
350365 normalize_data : bool = False ,
351- regression_guided : Dict [str , Dict [ str , Any ] ] = {},
366+ regression_guided : Dict [str , List ] = {},
352367 ) -> Tuple [pd .DataFrame , pd .DataFrame ]:
353368 """
354369 Fit the K-Means algorithm to the provided data and predict the nearest centroid
@@ -373,22 +388,25 @@ def fit_predict(
373388 Default is 10.
374389 normalize_data : bool, optional
375390 A flag to normalize the data. Default is False.
391+ regression_guided : dict, optional
392+ A dictionary specifying regression-guided clustering variables and relative weights.
393+ Example: {"vars":["Fe"],"alpha":[0.6]}. Default is {}.
376394
377395 Returns
378396 -------
379397 Tuple[pd.DataFrame, pd.DataFrame]
380398 A tuple containing the nearest centroid index for each data point,
381399 and the nearest centroids.
382400 """
383-
401+
384402 self .fit (
385403 data = data ,
386404 directional_variables = directional_variables ,
387405 custom_scale_factor = custom_scale_factor ,
388406 min_number_of_points = min_number_of_points ,
389407 max_number_of_iterations = max_number_of_iterations ,
390408 normalize_data = normalize_data ,
391- regression_guided = regression_guided
409+ regression_guided = regression_guided ,
392410 )
393411
394412 return self .predict (data = data )
0 commit comments