Skip to content

Commit 205c378

Browse files
author
Javier Tausia
committed
Merge branch 'main' of https://github.com/GeoOcean/BlueMath_tk into feature/downloaders
2 parents a6bacce + df91620 commit 205c378

File tree

11 files changed

+379
-151
lines changed

11 files changed

+379
-151
lines changed

bluemath_tk/core/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212

1313
class BlueMathModel(ABC):
14+
@abstractmethod
1415
def __init__(self):
1516
self._logger = get_file_logger(name=self.__class__.__name__)
1617

bluemath_tk/core/plotting/base_plotting.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from abc import ABC, abstractmethod
2+
from typing import Union
23
import matplotlib.pyplot as plt
4+
from matplotlib.colors import Colormap
35
import plotly.graph_objects as go
46
import cartopy.crs as ccrs
57

@@ -40,6 +42,29 @@ def plot_map(self, markers=None):
4042
"""
4143
pass
4244

45+
def get_list_of_colors_for_colormap(
46+
self, cmap: Union[str, Colormap], num_colors: int
47+
) -> list:
48+
"""
49+
Get a list of colors from a colormap.
50+
51+
Parameters
52+
----------
53+
cmap : str or Colormap
54+
The colormap to use.
55+
num_colors : int
56+
The number of colors to generate.
57+
58+
Returns
59+
-------
60+
list
61+
A list of colors generated from the colormap.
62+
"""
63+
64+
if isinstance(cmap, str):
65+
cmap = plt.get_cmap(cmap)
66+
return [cmap(i) for i in range(0, 256, 256 // num_colors)]
67+
4368

4469
class DefaultStaticPlotting(BasePlotting):
4570
"""

