Skip to content

Commit 5d5d765

Browse files
committed
[RF] Faster Hesse in RooFit by advertising which params are independent
This reduces the time to run Hesse in the ATLAS Higgs benchmark from 123 s to 92 seconds. Given that some models take hours for this, this is a significant improvement for the user experience. Further improvement is possible by analyzing the computation graph a bit more to find more independent parameters (e.g., the different gammas for stat uncertainties from different bins).
1 parent 1d42767 commit 5d5d765

20 files changed

+172
-14
lines changed

math/mathcore/inc/Math/Functor.h

+10
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ class Functor : public IBaseFunctionMultiDim {
6767
// for multi-dimensional functions
6868
unsigned int NDim() const override { return fDim; }
6969

70+
bool VanishingSecondDerivative(int i, int j) const override { return fVanishSecondDerivFunc ? fVanishSecondDerivFunc(i, j) : false; }
71+
72+
void SetVanishingSecondDerivativeFunc(std::function<bool(int, int)> func) { fVanishSecondDerivFunc = std::move(func); }
73+
7074
private :
7175

7276
inline double DoEval (const double * x) const override {
@@ -75,6 +79,7 @@ private :
7579

7680
unsigned int fDim;
7781
std::function<double(double const *)> fFunc;
82+
std::function<bool(int, int)> fVanishSecondDerivFunc;
7883
};
7984

8085
/**
@@ -222,6 +227,10 @@ class GradFunctor : public IGradientFunctionMultiDim {
222227
fGradFunc(x, g);
223228
}
224229

230+
bool VanishingSecondDerivative(int i, int j) const override { return fVanishSecondDerivFunc ? fVanishSecondDerivFunc(i, j) : false; }
231+
232+
void SetVanishingSecondDerivativeFunc(std::function<bool(int, int)> func) { fVanishSecondDerivFunc = std::move(func); }
233+
225234
private :
226235

227236
inline double DoEval (const double * x) const override {
@@ -244,6 +253,7 @@ private :
244253
std::function<double(const double *)> fFunc;
245254
std::function<double(double const *, unsigned int)> fDerivFunc;
246255
std::function<void(const double *, double*)> fGradFunc;
256+
std::function<bool(int, int)> fVanishSecondDerivFunc;
247257
};
248258

249259

math/mathcore/inc/Math/IFunction.h

+5
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ namespace ROOT {
9191
// if it inherits from ROOT::Math::IGradientFunctionMultiDim.
9292
virtual bool HasGradient() const { return false; }
9393

94+
/// Indicate whether a given second order derivative with respect to
95+
/// parameters i and j is always zero. This can help to avoid
96+
/// expensive function calls in Hessian evaluations.
97+
virtual bool VanishingSecondDerivative(int /*i*/, int /*j*/) const { return false; }
98+
9499
private:
95100

96101
/// Implementation of the evaluation function. Must be implemented by derived classes.

math/minuit2/inc/Minuit2/FCNAdapter.h

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class FCNAdapter : public FCNBase {
4343

4444
void SetErrorDef(double up) override { fUp = up; }
4545

46+
bool VanishingSecondDerivative(int i, int j) const override { return fFunc.VanishingSecondDerivative(i, j); }
47+
4648
private:
4749
const Function &fFunc;
4850
double fUp;

math/minuit2/inc/Minuit2/FCNBase.h

+5
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ class FCNBase : public GenericFunction {
132132
virtual bool HasHessian() const { return false; }
133133

134134
virtual bool HasG2() const { return false; }
135+
136+
/// Indicate whether a given second order derivative with respect to
137+
/// parameters i and j is always zero. This can help to avoid
138+
/// expensive function calls in Hessian evaluations.
139+
virtual bool VanishingSecondDerivative(int /*i*/, int /*j*/) const { return false; }
135140
};
136141

137142
} // namespace Minuit2

math/minuit2/inc/Minuit2/FCNGradAdapter.h

