Skip to content

Commit eae7b17

Browse files
committed
[RF] Faster Hesse in RooFit by advertising which params are independent
This reduces the time to run Hesse in the ATLAS Higgs benchmark from 123 s to 92 seconds. Given that some models take hours for this, this is a significant improvement for the user experience. Further improvement is possible by analyzing the computation graph a bit more to find more independent parameters (e.g., the different gammas for stat uncertainties from different bins).
1 parent 5e4ba5d commit eae7b17

20 files changed

+171
-12
lines changed

math/mathcore/inc/Math/Functor.h

+10
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ class Functor : public IBaseFunctionMultiDim {
6767
// for multi-dimensional functions
6868
unsigned int NDim() const override { return fDim; }
6969

70+
bool VanishingSecondDerivative(int i, int j) const override { return fVanishSecondDerivFunc ? fVanishSecondDerivFunc(i, j) : false; }
71+
72+
void SetVanishingSecondDerivativeFunc(std::function<bool(int, int)> func) { fVanishSecondDerivFunc = std::move(func); }
73+
7074
private :
7175

7276
inline double DoEval (const double * x) const override {
@@ -75,6 +79,7 @@ private :
7579

7680
unsigned int fDim;
7781
std::function<double(double const *)> fFunc;
82+
std::function<bool(int, int)> fVanishSecondDerivFunc;
7883
};
7984

8085
/**
@@ -222,6 +227,10 @@ class GradFunctor : public IGradientFunctionMultiDim {
222227
fGradFunc(x, g);
223228
}
224229

230+
bool VanishingSecondDerivative(int i, int j) const override { return fVanishSecondDerivFunc ? fVanishSecondDerivFunc(i, j) : false; }
231+
232+
void SetVanishingSecondDerivativeFunc(std::function<bool(int, int)> func) { fVanishSecondDerivFunc = std::move(func); }
233+
225234
private :
226235

227236
inline double DoEval (const double * x) const override {
@@ -244,6 +253,7 @@ private :
244253
std::function<double(const double *)> fFunc;
245254
std::function<double(double const *, unsigned int)> fDerivFunc;
246255
std::function<void(const double *, double*)> fGradFunc;
256+
std::function<bool(int, int)> fVanishSecondDerivFunc;
247257
};
248258

249259

math/mathcore/inc/Math/IFunction.h

+5
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ namespace ROOT {
8989
// if it inherits from ROOT::Math::IGradientFunctionMultiDim.
9090
virtual bool HasGradient() const { return false; }
9191

92+
/// Indicate whether a given second order derivative with respect to
93+
/// parameters i and j is always zero. This can help to avoid
94+
/// expensive function calls in Hessian evaluations.
95+
virtual bool VanishingSecondDerivative(int /*i*/, int /*j*/) const { return false; }
96+
9297
private:
9398

9499
/// Implementation of the evaluation function. Must be implemented by derived classes.

math/minuit2/inc/Minuit2/FCNAdapter.h

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class FCNAdapter : public FCNBase {
5050
// forward interface
5151
// virtual double operator()(int npar, double* params,int iflag = 4) const;
5252

53+
bool VanishingSecondDerivative(int i, int j) const { return fFunc.VanishingSecondDerivative(i, j); }
54+
5355
private:
5456
const Function &fFunc;
5557
double fUp;

math/minuit2/inc/Minuit2/FCNBase.h

+5
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ class FCNBase : public GenericFunction {
132132
virtual bool HasHessian() const { return false; }
133133

134134
virtual bool HasG2() const { return false; }
135+
136+
/// Indicate whether a given second order derivative with respect to
137+
/// parameters i and j is always zero. This can help to avoid
138+
/// expensive function calls in Hessian evaluations.
139+
virtual bool VanishingSecondDerivative(int /*i*/, int /*j*/) const { return false; }
135140
};
136141

137142
} // namespace Minuit2

math/minuit2/inc/Minuit2/FCNGradAdapter.h