bluemath_tk/datamining/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,19 @@
88
"""
99

1010
# Import essential functions/classes to be available at the package level.
11+
from ._base_datamining import BaseSampling, BaseClustering, BaseReduction
1112
from .mda import MDA
1213
from .lhs import LHS
1314
from .kma import KMA
15+
from .pca import PCA
1416

1517
# Optionally, define the module's `__all__` variable to control what gets imported when using `from module import *`.
16-
__all__ = ["MDA", "LHS", "KMA", "PCA"]
18+
__all__ = [
19+
"BaseSampling",
20+
"BaseClustering",
21+
"BaseReduction",
22+
"MDA",
23+
"LHS",
24+
"KMA",
25+
"PCA",
26+
]
Lines changed: 185 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,145 @@
1+
from abc import abstractmethod
2+
from typing import Tuple
3+
import numpy as np
14
import pandas as pd
5+
from matplotlib import pyplot as plt
26
from ..core.models import BlueMathModel
37
from ..core.plotting.base_plotting import DefaultStaticPlotting
48

59

10+
class BaseSampling(BlueMathModel):
11+
"""
12+
Base class for all sampling BlueMath models.
13+
This class provides the basic structure for all sampling models.
14+
15+
Methods
16+
-------
17+
generate(*args, **kwargs)
18+
"""
19+
20+
@abstractmethod
21+
def __init__(self):
22+
super().__init__()
23+
24+
@abstractmethod
25+
def generate(self, *args, **kwargs) -> pd.DataFrame:
26+
"""
27+
Generates samples.
28+
29+
Parameters
30+
----------
31+
*args : list
32+
Positional arguments.
33+
**kwargs : dict
34+
Keyword arguments.
35+
36+
Returns
37+
-------
38+
pd.DataFrame
39+
The generated samples.
40+
"""
41+
42+
return pd.DataFrame()
43+
44+
645
class BaseClustering(BlueMathModel):
746
"""
847
Base class for all clustering BlueMath models.
948
This class provides the basic structure for all clustering models.
1049
1150
Methods
1251
-------
13-
plot_selected_data(data, centroids, data_color, centroids_color, **kwargs)
52+
fit(*args, **kwargs)
53+
predict(*args, **kwargs)
54+
fit_predict(*args, **kwargs)
55+
plot_selected_data(data_color, centroids_color, **kwargs)
1456
"""
1557

58+
@abstractmethod
1659
def __init__(self):
1760
super().__init__()
1861

19-
def plot_selected_data(
62+
@abstractmethod
63+
def fit(self, *args, **kwargs):
64+
"""
65+
Fits the model to the data.
66+
67+
Parameters
68+
----------
69+
*args : list
70+
Positional arguments.
71+
**kwargs : dict
72+
Keyword arguments.
73+
"""
74+
75+
pass
76+
77+
@abstractmethod
78+
def predict(self, *args, **kwargs):
79+
"""
80+
Predicts the clusters for the provided data.
81+
82+
Parameters
83+
----------
84+
*args : list
85+
Positional arguments.
86+
**kwargs : dict
87+
Keyword arguments.
88+
"""
89+
90+
pass
91+
92+
@abstractmethod
93+
def fit_predict(self, *args, **kwargs):
94+
"""
95+
Fits the model to the data and predicts the clusters.
96+
97+
Parameters
98+
----------
99+
*args : list
100+
Positional arguments.
101+
**kwargs : dict
102+
Keyword arguments.
103+
"""
104+
105+
pass
106+
107+
def plot_selected_centroids(
20108
self,
21-
data: pd.DataFrame,
22-
centroids: pd.DataFrame = None,
23109
data_color: str = "blue",
24110
centroids_color: str = "red",
25111
**kwargs,
26-
):
112+
) -> Tuple[plt.figure, plt.axes]:
27113
"""
28-
Plots selected data and centroids on a scatter plot matrix.
114+
Plots data and selected centroids on a scatter plot matrix.
29115
30-
Parameters:
31-
-----------
32-
data : pd.DataFrame
33-
DataFrame containing the data to be plotted.
34-
centroids : pd.DataFrame, optional
35-
DataFrame containing the centroids to be plotted. Default is None.
116+
Parameters
117+
----------
36118
data_color : str, optional
37119
Color for the data points. Default is "blue".
38120
centroids_color : str, optional
39121
Color for the centroid points. Default is "red".
40122
**kwargs : dict, optional
41123
Additional keyword arguments to be passed to the scatter plot function.
42124
43-
Returns:
44-
--------
45-
fig : matplotlib.figure.Figure
125+
Returns
126+
-------
127+
fig : plt.figure
46128
The figure object containing the plot.
47-
axes : numpy.ndarray
129+
axes : plt.axes
48130
Array of axes objects for the subplots.
49131
50-
Raises:
51-
-------
132+
Raises
133+
------
52134
ValueError
53135
If the data and centroids do not have the same number of columns or if the columns are empty.
54136
"""
55137

56-
if list(data.columns) == list(centroids.columns) and list(data.columns) != []:
57-
variables_names = list(data.columns)
138+
if (
139+
list(self.data.columns) == list(self.centroids.columns)
140+
and list(self.data.columns) != []
141+
):
142+
variables_names = list(self.data.columns)
58143
num_variables = len(variables_names)
59144
else:
60145
raise ValueError(
@@ -74,18 +159,18 @@ def plot_selected_data(
74159
for c2, v2 in enumerate(variables_names[:-1]):
75160
default_static_plot.plot_scatter(
76161
ax=axes[c2, c1],
77-
x=data[v1],
78-
y=data[v2],
162+
x=self.data[v1],
163+
y=self.data[v2],
79164
c=data_color,
80165
s=kwargs.get("s", default_static_plot.default_scatter_size),
81166
alpha=kwargs.get("alpha", 0.7),
82167
)
83168
# Plot centroids in selected ax if passed
84-
if centroids is not None:
169+
if self.centroids is not None:
85170
default_static_plot.plot_scatter(
86171
ax=axes[c2, c1],
87-
x=centroids[v1],
88-
y=centroids[v2],
172+
x=self.centroids[v1],
173+
y=self.centroids[v2],
89174
c=centroids_color,
90175
s=kwargs.get("s", default_static_plot.default_scatter_size),
91176
alpha=kwargs.get("alpha", 0.9),
@@ -101,12 +186,87 @@ def plot_selected_data(
101186

102187
return fig, axes
103188

189+
def plot_data_as_clusters(
190+
self,
191+
data: pd.DataFrame,
192+
closest_centroids: np.ndarray,
193+
**kwargs,
194+
) -> Tuple[plt.figure, plt.axes]:
195+
"""
196+
Plots data as closest clusters.
197+
198+
Parameters
199+
----------
200+
data : pd.DataFrame
201+
The data to plot.
202+
closest_centroids : np.ndarray
203+
The closest centroids.
204+
**kwargs : dict, optional
205+
Additional keyword arguments to be passed to the scatter plot function.
206+
207+
Returns
208+
-------
209+
fig : plt.figure
210+
The figure object containing the plot.
211+
axes : plt.axes
212+
The axes object for the plot.
213+
"""
214+
215+
if (
216+
not data.empty
217+
and list(self.data.columns) != []
218+
and closest_centroids.size > 0
219+
):
220+
variables_names = list(data.columns)
221+
num_variables = len(variables_names)
222+
else:
223+
raise ValueError(
224+
"Data must have columns and closest centroids must have values."
225+
)
226+
227+
# Create figure and axes
228+
default_static_plot = DefaultStaticPlotting()
229+
fig, axes = default_static_plot.get_subplots(
230+
nrows=num_variables - 1,
231+
ncols=num_variables - 1,
232+
sharex=False,
233+
sharey=False,
234+
)
235+
236+
# Gets colors for clusters and append to each closest centroid
237+
colors_for_clusters = default_static_plot.get_list_of_colors_for_colormap(
238+
cmap="viridis", num_colors=self.centroids.shape[0]
239+
)
240+
closest_centroids_colors = [colors_for_clusters[i] for i in closest_centroids]
241+
242+
for c1, v1 in enumerate(variables_names[1:]):
243+
for c2, v2 in enumerate(variables_names[:-1]):
244+
default_static_plot.plot_scatter(
245+
ax=axes[c2, c1],
246+
x=data[v1],
247+
y=data[v2],
248+
c=closest_centroids_colors,
249+
s=kwargs.get("s", default_static_plot.default_scatter_size),
250+
alpha=kwargs.get("alpha", 0.7),
251+
)
252+
if c1 == c2:
253+
axes[c2, c1].set_xlabel(variables_names[c1 + 1])
254+
axes[c2, c1].set_ylabel(variables_names[c2])
255+
elif c1 > c2:
256+
axes[c2, c1].xaxis.set_ticklabels([])
257+
axes[c2, c1].yaxis.set_ticklabels([])
258+
else:
259+
fig.delaxes(axes[c2, c1])
260+
261+
return fig, axes
262+
104263

105264
class BaseReduction(BlueMathModel):
106265
"""
107266
Base class for all dimensionality reduction BlueMath models.
108267
This class provides the basic structure for all dimensionality reduction models.
109268
"""
110269

270+
@abstractmethod
111271
def __init__(self):
112272
super().__init__()

0 commit comments

Comments
 (0)