1+ from abc import abstractmethod
2+ from typing import Tuple
3+ import numpy as np
14import pandas as pd
5+ from matplotlib import pyplot as plt
26from ..core .models import BlueMathModel
37from ..core .plotting .base_plotting import DefaultStaticPlotting
48
59
class BaseSampling(BlueMathModel):
    """
    Abstract parent of every BlueMath sampling model.

    Both the constructor and ``generate`` are marked abstract; this class only
    lays out the structure shared by sampling models.

    Methods
    -------
    generate(*args, **kwargs)
        Produce the samples as a pandas DataFrame.
    """

    @abstractmethod
    def __init__(self):
        """
        Initialize the sampling model through the parent class initializer.
        """
        super().__init__()

    @abstractmethod
    def generate(self, *args, **kwargs) -> pd.DataFrame:
        """
        Produce a set of samples.

        Parameters
        ----------
        *args : list
            Positional arguments.
        **kwargs : dict
            Keyword arguments.

        Returns
        -------
        pd.DataFrame
            The generated samples. This base implementation returns an
            empty DataFrame.
        """

        empty_samples = pd.DataFrame()
        return empty_samples
43+
44+
645class BaseClustering (BlueMathModel ):
746 """
847 Base class for all clustering BlueMath models.
948 This class provides the basic structure for all clustering models.
1049
1150 Methods
1251 -------
13- plot_selected_data(data, centroids, data_color, centroids_color, **kwargs)
52+ fit(*args, **kwargs)
53+ predict(*args, **kwargs)
54+ fit_predict(*args, **kwargs)
55+ plot_selected_centroids(data_color, centroids_color, **kwargs)
1456 """
1557
@abstractmethod
def __init__(self):
    """
    Initialize the clustering model through the parent class initializer.
    """
    super().__init__()