+2
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ class FCNGradAdapter : public FCNBase {
119119

120120
void SetErrorDef(double up) override { fUp = up; }
121121

122+
bool VanishingSecondDerivative(int i, int j) const { return fFunc.VanishingSecondDerivative(i, j); }
123+
122124
private:
123125
const Function &fFunc;
124126
double fUp;

math/minuit2/src/MnHesse.cxx

+7-4
Original file line numberDiff line numberDiff line change
@@ -352,14 +352,17 @@ MinimumState MnHesse::ComputeNumerical(const MnFcn &mfcn, const MinimumState &st
352352
if ((i + 1) == j || in == startParIndexOffDiagonal)
353353
x(i) += dirin(i);
354354

355-
x(j) += dirin(j);
356-
357-
double fs1 = mfcn(x);
358-
if(!doCentralFD) {
355+
if(mfcn.Fcn().VanishingSecondDerivative(i, j)) {
356+
vhmat(i, j) = 0.;
357+
} else if(!doCentralFD) {
358+
x(j) += dirin(j);
359+
double fs1 = mfcn(x);
359360
double elem = (fs1 + amin - yy(i) - yy(j)) / (dirin(i) * dirin(j));
360361
vhmat(i, j) = elem;
361362
x(j) -= dirin(j);
362363
} else {
364+
x(j) += dirin(j);
365+
double fs1 = mfcn(x);
363366
// three more function evaluations required for central fd
364367
x(i) -= dirin(i); x(i) -= dirin(i);double fs3 = mfcn(x);
365368
x(j) -= dirin(j); x(j) -= dirin(j);double fs4 = mfcn(x);

roofit/roofitcore/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,9 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFitCore
133133
RooFit/TestStatistics/RooSubsidiaryL.h
134134
RooFit/TestStatistics/RooSumL.h
135135
RooFit/TestStatistics/RooUnbinnedL.h
136-
RooFit/TestStatistics/buildLikelihood.h
137136
RooFit/TestStatistics/SharedOffset.h
137+
RooFit/TestStatistics/buildLikelihood.h
138+
RooFit/VariableGroups.h
138139
RooFitLegacy/RooCatTypeLegacy.h
139140
RooFitLegacy/RooCategorySharedProperties.h
140141
RooFitLegacy/RooTreeData.h

roofit/roofitcore/inc/RooAbsArg.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,11 @@ using RooListProxy = RooCollectionProxy<RooArgList>;
5454
class RooExpensiveObjectCache ;
5555
class RooWorkspace ;
5656
namespace RooFit {
57+
5758
namespace Detail {
5859
class CodeSquashContext;
5960
}
61+
class VariableGroups;
6062
}
6163

6264
class RooRefArray : public TObjArray {
@@ -270,7 +272,7 @@ class RooAbsArg : public TNamed, public RooPrintable {
270272
bool recursiveCheckObservables(const RooArgSet* nset) const ;
271273
RooFit::OwningPtr<RooArgSet> getComponents() const ;
272274

273-
275+
virtual void fillVariableGroups(RooFit::VariableGroups &out) const;
274276

275277
void attachArgs(const RooAbsCollection &set);
276278
void attachDataSet(const RooAbsData &set);

roofit/roofitcore/inc/RooAddition.h

+2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class RooAddition : public RooAbsReal {
5858

5959
void translate(RooFit::Detail::CodeSquashContext &ctx) const override;
6060

61+
void fillVariableGroups(RooFit::VariableGroups &out) const override;
62+
6163
protected:
6264

6365
RooArgList _ownedList ; ///< List of owned components

roofit/roofitcore/inc/RooConstraintSum.h

+3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class RooConstraintSum : public RooAbsReal {
4646
std::unique_ptr<RooAbsArg> compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileContext & ctx) const override;
4747

4848
void translate(RooFit::Detail::CodeSquashContext &ctx) const override;
49+
50+
void fillVariableGroups(RooFit::VariableGroups &out) const override;
51+
4952
protected:
5053

5154
RooListProxy _set1 ; ///< Set of constraint terms
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#ifndef RooFit_VariableGroups_h
2+
#define RooFit_VariableGroups_h
3+
4+
#include <TNamed.h>
5+
6+
#include <iostream>
7+
#include <unordered_map>
8+
#include <vector>
9+
10+
class TNamed;
11+
12+
namespace RooFit {
13+
14+
struct VariableGroups {
15+
16+
std::unordered_map<TNamed const*, std::vector<int>> groups;
17+
18+
inline void print() {
19+
for (auto const& item : groups) {
20+
std::cout << item.first->GetName() << " :";
21+
for (int n : item.second) {
22+
std::cout << " " << n;
23+
}
24+
std::cout << std::endl;
25+
}
26+
}
27+
28+
int currentIndex = 0;
29+
};
30+
31+
}
32+
33+
#endif

roofit/roofitcore/src/RooAbsArg.cxx

+19-1
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ for single nodes.
7575
#include <RooArgSet.h>
7676
#include <RooConstVar.h>
7777
#include <RooExpensiveObjectCache.h>
78+
#include <RooFit/VariableGroups.h>
7879
#include <RooHelpers.h>
79-
#include "RooFitImplHelpers.h"
8080
#include <RooListProxy.h>
8181
#include <RooMsgService.h>
8282
#include <RooRealIntegral.h>
@@ -87,6 +87,8 @@ for single nodes.
8787
#include <RooVectorDataStore.h>
8888
#include <RooWorkspace.h>
8989

90+
#include "RooFitImplHelpers.h"
91+
9092
#include <TBuffer.h>
9193
#include <TClass.h>
9294
#include <TVirtualStreamerInfo.h>
@@ -2566,3 +2568,19 @@ void RooAbsArg::setDataToken(std::size_t index)
25662568
}
25672569
_dataToken = index;
25682570
}
2571+
2572+
void RooAbsArg::fillVariableGroups(RooFit::VariableGroups &out) const
2573+
{
2574+
// Get the set of nodes in the computation graph. Do the detour via
2575+
// RooArgList to avoid deduplication done after adding each element.
2576+
RooArgSet serverSet;
2577+
RooArgList serverList;
2578+
treeNodeServerList(&serverList, nullptr, /*branches*/ false, /*leaves*/ true, /*valueOnly*/ false,
2579+
/*recurseFundamental*/ true);
2580+
serverSet.add(serverList.begin(), serverList.end());
2581+
2582+
for (RooAbsArg const *arg : serverSet) {
2583+
out.groups[arg->namePtr()].push_back(out.currentIndex);
2584+
}
2585+
out.currentIndex++;
2586+
}

roofit/roofitcore/src/RooAddition.cxx

+7
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,10 @@ std::list<double>* RooAddition::plotSamplingHint(RooAbsRealLValue& obs, double x
333333
{
334334
return RooRealSumPdf::plotSamplingHint(_set, obs, xlo, xhi);
335335
}
336+
337+
void RooAddition::fillVariableGroups(RooFit::VariableGroups &out) const
338+
{
339+
for (RooAbsArg *arg : _set) {
340+
arg->fillVariableGroups(out);
341+
}
342+
}

roofit/roofitcore/src/RooConstraintSum.cxx

+7
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,10 @@ bool RooConstraintSum::setData(RooAbsData const& data, bool /*cloneData=true*/)
120120
}
121121
return true;
122122
}
123+
124+
void RooConstraintSum::fillVariableGroups(RooFit::VariableGroups &out) const
125+
{
126+
for (RooAbsArg *arg : _set1) {
127+
arg->fillVariableGroups(out);
128+
}
129+
}

roofit/roofitcore/src/RooEvaluatorWrapper.cxx

+5
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,9 @@ bool RooEvaluatorWrapper::setData(RooAbsData &data, bool /*cloneData*/)
136136
return true;
137137
}
138138

139+
void RooEvaluatorWrapper::fillVariableGroups(RooFit::VariableGroups &out) const
140+
{
141+
_topNode->fillVariableGroups(out);
142+
}
143+
139144
/// \endcond

roofit/roofitcore/src/RooEvaluatorWrapper.h

+2
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class RooEvaluatorWrapper final : public RooAbsReal {
6161
/// The RooFit::Evaluator is dealing with constant terms itself.
6262
void constOptimizeTestStatistic(ConstOpCode /*opcode*/, bool /*doAlsoTrackingOpt*/) override {}
6363

64+
void fillVariableGroups(RooFit::VariableGroups &out) const override;
65+
6466
protected:
6567
double evaluate() const override;
6668

roofit/roofitcore/src/RooMinimizer.cxx

+1-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ RooMinimizer::RooMinimizer(RooAbsReal &function, Config const &cfg) : _cfg(cfg)
138138
_fcn = std::make_unique<RooMinimizerFcn>(&function, this);
139139
}
140140
initMinimizerFcnDependentPart(function.defaultErrorLevel());
141-
};
141+
}
142142

143143
/// Initialize the part of the minimizer that is independent of the function to be minimized
144144
void RooMinimizer::initMinimizerFirstPart()

roofit/roofitcore/src/RooMinimizerFcn.cxx

+50-4
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323
#include "RooAbsArg.h"
2424
#include "RooAbsPdf.h"
2525
#include "RooArgSet.h"
26-
#include "RooRealVar.h"
27-
#include "RooMsgService.h"
26+
#include "RooFit/VariableGroups.h"
2827
#include "RooMinimizer.h"
28+
#include "RooMsgService.h"
2929
#include "RooNaNPacker.h"
30+
#include "RooRealVar.h"
3031

3132
#include "Math/Functor.h"
3233
#include "TMatrixDSym.h"
@@ -38,6 +39,23 @@ using std::cout, std::endl, std::setprecision;
3839

3940
namespace {
4041

42+
template <class InputIt1, class InputIt2>
43+
bool intersect(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2)
44+
{
45+
while (first1 != last1 && first2 != last2) {
46+
if (*first1 < *first2) {
47+
++first1;
48+
continue;
49+
}
50+
if (*first2 < *first1) {
51+
++first2;
52+
continue;
53+
}
54+
return true;
55+
}
56+
return false;
57+
}
58+
4159
// Helper function that wraps RooAbsArg::getParameters and directly returns the
4260
// output RooArgSet. To be used in the initializer list of the RooMinimizerFcn
4361
// constructor.
@@ -54,11 +72,34 @@ RooArgSet getParameters(RooAbsReal const &funct)
5472
RooMinimizerFcn::RooMinimizerFcn(RooAbsReal *funct, RooMinimizer *context)
5573
: RooAbsMinimizerFcn(getParameters(*funct), context), _funct(funct)
5674
{
75+
RooFit::VariableGroups groups;
76+
funct->fillVariableGroups(groups);
77+
78+
RooArgList const &parameters = *GetFloatParamList();
79+
80+
std::size_t nParams = parameters.size();
81+
82+
_secondDerivMask.resize(nParams * nParams);
83+
for (std::size_t i = 0; i < nParams; ++i) {
84+
_secondDerivMask[nParams * i + i] = 1;
85+
for (std::size_t j = 0; j < i; ++j) {
86+
// std::cout << parameters[i].GetName() << " " << parameters[j].GetName() << std::endl;
87+
auto const &gr1 = groups.groups.at(parameters[i].namePtr());
88+
auto const &gr2 = groups.groups.at(parameters[j].namePtr());
89+
_secondDerivMask[nParams * i + j] = intersect(gr1.begin(), gr1.end(), gr2.begin(), gr2.end());
90+
_secondDerivMask[nParams * j + i] = _secondDerivMask[nParams * i + j];
91+
}
92+
}
93+
5794
if (context->_cfg.useGradient && funct->hasGradient()) {
58-
_multiGenFcn = std::make_unique<ROOT::Math::GradFunctor>(this, &RooMinimizerFcn::operator(),
95+
auto functor = std::make_unique<ROOT::Math::GradFunctor>(this, &RooMinimizerFcn::operator(),
5996
&RooMinimizerFcn::evaluateGradient, getNDim());
97+
functor->SetVanishingSecondDerivativeFunc([this](int i, int j) { return this->vanishingSecondDerivative(i, j); });
98+
_multiGenFcn = std::move(functor);
6099
} else {
61-
_multiGenFcn = std::make_unique<ROOT::Math::Functor>(std::cref(*this), getNDim());
100+
auto functor = std::make_unique<ROOT::Math::Functor>(std::cref(*this), getNDim());
101+
functor->SetVanishingSecondDerivativeFunc([this](int i, int j) { return this->vanishingSecondDerivative(i, j); });
102+
_multiGenFcn = std::move(functor);
62103
}
63104
}
64105

@@ -132,3 +173,8 @@ void RooMinimizerFcn::setOffsetting(bool flag)
132173
{
133174
_funct->enableOffsetting(flag);
134175
}
176+
177+
bool RooMinimizerFcn::vanishingSecondDerivative(int i, int j) const
178+
{
179+
return _secondDerivMask[_nDim * i + j] == 0;
180+
}

roofit/roofitcore/src/RooMinimizerFcn.h

+3
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,12 @@ class RooMinimizerFcn : public RooAbsMinimizerFcn {
4444
double operator()(const double *x) const;
4545
void evaluateGradient(const double *x, double *out) const;
4646

47+
bool vanishingSecondDerivative(int i, int j) const;
48+
4749
private:
4850
RooAbsReal *_funct;
4951
std::unique_ptr<ROOT::Math::IBaseFunctionMultiDim> _multiGenFcn;
52+
std::vector<int> _secondDerivMask;
5053
};
5154

5255
#endif

roofit/roofitcore/src/TestStatistics/MinuitFcnGrad.cxx

+3
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ class MinuitGradFunctor : public ROOT::Math::IMultiGradFunction {
4343

4444
bool returnsInMinuit2ParameterSpace() const override { return _fcn.returnsInMinuit2ParameterSpace(); }
4545

46+
// TODO: Implement this
47+
bool VanishingSecondDerivative(int /*i*/, int /*j*/) const override { false; }
48+
4649
private:
4750
double DoEval(const double *x) const override { return _fcn(x); }
4851

0 commit comments

Comments
 (0)