@@ -75,6 +75,7 @@ class MLHistoConf:
7575 for i in range (len (feature_names ))
7676]
7777
78+
7879def load_cpp (max_n_jets = 6 ):
7980 # the default value of max_n_jets is the same as in the reference implementation
8081 # https://github.com/iris-hep/analysis-grand-challenge
@@ -91,12 +92,13 @@ def load_cpp(max_n_jets=6):
9192 model_even_path = "models/bdt_even.root"
9293 model_odd_path = "models/bdt_odd.root"
9394
94- ROOT .gInterpreter .Declare (f" #include \ "{ cpp_source } \" " )
95+ ROOT .gInterpreter .Declare (f' #include "{ cpp_source } "' )
9596 # Initialize FastForest models.
9697 # Our BDT models have 20 input features according to the AGC documentation
9798 # https://agc.readthedocs.io/en/latest/taskbackground.html#machine-learning-component
9899
99- ROOT .gInterpreter .ProcessLine (f"""
100+ ROOT .gInterpreter .ProcessLine (
101+ f"""
100102 #ifndef AGC_MODELS
101103 #define AGC_MODELS
102104 const static TMVA::Experimental::RBDT model_even{{"feven", "{ model_even_path } "}};
@@ -105,7 +107,8 @@ def load_cpp(max_n_jets=6):
105107 const static auto permutations = get_permutations_dict(max_n_jets);
106108 #endif
107109 """
108- )
110+ )
111+
109112
110113def define_features (df : ROOT .RDataFrame ) -> ROOT .RDataFrame :
111114 return df .Define (
@@ -131,6 +134,7 @@ def define_features(df: ROOT.RDataFrame) -> ROOT.RDataFrame:
131134 """ ,
132135 )
133136
137+
134138def predict_proba (df : ROOT .RDataFrame ) -> ROOT .RDataFrame :
135139 """get probability scores for every permutation in event"""
136140
@@ -147,6 +151,7 @@ def predict_proba(df: ROOT.RDataFrame) -> ROOT.RDataFrame:
147151 """ ,
148152 )
149153
154+
150155def infer_output_ml_features (df : ROOT .RDataFrame ) -> ROOT .RDataFrame :
151156 """
152157 Choose for each feature the best candidate with the highest probability score.
0 commit comments