+2
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ class FCNGradAdapter : public FCNBase {
119119

120120
void SetErrorDef(double up) override { fUp = up; }
121121

122+
bool VanishingSecondDerivative(int i, int j) const { return fFunc.VanishingSecondDerivative(i, j); }
123+
122124
private:
123125
const Function &fFunc;
124126
double fUp;

math/minuit2/src/MnHesse.cxx

+7-4
Original file line numberDiff line numberDiff line change
@@ -318,14 +318,17 @@ MinimumState MnHesse::ComputeNumerical(const MnFcn &mfcn, const MinimumState &st
318318
if ((i + 1) == j || in == startParIndexOffDiagonal)
319319
x(i) += dirin(i);
320320

321-
x(j) += dirin(j);
322-
323-
double fs1 = mfcn(x);
324-
if(!doCentralFD) {
321+
if(mfcn.Fcn().VanishingSecondDerivative(i, j)) {
322+
vhmat(i, j) = 0.;
323+
} else if(!doCentralFD) {
324+
x(j) += dirin(j);
325+
double fs1 = mfcn(x);
325326
double elem = (fs1 + amin - yy(i) - yy(j)) / (dirin(i) * dirin(j));
326327
vhmat(i, j) = elem;
327328
x(j) -= dirin(j);
328329
} else {
330+
x(j) += dirin(j);
331+
double fs1 = mfcn(x);
329332
// three more function evaluations required for central fd
330333
x(i) -= dirin(i); x(i) -= dirin(i);double fs3 = mfcn(x);
331334
x(j) -= dirin(j); x(j) -= dirin(j);double fs4 = mfcn(x);

roofit/roofitcore/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFitCore
136136
RooFit/TestStatistics/RooUnbinnedL.h
137137
RooFit/TestStatistics/SharedOffset.h
138138
RooFit/TestStatistics/buildLikelihood.h
139+
RooFit/VariableGroups.h
139140
RooFitLegacy/RooCatTypeLegacy.h
140141
RooFitLegacy/RooCategorySharedProperties.h
141142
RooFitLegacy/RooTreeData.h

roofit/roofitcore/inc/RooAbsArg.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ namespace RooFit {
5757
namespace Experimental {
5858
class CodegenContext;
5959
}
60-
}
60+
struct VariableGroups;
61+
} // namespace RooFit
6162

6263
class RooRefArray : public TObjArray {
6364
public:
@@ -248,7 +249,7 @@ class RooAbsArg : public TNamed, public RooPrintable {
248249
bool recursiveCheckObservables(const RooArgSet* nset) const ;
249250
RooFit::OwningPtr<RooArgSet> getComponents() const ;
250251

251-
252+
virtual void fillVariableGroups(RooFit::VariableGroups &out) const;
252253

253254
void attachArgs(const RooAbsCollection &set);
254255
void attachDataSet(const RooAbsData &set);

roofit/roofitcore/inc/RooAddition.h

+2
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ class RooAddition : public RooAbsReal {
5656

5757
void doEval(RooFit::EvalContext &) const override;
5858

59+
void fillVariableGroups(RooFit::VariableGroups &out) const override;
60+
5961
protected:
6062

6163
RooArgList _ownedList ; ///< List of owned components

roofit/roofitcore/inc/RooConstraintSum.h

+2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class RooConstraintSum : public RooAbsReal {
4545

4646
std::unique_ptr<RooAbsArg> compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileContext & ctx) const override;
4747

48+
void fillVariableGroups(RooFit::VariableGroups &out) const override;
49+
4850
protected:
4951

5052
RooListProxy _set1 ; ///< Set of constraint terms
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#ifndef RooFit_VariableGroups_h
2+
#define RooFit_VariableGroups_h
3+
4+
#include <TNamed.h>
5+
6+
#include <iostream>
7+
#include <unordered_map>
8+
#include <vector>
9+
10+
class TNamed;
11+
12+
namespace RooFit {
13+
14+
struct VariableGroups {
15+
16+
std::unordered_map<TNamed const*, std::vector<int>> groups;
17+
18+
inline void print() {
19+
for (auto const& item : groups) {
20+
std::cout << item.first->GetName() << " :";
21+
for (int n : item.second) {
22+
std::cout << " " << n;
23+
}
24+
std::cout << std::endl;
25+
}
26+
}
27+
28+
int currentIndex = 0;
29+
};
30+
31+
}
32+
33+
#endif

roofit/roofitcore/src/RooAbsArg.cxx

+19-1
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ for single nodes.
7575
#include <RooArgSet.h>
7676
#include <RooConstVar.h>
7777
#include <RooExpensiveObjectCache.h>
78+
#include <RooFit/VariableGroups.h>
7879
#include <RooHelpers.h>
79-
#include "RooFitImplHelpers.h"
8080
#include <RooListProxy.h>
8181
#include <RooMsgService.h>
8282
#include <RooRealIntegral.h>
@@ -87,6 +87,8 @@ for single nodes.
8787
#include <RooVectorDataStore.h>
8888
#include <RooWorkspace.h>
8989

90+
#include "RooFitImplHelpers.h"
91+
9092
#include <TBuffer.h>
9193
#include <TClass.h>
9294
#include <TVirtualStreamerInfo.h>
@@ -2529,3 +2531,19 @@ void RooAbsArg::setDataToken(std::size_t index)
25292531
}
25302532
_dataToken = index;
25312533
}
2534+
2535+
void RooAbsArg::fillVariableGroups(RooFit::VariableGroups &out) const
2536+
{
2537+
// Get the set of nodes in the computation graph. Do the detour via
2538+
// RooArgList to avoid deduplication done after adding each element.
2539+
RooArgSet serverSet;
2540+
RooArgList serverList;
2541+
treeNodeServerList(&serverList, nullptr, /*branches*/ false, /*leaves*/ true, /*valueOnly*/ false,
2542+
/*recurseFundamental*/ true);
2543+
serverSet.add(serverList.begin(), serverList.end());
2544+
2545+
for (RooAbsArg const *arg : serverSet) {
2546+
out.groups[arg->namePtr()].push_back(out.currentIndex);
2547+
}
2548+
out.currentIndex++;
2549+
}

roofit/roofitcore/src/RooAddition.cxx

+7
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,10 @@ std::list<double>* RooAddition::plotSamplingHint(RooAbsRealLValue& obs, double x
297297
{
298298
return RooRealSumPdf::plotSamplingHint(_set, obs, xlo, xhi);
299299
}
300+
301+
void RooAddition::fillVariableGroups(RooFit::VariableGroups &out) const
302+
{
303+
for (RooAbsArg *arg : _set) {
304+
arg->fillVariableGroups(out);
305+
}
306+
}

roofit/roofitcore/src/RooConstraintSum.cxx

+7
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,10 @@ bool RooConstraintSum::setData(RooAbsData const& data, bool /*cloneData=true*/)
114114
}
115115
return true;
116116
}
117+
118+
void RooConstraintSum::fillVariableGroups(RooFit::VariableGroups &out) const
119+
{
120+
for (RooAbsArg *arg : _set1) {
121+
arg->fillVariableGroups(out);
122+
}
123+
}

roofit/roofitcore/src/RooEvaluatorWrapper.cxx

+5
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,9 @@ bool RooEvaluatorWrapper::setData(RooAbsData &data, bool /*cloneData*/)
136136
return true;
137137
}
138138

139+
void RooEvaluatorWrapper::fillVariableGroups(RooFit::VariableGroups &out) const
140+
{
141+
_topNode->fillVariableGroups(out);
142+
}
143+
139144
/// \endcond

roofit/roofitcore/src/RooEvaluatorWrapper.h

+2
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class RooEvaluatorWrapper final : public RooAbsReal {
6161
/// The RooFit::Evaluator is dealing with constant terms itself.
6262
void constOptimizeTestStatistic(ConstOpCode /*opcode*/, bool /*doAlsoTrackingOpt*/) override {}
6363

64+
void fillVariableGroups(RooFit::VariableGroups &out) const override;
65+
6466
protected:
6567
double evaluate() const override;
6668

roofit/roofitcore/src/RooMinimizer.cxx

+1-1
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ RooMinimizer::RooMinimizer(RooAbsReal &function, Config const &cfg) : _cfg(cfg)
134134
_fcn = std::make_unique<RooMinimizerFcn>(&function, this);
135135
}
136136
initMinimizerFcnDependentPart(function.defaultErrorLevel());
137-
};
137+
}
138138

