@@ -261,9 +261,9 @@ def _figshape(n: int) -> tuple[int, int]:
261261 nrows = ncols - n_empty_rows
262262 return int (nrows ), int (ncols )
263263
264- def plot_data (self , type : str ):
264+ def plot_data (self , type : str , show_jitter : bool = True , show_violin : bool = True , ** kwargs ):
265265 """Boxplots for the sampled output."""
266- super ().plot ()
266+ super ().plot (** kwargs )
267267
268268 # calculate number of rows and columns
269269 if type == "samples" :
@@ -314,27 +314,29 @@ def plot_data(self, type: str):
314314 box .set_facecolor (color )
315315
316316 # violin
317- violin_offset = 0.3
318- vp = ax .violinplot (
319- data ,
320- positions = [k + violin_offset for k in range (self .num_groups )],
321- showmeans = True ,
322- showmedians = True ,
323- showextrema = False ,
324- )
317+ if show_violin :
318+ violin_offset = 0.3
319+ vp = ax .violinplot (
320+ data ,
321+ positions = [k + violin_offset for k in range (self .num_groups )],
322+ showmeans = True ,
323+ showmedians = True ,
324+ showextrema = False ,
325+ )
325326
326- for body , color in zip (vp ["bodies" ], colors ):
327- body .set_facecolor (color )
327+ for body , color in zip (vp ["bodies" ], colors ):
328+ body .set_facecolor (color )
328329
329330 # jitter
330- jitter_offset = 0.3
331- jitter_width = 0.02 # Adjust for spacing
332- for kg , g in enumerate (self .groups ):
333- data_g = data [kg ]
334- x_jitter = np .random .normal (kg + jitter_offset , jitter_width , len (data_g ))
335- ax .scatter (x_jitter , data_g , alpha = 0.7 , s = 30 , color = 'white' ,
336- edgecolors = 'black'
337- )
331+ if show_jitter :
332+ jitter_offset = 0.3
333+ jitter_width = 0.02 # Adjust for spacing
334+ for kg , g in enumerate (self .groups ):
335+ data_g = data [kg ]
336+ x_jitter = np .random .normal (kg + jitter_offset , jitter_width , len (data_g ))
337+ ax .scatter (x_jitter , data_g , alpha = 0.7 , s = 30 , color = 'white' ,
338+ edgecolors = 'black'
339+ )
338340
339341 # ax.set_xlabel('Parameter', fontsize=label_fontsize, fontweight="bold")
340342 # ax.set_ylim(bottom=0)
@@ -371,7 +373,7 @@ def plot_data(self, type: str):
371373
372374
373375
374- def plot (self ):
376+ def plot (self , ** kwargs ):
375377 """Boxplots for the Sampling sensitivity."""
376- self .plot_data (type = "samples" )
377- self .plot_data (type = "outputs" )
378+ self .plot_data (type = "samples" , ** kwargs )
379+ self .plot_data (type = "outputs" , ** kwargs )
0 commit comments