Skip to content

feat: Allowing the merging of samples for displaying model predictions without touching the PyHF model #505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
255b184
some initial implementation of a LightConfig carrying info from pyhf …
MoAly98 Feb 20, 2025
942d62c
fix typing
MoAly98 Feb 20, 2025
789e91d
merging samples in stdev, working but not fully tested
MoAly98 Feb 21, 2025
ba7517d
fix model_yields type after sample merge to be list
MoAly98 Feb 21, 2025
e1b6190
pass sample merging map to stdev from prediction call
MoAly98 Feb 21, 2025
5a82207
do not inherit LightConfig from _ModelConfig, and cleanup comments
MoAly98 Feb 24, 2025
94dfdf7
fix backend test
MoAly98 Feb 24, 2025
071bad6
change model in ModelPrediction object and change typing minimally in…
MoAly98 Feb 24, 2025
0426937
predictions test
MoAly98 Feb 24, 2025
23b8697
add tests for stdev function
MoAly98 Feb 24, 2025
c312f2a
write LightModel test and add set of samples as a key to cache
MoAly98 Feb 24, 2025
fd02e38
simplify the LightModel and LightConfig to only hold minimal information
MoAly98 Feb 25, 2025
2cb952d
clean comments
MoAly98 Feb 25, 2025
548ec1f
missing update of backend test
MoAly98 Feb 25, 2025
f611de5
Merge branch 'master' into maly-issue-501
MoAly98 Mar 12, 2025
f0f6ce5
mypy -fixes- or -getarounds-
MoAly98 Mar 13, 2025
1efdadf
fix mutable class defaults and typing issues
MoAly98 Mar 17, 2025
f97ef39
missing docstring
MoAly98 Mar 17, 2025
a48d8f2
pass sample merging maps around, not light model
MoAly98 Mar 27, 2025
679ea24
add test for sample merging of yields to improve coverage, and raise…
MoAly98 Mar 27, 2025
cceb36c
Merge branch 'master' into maly-issue-501
MoAly98 Mar 27, 2025
a6ab2ab
Merge branch 'master' into maly-issue-501
MoAly98 Apr 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 170 additions & 14 deletions src/cabinetry/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,134 @@
_YIELD_STDEV_CACHE: Dict[Any, Tuple[List[List[List[float]]], List[List[float]]]] = {}


class LightConfig:
    """Light-weight copy of the parts of a ``pyhf`` model config used for display.

    Optionally merges groups of samples into new samples without touching the
    underlying ``pyhf`` model. New (merged) samples are placed at the start of the
    sample list (bottom of the stack) in the order of ``samples_merge_map``,
    followed by all samples that are not merged, in their original order.

    Args:
        model (pyhf.pdf.Model): model from which the configuration is copied
        samples_merge_map (Optional[Dict[str, List[str]]], optional): mapping from
            the name of a new sample to the names of the samples merged into it,
            defaults to None (no merging)

    Raises:
        ValueError: if a sample to merge is not part of the model, or if a sample
            is listed more than once in ``samples_merge_map``
    """

    def __init__(
        self,
        model: pyhf.pdf.Model,
        samples_merge_map: Optional[Dict[str, List[str]]] = None,
    ):
        # copy the list so that this object never aliases the list in the pyhf config
        self.samples = list(model.config.samples)
        self.channels = model.config.channels
        self.channel_slices = model.config.channel_slices
        self.channel_nbins = model.config.channel_nbins
        self.npars = model.config.npars
        # this is going to break if more config kwargs added in pyhf _ModelConfig
        self.modifier_settings = model.config.modifier_settings
        self.samples_merge_map = samples_merge_map
        # sorted indices (into the original sample list) of all samples that get
        # merged; integer dtype so that the array is always usable as an index
        self.merged_samples_indices: np.ndarray = np.zeros(0, dtype=int)
        # for each new sample (in map order): original indices of its constituents
        self.merged_samples_groups: List[List[int]] = []
        if samples_merge_map is not None:
            self._update_samples(samples_merge_map)

    @property
    def samples(self) -> List[str]:
        """
        Ordered list of sample names in the model.
        """
        return self._samples

    @samples.setter
    def samples(self, samples: List[str]) -> None:
        """
        Set the ordered list of sample names in the model.
        """
        self._samples = samples

    def _update_samples(self, samples_merge_map: Dict[str, List[str]]) -> None:
        """Replaces merged samples in the sample list by the new merged samples.

        Args:
            samples_merge_map (Dict[str, List[str]]): mapping from the name of a new
                sample to the names of the samples merged into it

        Raises:
            ValueError: if a sample to merge is not part of the model, or if a
                sample is listed more than once in ``samples_merge_map``
        """
        original_samples = list(self.samples)
        index_of = {name: idx for idx, name in enumerate(original_samples)}

        groups: List[List[int]] = []
        already_merged = set()
        for new_sample, constituents in samples_merge_map.items():
            group = []
            for constituent in constituents:
                if constituent not in index_of:
                    raise ValueError(
                        f"cannot merge sample '{constituent}' into '{new_sample}': "
                        "no such sample in the model"
                    )
                if constituent in already_merged:
                    # merging a sample twice would double-count its yields
                    raise ValueError(
                        f"sample '{constituent}' is listed more than once in the "
                        "sample merging map"
                    )
                already_merged.add(constituent)
                group.append(index_of[constituent])
            groups.append(group)

        merged_indices = sorted(idx for group in groups for idx in group)
        merged_lookup = set(merged_indices)
        self.merged_samples_groups = groups
        self.merged_samples_indices = np.asarray(merged_indices, dtype=int)
        # new samples go to the bottom of the stack (front of the list) as one
        # contiguous group, the samples that were not merged keep their order
        self.samples = list(samples_merge_map) + [
            name
            for idx, name in enumerate(original_samples)
            if idx not in merged_lookup
        ]


