 from time import time
 from typing import Tuple

-from distributed import Client, get_worker, LocalCluster, SSHCluster
 import ml
-from plotting import save_ml_plots, save_plots
 import ROOT
-from utils import (
-    AGCInput,
-    AGCResult,
-    postprocess_results,
-    retrieve_inputs,
-    save_histos,
-)
+from distributed import Client, LocalCluster, SSHCluster, get_worker
+from plotting import save_ml_plots, save_plots
+from statistical import fit_histograms
+from utils import AGCInput, AGCResult, postprocess_results, retrieve_inputs, save_histos

 # Using https://atlas-groupdata.web.cern.ch/atlas-groupdata/dev/AnalysisTop/TopDataPreparation/XSection-MC15-13TeV.data
 # as a reference. Values are in pb.
@@ -90,7 +85,24 @@ def parse_args() -> argparse.Namespace:
9085 "--hosts" ,
9186 help = "A comma-separated list of worker node hostnames. Only required if --scheduler=dask-ssh, ignored otherwise." ,
9287 )
93- p .add_argument ("-v" , "--verbose" , help = "Turn on verbose execution logs." , action = "store_true" )
88+ p .add_argument (
89+ "-v" ,
90+ "--verbose" ,
91+ help = "Turn on verbose execution logs." ,
92+ action = "store_true" ,
93+ )
94+
95+ p .add_argument (
96+ "--statistical-validation" ,
97+ help = argparse .SUPPRESS ,
98+ action = "store_true" ,
99+ )
100+
101+ p .add_argument (
102+ "--no-fitting" ,
103+ help = "Do not run statistical validation part of the analysis." ,
104+ action = "store_true" ,
105+ )
94106
95107 return p .parse_args ()
96108
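Read together with the `main()` changes further down in this diff, the two new flags split the job into an analysis phase and a fitting phase. The following is only a sketch of that control flow under the names the patch introduces (`fit_histograms`, `args.output`); `run_analysis` is a hypothetical stand-in for the existing histogramming code, not a function in this file:

```python
args = parse_args()

if args.statistical_validation:
    # Fit-only mode: read histograms back from a previous run's output file and stop.
    fit_histograms(filename=args.output)
else:
    run_analysis(args)  # hypothetical stand-in for the existing event-loop code
    if not args.no_fitting:
        # Default behaviour: fit right after the histograms have been written.
        fit_histograms(filename=args.output)
```
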
@@ -109,7 +121,11 @@ def create_dask_client(scheduler: str, ncores: int, hosts: str, scheduler_addres
         sshc = SSHCluster(
             workers,
             connect_options={"known_hosts": None},
-            worker_options={"nprocs": ncores, "nthreads": 1, "memory_limit": "32GB"},
+            worker_options={
+                "nprocs": ncores,
+                "nthreads": 1,
+                "memory_limit": "32GB",
+            },
         )
         return Client(sshc)

@@ -128,7 +144,10 @@ def define_trijet_mass(df: ROOT.RDataFrame) -> ROOT.RDataFrame:
     df = df.Filter("Sum(Jet_btagCSVV2_cut > 0.5) > 1")

     # Build four-momentum vectors for each jet
-    df = df.Define("Jet_p4", "ConstructP4(Jet_pt_cut, Jet_eta_cut, Jet_phi_cut, Jet_mass_cut)")
+    df = df.Define(
+        "Jet_p4",
+        "ConstructP4(Jet_pt_cut, Jet_eta_cut, Jet_phi_cut, Jet_mass_cut)",
+    )

     # Build trijet combinations
     df = df.Define("Trijet_idx", "Combinations(Jet_pt_cut, 3)")
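For context on `Combinations(Jet_pt_cut, 3)`: `ROOT::VecOps::Combinations(v, n)` returns the unique index combinations organised by slot, i.e. element 0 of the result holds the first index of every combination, element 1 the second, and so on. A small standalone sketch with toy values (not part of the patch; the PyROOT construction of the `RVec` is an assumption about the bindings):

```python
import ROOT

# Four toy jet pTs; Combinations(pt, 3) enumerates all unique trijets by index.
pt = ROOT.VecOps.RVec["float"]()
for value in (40.0, 35.0, 30.0, 25.0):
    pt.push_back(value)

idx = ROOT.VecOps.Combinations(pt, 3)

# idx[0], idx[1], idx[2] hold the first, second and third jet index of each trijet.
for i0, i1, i2 in zip(idx[0], idx[1], idx[2]):
    print(i0, i1, i2)  # (0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)
```
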
@@ -186,7 +205,7 @@ def book_histos(
     # pt_res_up(jet_pt) - jet resolution systematic
     df = df.Vary(
         "Jet_pt",
-        "ROOT::RVec<ROOT::RVecF>{Jet_pt*pt_scale_up(), Jet_pt*jet_pt_resolution(Jet_pt.size())}",
+        "ROOT::RVec<ROOT::RVecF>{Jet_pt*pt_scale_up(), Jet_pt*jet_pt_resolution(Jet_pt)}",
         ["pt_scale_up", "pt_res_up"],
     )

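The one-line change above only touches the argument handed to the analysis' own `jet_pt_resolution` helper. Independent of that helper, the expression passed to `Vary` must evaluate to an `RVec` with one inner entry per variation tag, each of the same type as the varied column; the varied histograms are then materialised with `ROOT.RDF.Experimental.VariationsFor`. A minimal sketch of that pattern on toy columns (nothing here reuses AGC helpers):

```python
import ROOT

# Toy dataframe with a single RVecF column "pt".
df = ROOT.RDataFrame(100).Define("pt", "ROOT::RVecF{30.f, 40.f, 50.f}")

# One inner vector per variation tag, in the same order as the tag list.
df = df.Vary("pt", "ROOT::RVec<ROOT::RVecF>{pt * 0.95f, pt * 1.05f}", ["down", "up"])

nominal = df.Histo1D(("h_pt", "pt", 20, 0.0, 100.0), "pt")

# One event loop fills the nominal and both varied histograms.
histos = ROOT.RDF.Experimental.VariationsFor(nominal)
print(histos["nominal"].GetMean(), histos["pt:up"].GetMean())
```
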
@@ -240,8 +259,7 @@ def book_histos(
     # Only one b-tagged region required
     # The observable is the total transverse momentum
     # fmt: off
-    df4j1b = df.Filter("Sum(Jet_btagCSVV2_cut > 0.5) == 1")\
-              .Define("HT", "Sum(Jet_pt_cut)")
+    df4j1b = df.Filter("Sum(Jet_btagCSVV2_cut > 0.5) == 1").Define("HT", "Sum(Jet_pt_cut)")
     # fmt: on

     # Define trijet_mass observable for the 4j2b region (this one is more complicated)
@@ -251,20 +269,34 @@ def book_histos(
     results = []
     for df, observable, region in zip([df4j1b, df4j2b], ["HT", "Trijet_mass"], ["4j1b", "4j2b"]):
         histo_model = ROOT.RDF.TH1DModel(
-            name=f"{region}_{process}_{variation}", title=process, nbinsx=25, xlow=50, xup=550
+            name=f"{region}_{process}_{variation}",
+            title=process,
+            nbinsx=25,
+            xlow=50,
+            xup=550,
         )
         nominal_histo = df.Histo1D(histo_model, observable, "Weights")

         if variation == "nominal":
             results.append(
                 AGCResult(
-                    nominal_histo, region, process, variation, nominal_histo, should_vary=True
+                    nominal_histo,
+                    region,
+                    process,
+                    variation,
+                    nominal_histo,
+                    should_vary=True,
                 )
             )
         else:
             results.append(
                 AGCResult(
-                    nominal_histo, region, process, variation, nominal_histo, should_vary=False
+                    nominal_histo,
+                    region,
+                    process,
+                    variation,
+                    nominal_histo,
+                    should_vary=False,
                 )
             )
         print(f"Booked histogram {histo_model.fName}")
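As a reminder of why "Booked histogram ..." can be printed before any event has been read: `Histo1D` is lazy, so the models booked in this loop are only filled once a result is first accessed (or when `ROOT.RDF.RunGraphs` is called on all of them). A toy illustration, independent of the AGC columns:

```python
import ROOT

# 1000 empty events with a Gaussian toy observable and a unit weight.
df = ROOT.RDataFrame(1000).Define("x", "gRandom->Gaus(0, 1)").Define("w", "1.")

model = ROOT.RDF.TH1DModel("h_x", "toy", 25, -5.0, 5.0)
h = df.Histo1D(model, "x", "w")  # lazy: the event loop has not run yet

print(h.GetEntries())  # first access triggers the (single) event loop
```
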
@@ -292,7 +324,12 @@ def book_histos(
         if variation == "nominal":
             ml_results.append(
                 AGCResult(
-                    nominal_histo, feature.name, process, variation, nominal_histo, should_vary=True
+                    nominal_histo,
+                    feature.name,
+                    process,
+                    variation,
+                    nominal_histo,
+                    should_vary=True,
                 )
             )
         else:
@@ -382,7 +419,10 @@ def ml_init():
     with create_dask_client(args.scheduler, args.ncores, args.hosts, scheduler_address) as client:
         for input in inputs:
             df = ROOT.RDF.Experimental.Distributed.Dask.RDataFrame(
-                "Events", input.paths, daskclient=client, npartitions=args.npartitions
+                "Events",
+                input.paths,
+                daskclient=client,
+                npartitions=args.npartitions,
             )
             df._headnode.backend.distribute_unique_paths(
                 [
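For trying the distributed path without an SSH cluster, the same constructor also works against a local Dask cluster, which is presumably what `create_dask_client` returns for `--scheduler=dask-local`. A sketch under that assumption; the input path and partition count are hypothetical toy values:

```python
import ROOT
from distributed import Client, LocalCluster

# A small local cluster standing in for the SSH cluster used in production runs.
with Client(LocalCluster(n_workers=2, threads_per_worker=1, processes=True)) as client:
    df = ROOT.RDF.Experimental.Distributed.Dask.RDataFrame(
        "Events",
        ["root://eospublic.cern.ch//eos/some/file.root"],  # hypothetical input file
        daskclient=client,
        npartitions=4,
    )
    print(df.Count().GetValue())  # triggers the distributed event loop
```
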
@@ -426,6 +466,10 @@ def main() -> None:
         # To only change the verbosity in a given scope, use ROOT.Experimental.RLogScopedVerbosity.
         ROOT.Detail.RDF.RDFLogChannel().SetVerbosity(ROOT.Experimental.ELogLevel.kInfo)

+    if args.statistical_validation:
+        fit_histograms(filename=args.output)
+        return
+
     inputs: list[AGCInput] = retrieve_inputs(
         args.n_max_files_per_sample, args.remote_data_prefix, args.data_cache
     )
@@ -457,6 +501,9 @@ def main() -> None:
     save_histos([r.histo for r in ml_results], output_fname=output_fname)
     print(f"Result histograms from ML inference step saved in file {output_fname}")

+    if not args.no_fitting:
+        fit_histograms(filename=args.output)
+

 if __name__ == "__main__":
     main()