@@ -169,7 +169,9 @@ def _combine_hbb_bgs(hists, bg_keys):
169169 return h , bg_keys
170170
171171
172- def _process_samples (sig_keys , bg_keys , bg_colours , sig_scale_dict , bg_order , syst , variation ):
172+ def _process_samples (
173+ sig_keys , bg_keys , bg_colours , sig_scale_dict , bg_order , syst , variation , sample_label_map
174+ ):
173175 # set up samples, colours and labels
174176 bg_keys = [key for key in bg_order if key in bg_keys ]
175177 bg_colours = [COLOURS [bg_colours [sample ]] for sample in bg_keys ]
@@ -301,7 +303,7 @@ def ratioHistPlot(
301303 plot_significance : bool = False ,
302304 significance_dir : str = "right" ,
303305 plot_ratio : bool = True ,
304- axrax : tuple = None ,
306+ axraxsax : tuple = None ,
305307 leg_args : dict = None ,
306308 cmslabel : str = None ,
307309 cmsloc : int = 0 ,
@@ -353,17 +355,18 @@ def ratioHistPlot(
353355 if sig_colours is None :
354356 sig_colours = SIG_COLOURS
355357 if leg_args is None :
356- leg_args = {"ncol" : 2 if log else 1 , "fontsize" : 24 }
358+ leg_args = {"fontsize" : 24 }
359+ leg_args ["ncol" ] = leg_args .get ("ncol" , (2 if log else 1 ))
357360
358361 # copy hists and bg_keys so input objects are not changed
359362 hists , bg_keys = deepcopy (hists ), deepcopy (bg_keys )
360363 hists , bg_keys = _combine_hbb_bgs (hists , bg_keys )
364+ data_label = sample_label_map .get (data_key , data_key )
361365
362366 bg_keys , bg_colours , bg_labels , sig_scale_dict , sig_labels = _process_samples (
363- sig_keys , bg_keys , bg_colours , sig_scale_dict , bg_order , syst , variation
367+ sig_keys , bg_keys , bg_colours , sig_scale_dict , bg_order , syst , variation , sample_label_map
364368 )
365369
366- print ([hists [sample , :] for sample in bg_keys ])
367370 bg_tot = np .maximum (sum ([hists [sample , :] for sample in bg_keys ]).values (), 0.0 )
368371
369372 if syst is not None and variation is None :
@@ -390,18 +393,20 @@ def ratioHistPlot(
390393 hists , data_err , bg_tot , bg_err = _divide_bin_widths (hists , data_err , bg_tot , bg_err )
391394
392395 # set up plots
393- if axrax is not None :
396+ if axraxsax is not None :
394397 if plot_significance :
395- raise RuntimeError ("Significance plots with input axes not implemented yet." )
398+ ax , rax , sax = axraxsax
399+ elif plot_ratio :
400+ ax , rax = axraxsax
401+ else :
402+ ax = axraxsax
396403
397- ax , rax = axrax
398- ax .sharex (rax )
399404 elif plot_significance :
400405 fig , (ax , rax , sax ) = plt .subplots (
401406 3 ,
402407 1 ,
403408 figsize = (12 , 18 ),
404- gridspec_kw = {"height_ratios" : [3 , 1 , 1 ], "hspace" : 0 },
409+ gridspec_kw = {"height_ratios" : [3 , 1 , 1 ], "hspace" : 0.1 },
405410 sharex = True ,
406411 )
407412 elif plot_ratio :
@@ -430,6 +435,7 @@ def ratioHistPlot(
430435 stack = True ,
431436 label = bg_labels ,
432437 color = bg_colours ,
438+ flow = "none" ,
433439 )
434440
435441 # signal samples
@@ -441,6 +447,7 @@ def ratioHistPlot(
441447 label = list (sig_labels .values ()),
442448 color = sig_colours [: len (sig_keys )],
443449 linewidth = 3 ,
450+ flow = "none" ,
444451 )
445452
446453 # plot signal errors
@@ -457,6 +464,7 @@ def ratioHistPlot(
457464 label = [f"{ sig_label } { skey } " for sig_label in sig_labels .values ()],
458465 alpha = 0.6 ,
459466 color = sig_colours [: len (sig_keys )],
467+ flow = "none" ,
460468 )
461469 elif sig_err is not None :
462470 for sig_key , sig_scale in sig_scale_dict .items ():
@@ -521,24 +529,22 @@ def ratioHistPlot(
521529 ax = ax ,
522530 yerr = data_err ,
523531 xerr = divide_bin_width ,
524- label = data_key ,
532+ label = data_label ,
525533 ** DATA_STYLE ,
534+ flow = "none" ,
526535 )
527536
537+ # legend ordering
538+ legend_order = [data_label ] + bg_order [::- 1 ] + list (sig_labels .values ()) + [BG_UNC_LABEL ]
539+ legend_order = [sample_label_map .get (k , k ) for k in legend_order ]
540+
541+ handles , labels = ax .get_legend_handles_labels ()
542+ ordered_handles = [handles [labels .index (label )] for label in legend_order if label in labels ]
543+ ordered_labels = [label for label in legend_order if label in labels ]
544+ ax .legend (ordered_handles , ordered_labels , ** leg_args )
545+
528546 if log :
529547 ax .set_yscale ("log" )
530- # two column legend
531- ax .legend (** leg_args )
532- else :
533- legend_order = [data_key ] + bg_order [::- 1 ] + list (sig_labels .values ()) + [BG_UNC_LABEL ]
534- legend_order = [sample_label_map .get (k , k ) for k in legend_order ]
535-
536- handles , labels = ax .get_legend_handles_labels ()
537- ordered_handles = [
538- handles [labels .index (label )] for label in legend_order if label in labels
539- ]
540- ordered_labels = [label for label in legend_order if label in labels ]
541- ax .legend (ordered_handles , ordered_labels , ** leg_args )
542548
543549 y_lowlim = 0 if not log else 1e-5
544550 if ylim is not None :
@@ -566,6 +572,7 @@ def ratioHistPlot(
566572 xerr = divide_bin_width ,
567573 ax = rax ,
568574 ** DATA_STYLE ,
575+ flow = "none" ,
569576 )
570577
571578 if bg_err is not None and bg_err_type == "shaded" :
@@ -614,12 +621,15 @@ def ratioHistPlot(
614621 histtype = "step" ,
615622 label = [sample_label_map .get (sig_key , sig_key ) for sig_key in sig_scale_dict ],
616623 color = sig_colours [: len (sig_keys )],
624+ flow = "none" ,
617625 )
618626
619- sax .legend (fontsize = 12 )
627+ sax .legend (fontsize = 15 )
620628 sax .set_yscale ("log" )
621629 sax .set_ylim ([1e-7 , 10 ])
622630 sax .set_xlabel (hists .axes [1 ].label )
631+ sax .set_ylabel (sax .get_ylabel (), fontsize = 22 )
632+ rax .set_xlabel (None )
623633
624634 if title is not None :
625635 ax .set_title (title , y = 1.08 )
@@ -639,7 +649,7 @@ def ratioHistPlot(
639649
640650 add_cms_label (ax , year , label = cmslabel , loc = cmsloc )
641651
642- if axrax is None :
652+ if axraxsax is None :
643653 if len (name ):
644654 plt .savefig (name , bbox_inches = "tight" )
645655
0 commit comments