Skip to content

Commit 3158359

Browse files
committed
Added more error checking for forest plot
1 parent 8aff220 commit 3158359

File tree

6 files changed

+105
-57
lines changed

6 files changed

+105
-57
lines changed

dabest/forest_plot.py

+49-25
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ def check_for_errors(
127127
horizontal,
128128
marker_size,
129129
custom_palette,
130-
halfviolin_alpha,
131-
halfviolin_desat,
130+
contrast_alpha,
131+
contrast_desat,
132132
labels,
133133
labels_rotation,
134134
labels_fontsize,
@@ -166,8 +166,26 @@ def check_for_errors(
166166
if idx is not None:
167167
if not isinstance(idx, (tuple, list)):
168168
raise TypeError("`idx` must be a tuple or list of integers.")
169-
# if contrast_type == "mini_meta":
170-
# raise ValueError("The `idx` argument is not applicable to mini-meta analyses.")
169+
170+
msg1 = "The `idx` argument must have the same length as the number of dabest objects. "
171+
msg2 = "E.g., If two dabest objects are supplied, there should be two lists within `idx`. "
172+
msg3 = "E.g., `idx` = [[1,2],[0,1]]."
173+
_total = 0
174+
for _group in idx:
175+
if isinstance(_group, int | float):
176+
raise ValueError(msg1+msg2+msg3)
177+
else:
178+
_total += 1
179+
if _total != len(data):
180+
raise ValueError(msg1+msg2+msg3)
181+
182+
if idx is not None:
183+
number_of_curves_to_plot = sum([len(i) for i in idx])
184+
else:
185+
if contrast_type == 'delta':
186+
number_of_curves_to_plot = sum(len(getattr(i, effect_size).results) for i in data)
187+
else:
188+
number_of_curves_to_plot = len(data)
171189

172190
# Axes
173191
if ax is not None and not isinstance(ax, plt.Axes):
@@ -195,25 +213,26 @@ def check_for_errors(
195213
raise TypeError("`marker_size` must be a positive integer or float.")
196214

197215
# Custom palette
198-
if custom_palette is not None and not isinstance(custom_palette, (dict, list, str, type(None))):
216+
if custom_palette is not None and not isinstance(custom_palette, (dict, list, tuple, str, type(None))):
199217
raise TypeError("The `custom_palette` must be either a dictionary, list, string, or `None`.")
200218
if isinstance(custom_palette, dict) and labels is None:
201219
raise ValueError("The `labels` argument must be provided if `custom_palette` is a dictionary.")
220+
if isinstance(custom_palette, (list, tuple)) and len(custom_palette) < number_of_curves_to_plot:
221+
raise ValueError("The `custom_palette` list/tuple must have the same length as the number of `data` provided.")
202222

203-
204-
# Halfviolin alpha and desat
205-
if not isinstance(halfviolin_alpha, float) or not 0 <= halfviolin_alpha <= 1:
206-
raise TypeError("`halfviolin_alpha` must be a float between 0 and 1.")
223+
# Contrast alpha and desat
224+
if not isinstance(contrast_alpha, float) or not 0 <= contrast_alpha <= 1:
225+
raise TypeError("`contrast_alpha` must be a float between 0 and 1.")
207226

208-
if not isinstance(halfviolin_desat, (float, int)) or not 0 <= halfviolin_desat <= 1:
209-
raise TypeError("`halfviolin_desat` must be a float between 0 and 1 or an int (1).")
227+
if not isinstance(contrast_desat, (float, int)) or not 0 <= contrast_desat <= 1:
228+
raise TypeError("`contrast_desat` must be a float between 0 and 1 or an int (1).")
210229

211230

212231
# Contrast labels
213232
if labels is not None and not all(isinstance(label, str) for label in labels):
214233
raise TypeError("The `labels` must be a list of strings or `None`.")
215234

216-
number_of_curves_to_plot = sum([len(i) for i in idx]) if idx is not None else len(data)
235+
217236
if labels is not None and len(labels) != number_of_curves_to_plot:
218237
raise ValueError("`labels` must match the number of `data` provided.")
219238

@@ -360,7 +379,7 @@ def color_palette(
360379
custom_palette,
361380
labels,
362381
number_of_curves_to_plot,
363-
halfviolin_desat
382+
contrast_desat
364383
):
365384
if custom_palette is not None:
366385
if isinstance(custom_palette, dict):
@@ -378,7 +397,7 @@ def color_palette(
378397
)
379398
else:
380399
violin_colors = sns.color_palette(n_colors=number_of_curves_to_plot)
381-
violin_colors = [sns.desaturate(color, halfviolin_desat) for color in violin_colors]
400+
violin_colors = [sns.desaturate(color, contrast_desat) for color in violin_colors]
382401
return violin_colors
383402

384403

@@ -392,8 +411,8 @@ def forest_plot(
392411

393412
marker_size: int = 10,
394413
custom_palette: Optional[Union[dict, list, str]] = None,
395-
halfviolin_alpha: float = 0.8,
396-
halfviolin_desat: float = 1,
414+
contrast_alpha: float = 0.8,
415+
contrast_desat: float = 1,
397416

398417
labels: list[str] = None,
399418
labels_rotation: int = None,
@@ -442,9 +461,9 @@ def forest_plot(
442461
Marker size for plotting effect size dots.
443462
custom_palette : Optional[Union[dict, list, str]], default=None
444463
Custom color palette for the plot.
445-
halfviolin_alpha : float, default=0.8
464+
contrast_alpha : float, default=0.8
446465
Transparency level for violin plots.
447-
halfviolin_desat : float, default=1
466+
contrast_desat : float, default=1
448467
Saturation level for violin plots.
449468
labels : List[str]
450469
Labels for each contrast. If None, defaults to 'Contrast 1', 'Contrast 2', etc.
@@ -468,9 +487,14 @@ def forest_plot(
468487
Custom y-tick labels for the plot.
469488
remove_spines : bool, default=True
470489
If True, removes plot spines (except the relevant dependent variable spine).
471-
472-
473-
490+
delta_text : bool, default=True
491+
If True, it adds text next to each curve representing the effect size value.
492+
delta_text_kwargs : dict, default=None
493+
Additional keyword arguments for the delta_text.
494+
contrast_bars : bool, default=True
495+
If True, it adds bars from the zeroline to the effect size curve.
496+
contrast_bars_kwargs : dict, default=None
497+
Additional keyword arguments for the contrast_bars.
474498
violin_kwargs : Optional[dict], default=None
475499
Additional arguments for violin plot customization.
476500
zeroline_kwargs : Optional[dict], default=None
@@ -498,8 +522,8 @@ def forest_plot(
498522
horizontal = horizontal,
499523
marker_size = marker_size,
500524
custom_palette = custom_palette,
501-
halfviolin_alpha = halfviolin_alpha,
502-
halfviolin_desat = halfviolin_desat,
525+
contrast_alpha = contrast_alpha,
526+
contrast_desat = contrast_desat,
503527
labels = labels,
504528
labels_rotation = labels_rotation,
505529
labels_fontsize = labels_fontsize,
@@ -551,7 +575,7 @@ def forest_plot(
551575
)
552576
halfviolin(
553577
v,
554-
alpha = halfviolin_alpha,
578+
alpha = contrast_alpha,
555579
half = "bottom" if horizontal else "right",
556580
)
557581

@@ -570,7 +594,7 @@ def forest_plot(
570594
custom_palette = custom_palette,
571595
labels = labels,
572596
number_of_curves_to_plot = number_of_curves_to_plot,
573-
halfviolin_desat = halfviolin_desat
597+
contrast_desat = contrast_desat
574598
)
575599

576600
for patch, color in zip(v["bodies"], violin_colors):

nbs/API/forest_plot.ipynb

+49-25
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,8 @@
187187
" horizontal,\n",
188188
" marker_size,\n",
189189
" custom_palette,\n",
190-
" halfviolin_alpha,\n",
191-
" halfviolin_desat,\n",
190+
" contrast_alpha,\n",
191+
" contrast_desat,\n",
192192
" labels,\n",
193193
" labels_rotation,\n",
194194
" labels_fontsize,\n",
@@ -226,8 +226,26 @@
226226
" if idx is not None:\n",
227227
" if not isinstance(idx, (tuple, list)):\n",
228228
" raise TypeError(\"`idx` must be a tuple or list of integers.\")\n",
229-
" # if contrast_type == \"mini_meta\":\n",
230-
" # raise ValueError(\"The `idx` argument is not applicable to mini-meta analyses.\")\n",
229+
"\n",
230+
" msg1 = \"The `idx` argument must have the same length as the number of dabest objects. \"\n",
231+
" msg2 = \"E.g., If two dabest objects are supplied, there should be two lists within `idx`. \"\n",
232+
" msg3 = \"E.g., `idx` = [[1,2],[0,1]].\"\n",
233+
" _total = 0\n",
234+
" for _group in idx:\n",
235+
" if isinstance(_group, int | float):\n",
236+
" raise ValueError(msg1+msg2+msg3)\n",
237+
" else:\n",
238+
" _total += 1\n",
239+
" if _total != len(data):\n",
240+
" raise ValueError(msg1+msg2+msg3)\n",
241+
" \n",
242+
" if idx is not None:\n",
243+
" number_of_curves_to_plot = sum([len(i) for i in idx])\n",
244+
" else:\n",
245+
" if contrast_type == 'delta':\n",
246+
" number_of_curves_to_plot = sum(len(getattr(i, effect_size).results) for i in data)\n",
247+
" else:\n",
248+
" number_of_curves_to_plot = len(data)\n",
231249
"\n",
232250
" # Axes\n",
233251
" if ax is not None and not isinstance(ax, plt.Axes):\n",
@@ -255,25 +273,26 @@
255273
" raise TypeError(\"`marker_size` must be a positive integer or float.\")\n",
256274
"\n",
257275
" # Custom palette\n",
258-
" if custom_palette is not None and not isinstance(custom_palette, (dict, list, str, type(None))):\n",
276+
" if custom_palette is not None and not isinstance(custom_palette, (dict, list, tuple, str, type(None))):\n",
259277
" raise TypeError(\"The `custom_palette` must be either a dictionary, list, string, or `None`.\")\n",
260278
" if isinstance(custom_palette, dict) and labels is None:\n",
261279
" raise ValueError(\"The `labels` argument must be provided if `custom_palette` is a dictionary.\")\n",
280+
" if isinstance(custom_palette, (list, tuple)) and len(custom_palette) < number_of_curves_to_plot:\n",
281+
" raise ValueError(\"The `custom_palette` list/tuple must have the same length as the number of `data` provided.\")\n",
262282
"\n",
263-
"\n",
264-
" # Halfviolin alpha and desat\n",
265-
" if not isinstance(halfviolin_alpha, float) or not 0 <= halfviolin_alpha <= 1:\n",
266-
" raise TypeError(\"`halfviolin_alpha` must be a float between 0 and 1.\")\n",
283+
" # Contrast alpha and desat\n",
284+
" if not isinstance(contrast_alpha, float) or not 0 <= contrast_alpha <= 1:\n",
285+
" raise TypeError(\"`contrast_alpha` must be a float between 0 and 1.\")\n",
267286
" \n",
268-
" if not isinstance(halfviolin_desat, (float, int)) or not 0 <= halfviolin_desat <= 1:\n",
269-
" raise TypeError(\"`halfviolin_desat` must be a float between 0 and 1 or an int (1).\")\n",
287+
" if not isinstance(contrast_desat, (float, int)) or not 0 <= contrast_desat <= 1:\n",
288+
" raise TypeError(\"`contrast_desat` must be a float between 0 and 1 or an int (1).\")\n",
270289
" \n",
271290
"\n",
272291
" # Contrast labels\n",
273292
" if labels is not None and not all(isinstance(label, str) for label in labels):\n",
274293
" raise TypeError(\"The `labels` must be a list of strings or `None`.\")\n",
275294
" \n",
276-
" number_of_curves_to_plot = sum([len(i) for i in idx]) if idx is not None else len(data)\n",
295+
" \n",
277296
" if labels is not None and len(labels) != number_of_curves_to_plot:\n",
278297
" raise ValueError(\"`labels` must match the number of `data` provided.\")\n",
279298
" \n",
@@ -420,7 +439,7 @@
420439
" custom_palette, \n",
421440
" labels, \n",
422441
" number_of_curves_to_plot,\n",
423-
" halfviolin_desat\n",
442+
" contrast_desat\n",
424443
" ):\n",
425444
" if custom_palette is not None:\n",
426445
" if isinstance(custom_palette, dict):\n",
@@ -438,7 +457,7 @@
438457
" )\n",
439458
" else:\n",
440459
" violin_colors = sns.color_palette(n_colors=number_of_curves_to_plot)\n",
441-
" violin_colors = [sns.desaturate(color, halfviolin_desat) for color in violin_colors]\n",
460+
" violin_colors = [sns.desaturate(color, contrast_desat) for color in violin_colors]\n",
442461
" return violin_colors\n",
443462
"\n",
444463
"\n",
@@ -452,8 +471,8 @@
452471
"\n",
453472
" marker_size: int = 10,\n",
454473
" custom_palette: Optional[Union[dict, list, str]] = None,\n",
455-
" halfviolin_alpha: float = 0.8,\n",
456-
" halfviolin_desat: float = 1,\n",
474+
" contrast_alpha: float = 0.8,\n",
475+
" contrast_desat: float = 1,\n",
457476
"\n",
458477
" labels: list[str] = None,\n",
459478
" labels_rotation: int = None,\n",
@@ -502,9 +521,9 @@
502521
" Marker size for plotting effect size dots.\n",
503522
" custom_palette : Optional[Union[dict, list, str]], default=None\n",
504523
" Custom color palette for the plot.\n",
505-
" halfviolin_alpha : float, default=0.8\n",
524+
" contrast_alpha : float, default=0.8\n",
506525
" Transparency level for violin plots.\n",
507-
" halfviolin_desat : float, default=1\n",
526+
" contrast_desat : float, default=1\n",
508527
" Saturation level for violin plots.\n",
509528
" labels : List[str]\n",
510529
" Labels for each contrast. If None, defaults to 'Contrast 1', 'Contrast 2', etc.\n",
@@ -528,9 +547,14 @@
528547
" Custom y-tick labels for the plot.\n",
529548
" remove_spines : bool, default=True\n",
530549
" If True, removes plot spines (except the relevant dependent variable spine).\n",
531-
"\n",
532-
"\n",
533-
"\n",
550+
" delta_text : bool, default=True\n",
551+
" If True, it adds text next to each curve representing the effect size value.\n",
552+
" delta_text_kwargs : dict, default=None\n",
553+
" Additional keyword arguments for the delta_text.\n",
554+
" contrast_bars : bool, default=True\n",
555+
" If True, it adds bars from the zeroline to the effect size curve.\n",
556+
" contrast_bars_kwargs : dict, default=None\n",
557+
" Additional keyword arguments for the contrast_bars.\n",
534558
" violin_kwargs : Optional[dict], default=None\n",
535559
" Additional arguments for violin plot customization.\n",
536560
" zeroline_kwargs : Optional[dict], default=None\n",
@@ -558,8 +582,8 @@
558582
" horizontal = horizontal,\n",
559583
" marker_size = marker_size,\n",
560584
" custom_palette = custom_palette,\n",
561-
" halfviolin_alpha = halfviolin_alpha,\n",
562-
" halfviolin_desat = halfviolin_desat,\n",
585+
" contrast_alpha = contrast_alpha,\n",
586+
" contrast_desat = contrast_desat,\n",
563587
" labels = labels,\n",
564588
" labels_rotation = labels_rotation,\n",
565589
" labels_fontsize = labels_fontsize,\n",
@@ -611,7 +635,7 @@
611635
" )\n",
612636
" halfviolin(\n",
613637
" v, \n",
614-
" alpha = halfviolin_alpha, \n",
638+
" alpha = contrast_alpha, \n",
615639
" half = \"bottom\" if horizontal else \"right\",\n",
616640
" )\n",
617641
" \n",
@@ -630,7 +654,7 @@
630654
" custom_palette = custom_palette, \n",
631655
" labels = labels, \n",
632656
" number_of_curves_to_plot = number_of_curves_to_plot,\n",
633-
" halfviolin_desat = halfviolin_desat\n",
657+
" contrast_desat = contrast_desat\n",
634658
" )\n",
635659
" \n",
636660
" for patch, color in zip(v[\"bodies\"], violin_colors):\n",

nbs/tests/data/mocked_data_test_forestplot.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
"marker_size": 20, # Ensure it's a positive integer or float.
4949
"remove_spines": True, # Ensure it's a boolean.
5050
"labels_rotation": 45, # Ensure it's an integer or float between 0 and 360.
51-
"halfviolin_alpha": 0.8, # Ensure it's a float between 0 and 1.
51+
"contrast_alpha": 0.8, # Ensure it's a float between 0 and 1.
5252
"horizontal": False, # Ensure it's a boolean.
5353
}
5454

nbs/tests/mpl_image_tests/test_05_forest_plot.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,8 @@ def test_509_deltadelta_halfviolin_aesthetics_forest():
280280
return forest_plot(
281281
contrasts,
282282
labels=['Drug1', 'Drug2', 'Drug3'],
283-
halfviolin_alpha=0.2,
284-
halfviolin_desat=0.2
283+
contrast_alpha=0.2,
284+
contrast_desat=0.2
285285
)
286286

287287

nbs/tests/test_forest_plot.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def test_forest_plot_no_input_parameters():
2020
("horizontal", "sideways", "`horizontal` must be a boolean value.", TypeError),
2121
("marker_size", "large", "`marker_size` must be a positive integer or float.", TypeError),
2222
("custom_palette", 123, "The `custom_palette` must be either a dictionary, list, string, or `None`.", TypeError),
23-
("halfviolin_alpha", "opaque", "`halfviolin_alpha` must be a float between 0 and 1.", TypeError),
24-
("halfviolin_desat", "yes", "`halfviolin_desat` must be a float between 0 and 1 or an int (1).", TypeError),
23+
("contrast_alpha", "opaque", "`contrast_alpha` must be a float between 0 and 1.", TypeError),
24+
("contrast_desat", "yes", "`contrast_desat` must be a float between 0 and 1 or an int (1).", TypeError),
2525
("labels", ["valid", 123], "The `labels` must be a list of strings or `None`.", TypeError),
2626
("labels", ['valid', 'valid'], "`labels` must match the number of `data` provided.", ValueError),
2727
("labels_fontsize", "big", "`labels_fontsize` must be an integer or float.", TypeError),

nbs/tutorials/07-forest_plot.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -665,9 +665,9 @@
665665
"\n",
666666
"- `marker_size`: The size of the markers for the effect sizes. The default is 12.\n",
667667
"\n",
668-
"- ``halfviolin_alpha``: Transparency level for violin plots. The default is 0.8.\n",
668+
"- ``contrast_alpha``: Transparency level for violin plots. The default is 0.8.\n",
669669
"\n",
670-
"- ``halfviolin_desat``: Saturation level for violin plots. The default is 1.\n",
670+
"- ``contrast_desat``: Saturation level for violin plots. The default is 1.\n",
671671
"\n",
672672
"- `labels_rotation`: Rotation angle for contrast labels. The default is 45.\n",
673673
"\n",

0 commit comments

Comments
 (0)