Skip to content

Commit 6fb4e95

Browse files
committed
fix for typing of fit kwargs in impact calculations
1 parent f98b7fd commit 6fb4e95

File tree

1 file changed

+66
-30
lines changed

1 file changed

+66
-30
lines changed

Diff for: src/cabinetry/fit/__init__.py

+66-30
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections import defaultdict
44
import logging
5-
from typing import Any, cast, Dict, List, Literal, Optional, Tuple, Union
5+
from typing import Any, cast, Dict, List, Literal, Optional, Tuple, TypedDict, Union
66

77
import iminuit
88
import numpy as np
@@ -23,6 +23,16 @@
2323
log = logging.getLogger(__name__)
2424

2525

26+
class FitKwargs(TypedDict, total=False):
27+
init_pars: List[float]
28+
fix_pars: List[bool] # Explicitly state it's a list
29+
par_bounds: Optional[List[Tuple[float, float]]]
30+
strategy: Optional[Literal[0, 1, 2]]
31+
maxiter: Optional[int]
32+
tolerance: Optional[float]
33+
custom_fit: bool
34+
35+
2636
def print_results(fit_results: FitResults) -> None:
2737
"""Prints the best-fit parameter results and associated uncertainties.
2838
@@ -470,7 +480,9 @@ def _get_impacts_summary(
470480
non-systematic sources
471481
"""
472482
non_syst_modifiers = ["normfactor", "shapefactor", "staterror"] # Lumi ?
473-
impacts_summary = defaultdict(lambda: defaultdict(float))
483+
impacts_summary: Dict[str, Dict[str, float]] = defaultdict(
484+
lambda: defaultdict(float)
485+
)
474486
# Dictionary to store the merged values after removing certain modifiers
475487
syst_impacts_map = defaultdict(list)
476488
# Iterate through each modifier and its corresponding data
@@ -492,7 +504,7 @@ def _get_impacts_summary(
492504
)
493505
)
494506

495-
return impacts_summary
507+
return dict(impacts_summary)
496508

497509

