Skip to content

Commit 73023df

Browse files
committed
Refactor VerticalInterpPdf to enable codegen backend
1 parent 3dda45c commit 73023df

7 files changed

+237
-120
lines changed

interface/CombineCodegenImpl.h

+12-1
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,28 @@
22
#define HiggsAnalysis_CombinedLimit_CombineCodegenImpl_h
33

44
#include <ROOT/RConfig.hxx> // for ROOT_VERSION
5+
#include <RooAbsReal.h>
6+
#include <string>
57

68
#if ROOT_VERSION_CODE >= ROOT_VERSION(6,35,0)
79
# define COMBINE_DECLARE_CODEGEN_IMPL(CLASS_NAME) \
810
namespace RooFit { namespace Experimental { void codegenImpl(CLASS_NAME &arg, CodegenContext &ctx); }}
11+
# define COMBINE_DECLARE_CODEGEN_INTEGRAL_IMPL(CLASS_NAME) \
12+
namespace RooFit { namespace Experimental { std::string codegenIntegralImpl(CLASS_NAME &arg, int code, const char *rangeName, CodegenContext &ctx); }}
913
# define COMBINE_DECLARE_TRANSLATE
14+
# define COMBINE_DECLARE_ANALYTICAL_INTEGRAL
1015
#elif ROOT_VERSION_CODE >= ROOT_VERSION(6,32,0)
1116
# define COMBINE_DECLARE_CODEGEN_IMPL(CLASS_NAME)
12-
# define COMBINE_DECLARE_TRANSLATE void translate(RooFit::Detail::CodeSquashContext &ctx) const override;
17+
# define COMBINE_DECLARE_CODEGEN_INTEGRAL_IMPL(CLASS_NAME)
18+
# define COMBINE_DECLARE_TRANSLATE \
19+
void translate(RooFit::Detail::CodeSquashContext &ctx) const override;
20+
# define COMBINE_DECLARE_ANALYTICAL_INTEGRAL \
21+
std::string buildCallToAnalyticIntegral(Int_t code, const char *rangeName, RooFit::Detail::CodeSquashContext &ctx) const override;
1322
#else
1423
# define COMBINE_DECLARE_CODEGEN_IMPL(_)
24+
# define COMBINE_DECLARE_CODEGEN_INTEGRAL_IMPL(_)
1525
# define COMBINE_DECLARE_TRANSLATE
26+
# define COMBINE_DECLARE_ANALYTICAL_INTEGRAL
1627
#endif
1728

1829
#endif

interface/CombineMathFuncs.h

+161
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
#ifndef CombineMathFuncs_h
22
#define CombineMathFuncs_h
33

4+
#include <RooAbsReal.h>
5+
#include <RooArgList.h>
6+
#include <RooArgSet.h>
7+
#include <RooConstVar.h>
8+
49
#include <cmath>
510

