Skip to content

Commit 3ea584a

Browse files
rigaanigamova
authored andcommitted
Fix caching in ShapeBuilder.
1 parent 18685de commit 3ea584a

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

python/ShapeTools.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ def __init__(self, datacard, options):
7171
self.norm_rename_map = {}
7272
self._fileCache = FileCache(self.options.baseDir)
7373

74+
self._get_shape_cache = {}
75+
self._get_pdf_cache = {}
76+
self._shape2data_cache = {}
77+
self._shape2pdf_cache = {}
78+
7479
## ------------------------------------------
7580
## -------- ModelBuilder interface ----------
7681
## ------------------------------------------
@@ -654,7 +659,10 @@ def doCombinedDataset(self):
654659
## -------------------------------------
655660
## -------- Low level helpers ----------
656661
## -------------------------------------
657-
def getShape(self, channel, process, syst="", _cache={}, allowNoSyst=False):
662+
def getShape(self, channel, process, syst="", _cache=None, allowNoSyst=False):
663+
if _cache is None:
664+
_cache = self._get_shape_cache
665+
658666
if (channel, process, syst) in _cache:
659667
if self.options.verbose > 2:
660668
print(
@@ -841,10 +849,13 @@ def getShape(self, channel, process, syst="", _cache={}, allowNoSyst=False):
841849
_cache[(channel, process, syst)] = ret
842850
return ret
843851

844-
def getData(self, channel, process, syst="", _cache={}):
852+
def getData(self, channel, process, syst="", _cache=None):
845853
return self.shape2Data(self.getShape(channel, process, syst), channel, process)
846854

847-
def getPdf(self, channel, process, _cache={}):
855+
def getPdf(self, channel, process, _cache=None):
856+
if _cache is None:
857+
_cache = self._get_pdf_cache
858+
848859
postFix = "Sig" if (process in self.DC.isSignal and self.DC.isSignal[process]) else "Bkg"
849860
if (channel, process) in _cache:
850861
return _cache[(channel, process)]
@@ -1202,7 +1213,10 @@ def rebinH1(self, shape):
12021213
rebinh1._original_bins = shapeNbins
12031214
return rebinh1
12041215

1205-
def shape2Data(self, shape, channel, process, _cache={}):
1216+
def shape2Data(self, shape, channel, process, _cache=None):
1217+
if _cache is None:
1218+
_cache = self._shape2data_cache
1219+
12061220
postFix = "Sig" if (process in self.DC.isSignal and self.DC.isSignal[process]) else "Bkg"
12071221
if not shape:
12081222
name = f"shape{postFix}_{channel}_{process}"
@@ -1238,7 +1252,10 @@ def shape2Data(self, shape, channel, process, _cache={}):
12381252
raise RuntimeError("shape2Data not implemented for %s" % shape.ClassName())
12391253
return _cache[shape.GetName()]
12401254

1241-
def shape2Pdf(self, shape, channel, process, _cache={}):
1255+
def shape2Pdf(self, shape, channel, process, _cache=None):
1256+
if _cache is None:
1257+
_cache = self._shape2pdf_cache
1258+
12421259
postFix = "Sig" if (process in self.DC.isSignal and self.DC.isSignal[process]) else "Bkg"
12431260
channelBinParFlag = channel in list(self.DC.binParFlags.keys())
12441261
if shape == None:

0 commit comments

Comments
 (0)