498510
def _get_datastat_impacts_np_shift(
@@ -501,7 +513,7 @@ def _get_datastat_impacts_np_shift(
501513
data: List[float],
502514
poi_index: int,
503515
fit_results: FitResults,
504-
fit_kwargs,
516+
fit_kwargs: FitKwargs,
505517
) -> Dict[str, Dict[str, float]]:
506518
"""
507519
Calculate the impact of statistical uncertainties on the parameter of interest by
@@ -514,7 +526,7 @@ def _get_datastat_impacts_np_shift(
514526
data (List[float]): data (including auxdata) the model is fit to
515527
poi_index (int): index of the parameter of interest
516528
fit_results (FitResults): nominal fit results to use in impacts calculation
517-
fit_kwargs (_type_): settings to be used in the fits
529+
fit_kwargs (FitKwargs): settings to be used in the fits
518530
519531
Returns:
520532
Dict[str, Dict[str, float]]: impacts summary categorized by systematic and
@@ -532,12 +544,19 @@ def _get_datastat_impacts_np_shift(
532544
fit_results.bestfit[i_par] if i_par != poi_index else init_pars_datastat[i_par]
533545
for i_par in range(len(model.config.par_names))
534546
]
547+
# removing init_pars and fix_pars from dict, cast it to FitKwargs
548+
# type then update values in dict since casting re-introduces them.
549+
updated_fit_kwargs = cast(
550+
FitKwargs,
551+
{k: v for k, v in fit_kwargs.items() if k not in ["init_pars", "fix_pars"]},
552+
)
553+
updated_fit_kwargs["init_pars"] = init_pars_datastat
554+
updated_fit_kwargs["fix_pars"] = fix_pars_datastat
555+
535556
fit_results_datastat = _fit_model(
536557
model,
537558
data,
538-
init_pars=init_pars_datastat,
539-
fix_pars=fix_pars_datastat,
540-
**{k: v for k, v in fit_kwargs.items() if k not in ["init_pars", "fix_pars"]},
559+
**updated_fit_kwargs,
541560
)
542561
datastat_poi_val = fit_results_datastat.bestfit[poi_index]
543562
datastat_impact = datastat_poi_val - nominal_poi
@@ -546,7 +565,9 @@ def _get_datastat_impacts_np_shift(
546565
return impacts_summary
547566

548567

549-
def _get_datastat_impacts_quadruture(impacts_summary, total_error):
568+
def _get_datastat_impacts_quadruture(
569+
impacts_summary: Dict[str, Dict[str, float]], total_error: float
570+
) -> Dict[str, Dict[str, float]]:
550571
"""
551572
Calculate the impact of statistical uncertainties on the parameter of subtracting
552573
other sources from the total error in quadrature.
@@ -636,7 +657,9 @@ def _cov_impacts(
636657
systematic and non-systematic sources
637658
"""
638659
total_poi_error = fit_results.uncertainty[poi_index]
639-
impacts_by_modifier_type = defaultdict(lambda: defaultdict(list))
660+
impacts_by_modifier_type: Dict[str, Dict[str, List[float]]] = defaultdict(
661+
lambda: defaultdict(list)
662+
)
640663
i_global_par = 0
641664

642665
for parameter in model.config.par_order:
@@ -688,7 +711,7 @@ def _cov_impacts(
688711
impacts_summary = _get_impacts_summary(impacts_by_modifier_type)
689712
impacts_summary = _get_datastat_impacts_quadruture(impacts_summary, total_poi_error)
690713

691-
return impacts_by_modifier_type, impacts_summary
714+
return dict(impacts_by_modifier_type), impacts_summary
692715

693716

694717
def _np_impacts(
@@ -698,7 +721,7 @@ def _np_impacts(
698721
fit_results: FitResults,
699722
prefit_unc: np.ndarray,
700723
labels: List[str],
701-
fit_kwargs,
724+
fit_kwargs: FitKwargs,
702725
) -> Tuple[Dict[str, Dict[str, List[float]]], Dict[str, Dict[str, float]]]:
703726
"""
704727
Computes the impact of nuisance parameters on the POI by shifting parameter
@@ -712,7 +735,7 @@ def _np_impacts(
712735
fit_results (FitResults): nominal fit results to use in impacts calculation
713736
prefit_unc (np.ndarray): pre-fit uncertainties of parameters
714737
labels (List[str]): list of parameter names
715-
fit_kwargs: settings to be used in the fits.
738+
fit_kwargs (FitKwargs): settings to be used in the fits.
716739
717740
Returns:
718741
Tuple[Dict[str, Dict[str, List[float]]], Dict[str, Dict[str, float]]]:
@@ -723,7 +746,9 @@ def _np_impacts(
723746
systematic and non-systematic sources.
724747
"""
725748
nominal_poi = fit_results.bestfit[poi_index]
726-
impacts_by_modifier_type = defaultdict(lambda: defaultdict(list))
749+
impacts_by_modifier_type: Dict[str, Dict[str, List[float]]] = defaultdict(
750+
lambda: defaultdict(list)
751+
)
727752
i_global_par = 0
728753

729754
for parameter in model.config.par_order:
@@ -765,17 +790,25 @@ def _np_impacts(
765790
init_pars_ranking = fit_kwargs["init_pars"].copy()
766791
# value of current nuisance parameter
767792
init_pars_ranking[i_par] = np_val
768-
fit_results_ranking = _fit_model(
769-
model,
770-
data,
771-
init_pars=init_pars_ranking,
772-
fix_pars=fix_pars_ranking,
773-
**{
793+
794+
# removing init_pars and fix_pars from dict, cast it to FitKwargs
795+
# type then update values in dict since casting re-introduces them.
796+
updated_fit_kwargs = cast(
797+
FitKwargs,
798+
{
774799
k: v
775800
for k, v in fit_kwargs.items()
776801
if k not in ["init_pars", "fix_pars"]
777802
},
778803
)
804+
updated_fit_kwargs["init_pars"] = init_pars_ranking
805+
updated_fit_kwargs["fix_pars"] = fix_pars_ranking
806+
807+
fit_results_ranking = _fit_model(
808+
model,
809+
data,
810+
**updated_fit_kwargs,
811+
)
779812
poi_val = fit_results_ranking.bestfit[poi_index]
780813
parameter_impact = poi_val - nominal_poi
781814
log.debug(
@@ -798,7 +831,7 @@ def _np_impacts(
798831
impacts_summary, model, data, poi_index, fit_results, fit_kwargs
799832
)
800833

801-
return impacts_by_modifier_type, impacts_summary
834+
return dict(impacts_by_modifier_type), impacts_summary
802835

803836

804837
def _auxdata_shift_impacts(
@@ -1047,15 +1080,18 @@ def ranking(
10471080
RankingResults: fit results for parameters, and pre- and post-fit impacts
10481081
"""
10491082

1050-
fit_settings = {
1051-
"init_pars": init_pars or model.config.suggested_init(),
1052-
"fix_pars": fix_pars or model.config.suggested_fixed(),
1053-
"par_bounds": par_bounds,
1054-
"strategy": strategy,
1055-
"maxiter": maxiter,
1056-
"tolerance": tolerance,
1057-
"custom_fit": custom_fit,
1058-
}
1083+
fit_settings = cast(
1084+
FitKwargs,
1085+
{
1086+
"init_pars": init_pars or model.config.suggested_init(),
1087+
"fix_pars": fix_pars or model.config.suggested_fixed(),
1088+
"par_bounds": par_bounds,
1089+
"strategy": strategy,
1090+
"maxiter": maxiter,
1091+
"tolerance": tolerance,
1092+
"custom_fit": custom_fit,
1093+
},
1094+
)
10591095

10601096
if fit_results is None:
10611097
fit_results = _fit_model(model, data, **fit_settings)

0 commit comments

Comments
 (0)