class LightModel:
    """Light-weight stand-in for a ``pyhf`` model, holding only config and spec.

    Args:
        model (pyhf.pdf.Model): model to extract the configuration and spec from
        samples_merge_map (Optional[Dict[str, List[str]]], optional): mapping from
            the name of a new sample to the names of the samples merged into it,
            defaults to None (no merging)
    """

    def __init__(
        self,
        model: pyhf.pdf.Model,
        samples_merge_map: Optional[Dict[str, List[str]]] = None,
    ):
        self.config = LightConfig(model, samples_merge_map)
        self.spec = model.spec


def _merge_sample_yields(
    model: LightModel,
    old_yields: Union[List[List[List[float]]], List[List[float]]],
    one_channel: Optional[bool] = False,
) -> np.ndarray:
    """Merges the yields of samples according to the sample merging map of a model.

    For every new sample in the merging map, the yields of its constituents are
    summed. The summed yields are placed first along the sample axis (in map order),
    followed by the yields of all rows that are not merged, in their original order.
    Rows beyond the samples of the model (such as a trailing row with the total
    model prediction) are treated like non-merged samples and keep their position
    at the end.

    Args:
        model (LightModel): light-weight model carrying the sample merging info
        old_yields (Union[List[List[List[float]]], List[List[float]]]): yields before
            merging, indices: (channel, sample, bin), or (sample, bin) if
            ``one_channel`` is True
        one_channel (Optional[bool], optional): whether ``old_yields`` has no channel
            axis, defaults to False

    Raises:
        ValueError: if the model was created without a sample merging map

    Returns:
        np.ndarray: yields after merging with the same axes as ``old_yields``; if
        channels have different numbers of bins, an object array with one nested
        list (sample, bin) per channel is returned
    """
    config = model.config
    if config.samples_merge_map is None:
        raise ValueError(
            "cannot merge sample yields: the model has no sample merging map"
        )

    groups = [np.asarray(group, dtype=int) for group in config.merged_samples_groups]
    merged_indices = np.asarray(config.merged_samples_indices, dtype=int)

    def _merge_one_channel(channel_yields: List[List[float]]) -> np.ndarray:
        # indices: sample, bin
        yields_np = np.asarray(channel_yields)
        # one summed row per new sample, each built only from its own constituents
        merged_rows = [yields_np[group].sum(axis=0) for group in groups]
        remaining_rows = np.delete(yields_np, merged_indices, axis=0)
        return np.vstack(merged_rows + [remaining_rows])

    if one_channel:
        return _merge_one_channel(cast(List[List[float]], old_yields))

    merged_channels = [
        _merge_one_channel(cast(List[List[float]], old_yields[i_ch]))
        for i_ch in range(len(config.channels))
    ]
    if len({channel.shape for channel in merged_channels}) <= 1:
        # all channels have the same shape: regular (channel, sample, bin) array
        return np.asarray(merged_channels)

    # channels with different numbers of bins cannot form a regular array, store
    # nested lists in an object array so that .tolist() still yields nested lists
    ragged_yields = np.empty(len(merged_channels), dtype=object)
    for i_ch, channel in enumerate(merged_channels):
        ragged_yields[i_ch] = channel.tolist()
    return ragged_yields