611
namespace RooFit {
@@ -145,6 +150,162 @@ inline double processNormalization(double nominalValue, std::size_t nThetas, std
145150
return norm;
146151
}
147152

153+
// Interpolation (from VerticalInterpPdf)
154+
inline Double_t interpolate(Double_t const coeff, Double_t const central, Double_t const fUp,
155+
Double_t const fDn, Double_t const quadraticRegion, Int_t const quadraticAlgo)
156+
{
157+
if (quadraticAlgo == -1) {
158+
Double_t kappa = (coeff > 0 ? fUp/central : central/fDn);
159+
return pow(kappa, sqrt(pow(coeff, 2)));
160+
}
161+
162+
if (fabs(coeff) >= quadraticRegion) {
163+
return coeff * (coeff > 0 ? fUp - central : central - fDn);
164+
} else {
165+
// quadratic interpolation coefficients between the three
166+
if (quadraticAlgo == 0) {
167+
// quadratic interpolation null at zero and continuous at boundaries, but not differentiable at boundaries
168+
// conditions:
169+
// c_up (+quadraticRegion) = +quadraticRegion
170+
// c_cen(+quadraticRegion) = -quadraticRegion
171+
// c_dn (+quadraticRegion) = 0
172+
// c_up (-quadraticRegion) = 0
173+
// c_cen(-quadraticRegion) = -quadraticRegion
174+
// c_dn (-quadraticRegion) = +quadraticRegion
175+
// c_up(0) = c_dn(0) = c_cen(0) = 0
176+
Double_t c_up = + coeff * (quadraticRegion + coeff) / (2 * quadraticRegion);
177+
Double_t c_dn = - coeff * (quadraticRegion - coeff) / (2 * quadraticRegion);
178+
Double_t c_cen = - coeff * coeff / quadraticRegion;
179+
return c_up * fUp + c_dn * fDn + c_cen * central;
180+
} else if (quadraticAlgo == 1) {
181+
// quadratic interpolation that is everywhere differentiable, but it's not null at zero
182+
// conditions on the function
183+
// c_up (+quadraticRegion) = +quadraticRegion
184+
// c_cen(+quadraticRegion) = -quadraticRegion
185+
// c_dn (+quadraticRegion) = 0
186+
// c_up (-quadraticRegion) = 0
187+
// c_cen(-quadraticRegion) = -quadraticRegion
188+
// c_dn (-quadraticRegion) = +quadraticRegion
189+
// conditions on the derivatives
190+
// c_up '(+quadraticRegion) = +1
191+
// c_cen'(+quadraticRegion) = -1
192+
// c_dn '(+quadraticRegion) = 0
193+
// c_up '(-quadraticRegion) = 0
194+
// c_cen'(-quadraticRegion) = +1
195+
// c_dn '(-quadraticRegion) = -1
196+
Double_t c_up = (quadraticRegion + coeff) * (quadraticRegion + coeff) / (4 * quadraticRegion);
197+
Double_t c_dn = (quadraticRegion - coeff) * (quadraticRegion - coeff) / (4 * quadraticRegion);
198+
Double_t c_cen = - c_up - c_dn;
199+
return c_up * fUp + c_dn * fDn + c_cen * central;
200+
} else/* if (quadraticAlgo == 1)*/ {
201+
// P(6) interpolation that is everywhere differentiable and null at zero
202+
/* === how the algorithm works, in theory ===
203+
* let dhi = h_hi - h_nominal
204+
* dlo = h_lo - h_nominal
205+
* and x be the morphing parameter
206+
* we define alpha = x * 0.5 * ((dhi-dlo) + (dhi+dlo)*smoothStepFunc(x));
207+
* which satisfies:
208+
* alpha(0) = 0
209+
* alpha(+1) = dhi
210+
* alpha(-1) = dlo
211+
* alpha(x >= +1) = |x|*dhi
212+
* alpha(x <= -1) = |x|*dlo
213+
* alpha is continuous and has continuous first and second derivative, as smoothStepFunc has them
214+
* === and in practice ===
215+
* we already have computed the histogram for diff=(dhi-dlo) and sum=(dhi+dlo)
216+
* so we just do template += (0.5 * x) * (diff + smoothStepFunc(x) * sum)
217+
* ========================================== */
218+
Double_t cnorm = coeff/quadraticRegion;
219+
Double_t cnorm2 = pow(cnorm, 2);
220+
Double_t hi = fUp - central;
221+
Double_t lo = fDn - central;
222+
Double_t sum = hi+lo;
223+
Double_t diff = hi-lo;
224+
Double_t a = coeff/2.; // cnorm*quadraticRegion
225+
Double_t b = 0.125 * cnorm * (cnorm2 * (3.*cnorm2 - 10.) + 15.);
226+
Double_t result = a*(diff + b*sum);
227+
return result;
228+
}
229+
}
230+
}
231+
232+
template <typename Operation>
233+
inline Double_t opInterpolate(RooArgList const& coefList, RooArgList const& funcList, Double_t const pdfFloorVal,
234+
Double_t const quadraticRegion, Int_t const quadraticAlgo, const RooArgSet* normSet2=nullptr)
235+
{
236+
// Do running sum of coef/func pairs, calculate lastCoef.
237+
RooAbsReal* func = &(RooAbsReal&)funcList[0];
238+
Double_t central = func->getVal();
239+
Double_t value = central;
240+
241+
Operation op;
242+
243+
for (int iCoef = 0; iCoef < coefList.getSize(); ++iCoef) {
244+
Double_t coefVal = static_cast<RooAbsReal&>(coefList[iCoef]).getVal(normSet2) ;
245+
RooAbsReal* funcUp = &(RooAbsReal&)funcList[2 * iCoef + 1];
246+
RooAbsReal* funcDn = &(RooAbsReal&)funcList[2 * iCoef + 2];
247+
value = op(value, interpolate(coefVal, central, funcUp->getVal(), funcDn->getVal(), quadraticRegion, quadraticAlgo));
248+
}
249+
return ( value > 0. ? value : pdfFloorVal);
250+
}
251+
252+
inline Double_t additiveInterpolate(double const* coefList, std::size_t nCoeffs, double const* funcList, std::size_t nFuncs,
253+
Double_t const pdfFloorVal, Double_t const quadraticRegion, Int_t const quadraticAlgo)
254+
{
255+
// Do running sum of coef/func pairs, calculate lastCoef.
256+
Double_t central = funcList[0];
257+
Double_t value = central;
258+
259+
for (std::size_t iCoef = 0; iCoef < nCoeffs; ++iCoef) {
260+
double coefVal = coefList[iCoef];
261+
double funcUp = funcList[2 * iCoef + 1];
262+
double funcDn = funcList[2 * iCoef + 2];
263+
value += interpolate(coefVal, central, funcUp, funcDn, quadraticRegion, quadraticAlgo);
264+
}
265+
return ( value > 0. ? value : pdfFloorVal);
266+
}
267+
268+
inline Double_t multiplicativeInterpolate(double const* coefList, std::size_t nCoeffs, double const* funcList, std::size_t nFuncs,
269+
Double_t const pdfFloorVal, Double_t const quadraticRegion, Int_t const quadraticAlgo)
270+
{
271+
// Do running sum of coef/func pairs, calculate lastCoef.
272+
Double_t central = funcList[0];
273+
Double_t value = central;
274+
275+
for (std::size_t iCoef = 0; iCoef < nCoeffs; ++iCoef) {
276+
double coefVal = coefList[iCoef];
277+
double funcUp = funcList[2 * iCoef + 1];
278+
double funcDn = funcList[2 * iCoef + 2];
279+
value *= interpolate(coefVal, central, funcUp, funcDn, quadraticRegion, quadraticAlgo);
280+
}
281+
return ( value > 0. ? value : pdfFloorVal);
282+
}
283+
284+
inline Double_t verticalInterpolate(double const* coefList, std::size_t nCoeffs, double const* funcList, std::size_t nFuncs,
285+
double const pdfFloorVal, double const quadraticRegion, Int_t const quadraticAlgo)
286+
{
287+
// Do running sum of coef/func pairs, calculate lastCoef.
288+
Double_t value = pdfFloorVal;
289+
if (quadraticAlgo >= 0) {
290+
value = RooFit::Detail::MathFuncs::additiveInterpolate(coefList, nCoeffs, funcList, nFuncs, pdfFloorVal, quadraticRegion, quadraticAlgo);
291+
} else {
292+
value = RooFit::Detail::MathFuncs::multiplicativeInterpolate(coefList, nCoeffs, funcList, nFuncs, pdfFloorVal, quadraticRegion, quadraticAlgo);
293+
}
294+
return value;
295+
}
296+
297+
inline Double_t verticalInterpPdfIntegral(double const* coefList, std::size_t nCoeffs, double const* funcIntList, std::size_t nFuncs,
298+
double const pdfFloorVal, double const integralFloorVal,
299+
double const quadraticRegion, Int_t const quadraticAlgo)
300+
{
301+
double value = RooFit::Detail::MathFuncs::additiveInterpolate(coefList, nCoeffs, funcIntList, nFuncs,
302+
pdfFloorVal, quadraticRegion, quadraticAlgo);
303+
double normVal(1);
304+
double result = 0;
305+
if(normVal>0.) result = value / normVal;
306+
return result > 0. ? result : integralFloorVal;
307+
}
308+
148309
} // namespace MathFuncs
149310
} // namespace Detail
150311
} // namespace RooFit

