Skip to content

Commit e4d7c15

Browse files
authored
Merge pull request #18 from GeoOcean/clean/datamining
[JTH] merge last changes in datamining
2 parents 6ad3f19 + 50439da commit e4d7c15

File tree

11 files changed

+703
-208
lines changed

11 files changed

+703
-208
lines changed

bluemath_tk/core/decorators.py

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,6 @@ def wrapper(
7272
directional_variables: List[str] = [],
7373
custom_scale_factor: dict = {},
7474
):
75-
# NOTE: Default custom scale factors are defined below
76-
_default_custom_scale_factor = {}
7775
if data is None:
7876
raise ValueError("Data cannot be None")
7977
elif not isinstance(data, pd.DataFrame):
@@ -82,19 +80,6 @@ def wrapper(
8280
raise TypeError("Directional variables must be a list")
8381
if not isinstance(custom_scale_factor, dict):
8482
raise TypeError("Custom scale factor must be a dict")
85-
for directional_variable in directional_variables:
86-
if directional_variable not in custom_scale_factor:
87-
if directional_variable in _default_custom_scale_factor:
88-
custom_scale_factor[directional_variable] = (
89-
_default_custom_scale_factor[directional_variable]
90-
)
91-
self.logger.warning(
92-
f"Using default custom scale factor for {directional_variable}"
93-
)
94-
else:
95-
self.logger.warning(
96-
f"No custom scale factor provided for {directional_variable}, min and max values will be used"
97-
)
9883
return func(self, data, directional_variables, custom_scale_factor)
9984

10085
return wrapper
@@ -119,11 +104,9 @@ def validate_data_kma(func):
119104
def wrapper(
120105
self,
121106
data: pd.DataFrame,
122-
directional_variables: List[str],
123-
custom_scale_factor: dict,
107+
directional_variables: List[str] = [],
108+
custom_scale_factor: dict = {},
124109
):
125-
# NOTE: Default custom scale factors are defined below
126-
_default_custom_scale_factor = {}
127110
if data is None:
128111
raise ValueError("Data cannot be None")
129112
elif not isinstance(data, pd.DataFrame):
@@ -132,24 +115,46 @@ def wrapper(
132115
raise TypeError("Directional variables must be a list")
133116
if not isinstance(custom_scale_factor, dict):
134117
raise TypeError("Custom scale factor must be a dict")
135-
for directional_variable in directional_variables:
136-
if directional_variable not in custom_scale_factor:
137-
if directional_variable in _default_custom_scale_factor:
138-
custom_scale_factor[directional_variable] = (
139-
_default_custom_scale_factor[directional_variable]
140-
)
141-
self.logger.warning(
142-
f"Using default custom scale factor for {directional_variable}"
143-
)
144-
else:
145-
self.logger.warning(
146-
f"No custom scale factor provided for {directional_variable}, min and max values will be used"
147-
)
148118
return func(self, data, directional_variables, custom_scale_factor)
149119

150120
return wrapper
151121

152122

123+
def validate_data_som(func):
124+
"""
125+
Decorator to validate data in SOM class fit method.
126+
127+
Parameters
128+
----------
129+
func : callable
130+
The function to be decorated
131+
132+
Returns
133+
-------
134+
callable
135+
The decorated function
136+
"""
137+
138+
@functools.wraps(func)
139+
def wrapper(
140+
self,
141+
data: pd.DataFrame,
142+
directional_variables: List[str] = [],
143+
num_iteration: int = 1000,
144+
):
145+
if data is None:
146+
raise ValueError("Data cannot be None")
147+
elif not isinstance(data, pd.DataFrame):
148+
raise TypeError("Data must be a pandas DataFrame")
149+
if not isinstance(directional_variables, list):
150+
raise TypeError("Directional variables must be a list")
151+
if not isinstance(num_iteration, int) or num_iteration <= 0:
152+
raise ValueError("Number of iterations must be integer and > 0")
153+
return func(self, data, directional_variables, num_iteration)
154+
155+
return wrapper
156+
157+
153158
def validate_data_pca(func):
154159
"""
155160
Decorator to validate data in PCA class fit method.

bluemath_tk/core/models.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,85 @@ def get_metrics(
270270
}
271271

272272
return pd.DataFrame(metrics).T
273+
274+
@staticmethod
275+
def _get_uv_components(x_deg: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
276+
"""
277+
This method calculates the u and v components for the given directional data.
278+
279+
Here, we assume that the directional data is in degrees,
280+
beign 0° the North direction,
281+
and increasing clockwise.
282+
283+
0° N
284+
|
285+
|
286+
270° W <---------> 90° E
287+
|
288+
|
289+
90° S
290+
291+
Parameters
292+
----------
293+
x_deg : np.ndarray
294+
The directional data in degrees.
295+
296+
Returns
297+
-------
298+
Tuple[np.ndarray, np.ndarray]
299+
The u and v components.
300+
301+
Notes
302+
-----
303+
- TODO: This method can be moved to a separate utility module.
304+
"""
305+
306+
# Convert degrees to radians and adjust by subtracting from π/2
307+
x_rad = x_deg * np.pi / 180
308+
309+
# Calculate x and y components using cosine and sine
310+
xu = np.sin(x_rad)
311+
xv = np.cos(x_rad)
312+
313+
# Return the u and v components
314+
return xu, xv
315+
316+
@staticmethod
317+
def _get_degrees_from_uv(xu: np.ndarray, xv: np.ndarray) -> np.ndarray:
318+
"""
319+
This method calculates the degrees from the u and v components.
320+
321+
Here, we assume u and v represent angles between 0 and 360 degrees,
322+
where 0° is the North direction,
323+
and increasing clockwise.
324+
325+
(u=0, v=1)
326+
|
327+
|
328+
(u=-1, v=0) <---------> (u=1, v=0)
329+
|
330+
|
331+
(u=0, v=-1)
332+
333+
Parameters
334+
----------
335+
xu : np.ndarray
336+
The u component.
337+
xv : np.ndarray
338+
The v component.
339+
340+
Returns
341+
-------
342+
np.ndarray
343+
The degrees.
344+
345+
Notes
346+
-----
347+
- TODO: This method can be moved to a separate utility module.
348+
"""
349+
350+
# Calculate the degrees using the arctangent function
351+
x_deg = np.arctan2(xu, xv) * 180 / np.pi % 360
352+
353+
# Return the degrees
354+
return x_deg

bluemath_tk/core/operations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def denormalize(
201201
return data
202202

203203

204+
# TODO: Return pd.DataFrame or xr.Dataset depending on input type
204205
def standarize(
205206
data: Union[np.ndarray, pd.DataFrame, xr.Dataset],
206207
scaler: StandardScaler = None,

bluemath_tk/core/plotting/base_plotting.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ def plot_scatter(self, ax, **kwargs):
129129
ax.scatter(**kwargs)
130130
self.set_grid(ax)
131131

132+
def plot_pie(self, ax, **kwargs):
133+
ax.pie(**kwargs)
134+
132135
def plot_map(self, ax, **kwargs):
133136
ax.set_global()
134137
ax.coastlines()

bluemath_tk/datamining/_base_datamining.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import abstractmethod
2-
from typing import Tuple
2+
from typing import Tuple, List
33
import numpy as np
44
import pandas as pd
55
from matplotlib import pyplot as plt
@@ -209,7 +209,7 @@ def plot_selected_centroids(
209209
"""
210210

211211
if (
212-
list(self.data.columns) == list(self.centroids.columns)
212+
len(self.data.columns) == len(self.centroids.columns)
213213
and list(self.data.columns) != []
214214
):
215215
variables_names = list(self.data.columns)
@@ -350,3 +350,60 @@ class BaseReduction(BlueMathModel):
350350
@abstractmethod
351351
def __init__(self) -> None:
352352
super().__init__()
353+
354+
355+
class ClusteringComparator:
356+
"""
357+
Class for comparing clustering models.
358+
"""
359+
360+
def __init__(self, list_of_models: List[BaseClustering]) -> None:
361+
"""
362+
Initializes the ClusteringComparator class.
363+
"""
364+
365+
self.list_of_models = list_of_models
366+
367+
def fit(
368+
self,
369+
data: pd.DataFrame,
370+
directional_variables: List[str] = [],
371+
custom_scale_factor: dict = {},
372+
) -> None:
373+
"""
374+
Fits the clustering models.
375+
"""
376+
377+
for model in self.list_of_models:
378+
if model.__class__.__name__ == "SOM":
379+
model.fit(
380+
data=data,
381+
directional_variables=directional_variables,
382+
)
383+
else:
384+
model.fit(
385+
data=data,
386+
directional_variables=directional_variables,
387+
custom_scale_factor=custom_scale_factor,
388+
)
389+
390+
def plot_selected_centroids(self) -> None:
391+
"""
392+
Plots the selected centroids for the clustering models.
393+
"""
394+
395+
for model in self.list_of_models:
396+
fig, axes = model.plot_selected_centroids()
397+
fig.suptitle(f"Selected centroids for {model.__class__.__name__}")
398+
399+
def plot_data_as_clusters(self, data: pd.DataFrame) -> None:
400+
"""
401+
Plots the data as clusters for the clustering models.
402+
"""
403+
404+
for model in self.list_of_models:
405+
nearest_centroids, _ = model.predict(data=data)
406+
fig, axes = model.plot_data_as_clusters(
407+
data=data, nearest_centroids=nearest_centroids
408+
)
409+
fig.suptitle(f"Data as clusters for {model.__class__.__name__}")

0 commit comments

Comments
 (0)