2
2
3
3
from collections import defaultdict
4
4
import logging
5
- from typing import Any , cast , Dict , List , Literal , Optional , Tuple , Union
5
+ from typing import Any , cast , Dict , List , Literal , Optional , Tuple , TypedDict , Union
6
6
7
7
import iminuit
8
8
import numpy as np
23
23
log = logging .getLogger (__name__ )
24
24
25
25
26
+ class FitKwargs (TypedDict , total = False ):
27
+ init_pars : List [float ]
28
+ fix_pars : List [bool ] # Explicitly state it's a list
29
+ par_bounds : Optional [List [Tuple [float , float ]]]
30
+ strategy : Optional [Literal [0 , 1 , 2 ]]
31
+ maxiter : Optional [int ]
32
+ tolerance : Optional [float ]
33
+ custom_fit : bool
34
+
35
+
26
36
def print_results (fit_results : FitResults ) -> None :
27
37
"""Prints the best-fit parameter results and associated uncertainties.
28
38
@@ -470,7 +480,9 @@ def _get_impacts_summary(
470
480
non-systematic sources
471
481
"""
472
482
non_syst_modifiers = ["normfactor" , "shapefactor" , "staterror" ] # Lumi ?
473
- impacts_summary = defaultdict (lambda : defaultdict (float ))
483
+ impacts_summary : Dict [str , Dict [str , float ]] = defaultdict (
484
+ lambda : defaultdict (float )
485
+ )
474
486
# Dictionary to store the merged values after removing certain modifiers
475
487
syst_impacts_map = defaultdict (list )
476
488
# Iterate through each modifier and its corresponding data
@@ -492,7 +504,7 @@ def _get_impacts_summary(
492
504
)
493
505
)
494
506
495
- return impacts_summary
507
+ return dict ( impacts_summary )
496
508
497
509
498
510
def _get_datastat_impacts_np_shift (
@@ -501,7 +513,7 @@ def _get_datastat_impacts_np_shift(
501
513
data : List [float ],
502
514
poi_index : int ,
503
515
fit_results : FitResults ,
504
- fit_kwargs ,
516
+ fit_kwargs : FitKwargs ,
505
517
) -> Dict [str , Dict [str , float ]]:
506
518
"""
507
519
Calculate the impact of statistical uncertainties on the parameter of interest by
@@ -514,7 +526,7 @@ def _get_datastat_impacts_np_shift(
514
526
data (List[float]): data (including auxdata) the model is fit to
515
527
poi_index (int): index of the parameter of interest
516
528
fit_results (FitResults): nominal fit results to use in impacts calculation
517
- fit_kwargs (_type_ ): settings to be used in the fits
529
+ fit_kwargs (FitKwargs ): settings to be used in the fits
518
530
519
531
Returns:
520
532
Dict[str, Dict[str, float]]: impacts summary categorized by systematic and
@@ -532,12 +544,19 @@ def _get_datastat_impacts_np_shift(
532
544
fit_results .bestfit [i_par ] if i_par != poi_index else init_pars_datastat [i_par ]
533
545
for i_par in range (len (model .config .par_names ))
534
546
]
547
+ # removing init_pars and fix_pars from dict, cast it to FitKwargs
548
+ # type then update values in dict since casting re-introduces them.
549
+ updated_fit_kwargs = cast (
550
+ FitKwargs ,
551
+ {k : v for k , v in fit_kwargs .items () if k not in ["init_pars" , "fix_pars" ]},
552
+ )
553
+ updated_fit_kwargs ["init_pars" ] = init_pars_datastat
554
+ updated_fit_kwargs ["fix_pars" ] = fix_pars_datastat
555
+
535
556
fit_results_datastat = _fit_model (
536
557
model ,
537
558
data ,
538
- init_pars = init_pars_datastat ,
539
- fix_pars = fix_pars_datastat ,
540
- ** {k : v for k , v in fit_kwargs .items () if k not in ["init_pars" , "fix_pars" ]},
559
+ ** updated_fit_kwargs ,
541
560
)
542
561
datastat_poi_val = fit_results_datastat .bestfit [poi_index ]
543
562
datastat_impact = datastat_poi_val - nominal_poi
@@ -546,7 +565,9 @@ def _get_datastat_impacts_np_shift(
546
565
return impacts_summary
547
566
548
567
549
- def _get_datastat_impacts_quadruture (impacts_summary , total_error ):
568
+ def _get_datastat_impacts_quadruture (
569
+ impacts_summary : Dict [str , Dict [str , float ]], total_error : float
570
+ ) -> Dict [str , Dict [str , float ]]:
550
571
"""
551
572
Calculate the impact of statistical uncertainties on the parameter of subtracting
552
573
other sources from the total error in quadrature.
@@ -636,7 +657,9 @@ def _cov_impacts(
636
657
systematic and non-systematic sources
637
658
"""
638
659
total_poi_error = fit_results .uncertainty [poi_index ]
639
- impacts_by_modifier_type = defaultdict (lambda : defaultdict (list ))
660
+ impacts_by_modifier_type : Dict [str , Dict [str , List [float ]]] = defaultdict (
661
+ lambda : defaultdict (list )
662
+ )
640
663
i_global_par = 0
641
664
642
665
for parameter in model .config .par_order :
@@ -688,7 +711,7 @@ def _cov_impacts(
688
711
impacts_summary = _get_impacts_summary (impacts_by_modifier_type )
689
712
impacts_summary = _get_datastat_impacts_quadruture (impacts_summary , total_poi_error )
690
713
691
- return impacts_by_modifier_type , impacts_summary
714
+ return dict ( impacts_by_modifier_type ) , impacts_summary
692
715
693
716
694
717
def _np_impacts (
@@ -698,7 +721,7 @@ def _np_impacts(
698
721
fit_results : FitResults ,
699
722
prefit_unc : np .ndarray ,
700
723
labels : List [str ],
701
- fit_kwargs ,
724
+ fit_kwargs : FitKwargs ,
702
725
) -> Tuple [Dict [str , Dict [str , List [float ]]], Dict [str , Dict [str , float ]]]:
703
726
"""
704
727
Computes the impact of nuisance parameters on the POI by shifting parameter
@@ -712,7 +735,7 @@ def _np_impacts(
712
735
fit_results (FitResults): nominal fit results to use in impacts calculation
713
736
prefit_unc (np.ndarray): pre-fit uncertainties of parameters
714
737
labels (List[str]): list of parameter names
715
- fit_kwargs: settings to be used in the fits.
738
+ fit_kwargs (FitKwargs) : settings to be used in the fits.
716
739
717
740
Returns:
718
741
Tuple[Dict[str, Dict[str, List[float]]], Dict[str, Dict[str, float]]]:
@@ -723,7 +746,9 @@ def _np_impacts(
723
746
systematic and non-systematic sources.
724
747
"""
725
748
nominal_poi = fit_results .bestfit [poi_index ]
726
- impacts_by_modifier_type = defaultdict (lambda : defaultdict (list ))
749
+ impacts_by_modifier_type : Dict [str , Dict [str , List [float ]]] = defaultdict (
750
+ lambda : defaultdict (list )
751
+ )
727
752
i_global_par = 0
728
753
729
754
for parameter in model .config .par_order :
@@ -765,17 +790,25 @@ def _np_impacts(
765
790
init_pars_ranking = fit_kwargs ["init_pars" ].copy ()
766
791
# value of current nuisance parameter
767
792
init_pars_ranking [i_par ] = np_val
768
- fit_results_ranking = _fit_model (
769
- model ,
770
- data ,
771
- init_pars = init_pars_ranking ,
772
- fix_pars = fix_pars_ranking ,
773
- ** {
793
+
794
+ # removing init_pars and fix_pars from dict, cast it to FitKwargs
795
+ # type then update values in dict since casting re-introduces them.
796
+ updated_fit_kwargs = cast (
797
+ FitKwargs ,
798
+ {
774
799
k : v
775
800
for k , v in fit_kwargs .items ()
776
801
if k not in ["init_pars" , "fix_pars" ]
777
802
},
778
803
)
804
+ updated_fit_kwargs ["init_pars" ] = init_pars_ranking
805
+ updated_fit_kwargs ["fix_pars" ] = fix_pars_ranking
806
+
807
+ fit_results_ranking = _fit_model (
808
+ model ,
809
+ data ,
810
+ ** updated_fit_kwargs ,
811
+ )
779
812
poi_val = fit_results_ranking .bestfit [poi_index ]
780
813
parameter_impact = poi_val - nominal_poi
781
814
log .debug (
@@ -798,7 +831,7 @@ def _np_impacts(
798
831
impacts_summary , model , data , poi_index , fit_results , fit_kwargs
799
832
)
800
833
801
- return impacts_by_modifier_type , impacts_summary
834
+ return dict ( impacts_by_modifier_type ) , impacts_summary
802
835
803
836
804
837
def _auxdata_shift_impacts (
@@ -1047,15 +1080,18 @@ def ranking(
1047
1080
RankingResults: fit results for parameters, and pre- and post-fit impacts
1048
1081
"""
1049
1082
1050
- fit_settings = {
1051
- "init_pars" : init_pars or model .config .suggested_init (),
1052
- "fix_pars" : fix_pars or model .config .suggested_fixed (),
1053
- "par_bounds" : par_bounds ,
1054
- "strategy" : strategy ,
1055
- "maxiter" : maxiter ,
1056
- "tolerance" : tolerance ,
1057
- "custom_fit" : custom_fit ,
1058
- }
1083
+ fit_settings = cast (
1084
+ FitKwargs ,
1085
+ {
1086
+ "init_pars" : init_pars or model .config .suggested_init (),
1087
+ "fix_pars" : fix_pars or model .config .suggested_fixed (),
1088
+ "par_bounds" : par_bounds ,
1089
+ "strategy" : strategy ,
1090
+ "maxiter" : maxiter ,
1091
+ "tolerance" : tolerance ,
1092
+ "custom_fit" : custom_fit ,
1093
+ },
1094
+ )
1059
1095
1060
1096
if fit_results is None :
1061
1097
fit_results = _fit_model (model , data , ** fit_settings )
0 commit comments