139139
/// Initialize the part of the minimizer that is independent of the function to be minimized
140140
void RooMinimizer::initMinimizerFirstPart()

roofit/roofitcore/src/RooMinimizerFcn.cxx

+53-6
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323
#include "RooAbsArg.h"
2424
#include "RooAbsPdf.h"
2525
#include "RooArgSet.h"
26-
#include "RooRealVar.h"
27-
#include "RooMsgService.h"
26+
#include "RooFit/VariableGroups.h"
2827
#include "RooMinimizer.h"
28+
#include "RooMsgService.h"
2929
#include "RooNaNPacker.h"
30+
#include "RooRealVar.h"
3031

3132
#include "Math/Functor.h"
3233
#include "TMatrixDSym.h"
@@ -38,6 +39,23 @@ using std::setprecision;
3839

3940
namespace {
4041

42+
template <class InputIt1, class InputIt2>
43+
bool intersect(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2)
44+
{
45+
while (first1 != last1 && first2 != last2) {
46+
if (*first1 < *first2) {
47+
++first1;
48+
continue;
49+
}
50+
if (*first2 < *first1) {
51+
++first2;
52+
continue;
53+
}
54+
return true;
55+
}
56+
return false;
57+
}
58+
4159
// Helper function that wraps RooAbsArg::getParameters and directly returns the
4260
// output RooArgSet. To be used in the initializer list of the RooMinimizerFcn
4361
// constructor.
@@ -54,14 +72,38 @@ RooArgSet getParameters(RooAbsReal const &funct)
5472
RooMinimizerFcn::RooMinimizerFcn(RooAbsReal *funct, RooMinimizer *context)
5573
: RooAbsMinimizerFcn(getParameters(*funct), context), _funct(funct)
5674
{
57-
unsigned int nDim = getNDim();
75+
RooFit::VariableGroups groups;
76+
funct->fillVariableGroups(groups);
77+
78+
RooArgList parameters;
79+
for (std::size_t i = 0; i < getNDim(); ++i) {
80+
parameters.add(floatableParam(i));
81+
}
82+
83+
std::size_t nParams = parameters.size();
84+
85+
_secondDerivMask.resize(nParams * nParams);
86+
for (std::size_t i = 0; i < nParams; ++i) {
87+
_secondDerivMask[nParams * i + i] = 1;
88+
for (std::size_t j = 0; j < i; ++j) {
89+
// std::cout << parameters[i].GetName() << " " << parameters[j].GetName() << std::endl;
90+
auto const &gr1 = groups.groups.at(parameters[i].namePtr());
91+
auto const &gr2 = groups.groups.at(parameters[j].namePtr());
92+
_secondDerivMask[nParams * i + j] = intersect(gr1.begin(), gr1.end(), gr2.begin(), gr2.end());
93+
_secondDerivMask[nParams * j + i] = _secondDerivMask[nParams * i + j];
94+
}
95+
}
5896

5997
if (context->_cfg.useGradient && funct->hasGradient()) {
6098
_gradientOutput.resize(_allParams.size());
61-
_multiGenFcn = std::make_unique<ROOT::Math::GradFunctor>(this, &RooMinimizerFcn::operator(),
62-
&RooMinimizerFcn::evaluateGradient, nDim);
99+
auto functor = std::make_unique<ROOT::Math::GradFunctor>(this, &RooMinimizerFcn::operator(),
100+
&RooMinimizerFcn::evaluateGradient, getNDim());
101+
functor->SetVanishingSecondDerivativeFunc([this](int i, int j) { return this->vanishingSecondDerivative(i, j); });
102+
_multiGenFcn = std::move(functor);
63103
} else {
64-
_multiGenFcn = std::make_unique<ROOT::Math::Functor>(std::cref(*this), getNDim());
104+
auto functor = std::make_unique<ROOT::Math::Functor>(std::cref(*this), getNDim());
105+
functor->SetVanishingSecondDerivativeFunc([this](int i, int j) { return this->vanishingSecondDerivative(i, j); });
106+
_multiGenFcn = std::move(functor);
65107
}
66108
}
67109

@@ -145,3 +187,8 @@ void RooMinimizerFcn::setOffsetting(bool flag)
145187
{
146188
_funct->enableOffsetting(flag);
147189
}
190+
191+
bool RooMinimizerFcn::vanishingSecondDerivative(int i, int j) const
192+
{
193+
return _secondDerivMask[getNDim() * i + j] == 0;
194+
}

roofit/roofitcore/src/RooMinimizerFcn.h

+3
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,13 @@ class RooMinimizerFcn : public RooAbsMinimizerFcn {
4444
double operator()(const double *x) const;
4545
void evaluateGradient(const double *x, double *out) const;
4646

47+
bool vanishingSecondDerivative(int i, int j) const;
48+
4749
private:
4850
RooAbsReal *_funct = nullptr;
4951
std::unique_ptr<ROOT::Math::IBaseFunctionMultiDim> _multiGenFcn;
5052
mutable std::vector<double> _gradientOutput;
53+
std::vector<int> _secondDerivMask;
5154
};
5255

5356
#endif

roofit/roofitcore/src/TestStatistics/MinuitFcnGrad.cxx

+3
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ class MinuitGradFunctor : public ROOT::Math::IMultiGradFunction {
4343

4444
bool returnsInMinuit2ParameterSpace() const override { return _fcn.returnsInMinuit2ParameterSpace(); }
4545

46+
// TODO: Implement this
47+
bool VanishingSecondDerivative(int /*i*/, int /*j*/) const override { return false; }
48+
4649
private:
4750
double DoEval(const double *x) const override { return _fcn(x); }
4851

0 commit comments

Comments
 (0)