Skip to content

Commit a6e0d0b

Browse files
committed
add summary bars to forest plot
1 parent 3158359 commit a6e0d0b

8 files changed

+326
-53
lines changed

dabest/forest_plot.py

+105-26
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def load_plot_data(
1919
data: List,
2020
effect_size: str = "mean_diff",
2121
contrast_type: str = None,
22+
ci_type: str = "bca",
2223
idx: Optional[List[int]] = None
2324
) -> List:
2425
"""
@@ -32,6 +33,8 @@ def load_plot_data(
3233
Type of effect size ('mean_diff', 'median_diff', etc.).
3334
contrast_type: str
3435
Type of dabest object to plot ('delta2' or 'mini-meta' or 'delta').
36+
ci_type: str
37+
Type of confidence interval to plot ('bca' or 'pct')
3538
idx: Optional[List[int]], default=None
3639
List of indices to select from the contrast objects if delta-delta experiment.
3740
If None, only the delta-delta objects are plotted.
@@ -53,14 +56,14 @@ def load_plot_data(
5356
current_plot_data = getattr(current_contrast, effect_attr)
5457
bootstraps.append(current_plot_data.results.bootstraps[index])
5558
differences.append(current_plot_data.results.difference[index])
56-
bcalows.append(current_plot_data.results.bca_low[index])
57-
bcahighs.append(current_plot_data.results.bca_high[index])
59+
bcalows.append(current_plot_data.results.get(ci_type+'_low')[index])
60+
bcahighs.append(current_plot_data.results.get(ci_type+'_high')[index])
5861
else:
5962
contrast_plot_data = [getattr(contrast, effect_attr) for contrast in data]
6063
bootstraps_nested = [result.results.bootstraps.to_list() for result in contrast_plot_data]
6164
differences_nested = [result.results.difference.to_list() for result in contrast_plot_data]
62-
bcalows_nested = [result.results.bca_low.to_list() for result in contrast_plot_data]
63-
bcahighs_nested = [result.results.bca_high.to_list() for result in contrast_plot_data]
65+
bcalows_nested = [result.results.get(ci_type+'_low').to_list() for result in contrast_plot_data]
66+
bcahighs_nested = [result.results.get(ci_type+'_high').to_list() for result in contrast_plot_data]
6467

6568
bootstraps = [element for innerList in bootstraps_nested for element in innerList]
6669
differences = [element for innerList in differences_nested for element in innerList]
@@ -79,14 +82,14 @@ def load_plot_data(
7982
current_plot_data = getattr(getattr(current_contrast, effect_attr), contrast_attr)
8083
bootstraps.append(current_plot_data.bootstraps_delta_delta)
8184
differences.append(current_plot_data.difference)
82-
bcalows.append(current_plot_data.bca_low)
83-
bcahighs.append(current_plot_data.bca_high)
85+
bcalows.append(current_plot_data.results.get(ci_type+'_low')[0])
86+
bcahighs.append(current_plot_data.results.get(ci_type+'_high')[0])
8487
elif index == 0 or index == 1:
8588
current_plot_data = getattr(current_contrast, effect_attr)
8689
bootstraps.append(current_plot_data.results.bootstraps[index])
8790
differences.append(current_plot_data.results.difference[index])
88-
bcalows.append(current_plot_data.results.bca_low[index])
89-
bcahighs.append(current_plot_data.results.bca_high[index])
91+
bcalows.append(current_plot_data.results.get(ci_type+'_low')[index])
92+
bcahighs.append(current_plot_data.results.get(ci_type+'_high')[index])
9093
else:
9194
raise ValueError("The selected indices must be 0, 1, or 2.")
9295
else:
@@ -95,14 +98,14 @@ def load_plot_data(
9598
current_plot_data = getattr(getattr(current_contrast, effect_attr), contrast_attr)
9699
bootstraps.append(current_plot_data.bootstraps_weighted_delta)
97100
differences.append(current_plot_data.difference)
98-
bcalows.append(current_plot_data.results.bca_low)
99-
bcahighs.append(current_plot_data.results.bca_high)
101+
bcalows.append(current_plot_data.results.get(ci_type+'_low')[0])
102+
bcahighs.append(current_plot_data.results.get(ci_type+'_high')[0])
100103
elif index < num_of_groups:
101104
current_plot_data = getattr(current_contrast, effect_attr)
102105
bootstraps.append(current_plot_data.results.bootstraps[index])
103106
differences.append(current_plot_data.results.difference[index])
104-
bcalows.append(current_plot_data.results.bca_low[index])
105-
bcahighs.append(current_plot_data.results.bca_high[index])
107+
bcalows.append(current_plot_data.results.get(ci_type+'_low')[index])
108+
bcahighs.append(current_plot_data.results.get(ci_type+'_high')[index])
106109
else:
107110
msg1 = "There are only {} groups (starting from zero) in this dabest object. ".format(num_of_groups)
108111
msg2 = "The idx given is {}.".format(index)
@@ -113,8 +116,8 @@ def load_plot_data(
113116

114117
bootstraps = [getattr(result, f"bootstraps_{attribute_suffix}") for result in contrast_plot_data]
115118
differences = [result.difference for result in contrast_plot_data]
116-
bcalows = [result.bca_low for result in contrast_plot_data]
117-
bcahighs = [result.bca_high for result in contrast_plot_data]
119+
bcalows = [result.results.get(ci_type+'_low')[0] for result in contrast_plot_data]
120+
bcahighs = [result.results.get(ci_type+'_high')[0] for result in contrast_plot_data]
118121

119122
return bootstraps, differences, bcalows, bcahighs
120123

@@ -124,6 +127,7 @@ def check_for_errors(
124127
ax,
125128
fig_size,
126129
effect_size,
130+
ci_type,
127131
horizontal,
128132
marker_size,
129133
custom_palette,
@@ -140,6 +144,7 @@ def check_for_errors(
140144
yticks,
141145
yticklabels,
142146
remove_spines,
147+
summary_bars,
143148
) -> str:
144149

145150
# Contrasts
@@ -203,6 +208,10 @@ def check_for_errors(
203208
raise ValueError("The `effect_size` argument must be `mean_diff` for mini-meta analyses.")
204209
if data[0].delta2 and effect_size not in ['mean_diff', 'hedges_g', 'delta_g']:
205210
raise ValueError("The `effect_size` argument must be `mean_diff`, `hedges_g`, or `delta_g` for delta-delta analyses.")
211+
212+
# CI type
213+
if ci_type not in ('bca', 'pct'):
214+
raise TypeError("`ci_type` must be either 'bca' or 'pct'.")
206215

207216
# Horizontal
208217
if not isinstance(horizontal, bool):
@@ -277,6 +286,15 @@ def check_for_errors(
277286
if not isinstance(remove_spines, bool):
278287
raise TypeError("`remove_spines` must be a boolean value.")
279288

289+
# Summary bars
290+
if summary_bars is not None:
291+
if not isinstance(summary_bars, list | tuple):
292+
raise TypeError("summary_bars must be a list/tuple of indices (ints).")
293+
if not all(isinstance(i, int) for i in summary_bars):
294+
raise TypeError("summary_bars must be a list/tuple of indices (ints).")
295+
if any(i >= number_of_curves_to_plot for i in summary_bars):
296+
raise ValueError("Index {} chosen is out of range for the contrast objects.".format([i for i in summary_bars if i >= number_of_curves_to_plot]))
297+
280298
return contrast_type
281299

282300

@@ -288,6 +306,7 @@ def get_kwargs(
288306
errorbar_kwargs,
289307
delta_text_kwargs,
290308
contrast_bars_kwargs,
309+
summary_bars_kwargs,
291310
marker_size
292311
):
293312
from .misc_tools import merge_two_dicts
@@ -369,9 +388,21 @@ def get_kwargs(
369388
else:
370389
contrast_bars_kwargs = merge_two_dicts(default_contrast_bars_kwargs, contrast_bars_kwargs)
371390

391+
# Summary bars kwargs.
392+
default_summary_bars_kwargs = {
393+
"span_ax": False,
394+
"color": None,
395+
"alpha": 0.15,
396+
"zorder":-3
397+
}
398+
if summary_bars_kwargs is None:
399+
summary_bars_kwargs = default_summary_bars_kwargs
400+
else:
401+
summary_bars_kwargs = merge_two_dicts(default_summary_bars_kwargs, summary_bars_kwargs)
402+
372403

373404
return (violin_kwargs, zeroline_kwargs, marker_kwargs, errorbar_kwargs,
374-
delta_text_kwargs, contrast_bars_kwargs)
405+
delta_text_kwargs, contrast_bars_kwargs, summary_bars_kwargs)
375406

376407

377408

@@ -407,6 +438,7 @@ def forest_plot(
407438
ax: Optional[plt.Axes] = None,
408439
fig_size: tuple[int, int] = None,
409440
effect_size: str = "mean_diff",
441+
ci_type='bca',
410442
horizontal: bool = False,
411443

412444
marker_size: int = 10,
@@ -431,6 +463,8 @@ def forest_plot(
431463

432464
contrast_bars: bool = True,
433465
contrast_bars_kwargs: dict = None,
466+
summary_bars: list|tuple = None,
467+
summary_bars_kwargs: dict = None,
434468

435469
violin_kwargs: Optional[dict] = None,
436470
zeroline_kwargs: Optional[dict] = None,
@@ -455,6 +489,8 @@ def forest_plot(
455489
Figure size for the plot.
456490
effect_size : str
457491
Type of effect size to plot (e.g., 'mean_diff', `hedges_g` or 'delta_g').
492+
ci_type : str
493+
Type of confidence interval to plot (bca' or 'pct')
458494
horizontal : bool, default=False
459495
If True, the plot will be horizontal.
460496
marker_size : int, default=12
@@ -495,6 +531,10 @@ def forest_plot(
495531
If True, it adds bars from the zeroline to the effect size curve.
496532
contrast_bars_kwargs : dict, default=None
497533
Additional keyword arguments for the contrast_bars.
534+
summary_bars: list | tuple, default=None,
535+
If True, it adds summary bars to the relevant effect size curves.
536+
summary_bars_kwargs : dict, default=None,
537+
Additional keyword arguments for the summary_bars.
498538
violin_kwargs : Optional[dict], default=None
499539
Additional arguments for violin plot customization.
500540
zeroline_kwargs : Optional[dict], default=None
@@ -519,6 +559,7 @@ def forest_plot(
519559
ax = ax,
520560
fig_size = fig_size,
521561
effect_size = effect_size,
562+
ci_type = ci_type,
522563
horizontal = horizontal,
523564
marker_size = marker_size,
524565
custom_palette = custom_palette,
@@ -535,16 +576,17 @@ def forest_plot(
535576
yticks = yticks,
536577
yticklabels = yticklabels,
537578
remove_spines = remove_spines,
579+
summary_bars = summary_bars,
538580
)
539581

540582
# Load plot data and extract info
541583
bootstraps, differences, bcalows, bcahighs = load_plot_data(
542584
data = data,
543585
effect_size = effect_size,
544586
contrast_type = contrast_type,
587+
ci_type = ci_type,
545588
idx = idx
546589
)
547-
548590
# Adjust figure size based on orientation
549591
number_of_curves_to_plot = len(bootstraps)
550592
# number_of_curves_to_plot = sum([len(i) for i in idx]) if idx is not None else len(data)
@@ -556,16 +598,17 @@ def forest_plot(
556598
fig, ax = plt.subplots(figsize=fig_size)
557599

558600
# Get Kwargs
559-
(violin_kwargs, zeroline_kwargs, marker_kwargs,
560-
errorbar_kwargs, delta_text_kwargs, contrast_bars_kwargs) = get_kwargs(
561-
violin_kwargs = violin_kwargs,
562-
zeroline_kwargs = zeroline_kwargs,
563-
horizontal = horizontal,
564-
marker_kwargs = marker_kwargs,
565-
errorbar_kwargs = errorbar_kwargs,
566-
delta_text_kwargs = delta_text_kwargs,
567-
contrast_bars_kwargs = contrast_bars_kwargs,
568-
marker_size = marker_size
601+
(violin_kwargs, zeroline_kwargs, marker_kwargs, errorbar_kwargs,
602+
delta_text_kwargs, contrast_bars_kwargs, summary_bars_kwargs) = get_kwargs(
603+
violin_kwargs = violin_kwargs,
604+
zeroline_kwargs = zeroline_kwargs,
605+
horizontal = horizontal,
606+
marker_kwargs = marker_kwargs,
607+
errorbar_kwargs = errorbar_kwargs,
608+
delta_text_kwargs = delta_text_kwargs,
609+
contrast_bars_kwargs = contrast_bars_kwargs,
610+
summary_bars_kwargs = summary_bars_kwargs,
611+
marker_size = marker_size
569612
)
570613

571614
# Plot the violins and make adjustments
@@ -719,6 +762,42 @@ def forest_plot(
719762
else:
720763
ax.add_patch(mpatches.Rectangle((x, 0), 0.25, y, color=bar_colors[x-1], **contrast_bars_kwargs))
721764

765+
# Summary bars
766+
if summary_bars:
767+
_bar_color = summary_bars_kwargs.pop('color')
768+
if _bar_color is not None:
769+
bar_colors = [_bar_color] * number_of_curves_to_plot
770+
else:
771+
bar_colors = violin_colors
772+
773+
span_ax = summary_bars_kwargs.pop("span_ax")
774+
summary_xmin, summary_xmax = ax.get_xlim()
775+
summary_ymin, summary_ymax = ax.get_ylim()
776+
777+
for summary_index in summary_bars:
778+
if span_ax == True:
779+
starting_location = summary_ymin if horizontal else summary_xmin
780+
else:
781+
starting_location = summary_index+1
782+
783+
summary_color = bar_colors[summary_index]
784+
summary_ci_low, summary_ci_high = bcalows[summary_index], bcahighs[summary_index]
785+
786+
if horizontal:
787+
ax.add_patch(mpatches.Rectangle(
788+
(summary_ci_low, starting_location),
789+
summary_ci_high-summary_ci_low, summary_ymax+1,
790+
color=summary_color,
791+
**summary_bars_kwargs)
792+
)
793+
else:
794+
ax.add_patch(mpatches.Rectangle(
795+
(starting_location, summary_ci_low),
796+
summary_xmax+1, summary_ci_high-summary_ci_low,
797+
color=summary_color,
798+
**summary_bars_kwargs)
799+
)
800+
722801
## Invert Y-axis if horizontal
723802
if horizontal:
724803
ax.invert_yaxis()

0 commit comments

Comments
 (0)