Skip to content

Commit f0dabcb

Browse files
committed
Merge branch 'develop' of https://github.com/GeoOcean/BlueMath_tk into feature/shytcwaves
2 parents 8111295 + 64e865c commit f0dabcb

File tree

16 files changed

+2774
-124
lines changed

16 files changed

+2774
-124
lines changed

bluemath_tk/core/decorators.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,14 @@ def wrapper(
216216
raise ValueError("Number of iterations must be integer and > 0")
217217
if not isinstance(normalize_data, bool):
218218
raise TypeError("Normalize data must be a boolean")
219-
return func(self, data, directional_variables, num_iteration)
219+
return func(
220+
self,
221+
data,
222+
directional_variables,
223+
custom_scale_factor,
224+
num_iteration,
225+
normalize_data,
226+
)
220227

221228
return wrapper
222229

bluemath_tk/core/operations.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,10 @@ def normalize(
6565
... }
6666
... )
6767
>>> normalized_data, scale_factor = normalize(data=df)
68+
6869
>>> import numpy as np
6970
>>> import xarray as xr
70-
>>> from bluemath_tk.core.data import normalize
71+
>>> from bluemath_tk.core.operations import normalize
7172
>>> ds = xr.Dataset(
7273
... {
7374
... "Hs": (("time",), np.random.rand(1000) * 7),
@@ -85,6 +86,7 @@ def normalize(
8586
vars_to_normalize = list(data.data_vars)
8687
else:
8788
raise TypeError("Data must be a pandas DataFrame or an xarray Dataset")
89+
8890
normalized_data = data.copy() # Copy data to avoid bad memory replacements
8991
scale_factor = (
9092
custom_scale_factor.copy()
@@ -122,6 +124,7 @@ def normalize(
122124
normalized_data[data_var] = (normalized_data[data_var] - data_var_min) / (
123125
data_var_max - data_var_min
124126
)
127+
125128
return normalized_data, scale_factor
126129

127130

@@ -173,6 +176,7 @@ def denormalize(
173176
... "Dir": [0, 360],
174177
... }
175178
>>> denormalized_data = denormalize(normalized_data=df, scale_factor=scale_factor)
179+
176180
>>> import numpy as np
177181
>>> import xarray as xr
178182
>>> from bluemath_tk.core.operations import denormalize
@@ -204,6 +208,7 @@ def denormalize(
204208
data[data_var] * (scale_factor[data_var][1] - scale_factor[data_var][0])
205209
+ scale_factor[data_var][0]
206210
)
211+
207212
return data
208213

209214

@@ -263,6 +268,7 @@ def standarize(
263268
},
264269
coords=data.coords,
265270
)
271+
266272
return standarized_data, scaler
267273

268274

@@ -308,6 +314,7 @@ def destandarize(
308314
},
309315
coords=standarized_data.coords,
310316
)
317+
311318
return data
312319

313320

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from typing import List, Tuple
2+
3+
import numpy as np
4+
import pandas as pd
5+
from matplotlib.axes import Axes
6+
from matplotlib.figure import Figure
7+
8+
from .base_plotting import DefaultStaticPlotting
9+
from .colors import default_colors
10+
11+
12+
def plot_scatters_in_triangle(
13+
dataframes: List[pd.DataFrame],
14+
data_colors: List[str] = default_colors,
15+
**kwargs,
16+
) -> Tuple[Figure, Axes]:
17+
"""
18+
Plot a scatter plot of the dataframes with axes in a triangle.
19+
20+
Parameters
21+
----------
22+
dataframes : List[pd.DataFrame]
23+
List of dataframes to plot.
24+
data_colors : List[str], optional
25+
List of colors for the dataframes.
26+
**kwargs : dict, optional
27+
Keyword arguments for the scatter plot. Will be passed to the
28+
DefaultStaticPlotting.plot_scatter method, which is the same
29+
as the one in matplotlib.pyplot.scatter.
30+
For example, to change the marker size, you can use:
31+
``plot_scatters_in_triangle(dataframes, s=10)``
32+
33+
Returns
34+
-------
35+
fig : Figure
36+
Figure object.
37+
axes : Axes
38+
Axes object.
39+
"""
40+
41+
# Get the number and names of variables from the first dataframe
42+
variables_names = list(dataframes[0].columns)
43+
num_variables = len(variables_names)
44+
45+
# Check variables names are in all dataframes
46+
for df in dataframes:
47+
if not all(v in df.columns for v in variables_names):
48+
raise ValueError(
49+
f"Variables {variables_names} are not in dataframe {df.columns}."
50+
)
51+
52+
# Create figure and axes
53+
default_static_plot = DefaultStaticPlotting()
54+
fig, axes = default_static_plot.get_subplots(
55+
nrows=num_variables - 1,
56+
ncols=num_variables - 1,
57+
sharex=False,
58+
sharey=False,
59+
)
60+
if isinstance(axes, Axes):
61+
axes = np.array([[axes]])
62+
63+
for c1, v1 in enumerate(variables_names[1:]):
64+
for c2, v2 in enumerate(variables_names[:-1]):
65+
for idf, df in enumerate(dataframes):
66+
default_static_plot.plot_scatter(
67+
ax=axes[c2, c1],
68+
x=df[v1],
69+
y=df[v2],
70+
c=data_colors[idf],
71+
alpha=0.6,
72+
**kwargs,
73+
)
74+
if c1 == c2:
75+
axes[c2, c1].set_xlabel(variables_names[c1 + 1])
76+
axes[c2, c1].set_ylabel(variables_names[c2])
77+
elif c1 > c2:
78+
axes[c2, c1].xaxis.set_ticklabels([])
79+
axes[c2, c1].yaxis.set_ticklabels([])
80+
else:
81+
fig.delaxes(axes[c2, c1])
82+
83+
return fig, axes

bluemath_tk/datamining/_base_datamining.py

Lines changed: 32 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,18 @@
44
import numpy as np
55
import pandas as pd
66
import xarray as xr
7-
from matplotlib import pyplot as plt
87
from matplotlib.axes import Axes
8+
from matplotlib.figure import Figure
99

1010
from ..core.models import BlueMathModel
1111
from ..core.plotting.base_plotting import DefaultStaticPlotting
12+
from ..core.plotting.scatter import plot_scatters_in_triangle
1213

1314

1415
class BaseSampling(BlueMathModel):
1516
"""
1617
Base class for all sampling BlueMath models.
1718
This class provides the basic structure for all sampling models.
18-
19-
Methods
20-
-------
21-
generate : pd.DataFrame
22-
Generates samples.
23-
plot_generated_data : Tuple[plt.figure, plt.axes]
24-
Plots the generated data on a scatter plot matrix.
2519
"""
2620

2721
@abstractmethod
@@ -52,7 +46,7 @@ def plot_generated_data(
5246
self,
5347
data_color: str = "blue",
5448
**kwargs,
55-
) -> Tuple[plt.figure, plt.axes]:
49+
) -> Tuple[Figure, Axes]:
5650
"""
5751
Plots the generated data on a scatter plot matrix.
5852
@@ -65,9 +59,9 @@ def plot_generated_data(
6559
6660
Returns
6761
-------
68-
plt.figure
62+
Figure
6963
The figure object containing the plot.
70-
plt.axes
64+
Axes
7165
Array of axes objects for the subplots.
7266
7367
Raises
@@ -76,40 +70,14 @@ def plot_generated_data(
7670
If the data is empty.
7771
"""
7872

79-
if not self.data.empty:
80-
variables_names = list(self.data.columns)
81-
num_variables = len(variables_names)
82-
else:
73+
if self.data.empty:
8374
raise ValueError("Data must be a non-empty DataFrame with columns to plot.")
8475

85-
# Create figure and axes
86-
default_static_plot = DefaultStaticPlotting()
87-
fig, axes = default_static_plot.get_subplots(
88-
nrows=num_variables - 1,
89-
ncols=num_variables - 1,
90-
sharex=False,
91-
sharey=False,
76+
fig, axes = plot_scatters_in_triangle(
77+
dataframes=[self.data],
78+
data_colors=[data_color],
79+
**kwargs,
9280
)
93-
if isinstance(axes, Axes):
94-
axes = np.array([[axes]])
95-
96-
for c1, v1 in enumerate(variables_names[1:]):
97-
for c2, v2 in enumerate(variables_names[:-1]):
98-
default_static_plot.plot_scatter(
99-
ax=axes[c2, c1],
100-
x=self.data[v1],
101-
y=self.data[v2],
102-
c=data_color,
103-
**kwargs,
104-
)
105-
if c1 == c2:
106-
axes[c2, c1].set_xlabel(variables_names[c1 + 1])
107-
axes[c2, c1].set_ylabel(variables_names[c2])
108-
elif c1 > c2:
109-
axes[c2, c1].xaxis.set_ticklabels([])
110-
axes[c2, c1].yaxis.set_ticklabels([])
111-
else:
112-
fig.delaxes(axes[c2, c1])
11381

11482
return fig, axes
11583

@@ -162,7 +130,7 @@ def fit(
162130
self.custom_scale_factor = custom_scale_factor.copy()
163131
else:
164132
self.logger.info(
165-
"Normalization is disabled. Using default scale factor (0, 1) for all fitting variables."
133+
"Normalization is disabled. Set normalize_data to True to enable normalization."
166134
)
167135
self.custom_scale_factor = {
168136
fitting_variable: (0, 1) for fitting_variable in self.fitting_variables
@@ -198,7 +166,7 @@ def plot_selected_centroids(
198166
centroids_color: str = "red",
199167
plot_text: bool = False,
200168
**kwargs,
201-
) -> Tuple[plt.figure, plt.axes]:
169+
) -> Tuple[Figure, Axes]:
202170
"""
203171
Plots data and selected centroids on a scatter plot matrix.
204172
@@ -215,9 +183,9 @@ def plot_selected_centroids(
215183
216184
Returns
217185
-------
218-
plt.figure
186+
Figure
219187
The figure object containing the plot.
220-
plt.axes
188+
Axes
221189
Array of axes objects for the subplots.
222190
223191
Raises
@@ -231,59 +199,27 @@ def plot_selected_centroids(
231199
and list(self.data.columns) != []
232200
):
233201
variables_names = list(self.data.columns)
234-
num_variables = len(variables_names)
235202
else:
236203
raise ValueError(
237204
"Data and centroids must have the same number of columns > 0."
238205
)
239206

240-
# Create figure and axes
241-
default_static_plot = DefaultStaticPlotting()
242-
fig, axes = default_static_plot.get_subplots(
243-
nrows=num_variables - 1,
244-
ncols=num_variables - 1,
245-
sharex=False,
246-
sharey=False,
207+
fig, axes = plot_scatters_in_triangle(
208+
dataframes=[self.data, self.centroids],
209+
data_colors=[data_color, centroids_color],
210+
**kwargs,
247211
)
248-
if isinstance(axes, Axes):
249-
axes = np.array([[axes]])
250-
251-
for c1, v1 in enumerate(variables_names[1:]):
252-
for c2, v2 in enumerate(variables_names[:-1]):
253-
default_static_plot.plot_scatter(
254-
ax=axes[c2, c1],
255-
x=self.data[v1],
256-
y=self.data[v2],
257-
c=data_color,
258-
alpha=0.6,
259-
**kwargs,
260-
)
261-
if self.centroids is not None:
262-
default_static_plot.plot_scatter(
263-
ax=axes[c2, c1],
264-
x=self.centroids[v1],
265-
y=self.centroids[v2],
266-
c=centroids_color,
267-
alpha=0.9,
268-
**kwargs,
269-
)
270-
if plot_text:
271-
for i in range(self.centroids.shape[0]):
272-
axes[c2, c1].text(
273-
self.centroids[v1][i],
274-
self.centroids[v2][i],
275-
str(i + 1),
276-
fontsize=12,
277-
fontweight="bold",
278-
)
279-
if c1 == c2:
280-
axes[c2, c1].set_xlabel(variables_names[c1 + 1])
281-
axes[c2, c1].set_ylabel(variables_names[c2])
282-
elif c1 > c2:
283-
axes[c2, c1].xaxis.set_ticklabels([])
284-
axes[c2, c1].yaxis.set_ticklabels([])
285-
else:
286-
fig.delaxes(axes[c2, c1])
212+
if plot_text:
213+
for c1, v1 in enumerate(variables_names[1:]):
214+
for c2, v2 in enumerate(variables_names[:-1]):
215+
for i in range(self.centroids.shape[0]):
216+
axes[c2, c1].text(
217+
self.centroids[v1][i],
218+
self.centroids[v2][i],
219+
str(i + 1),
220+
fontsize=12,
221+
fontweight="bold",
222+
)
287223

288224
return fig, axes
289225

@@ -292,7 +228,7 @@ def plot_data_as_clusters(
292228
data: pd.DataFrame,
293229
nearest_centroids: np.ndarray,
294230
**kwargs,
295-
) -> Tuple[plt.figure, plt.axes]:
231+
) -> Tuple[Figure, Axes]:
296232
"""
297233
Plots data as nearest clusters.
298234
@@ -307,9 +243,9 @@ def plot_data_as_clusters(
307243
308244
Returns
309245
-------
310-
plt.figure
246+
Figure
311247
The figure object containing the plot.
312-
plt.axes
248+
Axes
313249
The axes object for the plot.
314250
"""
315251

0 commit comments

Comments
 (0)