Skip to content

Commit 9ebb89a

Browse files
committed
[RF] Faster Hesse in RooFit by advertising which params are independent
This reduces the time to run Hesse in the ATLAS Higgs benchmark from 123 s to 92 s. Given that Hesse takes hours for some models, this is a significant improvement to the user experience. Further improvement is possible by analyzing the computation graph in more detail to find more independent parameters (e.g., the different gammas for stat uncertainties from different bins).
1 parent 3091720 commit 9ebb89a

20 files changed

+174
-14
lines changed

math/mathcore/inc/Math/Functor.h

+10
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ class Functor : public IBaseFunctionMultiDim {
6767
// for multi-dimensional functions
6868
unsigned int NDim() const override { return fDim; }
6969

70+
/// Ask the user-installed callback whether the second derivative with respect
/// to parameters i and j is always zero. Without a callback the answer is false.
bool VanishingSecondDerivative(int i, int j) const override
{
   return fVanishSecondDerivFunc && fVanishSecondDerivFunc(i, j);
}
71+
72+
/// Install the callback consulted by VanishingSecondDerivative().
/// The callable is taken by value and moved into the functor.
void SetVanishingSecondDerivativeFunc(std::function<bool(int, int)> func)
{
   fVanishSecondDerivFunc = std::move(func);
}
73+
7074
private :
7175

7276
inline double DoEval (const double * x) const override {
@@ -75,6 +79,7 @@ private :
7579

7680
unsigned int fDim;
7781
std::function<double(double const *)> fFunc;
82+
std::function<bool(int, int)> fVanishSecondDerivFunc;
7883
};
7984

8085
/**
@@ -222,6 +227,10 @@ class GradFunctor : public IGradientFunctionMultiDim {
222227
fGradFunc(x, g);
223228
}
224229

230+
/// Ask the user-installed callback whether the second derivative with respect
/// to parameters i and j is always zero. Without a callback the answer is false.
bool VanishingSecondDerivative(int i, int j) const override
{
   return fVanishSecondDerivFunc && fVanishSecondDerivFunc(i, j);
}
231+
232+
/// Install the callback consulted by VanishingSecondDerivative().
/// The callable is taken by value and moved into the functor.
void SetVanishingSecondDerivativeFunc(std::function<bool(int, int)> func)
{
   fVanishSecondDerivFunc = std::move(func);
}
233+
225234
private :
226235

227236
inline double DoEval (const double * x) const override {
@@ -244,6 +253,7 @@ private :
244253
std::function<double(const double *)> fFunc;
245254
std::function<double(double const *, unsigned int)> fDerivFunc;
246255
std::function<void(const double *, double*)> fGradFunc;
256+
std::function<bool(int, int)> fVanishSecondDerivFunc;
247257
};
248258

249259

math/mathcore/inc/Math/IFunction.h

+5
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ namespace ROOT {
8989
// if it inherits from ROOT::Math::IGradientFunctionMultiDim.
9090
virtual bool HasGradient() const { return false; }
9191

92+
/// Indicate whether a given second order derivative with respect to
93+
/// parameters i and j is always zero. This can help to avoid
94+
/// expensive function calls in Hessian evaluations.
95+
/// \return true if the second derivative is known to be identically zero.
///         This default implementation always returns false, i.e. it claims
///         no such knowledge.
virtual bool VanishingSecondDerivative(int /*i*/, int /*j*/) const { return false; }
96+
9297
private:
9398

9499
/// Implementation of the evaluation function. Must be implemented by derived classes.

math/minuit2/inc/Minuit2/FCNAdapter.h

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class FCNAdapter : public FCNBase {
4343

4444
void SetErrorDef(double up) override { fUp = up; }
4545

46+
bool VanishingSecondDerivative(int i, int j) const { return fFunc.VanishingSecondDerivative(i, j); }
47+
4648
private:
4749
const Function &fFunc;
4850
double fUp;

math/minuit2/inc/Minuit2/FCNBase.h

+5
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ class FCNBase : public GenericFunction {
132132
virtual bool HasHessian() const { return false; }
133133

134134
virtual bool HasG2() const { return false; }
135+
136+
/// Indicate whether a given second order derivative with respect to
137+
/// parameters i and j is always zero. This can help to avoid
138+
/// expensive function calls in Hessian evaluations.
139+
/// \return true if the second derivative is known to be identically zero.
///         This default implementation always returns false.
/// NOTE(review): MnHesse::ComputeNumerical calls this with its own loop
/// indices; confirm whether those are internal (variable-only) or external
/// parameter indices when some parameters are fixed.
virtual bool VanishingSecondDerivative(int /*i*/, int /*j*/) const { return false; }
135140
};
136141

137142
} // namespace Minuit2

