@@ -19,6 +19,7 @@ def load_plot_data(
19
19
data : List ,
20
20
effect_size : str = "mean_diff" ,
21
21
contrast_type : str = None ,
22
+ ci_type : str = "bca" ,
22
23
idx : Optional [List [int ]] = None
23
24
) -> List :
24
25
"""
@@ -32,6 +33,8 @@ def load_plot_data(
32
33
Type of effect size ('mean_diff', 'median_diff', etc.).
33
34
contrast_type: str
34
35
Type of dabest object to plot ('delta2' or 'mini-meta' or 'delta').
36
+ ci_type: str
37
+ Type of confidence interval to plot ('bca' or 'pct')
35
38
idx: Optional[List[int]], default=None
36
39
List of indices to select from the contrast objects if delta-delta experiment.
37
40
If None, only the delta-delta objects are plotted.
@@ -53,14 +56,14 @@ def load_plot_data(
53
56
current_plot_data = getattr (current_contrast , effect_attr )
54
57
bootstraps .append (current_plot_data .results .bootstraps [index ])
55
58
differences .append (current_plot_data .results .difference [index ])
56
- bcalows .append (current_plot_data .results .bca_low [index ])
57
- bcahighs .append (current_plot_data .results .bca_high [index ])
59
+ bcalows .append (current_plot_data .results .get ( ci_type + '_low' ) [index ])
60
+ bcahighs .append (current_plot_data .results .get ( ci_type + '_high' ) [index ])
58
61
else :
59
62
contrast_plot_data = [getattr (contrast , effect_attr ) for contrast in data ]
60
63
bootstraps_nested = [result .results .bootstraps .to_list () for result in contrast_plot_data ]
61
64
differences_nested = [result .results .difference .to_list () for result in contrast_plot_data ]
62
- bcalows_nested = [result .results .bca_low .to_list () for result in contrast_plot_data ]
63
- bcahighs_nested = [result .results .bca_high .to_list () for result in contrast_plot_data ]
65
+ bcalows_nested = [result .results .get ( ci_type + '_low' ) .to_list () for result in contrast_plot_data ]
66
+ bcahighs_nested = [result .results .get ( ci_type + '_high' ) .to_list () for result in contrast_plot_data ]
64
67
65
68
bootstraps = [element for innerList in bootstraps_nested for element in innerList ]
66
69
differences = [element for innerList in differences_nested for element in innerList ]
@@ -79,14 +82,14 @@ def load_plot_data(
79
82
current_plot_data = getattr (getattr (current_contrast , effect_attr ), contrast_attr )
80
83
bootstraps .append (current_plot_data .bootstraps_delta_delta )
81
84
differences .append (current_plot_data .difference )
82
- bcalows .append (current_plot_data .bca_low )
83
- bcahighs .append (current_plot_data .bca_high )
85
+ bcalows .append (current_plot_data .results . get ( ci_type + '_low' )[ 0 ] )
86
+ bcahighs .append (current_plot_data .results . get ( ci_type + '_high' )[ 0 ] )
84
87
elif index == 0 or index == 1 :
85
88
current_plot_data = getattr (current_contrast , effect_attr )
86
89
bootstraps .append (current_plot_data .results .bootstraps [index ])
87
90
differences .append (current_plot_data .results .difference [index ])
88
- bcalows .append (current_plot_data .results .bca_low [index ])
89
- bcahighs .append (current_plot_data .results .bca_high [index ])
91
+ bcalows .append (current_plot_data .results .get ( ci_type + '_low' ) [index ])
92
+ bcahighs .append (current_plot_data .results .get ( ci_type + '_high' ) [index ])
90
93
else :
91
94
raise ValueError ("The selected indices must be 0, 1, or 2." )
92
95
else :
@@ -95,14 +98,14 @@ def load_plot_data(
95
98
current_plot_data = getattr (getattr (current_contrast , effect_attr ), contrast_attr )
96
99
bootstraps .append (current_plot_data .bootstraps_weighted_delta )
97
100
differences .append (current_plot_data .difference )
98
- bcalows .append (current_plot_data .results .bca_low )
99
- bcahighs .append (current_plot_data .results .bca_high )
101
+ bcalows .append (current_plot_data .results .get ( ci_type + '_low' )[ 0 ] )
102
+ bcahighs .append (current_plot_data .results .get ( ci_type + '_high' )[ 0 ] )
100
103
elif index < num_of_groups :
101
104
current_plot_data = getattr (current_contrast , effect_attr )
102
105
bootstraps .append (current_plot_data .results .bootstraps [index ])
103
106
differences .append (current_plot_data .results .difference [index ])
104
- bcalows .append (current_plot_data .results .bca_low [index ])
105
- bcahighs .append (current_plot_data .results .bca_high [index ])
107
+ bcalows .append (current_plot_data .results .get ( ci_type + '_low' ) [index ])
108
+ bcahighs .append (current_plot_data .results .get ( ci_type + '_high' ) [index ])
106
109
else :
107
110
msg1 = "There are only {} groups (starting from zero) in this dabest object. " .format (num_of_groups )
108
111
msg2 = "The idx given is {}." .format (index )
@@ -113,8 +116,8 @@ def load_plot_data(
113
116
114
117
bootstraps = [getattr (result , f"bootstraps_{ attribute_suffix } " ) for result in contrast_plot_data ]
115
118
differences = [result .difference for result in contrast_plot_data ]
116
- bcalows = [result .bca_low for result in contrast_plot_data ]
117
- bcahighs = [result .bca_high for result in contrast_plot_data ]
119
+ bcalows = [result .results . get ( ci_type + '_low' )[ 0 ] for result in contrast_plot_data ]
120
+ bcahighs = [result .results . get ( ci_type + '_high' )[ 0 ] for result in contrast_plot_data ]
118
121
119
122
return bootstraps , differences , bcalows , bcahighs
120
123
@@ -124,6 +127,7 @@ def check_for_errors(
124
127
ax ,
125
128
fig_size ,
126
129
effect_size ,
130
+ ci_type ,
127
131
horizontal ,
128
132
marker_size ,
129
133
custom_palette ,
@@ -140,6 +144,7 @@ def check_for_errors(
140
144
yticks ,
141
145
yticklabels ,
142
146
remove_spines ,
147
+ summary_bars ,
143
148
) -> str :
144
149
145
150
# Contrasts
@@ -203,6 +208,10 @@ def check_for_errors(
203
208
raise ValueError ("The `effect_size` argument must be `mean_diff` for mini-meta analyses." )
204
209
if data [0 ].delta2 and effect_size not in ['mean_diff' , 'hedges_g' , 'delta_g' ]:
205
210
raise ValueError ("The `effect_size` argument must be `mean_diff`, `hedges_g`, or `delta_g` for delta-delta analyses." )
211
+
212
+ # CI type
213
+ if ci_type not in ('bca' , 'pct' ):
214
+ raise TypeError ("`ci_type` must be either 'bca' or 'pct'." )
206
215
207
216
# Horizontal
208
217
if not isinstance (horizontal , bool ):
@@ -277,6 +286,15 @@ def check_for_errors(
277
286
if not isinstance (remove_spines , bool ):
278
287
raise TypeError ("`remove_spines` must be a boolean value." )
279
288
289
+ # Summary bars
290
+ if summary_bars is not None :
291
+ if not isinstance (summary_bars , list | tuple ):
292
+ raise TypeError ("summary_bars must be a list/tuple of indices (ints)." )
293
+ if not all (isinstance (i , int ) for i in summary_bars ):
294
+ raise TypeError ("summary_bars must be a list/tuple of indices (ints)." )
295
+ if any (i >= number_of_curves_to_plot for i in summary_bars ):
296
+ raise ValueError ("Index {} chosen is out of range for the contrast objects." .format ([i for i in summary_bars if i >= number_of_curves_to_plot ]))
297
+
280
298
return contrast_type
281
299
282
300
@@ -288,6 +306,7 @@ def get_kwargs(
288
306
errorbar_kwargs ,
289
307
delta_text_kwargs ,
290
308
contrast_bars_kwargs ,
309
+ summary_bars_kwargs ,
291
310
marker_size
292
311
):
293
312
from .misc_tools import merge_two_dicts
@@ -369,9 +388,21 @@ def get_kwargs(
369
388
else :
370
389
contrast_bars_kwargs = merge_two_dicts (default_contrast_bars_kwargs , contrast_bars_kwargs )
371
390
391
+ # Summary bars kwargs.
392
+ default_summary_bars_kwargs = {
393
+ "span_ax" : False ,
394
+ "color" : None ,
395
+ "alpha" : 0.15 ,
396
+ "zorder" :- 3
397
+ }
398
+ if summary_bars_kwargs is None :
399
+ summary_bars_kwargs = default_summary_bars_kwargs
400
+ else :
401
+ summary_bars_kwargs = merge_two_dicts (default_summary_bars_kwargs , summary_bars_kwargs )
402
+
372
403
373
404
return (violin_kwargs , zeroline_kwargs , marker_kwargs , errorbar_kwargs ,
374
- delta_text_kwargs , contrast_bars_kwargs )
405
+ delta_text_kwargs , contrast_bars_kwargs , summary_bars_kwargs )
375
406
376
407
377
408
@@ -407,6 +438,7 @@ def forest_plot(
407
438
ax : Optional [plt .Axes ] = None ,
408
439
fig_size : tuple [int , int ] = None ,
409
440
effect_size : str = "mean_diff" ,
441
+ ci_type = 'bca' ,
410
442
horizontal : bool = False ,
411
443
412
444
marker_size : int = 10 ,
@@ -431,6 +463,8 @@ def forest_plot(
431
463
432
464
contrast_bars : bool = True ,
433
465
contrast_bars_kwargs : dict = None ,
466
+ summary_bars : list | tuple = None ,
467
+ summary_bars_kwargs : dict = None ,
434
468
435
469
violin_kwargs : Optional [dict ] = None ,
436
470
zeroline_kwargs : Optional [dict ] = None ,
@@ -455,6 +489,8 @@ def forest_plot(
455
489
Figure size for the plot.
456
490
effect_size : str
457
491
Type of effect size to plot (e.g., 'mean_diff', `hedges_g` or 'delta_g').
492
+ ci_type : str
493
+ Type of confidence interval to plot (bca' or 'pct')
458
494
horizontal : bool, default=False
459
495
If True, the plot will be horizontal.
460
496
marker_size : int, default=12
@@ -495,6 +531,10 @@ def forest_plot(
495
531
If True, it adds bars from the zeroline to the effect size curve.
496
532
contrast_bars_kwargs : dict, default=None
497
533
Additional keyword arguments for the contrast_bars.
534
+ summary_bars: list | tuple, default=None,
535
+ If True, it adds summary bars to the relevant effect size curves.
536
+ summary_bars_kwargs : dict, default=None,
537
+ Additional keyword arguments for the summary_bars.
498
538
violin_kwargs : Optional[dict], default=None
499
539
Additional arguments for violin plot customization.
500
540
zeroline_kwargs : Optional[dict], default=None
@@ -519,6 +559,7 @@ def forest_plot(
519
559
ax = ax ,
520
560
fig_size = fig_size ,
521
561
effect_size = effect_size ,
562
+ ci_type = ci_type ,
522
563
horizontal = horizontal ,
523
564
marker_size = marker_size ,
524
565
custom_palette = custom_palette ,
@@ -535,16 +576,17 @@ def forest_plot(
535
576
yticks = yticks ,
536
577
yticklabels = yticklabels ,
537
578
remove_spines = remove_spines ,
579
+ summary_bars = summary_bars ,
538
580
)
539
581
540
582
# Load plot data and extract info
541
583
bootstraps , differences , bcalows , bcahighs = load_plot_data (
542
584
data = data ,
543
585
effect_size = effect_size ,
544
586
contrast_type = contrast_type ,
587
+ ci_type = ci_type ,
545
588
idx = idx
546
589
)
547
-
548
590
# Adjust figure size based on orientation
549
591
number_of_curves_to_plot = len (bootstraps )
550
592
# number_of_curves_to_plot = sum([len(i) for i in idx]) if idx is not None else len(data)
@@ -556,16 +598,17 @@ def forest_plot(
556
598
fig , ax = plt .subplots (figsize = fig_size )
557
599
558
600
# Get Kwargs
559
- (violin_kwargs , zeroline_kwargs , marker_kwargs ,
560
- errorbar_kwargs , delta_text_kwargs , contrast_bars_kwargs ) = get_kwargs (
561
- violin_kwargs = violin_kwargs ,
562
- zeroline_kwargs = zeroline_kwargs ,
563
- horizontal = horizontal ,
564
- marker_kwargs = marker_kwargs ,
565
- errorbar_kwargs = errorbar_kwargs ,
566
- delta_text_kwargs = delta_text_kwargs ,
567
- contrast_bars_kwargs = contrast_bars_kwargs ,
568
- marker_size = marker_size
601
+ (violin_kwargs , zeroline_kwargs , marker_kwargs , errorbar_kwargs ,
602
+ delta_text_kwargs , contrast_bars_kwargs , summary_bars_kwargs ) = get_kwargs (
603
+ violin_kwargs = violin_kwargs ,
604
+ zeroline_kwargs = zeroline_kwargs ,
605
+ horizontal = horizontal ,
606
+ marker_kwargs = marker_kwargs ,
607
+ errorbar_kwargs = errorbar_kwargs ,
608
+ delta_text_kwargs = delta_text_kwargs ,
609
+ contrast_bars_kwargs = contrast_bars_kwargs ,
610
+ summary_bars_kwargs = summary_bars_kwargs ,
611
+ marker_size = marker_size
569
612
)
570
613
571
614
# Plot the violins and make adjustments
@@ -719,6 +762,42 @@ def forest_plot(
719
762
else :
720
763
ax .add_patch (mpatches .Rectangle ((x , 0 ), 0.25 , y , color = bar_colors [x - 1 ], ** contrast_bars_kwargs ))
721
764
765
+ # Summary bars
766
+ if summary_bars :
767
+ _bar_color = summary_bars_kwargs .pop ('color' )
768
+ if _bar_color is not None :
769
+ bar_colors = [_bar_color ] * number_of_curves_to_plot
770
+ else :
771
+ bar_colors = violin_colors
772
+
773
+ span_ax = summary_bars_kwargs .pop ("span_ax" )
774
+ summary_xmin , summary_xmax = ax .get_xlim ()
775
+ summary_ymin , summary_ymax = ax .get_ylim ()
776
+
777
+ for summary_index in summary_bars :
778
+ if span_ax == True :
779
+ starting_location = summary_ymin if horizontal else summary_xmin
780
+ else :
781
+ starting_location = summary_index + 1
782
+
783
+ summary_color = bar_colors [summary_index ]
784
+ summary_ci_low , summary_ci_high = bcalows [summary_index ], bcahighs [summary_index ]
785
+
786
+ if horizontal :
787
+ ax .add_patch (mpatches .Rectangle (
788
+ (summary_ci_low , starting_location ),
789
+ summary_ci_high - summary_ci_low , summary_ymax + 1 ,
790
+ color = summary_color ,
791
+ ** summary_bars_kwargs )
792
+ )
793
+ else :
794
+ ax .add_patch (mpatches .Rectangle (
795
+ (starting_location , summary_ci_low ),
796
+ summary_xmax + 1 , summary_ci_high - summary_ci_low ,
797
+ color = summary_color ,
798
+ ** summary_bars_kwargs )
799
+ )
800
+
722
801
## Invert Y-axis if horizontal
723
802
if horizontal :
724
803
ax .invert_yaxis ()
0 commit comments