interface/VerticalInterpPdf.h

+11
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#include "RooListProxy.h"
99
#include "RooObjCacheManager.h"
1010

11+
#include "CombineCodegenImpl.h"
12+
1113
class VerticalInterpPdf : public RooAbsPdf {
1214
public:
1315

@@ -18,17 +20,23 @@ class VerticalInterpPdf : public RooAbsPdf {
1820

1921
Double_t evaluate() const override ;
2022
Bool_t checkObservables(const RooArgSet* nset) const override ;
23+
COMBINE_DECLARE_TRANSLATE
2124

2225
Bool_t forceAnalyticalInt(const RooAbsArg&) const override { return kTRUE ; }
2326
Int_t getAnalyticalIntegralWN(RooArgSet& allVars, RooArgSet& numVars, const RooArgSet* normSet, const char* rangeName=0) const override ;
2427
Double_t analyticalIntegralWN(Int_t code, const RooArgSet* normSet, const char* rangeName=0) const override ;
28+
COMBINE_DECLARE_ANALYTICAL_INTEGRAL
2529

2630
const RooArgList& funcList() const { return _funcList ; }
31+
const RooArgList& funcIntListFromCache() const;
2732
const RooArgList& coefList() const { return _coefList ; }
2833

2934
const Double_t quadraticRegion() const { return _quadraticRegion; }
3035
const Int_t quadraticAlgo() const { return _quadraticAlgo; }
3136

37+
const Double_t pdfFloorVal() const { return _pdfFloorVal; }
38+
const Double_t integralFloorVal() const { return _integralFloorVal; }
39+
3240
void setFloorVals(Double_t const& pdf_val, Double_t const& integral_val);
3341

3442
#if ROOT_VERSION_CODE >= ROOT_VERSION(6,34,06)
@@ -62,4 +70,7 @@ class VerticalInterpPdf : public RooAbsPdf {
6270
ClassDefOverride(VerticalInterpPdf,3) // PDF constructed from a sum of (non-pdf) functions
6371
};
6472

73+
COMBINE_DECLARE_CODEGEN_IMPL(VerticalInterpPdf);
74+
COMBINE_DECLARE_CODEGEN_INTEGRAL_IMPL(VerticalInterpPdf);
75+
6576
#endif

src/CombineCodegenImpl.cxx

+31
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,20 @@
55
#include "../interface/AsymPow.h"
66
#include "../interface/ProcessNormalization.h"
77
#include "../interface/VerticalInterpHistPdf.h"
8+
#include "../interface/VerticalInterpPdf.h"
9+
#include "../interface/CombineMathFuncs.h"
810

911
#include <RooUniformBinning.h>
1012

1113
#if ROOT_VERSION_CODE >= ROOT_VERSION(6,35,0)
1214
namespace RooFit {
1315
namespace Experimental {
1416
# define CODEGEN_IMPL(CLASS_NAME) void codegenImpl(CLASS_NAME &arg0, CodegenContext &ctx)
17+
# define CODEGEN_INTEGRAL_IMPL(CLASS_NAME) std::string codegenIntegralImpl(CLASS_NAME &arg0, int code, const char *rangeName, CodegenContext &ctx)
1518
# define ARG_VAR auto &arg = arg0;
1619
#else
1720
# define CODEGEN_IMPL(CLASS_NAME) void CLASS_NAME::translate(RooFit::Detail::CodeSquashContext &ctx) const
21+
# define CODEGEN_INTEGRAL_IMPL(CLASS_NAME) std::string CLASS_NAME::buildCallToAnalyticIntegral(Int_t code, const char *rangeName, RooFit::Detail::CodeSquashContext &ctx) const
1822
# define ARG_VAR auto &arg = *this;
1923
#endif
2024

@@ -217,6 +221,33 @@ CODEGEN_IMPL(FastVerticalInterpHistPdf2D2) {
217221
ctx.addResult(&arg, arrName + "[" + binIdx.str() + "]");
218222
}
219223

224+
CODEGEN_IMPL(VerticalInterpPdf) {
225+
ARG_VAR;
226+
ctx.addResult(&arg,
227+
ctx.buildCall("RooFit::Detail::MathFuncs::verticalInterpolate",
228+
arg.coefList(),
229+
arg.coefList().size(),
230+
arg.funcList(),
231+
arg.funcList().size(),
232+
arg.pdfFloorVal(),
233+
arg.quadraticRegion(),
234+
arg.quadraticAlgo()));
235+
236+
}
237+
238+
CODEGEN_INTEGRAL_IMPL(VerticalInterpPdf) {
239+
ARG_VAR;
240+
return ctx.buildCall("RooFit::Detail::MathFuncs::verticalInterpPdfIntegral",
241+
arg.coefList(),
242+
arg.coefList().size(),
243+
arg.funcIntListFromCache(),
244+
arg.funcIntListFromCache().size(),
245+
arg.pdfFloorVal(),
246+
arg.integralFloorVal(),
247+
arg.quadraticRegion(),
248+
arg.quadraticAlgo());
249+
}
250+
220251
#if ROOT_VERSION_CODE >= ROOT_VERSION(6,35,0)
221252
} // namespace RooFit
222253
} // namespace Experimental

src/VerticalInterpHistPdf.cc

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include "RooMsgService.h"
1111
#include "RooAbsData.h"
1212

13+
#include "RooUniformBinning.h"
14+
1315
//#define TRACE_CALLS
1416
#ifdef TRACE_CALLS
1517
#include "../interface/ProfilingTools.h"

0 commit comments

Comments
 (0)