44import numpy as np
55import pandas as pd
66import xarray as xr
7- from matplotlib import pyplot as plt
87from matplotlib .axes import Axes
8+ from matplotlib .figure import Figure
99
1010from ..core .models import BlueMathModel
1111from ..core .plotting .base_plotting import DefaultStaticPlotting
12+ from ..core .plotting .scatter import plot_scatters_in_triangle
1213
1314
1415class BaseSampling (BlueMathModel ):
1516 """
1617 Base class for all sampling BlueMath models.
1718 This class provides the basic structure for all sampling models.
18-
19- Methods
20- -------
21- generate : pd.DataFrame
22- Generates samples.
23- plot_generated_data : Tuple[plt.figure, plt.axes]
24- Plots the generated data on a scatter plot matrix.
2519 """
2620
2721 @abstractmethod
@@ -52,7 +46,7 @@ def plot_generated_data(
5246 self ,
5347 data_color : str = "blue" ,
5448 ** kwargs ,
55- ) -> Tuple [plt . figure , plt . axes ]:
49+ ) -> Tuple [Figure , Axes ]:
5650 """
5751 Plots the generated data on a scatter plot matrix.
5852
@@ -65,9 +59,9 @@ def plot_generated_data(
6559
6660 Returns
6761 -------
68- plt.figure
62+ Figure
6963 The figure object containing the plot.
70- plt.axes
64+ Axes
7165 Array of axes objects for the subplots.
7266
7367 Raises
@@ -76,40 +70,14 @@ def plot_generated_data(
7670 If the data is empty.
7771 """
7872
79- if not self .data .empty :
80- variables_names = list (self .data .columns )
81- num_variables = len (variables_names )
82- else :
73+ if self .data .empty :
8374 raise ValueError ("Data must be a non-empty DataFrame with columns to plot." )
8475
85- # Create figure and axes
86- default_static_plot = DefaultStaticPlotting ()
87- fig , axes = default_static_plot .get_subplots (
88- nrows = num_variables - 1 ,
89- ncols = num_variables - 1 ,
90- sharex = False ,
91- sharey = False ,
76+ fig , axes = plot_scatters_in_triangle (
77+ dataframes = [self .data ],
78+ data_colors = [data_color ],
79+ ** kwargs ,
9280 )
93- if isinstance (axes , Axes ):
94- axes = np .array ([[axes ]])
95-
96- for c1 , v1 in enumerate (variables_names [1 :]):
97- for c2 , v2 in enumerate (variables_names [:- 1 ]):
98- default_static_plot .plot_scatter (
99- ax = axes [c2 , c1 ],
100- x = self .data [v1 ],
101- y = self .data [v2 ],
102- c = data_color ,
103- ** kwargs ,
104- )
105- if c1 == c2 :
106- axes [c2 , c1 ].set_xlabel (variables_names [c1 + 1 ])
107- axes [c2 , c1 ].set_ylabel (variables_names [c2 ])
108- elif c1 > c2 :
109- axes [c2 , c1 ].xaxis .set_ticklabels ([])
110- axes [c2 , c1 ].yaxis .set_ticklabels ([])
111- else :
112- fig .delaxes (axes [c2 , c1 ])
11381
11482 return fig , axes
11583
@@ -162,7 +130,7 @@ def fit(
162130 self .custom_scale_factor = custom_scale_factor .copy ()
163131 else :
164132 self .logger .info (
165- "Normalization is disabled. Using default scale factor (0, 1) for all fitting variables ."
133+ "Normalization is disabled. Set normalize_data to True to enable normalization ."
166134 )
167135 self .custom_scale_factor = {
168136 fitting_variable : (0 , 1 ) for fitting_variable in self .fitting_variables
@@ -198,7 +166,7 @@ def plot_selected_centroids(
198166 centroids_color : str = "red" ,
199167 plot_text : bool = False ,
200168 ** kwargs ,
201- ) -> Tuple [plt . figure , plt . axes ]:
169+ ) -> Tuple [Figure , Axes ]:
202170 """
203171 Plots data and selected centroids on a scatter plot matrix.
204172
@@ -215,9 +183,9 @@ def plot_selected_centroids(
215183
216184 Returns
217185 -------
218- plt.figure
186+ Figure
219187 The figure object containing the plot.
220- plt.axes
188+ Axes
221189 Array of axes objects for the subplots.
222190
223191 Raises
@@ -231,59 +199,27 @@ def plot_selected_centroids(
231199 and list (self .data .columns ) != []
232200 ):
233201 variables_names = list (self .data .columns )
234- num_variables = len (variables_names )
235202 else :
236203 raise ValueError (
237204 "Data and centroids must have the same number of columns > 0."
238205 )
239206
240- # Create figure and axes
241- default_static_plot = DefaultStaticPlotting ()
242- fig , axes = default_static_plot .get_subplots (
243- nrows = num_variables - 1 ,
244- ncols = num_variables - 1 ,
245- sharex = False ,
246- sharey = False ,
207+ fig , axes = plot_scatters_in_triangle (
208+ dataframes = [self .data , self .centroids ],
209+ data_colors = [data_color , centroids_color ],
210+ ** kwargs ,
247211 )
248- if isinstance (axes , Axes ):
249- axes = np .array ([[axes ]])
250-
251- for c1 , v1 in enumerate (variables_names [1 :]):
252- for c2 , v2 in enumerate (variables_names [:- 1 ]):
253- default_static_plot .plot_scatter (
254- ax = axes [c2 , c1 ],
255- x = self .data [v1 ],
256- y = self .data [v2 ],
257- c = data_color ,
258- alpha = 0.6 ,
259- ** kwargs ,
260- )
261- if self .centroids is not None :
262- default_static_plot .plot_scatter (
263- ax = axes [c2 , c1 ],
264- x = self .centroids [v1 ],
265- y = self .centroids [v2 ],
266- c = centroids_color ,
267- alpha = 0.9 ,
268- ** kwargs ,
269- )
270- if plot_text :
271- for i in range (self .centroids .shape [0 ]):
272- axes [c2 , c1 ].text (
273- self .centroids [v1 ][i ],
274- self .centroids [v2 ][i ],
275- str (i + 1 ),
276- fontsize = 12 ,
277- fontweight = "bold" ,
278- )
279- if c1 == c2 :
280- axes [c2 , c1 ].set_xlabel (variables_names [c1 + 1 ])
281- axes [c2 , c1 ].set_ylabel (variables_names [c2 ])
282- elif c1 > c2 :
283- axes [c2 , c1 ].xaxis .set_ticklabels ([])
284- axes [c2 , c1 ].yaxis .set_ticklabels ([])
285- else :
286- fig .delaxes (axes [c2 , c1 ])
212+ if plot_text :
213+ for c1 , v1 in enumerate (variables_names [1 :]):
214+ for c2 , v2 in enumerate (variables_names [:- 1 ]):
215+ for i in range (self .centroids .shape [0 ]):
216+ axes [c2 , c1 ].text (
217+ self .centroids [v1 ][i ],
218+ self .centroids [v2 ][i ],
219+ str (i + 1 ),
220+ fontsize = 12 ,
221+ fontweight = "bold" ,
222+ )
287223
288224 return fig , axes
289225
@@ -292,7 +228,7 @@ def plot_data_as_clusters(
292228 data : pd .DataFrame ,
293229 nearest_centroids : np .ndarray ,
294230 ** kwargs ,
295- ) -> Tuple [plt . figure , plt . axes ]:
231+ ) -> Tuple [Figure , Axes ]:
296232 """
297233 Plots data as nearest clusters.
298234
@@ -307,9 +243,9 @@ def plot_data_as_clusters(
307243
308244 Returns
309245 -------
310- plt.figure
246+ Figure
311247 The figure object containing the plot.
312- plt.axes
248+ Axes
313249 The axes object for the plot.
314250 """
315251
0 commit comments