class ModelPrediction(NamedTuple):
"""Model prediction with yields and total uncertainties per bin and channel.

Args:
model (pyhf.pdf.Model): model to which prediction corresponds to
model (LightModel or pyhf.pdf.Model): model (or a light-weight version of
pyhf.pdf.Model) to which prediction corresponds to
model_yields (List[List[List[float]]]): yields per channel, sample and bin,
indices: channel, sample, bin
total_stdev_model_bins (List[List[List[float]]]): total yield uncertainty per
Expand All @@ -43,7 +166,7 @@ class ModelPrediction(NamedTuple):
label (str): label for the prediction, e.g. "pre-fit" or "post-fit"
"""

model: pyhf.pdf.Model
model: Union[LightModel, pyhf.pdf.Model]
model_yields: List[List[List[float]]]
total_stdev_model_bins: List[List[List[float]]]
total_stdev_model_channels: List[List[float]]
Expand Down Expand Up @@ -235,6 +358,7 @@ def yield_stdev(
parameters: np.ndarray,
uncertainty: np.ndarray,
corr_mat: np.ndarray,
light_model: Optional[LightModel] = None,
) -> Tuple[List[List[List[float]]], List[List[float]]]:
"""Calculates symmetrized model yield standard deviation per channel / sample / bin.

Expand All @@ -246,7 +370,7 @@ def yield_stdev(
of this function are cached to speed up subsequent calls with the same arguments.

Args:
model (pyhf.pdf.Model): the model for which to calculate the standard deviations
model (LightModel): the model for which to calculate the standard deviations
for all bins
parameters (np.ndarray): central values of model parameters
uncertainty (np.ndarray): uncertainty of model parameters
Expand All @@ -262,12 +386,18 @@ def yield_stdev(
over all samples)
"""
# check whether results are already stored in cache
samples_string = (
",".join(light_model.config.samples)
if light_model is not None
else ",".join(model.config.samples)
)
cached_results = _YIELD_STDEV_CACHE.get(
(
_hashable_model_key(model),
tuple(parameters),
tuple(uncertainty),
corr_mat.data.tobytes(),
samples_string,
),
None,
)
Expand Down Expand Up @@ -302,6 +432,11 @@ def yield_stdev(
# attach another entry with the total model prediction (sum over all samples)
# indices: sample, bin
up_comb = np.vstack((up_comb, np.sum(up_comb, axis=0)))
if light_model is not None:
up_comb = _merge_sample_yields(
light_model, up_comb.tolist(), one_channel=True
)

# turn into list of channels (keep all samples, select correct bins per channel)
# indices: channel, sample, bin
up_yields_per_channel = [
Expand All @@ -314,6 +449,7 @@ def yield_stdev(
for chan_yields in up_yields_per_channel
]
)

# reshape to drop bin axis, transpose to turn channel axis into new bin axis
# (channel, sample, bin) -> (sample, bin) where "bin" becomes channel sums
up_yields_channel_sum = up_yields_channel_sum.reshape(
Expand All @@ -323,14 +459,19 @@ def yield_stdev(
# concatenate per-channel sums to up_comb (along bin axis)
up_yields = np.concatenate((up_comb, up_yields_channel_sum), axis=1)
# indices: variation, sample, bin
up_variations.append(up_yields.tolist())
up_variations.append(up_yields)

# model distribution per sample with this parameter varied down
down_comb = pyhf.tensorlib.to_numpy(
model.main_model.expected_data(down_pars, return_by_sample=True)
)
# add total prediction (sum over samples)
down_comb = np.vstack((down_comb, np.sum(down_comb, axis=0)))
if light_model is not None:
down_comb = _merge_sample_yields(
light_model, down_comb.tolist(), one_channel=True
)

# turn into list of channels
down_yields_per_channel = [
down_comb[:, model.config.channel_slices[ch]]
Expand All @@ -354,7 +495,6 @@ def yield_stdev(
# convert to numpy arrays for further processing
up_variations_np = np.asarray(up_variations)
down_variations_np = np.asarray(down_variations)

# calculate symmetric uncertainties for all components
# indices: variation, channel (last entries sums), sample (last entry sum), bin
sym_uncs = (up_variations_np - down_variations_np) / 2
Expand Down Expand Up @@ -418,6 +558,7 @@ def yield_stdev(
tuple(parameters),
tuple(uncertainty),
corr_mat.data.tobytes(),
samples_string,
): (total_stdev_per_bin, total_stdev_per_channel)
}
)
Expand All @@ -430,6 +571,7 @@ def prediction(
*,
fit_results: Optional[FitResults] = None,
label: Optional[str] = None,
samples_merge_map: Optional[Dict[str, List[str]]] = None,
) -> ModelPrediction:
"""Returns model prediction, including model yields and uncertainties.

Expand All @@ -450,6 +592,7 @@ def prediction(
Returns:
ModelPrediction: model, yields and uncertainties per channel, sample, bin
"""
light_model = LightModel(model, samples_merge_map)
if fit_results is not None:
if fit_results.labels != model.config.par_names:
log.warning("parameter names in fit results and model do not match")
Expand Down Expand Up @@ -479,15 +622,28 @@ def prediction(
for ch in model.config.channels
]

if samples_merge_map is not None:
model_yields = cast(
List[List[List[float]]],
_merge_sample_yields(light_model, model_yields).tolist(),
)

# calculate the total standard deviation of the model prediction
# indices: (channel, sample, bin) for per-bin uncertainties,
# (channel, sample) for per-channel
total_stdev_model_bins, total_stdev_model_channels = yield_stdev(
model, param_values, param_uncertainty, corr_mat
model,
param_values,
param_uncertainty,
corr_mat,
light_model=light_model if samples_merge_map is not None else None,
)

return ModelPrediction(
model, model_yields, total_stdev_model_bins, total_stdev_model_channels, label
light_model,
model_yields,
total_stdev_model_bins,
total_stdev_model_channels,
label,
)


Expand Down Expand Up @@ -580,11 +736,11 @@ def _poi_index(
return poi_index


def _strip_auxdata(model: pyhf.pdf.Model, data: List[float]) -> List[float]:
def _strip_auxdata(model: LightModel, data: List[float]) -> List[float]:
"""Always returns observed yields, no matter whether data includes auxdata.

Args:
model (pyhf.pdf.Model): model to which data corresponds to
model (LightModel): model to which data corresponds to
data (List[float]): data, either including auxdata which is then stripped off or
only observed yields

Expand All @@ -598,11 +754,11 @@ def _strip_auxdata(model: pyhf.pdf.Model, data: List[float]) -> List[float]:
return data


def _data_per_channel(model: pyhf.pdf.Model, data: List[float]) -> List[List[float]]:
def _data_per_channel(model: LightModel, data: List[float]) -> List[List[float]]:
"""Returns data split per channel, and strips off auxiliary data if included.

Args:
model (pyhf.pdf.Model): model to which data corresponds to
model (LightModel): model to which data corresponds to
data (List[float]): data (not split by channel), can either include auxdata
which is then stripped off, or only observed yields

Expand All @@ -620,12 +776,12 @@ def _data_per_channel(model: pyhf.pdf.Model, data: List[float]) -> List[List[flo


def _filter_channels(
model: pyhf.pdf.Model, channels: Optional[Union[str, List[str]]]
model: LightModel, channels: Optional[Union[str, List[str]]]
) -> List[str]:
"""Returns a list of channels in a model after applying filtering.

Args:
model (pyhf.pdf.Model): model from which to extract channels
model (LightModel): model from which to extract channels
channels (Optional[Union[str, List[str]]]): name of channel or list of channels
to filter, only including those channels provided via this argument in the
return of the function
Expand Down
Loading