Skip to content

Commit e5b500e

Browse files
committed
plotting fixes
1 parent 114bc49 commit e5b500e

2 files changed

Lines changed: 36 additions & 26 deletions

File tree

src/boostedhh/plotting.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ def _combine_hbb_bgs(hists, bg_keys):
169169
return h, bg_keys
170170

171171

172-
def _process_samples(sig_keys, bg_keys, bg_colours, sig_scale_dict, bg_order, syst, variation):
172+
def _process_samples(
173+
sig_keys, bg_keys, bg_colours, sig_scale_dict, bg_order, syst, variation, sample_label_map
174+
):
173175
# set up samples, colours and labels
174176
bg_keys = [key for key in bg_order if key in bg_keys]
175177
bg_colours = [COLOURS[bg_colours[sample]] for sample in bg_keys]
@@ -301,7 +303,7 @@ def ratioHistPlot(
301303
plot_significance: bool = False,
302304
significance_dir: str = "right",
303305
plot_ratio: bool = True,
304-
axrax: tuple = None,
306+
axraxsax: tuple = None,
305307
leg_args: dict = None,
306308
cmslabel: str = None,
307309
cmsloc: int = 0,
@@ -353,17 +355,18 @@ def ratioHistPlot(
353355
if sig_colours is None:
354356
sig_colours = SIG_COLOURS
355357
if leg_args is None:
356-
leg_args = {"ncol": 2 if log else 1, "fontsize": 24}
358+
leg_args = {"fontsize": 24}
359+
leg_args["ncol"] = leg_args.get("ncol", (2 if log else 1))
357360

358361
# copy hists and bg_keys so input objects are not changed
359362
hists, bg_keys = deepcopy(hists), deepcopy(bg_keys)
360363
hists, bg_keys = _combine_hbb_bgs(hists, bg_keys)
364+
data_label = sample_label_map.get(data_key, data_key)
361365

362366
bg_keys, bg_colours, bg_labels, sig_scale_dict, sig_labels = _process_samples(
363-
sig_keys, bg_keys, bg_colours, sig_scale_dict, bg_order, syst, variation
367+
sig_keys, bg_keys, bg_colours, sig_scale_dict, bg_order, syst, variation, sample_label_map
364368
)
365369

366-
print([hists[sample, :] for sample in bg_keys])
367370
bg_tot = np.maximum(sum([hists[sample, :] for sample in bg_keys]).values(), 0.0)
368371

369372
if syst is not None and variation is None:
@@ -390,18 +393,20 @@ def ratioHistPlot(
390393
hists, data_err, bg_tot, bg_err = _divide_bin_widths(hists, data_err, bg_tot, bg_err)
391394

392395
# set up plots
393-
if axrax is not None:
396+
if axraxsax is not None:
394397
if plot_significance:
395-
raise RuntimeError("Significance plots with input axes not implemented yet.")
398+
ax, rax, sax = axraxsax
399+
elif plot_ratio:
400+
ax, rax = axraxsax
401+
else:
402+
ax = axraxsax
396403

397-
ax, rax = axrax
398-
ax.sharex(rax)
399404
elif plot_significance:
400405
fig, (ax, rax, sax) = plt.subplots(
401406
3,
402407
1,
403408
figsize=(12, 18),
404-
gridspec_kw={"height_ratios": [3, 1, 1], "hspace": 0},
409+
gridspec_kw={"height_ratios": [3, 1, 1], "hspace": 0.1},
405410
sharex=True,
406411
)
407412
elif plot_ratio:
@@ -430,6 +435,7 @@ def ratioHistPlot(
430435
stack=True,
431436
label=bg_labels,
432437
color=bg_colours,
438+
flow="none",
433439
)
434440

435441
# signal samples
@@ -441,6 +447,7 @@ def ratioHistPlot(
441447
label=list(sig_labels.values()),
442448
color=sig_colours[: len(sig_keys)],
443449
linewidth=3,
450+
flow="none",
444451
)
445452

446453
# plot signal errors
@@ -457,6 +464,7 @@ def ratioHistPlot(
457464
label=[f"{sig_label} {skey}" for sig_label in sig_labels.values()],
458465
alpha=0.6,
459466
color=sig_colours[: len(sig_keys)],
467+
flow="none",
460468
)
461469
elif sig_err is not None:
462470
for sig_key, sig_scale in sig_scale_dict.items():
@@ -521,24 +529,22 @@ def ratioHistPlot(
521529
ax=ax,
522530
yerr=data_err,
523531
xerr=divide_bin_width,
524-
label=data_key,
532+
label=data_label,
525533
**DATA_STYLE,
534+
flow="none",
526535
)
527536

537+
# legend ordering
538+
legend_order = [data_label] + bg_order[::-1] + list(sig_labels.values()) + [BG_UNC_LABEL]
539+
legend_order = [sample_label_map.get(k, k) for k in legend_order]
540+
541+
handles, labels = ax.get_legend_handles_labels()
542+
ordered_handles = [handles[labels.index(label)] for label in legend_order if label in labels]
543+
ordered_labels = [label for label in legend_order if label in labels]
544+
ax.legend(ordered_handles, ordered_labels, **leg_args)
545+
528546
if log:
529547
ax.set_yscale("log")
530-
# two column legend
531-
ax.legend(**leg_args)
532-
else:
533-
legend_order = [data_key] + bg_order[::-1] + list(sig_labels.values()) + [BG_UNC_LABEL]
534-
legend_order = [sample_label_map.get(k, k) for k in legend_order]
535-
536-
handles, labels = ax.get_legend_handles_labels()
537-
ordered_handles = [
538-
handles[labels.index(label)] for label in legend_order if label in labels
539-
]
540-
ordered_labels = [label for label in legend_order if label in labels]
541-
ax.legend(ordered_handles, ordered_labels, **leg_args)
542548

543549
y_lowlim = 0 if not log else 1e-5
544550
if ylim is not None:
@@ -566,6 +572,7 @@ def ratioHistPlot(
566572
xerr=divide_bin_width,
567573
ax=rax,
568574
**DATA_STYLE,
575+
flow="none",
569576
)
570577

571578
if bg_err is not None and bg_err_type == "shaded":
@@ -614,12 +621,15 @@ def ratioHistPlot(
614621
histtype="step",
615622
label=[sample_label_map.get(sig_key, sig_key) for sig_key in sig_scale_dict],
616623
color=sig_colours[: len(sig_keys)],
624+
flow="none",
617625
)
618626

619-
sax.legend(fontsize=12)
627+
sax.legend(fontsize=15)
620628
sax.set_yscale("log")
621629
sax.set_ylim([1e-7, 10])
622630
sax.set_xlabel(hists.axes[1].label)
631+
sax.set_ylabel(sax.get_ylabel(), fontsize=22)
632+
rax.set_xlabel(None)
623633

624634
if title is not None:
625635
ax.set_title(title, y=1.08)
@@ -639,7 +649,7 @@ def ratioHistPlot(
639649

640650
add_cms_label(ax, year, label=cmslabel, loc=cmsloc)
641651

642-
if axrax is None:
652+
if axraxsax is None:
643653
if len(name):
644654
plt.savefig(name, bbox_inches="tight")
645655

src/boostedhh/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -806,7 +806,7 @@ def singleVarHist(
806806
for sample in samples:
807807
events = events_dict[sample]
808808
if sample == "data" and var.endswith(("_up", "_down")):
809-
fill_var = "_".join(var.split("_")[:-2])
809+
fill_var = "_".join(var.split("_")[:-2]) # remove _up/_down
810810
else:
811811
fill_var = var
812812

0 commit comments

Comments
 (0)