math/minuit2/inc/Minuit2/FCNGradAdapter.h

+2
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ class FCNGradAdapter : public FCNBase {
119119

120120
void SetErrorDef(double up) override { fUp = up; }
121121

122+
bool VanishingSecondDerivative(int i, int j) const { return fFunc.VanishingSecondDerivative(i, j); }
123+
122124
private:
123125
const Function &fFunc;
124126
double fUp;

math/minuit2/src/MnHesse.cxx

+7-4
Original file line numberDiff line numberDiff line change
@@ -318,14 +318,17 @@ MinimumState MnHesse::ComputeNumerical(const MnFcn &mfcn, const MinimumState &st
318318
if ((i + 1) == j || in == startParIndexOffDiagonal)
319319
x(i) += dirin(i);
320320

321-
x(j) += dirin(j);
322-
323-
double fs1 = mfcn(x);
324-
if(!doCentralFD) {
321+
if(mfcn.Fcn().VanishingSecondDerivative(i, j)) {
322+
vhmat(i, j) = 0.;
323+
} else if(!doCentralFD) {
324+
x(j) += dirin(j);
325+
double fs1 = mfcn(x);
325326
double elem = (fs1 + amin - yy(i) - yy(j)) / (dirin(i) * dirin(j));
326327
vhmat(i, j) = elem;
327328
x(j) -= dirin(j);
328329
} else {
330+
x(j) += dirin(j);
331+
double fs1 = mfcn(x);
329332
// three more function evaluations required for central fd
330333
x(i) -= dirin(i); x(i) -= dirin(i);double fs3 = mfcn(x);
331334
x(j) -= dirin(j); x(j) -= dirin(j);double fs4 = mfcn(x);

roofit/roofitcore/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,9 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFitCore
132132
RooFit/TestStatistics/RooSubsidiaryL.h
133133
RooFit/TestStatistics/RooSumL.h
134134
RooFit/TestStatistics/RooUnbinnedL.h
135-
RooFit/TestStatistics/buildLikelihood.h
136135
RooFit/TestStatistics/SharedOffset.h
136+
RooFit/TestStatistics/buildLikelihood.h
137+
RooFit/VariableGroups.h
137138
RooFitLegacy/RooCatTypeLegacy.h
138139
RooFitLegacy/RooCategorySharedProperties.h
139140
RooFitLegacy/RooTreeData.h

roofit/roofitcore/inc/RooAbsArg.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,11 @@ using RooListProxy = RooCollectionProxy<RooArgList>;
5656
class RooExpensiveObjectCache ;
5757
class RooWorkspace ;
5858
namespace RooFit {
59+
5960
namespace Detail {
6061
class CodeSquashContext;
6162
}
63+
struct VariableGroups;
6264
}
6365

6466
class RooRefArray : public TObjArray {
@@ -279,7 +281,7 @@ class RooAbsArg : public TNamed, public RooPrintable {
279281
bool recursiveCheckObservables(const RooArgSet* nset) const ;
280282
RooFit::OwningPtr<RooArgSet> getComponents() const ;
281283

282-
284+
/// Record in `out` which variables this object depends on.
/// The default implementation appends `out.currentIndex` to the group list of
/// every leaf node below this object and then increments `out.currentIndex`.
/// RooAddition and RooConstraintSum override it to forward the call to each
/// of their terms instead.
virtual void fillVariableGroups(RooFit::VariableGroups &out) const;
283285

284286
void attachArgs(const RooAbsCollection &set);
285287
void attachDataSet(const RooAbsData &set);

roofit/roofitcore/inc/RooAddition.h

+2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class RooAddition : public RooAbsReal {
5858

5959
void translate(RooFit::Detail::CodeSquashContext &ctx) const override;
6060

61+
void fillVariableGroups(RooFit::VariableGroups &out) const override;
62+
6163
protected:
6264

6365
RooArgList _ownedList ; ///< List of owned components

roofit/roofitcore/inc/RooConstraintSum.h

+3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class RooConstraintSum : public RooAbsReal {
4646
std::unique_ptr<RooAbsArg> compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileContext & ctx) const override;
4747

4848
void translate(RooFit::Detail::CodeSquashContext &ctx) const override;
49+
50+
void fillVariableGroups(RooFit::VariableGroups &out) const override;
51+
4952
protected:
5053

5154
RooListProxy _set1 ; ///< Set of constraint terms
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#ifndef RooFit_VariableGroups_h
2+
#define RooFit_VariableGroups_h
3+
4+
#include <TNamed.h>
5+
6+
#include <iostream>
7+
#include <unordered_map>
8+
#include <vector>
9+
10+
class TNamed;
11+
12+
namespace RooFit {
13+
14+
struct VariableGroups {
15+
16+
std::unordered_map<TNamed const*, std::vector<int>> groups;
17+
18+
inline void print() {
19+
for (auto const& item : groups) {
20+
std::cout << item.first->GetName() << " :";
21+
for (int n : item.second) {
22+
std::cout << " " << n;
23+
}
24+
std::cout << std::endl;
25+
}
26+
}
27+
28+
int currentIndex = 0;
29+
};
30+
31+
}
32+
33+
#endif

roofit/roofitcore/src/RooAbsArg.cxx

+19-1
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ for single nodes.
7575
#include <RooArgSet.h>
7676
#include <RooConstVar.h>
7777
#include <RooExpensiveObjectCache.h>
78+
#include <RooFit/VariableGroups.h>
7879
#include <RooHelpers.h>
79-
#include "RooFitImplHelpers.h"
8080
#include <RooListProxy.h>
8181
#include <RooMsgService.h>
8282
#include <RooRealIntegral.h>
@@ -87,6 +87,8 @@ for single nodes.
8787
#include <RooVectorDataStore.h>
8888
#include <RooWorkspace.h>
8989

90+
#include "RooFitImplHelpers.h"
91+
9092
#include <TBuffer.h>
9193
#include <TClass.h>
9294
#include <TVirtualStreamerInfo.h>
@@ -2566,3 +2568,19 @@ void RooAbsArg::setDataToken(std::size_t index)
25662568
}
25672569
_dataToken = index;
25682570
}
2571+
2572+
/// Default implementation: treat this object as a single group. Every leaf
/// node found below it gets the current group index appended to its entry in
/// `out.groups`, after which the group index is advanced by one.
void RooAbsArg::fillVariableGroups(RooFit::VariableGroups &out) const
{
   // Gather the leaves in a list first and convert to a set in one step; the
   // original author notes this avoids deduplicating after every insertion.
   RooArgList leafList;
   treeNodeServerList(&leafList, nullptr, /*branches*/ false, /*leaves*/ true, /*valueOnly*/ false,
                      /*recurseFundamental*/ true);
   RooArgSet leafSet;
   leafSet.add(leafList.begin(), leafList.end());

   const int groupIndex = out.currentIndex;
   for (RooAbsArg const *leaf : leafSet) {
      out.groups[leaf->namePtr()].push_back(groupIndex);
   }
   out.currentIndex = groupIndex + 1;
}

roofit/roofitcore/src/RooAddition.cxx

+7
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,10 @@ std::list<double>* RooAddition::plotSamplingHint(RooAbsRealLValue& obs, double x
330330
{
331331
return RooRealSumPdf::plotSamplingHint(_set, obs, xlo, xhi);
332332
}
333+
334+
/// Let every summand register its own variable group(s) in `out`, instead of
/// registering the whole sum as a single group.
void RooAddition::fillVariableGroups(RooFit::VariableGroups &out) const
{
   for (RooAbsArg *summand : _set) {
      summand->fillVariableGroups(out);
   }
}

roofit/roofitcore/src/RooConstraintSum.cxx

+7
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,10 @@ bool RooConstraintSum::setData(RooAbsData const& data, bool /*cloneData=true*/)
120120
}
121121
return true;
122122
}
123+
124+
/// Let every constraint term register its own variable group(s) in `out`,
/// instead of registering the whole constraint sum as a single group.
void RooConstraintSum::fillVariableGroups(RooFit::VariableGroups &out) const
{
   for (RooAbsArg *constraint : _set1) {
      constraint->fillVariableGroups(out);
   }
}

roofit/roofitcore/src/RooEvaluatorWrapper.cxx

+5
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,9 @@ bool RooEvaluatorWrapper::setData(RooAbsData &data, bool /*cloneData*/)
136136
return true;
137137
}
138138

139+
/// Forward the variable-group collection to the wrapped top node, so the
/// groups are the ones reported by that node.
void RooEvaluatorWrapper::fillVariableGroups(RooFit::VariableGroups &out) const
140+
{
141+
_topNode->fillVariableGroups(out);
142+
}
143+
139144
/// \endcond

roofit/roofitcore/src/RooEvaluatorWrapper.h

+2
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class RooEvaluatorWrapper final : public RooAbsReal {
6161
/// The RooFit::Evaluator is dealing with constant terms itself.
6262
void constOptimizeTestStatistic(ConstOpCode /*opcode*/, bool /*doAlsoTrackingOpt*/) override {}
6363

64+
void fillVariableGroups(RooFit::VariableGroups &out) const override;
65+
6466
protected:
6567
double evaluate() const override;
6668

roofit/roofitcore/src/RooMinimizer.cxx

+1-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ RooMinimizer::RooMinimizer(RooAbsReal &function, Config const &cfg) : _cfg(cfg)
135135
_fcn = std::make_unique<RooMinimizerFcn>(&function, this);
136136
}
137137
initMinimizerFcnDependentPart(function.defaultErrorLevel());
138-
};
138+
}
139139

140140
/// Initialize the part of the minimizer that is independent of the function to be minimized
141141
void RooMinimizer::initMinimizerFirstPart()

roofit/roofitcore/src/RooMinimizerFcn.cxx

+53-6
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323
#include "RooAbsArg.h"
2424
#include "RooAbsPdf.h"
2525
#include "RooArgSet.h"
26-
#include "RooRealVar.h"
27-
#include "RooMsgService.h"
26+
#include "RooFit/VariableGroups.h"
2827
#include "RooMinimizer.h"
28+
#include "RooMsgService.h"
2929
#include "RooNaNPacker.h"
30+
#include "RooRealVar.h"
3031

3132
#include "Math/Functor.h"
3233
#include "TMatrixDSym.h"
@@ -38,6 +39,23 @@ using std::cout, std::endl, std::setprecision;
3839

3940
namespace {
4041

42+
/// Report whether two ranges have at least one element in common.
///
/// Both ranges are walked in lockstep, always advancing the side whose front
/// element is smaller. A common element is therefore only guaranteed to be
/// found when each range is sorted in ascending order.
template <class InputIt1, class InputIt2>
bool intersect(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2)
{
   while (first1 != last1 && first2 != last2) {
      if (*first1 < *first2) {
         ++first1;
      } else if (*first2 < *first1) {
         ++first2;
      } else {
         // Neither element is smaller than the other: shared element found.
         return true;
      }
   }
   return false;
}
58+
4159
// Helper function that wraps RooAbsArg::getParameters and directly returns the
4260
// output RooArgSet. To be used in the initializer list of the RooMinimizerFcn
4361
// constructor.
@@ -54,14 +72,38 @@ RooArgSet getParameters(RooAbsReal const &funct)
5472
/// Build the minimizer function wrapper around `funct`.
///
/// In addition to creating the functor handed to the minimizer, this fills
/// `_secondDerivMask`, an nParams x nParams row-major table in which a
/// non-zero entry (i, j) means that parameters i and j share at least one
/// variable group, so their mixed second derivative must be evaluated.
///
/// \param funct   Function to minimize; also queried for its variable groups.
/// \param context Owning RooMinimizer, used here for its configuration.
RooMinimizerFcn::RooMinimizerFcn(RooAbsReal *funct, RooMinimizer *context)
   : RooAbsMinimizerFcn(getParameters(*funct), context), _funct(funct)
{
   RooFit::VariableGroups groups;
   funct->fillVariableGroups(groups);

   RooArgList parameters;
   for (std::size_t i = 0; i < getNDim(); ++i) {
      parameters.add(floatableParam(i));
   }
   const std::size_t nParams = parameters.size();

   // Look up each parameter's group list once, not once per (i, j) pair.
   // A parameter missing from the map keeps a nullptr entry: using find()
   // rather than at() avoids throwing std::out_of_range out of the constructor.
   std::vector<std::vector<int> const *> paramGroups(nParams, nullptr);
   for (std::size_t i = 0; i < nParams; ++i) {
      auto const found = groups.groups.find(parameters[i].namePtr());
      if (found != groups.groups.end()) {
         paramGroups[i] = &found->second;
      }
   }

   _secondDerivMask.assign(nParams * nParams, 0);
   for (std::size_t i = 0; i < nParams; ++i) {
      _secondDerivMask[nParams * i + i] = 1;
      for (std::size_t j = 0; j < i; ++j) {
         std::vector<int> const *gr1 = paramGroups[i];
         std::vector<int> const *gr2 = paramGroups[j];
         // Without group information for either parameter, be conservative
         // and assume the pair is coupled, so the derivative gets computed.
         const bool coupled = !gr1 || !gr2 || intersect(gr1->begin(), gr1->end(), gr2->begin(), gr2->end());
         _secondDerivMask[nParams * i + j] = coupled;
         _secondDerivMask[nParams * j + i] = coupled;
      }
   }

   // Same callback for both functor flavours.
   auto vanishing = [this](int i, int j) { return this->vanishingSecondDerivative(i, j); };

   if (context->_cfg.useGradient && funct->hasGradient()) {
      _gradientOutput.resize(_allParams.size());
      auto functor = std::make_unique<ROOT::Math::GradFunctor>(this, &RooMinimizerFcn::operator(),
                                                               &RooMinimizerFcn::evaluateGradient, getNDim());
      functor->SetVanishingSecondDerivativeFunc(vanishing);
      _multiGenFcn = std::move(functor);
   } else {
      auto functor = std::make_unique<ROOT::Math::Functor>(std::cref(*this), getNDim());
      functor->SetVanishingSecondDerivativeFunc(vanishing);
      _multiGenFcn = std::move(functor);
   }
}
67109

@@ -145,3 +187,8 @@ void RooMinimizerFcn::setOffsetting(bool flag)
145187
{
146188
_funct->enableOffsetting(flag);
147189
}
190+
191+
bool RooMinimizerFcn::vanishingSecondDerivative(int i, int j) const
192+
{
193+
return _secondDerivMask[getNDim() * i + j] == 0;
194+
}

roofit/roofitcore/src/RooMinimizerFcn.h

+3
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,13 @@ class RooMinimizerFcn : public RooAbsMinimizerFcn {
4444
double operator()(const double *x) const;
4545
void evaluateGradient(const double *x, double *out) const;
4646

47+
bool vanishingSecondDerivative(int i, int j) const;
48+
4749
private:
4850
RooAbsReal *_funct = nullptr;
4951
std::unique_ptr<ROOT::Math::IBaseFunctionMultiDim> _multiGenFcn;
5052
mutable std::vector<double> _gradientOutput;
53+
std::vector<int> _secondDerivMask;
5154
};
5255

5356
#endif

roofit/roofitcore/src/TestStatistics/MinuitFcnGrad.cxx

+3
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ class MinuitGradFunctor : public ROOT::Math::IMultiGradFunction {
4343

4444
bool returnsInMinuit2ParameterSpace() const override { return _fcn.returnsInMinuit2ParameterSpace(); }
4545

46+
// TODO: Implement this. Until then, always answer false, i.e. never claim
// that a second derivative vanishes.
47+
bool VanishingSecondDerivative(int /*i*/, int /*j*/) const override { return false; }
48+
4649
private:
4750
double DoEval(const double *x) const override { return _fcn(x); }
4851

0 commit comments

Comments
 (0)