1861
19- def plot_selected_data (
@abstractmethod
def fit(self, *args, **kwargs):
    """
    Fit the clustering model to the data.

    Abstract placeholder: this base version performs no work and
    returns None.

    Parameters
    ----------
    *args : list
        Positional arguments.
    **kwargs : dict
        Keyword arguments.
    """

    return None
76+
@abstractmethod
def predict(self, *args, **kwargs):
    """
    Predict the clusters for the provided data.

    Abstract placeholder: this base version performs no work and
    returns None.

    Parameters
    ----------
    *args : list
        Positional arguments.
    **kwargs : dict
        Keyword arguments.
    """

    return None
91+
@abstractmethod
def fit_predict(self, *args, **kwargs):
    """
    Fit the clustering model to the data and predict the clusters.

    Abstract placeholder: this base version performs no work and
    returns None.

    Parameters
    ----------
    *args : list
        Positional arguments.
    **kwargs : dict
        Keyword arguments.
    """

    return None
106+
107+ def plot_selected_centroids (
20108 self ,
21- data : pd .DataFrame ,
22- centroids : pd .DataFrame = None ,
23109 data_color : str = "blue" ,
24110 centroids_color : str = "red" ,
25111 ** kwargs ,
26- ):
112+ ) -> Tuple [ plt . figure , plt . axes ] :
27113 """
28- Plots selected data and centroids on a scatter plot matrix.
114+ Plots data and selected centroids on a scatter plot matrix.
29115
30- Parameters:
31- -----------
32- data : pd.DataFrame
33- DataFrame containing the data to be plotted.
34- centroids : pd.DataFrame, optional
35- DataFrame containing the centroids to be plotted. Default is None.
116+ Parameters
117+ ----------
36118 data_color : str, optional
37119 Color for the data points. Default is "blue".
38120 centroids_color : str, optional
39121 Color for the centroid points. Default is "red".
40122 **kwargs : dict, optional
41123 Additional keyword arguments to be passed to the scatter plot function.
42124
43- Returns:
44- --------
45- fig : matplotlib .figure.Figure
125+ Returns
126+ -------
127+ fig : plt .figure
46128 The figure object containing the plot.
47- axes : numpy.ndarray
129+ axes : plt.axes
48130 Array of axes objects for the subplots.
49131
50- Raises:
51- -------
132+ Raises
133+ ------
52134 ValueError
53135 If the data and centroids do not have the same number of columns or if the columns are empty.
54136 """
55137
56- if list (data .columns ) == list (centroids .columns ) and list (data .columns ) != []:
57- variables_names = list (data .columns )
138+ if (
139+ list (self .data .columns ) == list (self .centroids .columns )
140+ and list (self .data .columns ) != []
141+ ):
142+ variables_names = list (self .data .columns )
58143 num_variables = len (variables_names )
59144 else :
60145 raise ValueError (
@@ -74,18 +159,18 @@ def plot_selected_data(
74159 for c2 , v2 in enumerate (variables_names [:- 1 ]):
75160 default_static_plot .plot_scatter (
76161 ax = axes [c2 , c1 ],
77- x = data [v1 ],
78- y = data [v2 ],
162+ x = self . data [v1 ],
163+ y = self . data [v2 ],
79164 c = data_color ,
80165 s = kwargs .get ("s" , default_static_plot .default_scatter_size ),
81166 alpha = kwargs .get ("alpha" , 0.7 ),
82167 )
83168 # Plot centroids in selected ax if passed
84- if centroids is not None :
169+ if self . centroids is not None :
85170 default_static_plot .plot_scatter (
86171 ax = axes [c2 , c1 ],
87- x = centroids [v1 ],
88- y = centroids [v2 ],
172+ x = self . centroids [v1 ],
173+ y = self . centroids [v2 ],
89174 c = centroids_color ,
90175 s = kwargs .get ("s" , default_static_plot .default_scatter_size ),
91176 alpha = kwargs .get ("alpha" , 0.9 ),
@@ -101,12 +186,87 @@ def plot_selected_data(
101186
102187 return fig , axes
103188
def plot_data_as_clusters(
    self,
    data: pd.DataFrame,
    closest_centroids: np.ndarray,
    **kwargs,
) -> Tuple[plt.figure, plt.axes]:
    """
    Plots data as closest clusters.

    Every pair of variables in ``data`` is drawn as a scatter panel in which
    each point takes the color of its closest centroid. Colors are sampled
    from the "viridis" colormap, one per row of ``self.centroids``.

    Parameters
    ----------
    data : pd.DataFrame
        The data to plot. Its columns are the variables shown in the panels.
    closest_centroids : np.ndarray
        The closest centroid for each data point, used to index the list of
        cluster colors.
        NOTE(review): assumed to be a 1-D array of integer indices, one per
        row of ``data`` -- confirm against callers.
    **kwargs : dict, optional
        Additional keyword arguments. Only "s" (marker size) and "alpha"
        (default 0.7) are read and forwarded to the scatter plot function.

    Returns
    -------
    fig : plt.figure
        The figure object containing the plot.
    axes : plt.axes
        The axes returned by ``DefaultStaticPlotting.get_subplots``, indexed
        here as ``axes[row, col]``.

    Raises
    ------
    ValueError
        If the data is empty or the closest centroids have no values.
    """

    # Validate the frame that is actually plotted (the ``data`` argument), not
    # ``self.data``. ``DataFrame.empty`` is True when there are no rows or no
    # columns, so it also covers the "must have columns" requirement.
    if data.empty or closest_centroids.size == 0:
        raise ValueError(
            "Data must have columns and closest centroids must have values."
        )
    variables_names = list(data.columns)
    num_variables = len(variables_names)

    # Create figure and axes
    default_static_plot = DefaultStaticPlotting()
    fig, axes = default_static_plot.get_subplots(
        nrows=num_variables - 1,
        ncols=num_variables - 1,
        sharex=False,
        sharey=False,
    )

    # Gets colors for clusters and append to each closest centroid
    colors_for_clusters = default_static_plot.get_list_of_colors_for_colormap(
        cmap="viridis", num_colors=self.centroids.shape[0]
    )
    closest_centroids_colors = [colors_for_clusters[i] for i in closest_centroids]

    # Scatter options do not change between panels, so resolve them once.
    scatter_size = kwargs.get("s", default_static_plot.default_scatter_size)
    scatter_alpha = kwargs.get("alpha", 0.7)

    for c1, v1 in enumerate(variables_names[1:]):
        for c2, v2 in enumerate(variables_names[:-1]):
            ax = axes[c2, c1]
            if c1 < c2:
                # Lower-triangle panels are removed from the figure, so
                # nothing is drawn on them first.
                fig.delaxes(ax)
                continue
            default_static_plot.plot_scatter(
                ax=ax,
                x=data[v1],
                y=data[v2],
                c=closest_centroids_colors,
                s=scatter_size,
                alpha=scatter_alpha,
            )
            if c1 == c2:
                # Only the diagonal panels carry axis labels.
                ax.set_xlabel(variables_names[c1 + 1])
                ax.set_ylabel(variables_names[c2])
            else:
                # Upper-triangle panels keep their ticks but hide the labels.
                ax.xaxis.set_ticklabels([])
                ax.yaxis.set_ticklabels([])

    return fig, axes
262+
104263
class BaseReduction(BlueMathModel):
    """
    Abstract parent of every BlueMath dimensionality reduction model.

    The constructor is marked abstract; this class only lays out the basic
    structure shared by dimensionality reduction models.
    """

    @abstractmethod
    def __init__(self):
        """
        Initialize the reduction model through the parent class initializer.
        """
        super().__init__()
0 commit comments