Skip to content

Commit 5d5d765

Browse files
committed
[RF] Faster Hesse in RooFit by advertising which params are independent
This reduces the time to run Hesse in the ATLAS Higgs benchmark from 123 s to 92 seconds. Given that some models take hours for this, this is a significant improvement for the user experience. Further improvement is possible by analyzing the computation graph a bit more to find more independent parameters (e.g., the different gammas for stat uncertainties from different bins).
1 parent 1d42767 commit 5d5d765

20 files changed

+172
-14
lines changed

math/mathcore/inc/Math/Functor.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ class Functor : public IBaseFunctionMultiDim {
6767
// for multi-dimensional functions
6868
unsigned int NDim() const override { return fDim; }
6969

70+
bool VanishingSecondDerivative(int i, int j) const override { return fVanishSecondDerivFunc ? fVanishSecondDerivFunc(i, j) : false; }
71+
72+
void SetVanishingSecondDerivativeFunc(std::function<bool(int, int)> func) { fVanishSecondDerivFunc = std::move(func); }
73+
7074
private :
7175

7276
inline double DoEval (const double * x) const override {
@@ -75,6 +79,7 @@ private :
7579

7680
unsigned int fDim;
7781
std::function<double(double const *)> fFunc;
82+
std::function<bool(int, int)> fVanishSecondDerivFunc;
7883
};
7984

8085
/**
@@ -222,6 +227,10 @@ class GradFunctor : public IGradientFunctionMultiDim {
222227
fGradFunc(x, g);
223228
}
224229

230+
bool VanishingSecondDerivative(int i, int j) const override { return fVanishSecondDerivFunc ? fVanishSecondDerivFunc(i, j) : false; }
231+
232+
void SetVanishingSecondDerivativeFunc(std::function<bool(int, int)> func) { fVanishSecondDerivFunc = std::move(func); }
233+
225234
private :
226235

227236
inline double DoEval (const double * x) const override {
@@ -244,6 +253,7 @@ private :
244253
std::function<double(const double *)> fFunc;
245254
std::function<double(double const *, unsigned int)> fDerivFunc;
246255
std::function<void(const double *, double*)> fGradFunc;
256+
std::function<bool(int, int)> fVanishSecondDerivFunc;
247257
};
248258

249259

math/mathcore/inc/Math/IFunction.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ namespace ROOT {
9191
// if it inherits from ROOT::Math::IGradientFunctionMultiDim.
9292
virtual bool HasGradient() const { return false; }
9393

94+
/// Indicate whether a given second order derivative with respect to
95+
/// parameters i and j is always zero. This can help to avoid
96+
/// expensive function calls in Hessian evaluations.
97+
virtual bool VanishingSecondDerivative(int /*i*/, int /*j*/) const { return false; }
98+
9499
private:
95100

96101
/// Implementation of the evaluation function. Must be implemented by derived classes.

math/minuit2/inc/Minuit2/FCNAdapter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class FCNAdapter : public FCNBase {
4343

4444
void SetErrorDef(double up) override { fUp = up; }
4545

46+
bool VanishingSecondDerivative(int i, int j) const override { return fFunc.VanishingSecondDerivative(i, j); }
47+
4648
private:
4749
const Function &fFunc;
4850
double fUp;

math/minuit2/inc/Minuit2/FCNBase.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ class FCNBase : public GenericFunction {
132132
virtual bool HasHessian() const { return false; }
133133

134134
virtual bool HasG2() const { return false; }
135+
136+
/// Indicate whether a given second order derivative with respect to
137+
/// parameters i and j is always zero. This can help to avoid
138+
/// expensive function calls in Hessian evaluations.
139+
virtual bool VanishingSecondDerivative(int /*i*/, int /*j*/) const { return false; }
135140
};
136141

137142
} // namespace Minuit2

math/minuit2/inc/Minuit2/FCNGradAdapter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ class FCNGradAdapter : public FCNBase {
119119

120120
void SetErrorDef(double up) override { fUp = up; }
121121

122+
bool VanishingSecondDerivative(int i, int j) const { return fFunc.VanishingSecondDerivative(i, j); }
123+
122124
private:
123125
const Function &fFunc;
124126
double fUp;

math/minuit2/src/MnHesse.cxx

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,14 +318,17 @@ MinimumState MnHesse::ComputeNumerical(const MnFcn &mfcn, const MinimumState &st
318318
if ((i + 1) == j || in == startParIndexOffDiagonal)
319319
x(i) += dirin(i);
320320

321-
x(j) += dirin(j);
322-
323-
double fs1 = mfcn(x);
324-
if(!doCentralFD) {
321+
if(mfcn.Fcn().VanishingSecondDerivative(i, j)) {
322+
vhmat(i, j) = 0.;
323+
} else if(!doCentralFD) {
324+
x(j) += dirin(j);
325+
double fs1 = mfcn(x);
325326
double elem = (fs1 + amin - yy(i) - yy(j)) / (dirin(i) * dirin(j));
326327
vhmat(i, j) = elem;
327328
x(j) -= dirin(j);
328329
} else {
330+
x(j) += dirin(j);
331+
double fs1 = mfcn(x);
329332
// three more function evaluations required for central fd
330333
x(i) -= dirin(i); x(i) -= dirin(i);double fs3 = mfcn(x);
331334
x(j) -= dirin(j); x(j) -= dirin(j);double fs4 = mfcn(x);

roofit/roofitcore/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFitCore
136136
RooFit/TestStatistics/RooUnbinnedL.h
137137
RooFit/TestStatistics/SharedOffset.h
138138
RooFit/TestStatistics/buildLikelihood.h
139+
RooFit/VariableGroups.h
139140
RooFitLegacy/RooCatTypeLegacy.h
140141
RooFitLegacy/RooCategorySharedProperties.h
141142
RooFitLegacy/RooTreeData.h

roofit/roofitcore/inc/RooAbsArg.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ namespace RooFit {
5757
namespace Experimental {
5858
class CodegenContext;
5959
}
60-
}
60+
struct VariableGroups;
61+
} // namespace RooFit
6162

6263
class RooRefArray : public TObjArray {
6364
public:
@@ -248,7 +249,7 @@ class RooAbsArg : public TNamed, public RooPrintable {
248249
bool recursiveCheckObservables(const RooArgSet* nset) const ;
249250
RooFit::OwningPtr<RooArgSet> getComponents() const ;
250251

251-
252+
virtual void fillVariableGroups(RooFit::VariableGroups &out) const;
252253

253254
void attachArgs(const RooAbsCollection &set);
254255
void attachDataSet(const RooAbsData &set);

roofit/roofitcore/inc/RooAddition.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ class RooAddition : public RooAbsReal {
5656

5757
void doEval(RooFit::EvalContext &) const override;
5858

59+
void fillVariableGroups(RooFit::VariableGroups &out) const override;
60+
5961
protected:
6062

6163
RooArgList _ownedList ; ///< List of owned components

roofit/roofitcore/inc/RooConstraintSum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class RooConstraintSum : public RooAbsReal {
4545

4646
std::unique_ptr<RooAbsArg> compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileContext & ctx) const override;
4747

48+
void fillVariableGroups(RooFit::VariableGroups &out) const override;
49+
4850
protected:
4951

5052
RooListProxy _set1 ; ///< Set of constraint terms

0 commit comments

Comments
 (0)