
Commit 80e8dc6

Merge branch 'master' into maly-issue-491
2 parents: e0351eb + 4728445

File tree

  .pre-commit-config.yaml
  src/cabinetry/fit/__init__.py
  src/cabinetry/model_utils.py
  src/cabinetry/workspace.py
  tests/fit/test_fit.py

5 files changed: +44 -24 lines

.pre-commit-config.yaml (+9 -9)

@@ -3,16 +3,16 @@ repos:
     rev: 25.1.0
     hooks:
       - id: black
-  # - repo: https://github.com/pre-commit/mirrors-mypy
-  #   rev: v1.14.1
-  #   hooks:
-  #     - id: mypy
-  #       name: mypy with Python 3.12
-  #       files: src/cabinetry
-  #       additional_dependencies: ["numpy>=1.22", "boost-histogram>=1.0.1", "click>=8", "types-tabulate", "types-PyYAML", "hist>=2.3.0"]
-  #       args: ["--python-version=3.12"]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.15.0
+    hooks:
+      - id: mypy
+        name: mypy with Python 3.12
+        files: src/cabinetry
+        additional_dependencies: ["numpy>=1.22", "boost-histogram>=1.0.1", "click>=8", "types-tabulate", "types-PyYAML", "hist>=2.3.0"]
+        args: ["--python-version=3.12"]
   - repo: https://github.com/pycqa/flake8
-    rev: 7.1.1
+    rev: 7.1.2
     hooks:
       - id: flake8
         additional_dependencies: [flake8-bugbear, flake8-import-order, flake8-print]
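
Note: with the mirrors-mypy hook re-enabled here (it was previously commented out) and pinned to v1.15.0, the same type check should be reproducible locally, assuming pre-commit itself is installed, via pre-commit run mypy --all-files; the files: src/cabinetry filter restricts the hook to the package sources.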

src/cabinetry/fit/__init__.py (+2 -2)

@@ -2,7 +2,7 @@

 from collections import defaultdict
 import logging
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, cast, Dict, List, Literal, Optional, Tuple, Union

 import iminuit
 import numpy as np
@@ -1105,7 +1105,7 @@ def scan(
     for i_par, par_value in enumerate(scan_values):
         log.debug(f"performing fit with {par_name} = {par_value:.3f}")
         init_pars_scan = init_pars.copy()
-        init_pars_scan[par_index] = par_value
+        init_pars_scan[par_index] = cast(float, par_value)
         scan_fit_results = _fit_model(
             model,
             data,
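
Note on the scan() change: the scan points come from a NumPy array, whose element type mypy does not accept directly for the List[float] of initial parameters, hence the explicit cast. A minimal standalone sketch of the pattern, with illustrative names only (not the actual cabinetry code):

from typing import List, cast

import numpy as np

init_pars: List[float] = [1.0, 0.0]
scan_values = np.linspace(0.5, 1.5, 11)  # elements are NumPy scalars, not plain float

for par_value in scan_values:
    init_pars_scan = init_pars.copy()
    # the cast pins the static type to float for the List[float] assignment
    init_pars_scan[0] = cast(float, par_value)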

src/cabinetry/model_utils.py (+18 -4)

@@ -3,7 +3,17 @@
 from collections import defaultdict
 import json
 import logging
-from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import (
+    Any,
+    cast,
+    DefaultDict,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Union,
+)

 import numpy as np
 import pyhf
@@ -746,7 +756,7 @@ def _parameters_maximizing_constraint_term(
     Returns:
         List[float]: parameters maximizing the model constraint term
     """
-    best_pars = []  # parameters maximizing constraint term
+    best_pars: List[float] = []  # parameters maximizing constraint term
     i_aux = 0  # current position in auxiliary data list
     i_poisson = 0  # current position in list of Poisson rescale factors

@@ -771,8 +781,12 @@
         else:
             rescale_factors = [1.0] * n_params  # no rescaling by default

-        best_pars += list(
-            np.asarray(aux_data[i_aux : i_aux + n_params]) / rescale_factors
+        # manually cast, possible cause https://github.com/numpy/numpy/issues/27944
+        best_pars += cast(
+            List[float],
+            (
+                np.asarray(aux_data[i_aux : i_aux + n_params]) / rescale_factors
+            ).tolist(),
         )
         i_aux += n_params

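
The workaround here — evaluate the NumPy expression, convert it with .tolist(), then cast the result to List[float] — addresses a type-inference problem that the in-code comment tentatively attributes to numpy issue 27944. A minimal standalone sketch with illustrative values (not the actual cabinetry inputs):

from typing import List, cast

import numpy as np

aux_data = [2.0, 4.0, 6.0]         # illustrative auxiliary data slice
rescale_factors = [2.0, 2.0, 2.0]  # illustrative Poisson rescale factors

best_pars: List[float] = []
# the division returns an ndarray; .tolist() produces Python floats, and the
# cast fixes the static type to List[float] for mypy
best_pars += cast(List[float], (np.asarray(aux_data) / rescale_factors).tolist())
assert best_pars == [1.0, 2.0, 3.0]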

src/cabinetry/workspace.py (+9 -5)

@@ -3,7 +3,7 @@
 import json
 import logging
 import pathlib
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, cast, Dict, List, Optional, Tuple, Union

 import pyhf

@@ -189,8 +189,9 @@ def normplusshape_modifiers(
         else:
            norm_effect_up = histogram_up.normalize_to_yield(histogram_nominal)
            norm_effect_down = histogram_down.normalize_to_yield(histogram_nominal)
-        histo_yield_up = histogram_up.yields.tolist()
-        histo_yield_down = histogram_down.yields.tolist()
+        # manually cast due to https://github.com/numpy/numpy/issues/27944
+        histo_yield_up = cast(List[float], histogram_up.yields.tolist())
+        histo_yield_down = cast(List[float], histogram_down.yields.tolist())

         log.debug(
             f"normalization impact of systematic {systematic['Name']} on sample "
@@ -508,9 +509,12 @@ def _symmetrized_templates_and_norm(
     # normalize the variation to the same yield as nominal
     norm_effect_var = variation.normalize_to_yield(reference)
     norm_effect_sym = 2 - norm_effect_var
-    histo_yield_var = variation.yields.tolist()
+    # manually cast due to https://github.com/numpy/numpy/issues/27944
+    histo_yield_var = cast(List[float], variation.yields.tolist())
     # need another histogram that corresponds to the symmetrized variation,
     # which is 2*nominal - variation
-    histo_yield_sym = (2 * reference.yields - variation.yields).tolist()
+    histo_yield_sym = cast(
+        List[float], (2 * reference.yields - variation.yields).tolist()
+    )

     return histo_yield_var, histo_yield_sym, norm_effect_var, norm_effect_sym
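
As a quick numerical check of the symmetrization logic touched here (values chosen purely for illustration): with nominal yields [10.0, 20.0] and a normalized variation [12.0, 18.0], the mirrored template 2 * nominal - variation is [8.0, 22.0], and a normalization effect of 1.05 maps to 2 - 1.05 = 0.95. A minimal sketch of the cast pattern applied to that arithmetic:

from typing import List, cast

import numpy as np

nominal = np.asarray([10.0, 20.0])    # illustrative nominal yields
variation = np.asarray([12.0, 18.0])  # illustrative normalized variation yields

# symmetrized variation is mirrored around nominal: 2 * nominal - variation
histo_yield_sym = cast(List[float], (2 * nominal - variation).tolist())
assert histo_yield_sym == [8.0, 22.0]

norm_effect_var = 1.05                 # illustrative normalization effect
norm_effect_sym = 2 - norm_effect_var  # 0.95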

tests/fit/test_fit.py (+6 -4)

@@ -376,12 +376,14 @@ def test__goodness_of_fit(
     assert np.allclose(p_val, 0.91926079)
     caplog.clear()

-    # same setup but using custom auxdata
+    # same setup but using custom auxdata and custom fixed parameters
     model, _ = model_utils.model_and_data(example_spec_multibin)
     data = [35, 8, 10] + [0.9, 1.1, 0.8]  # custom aux
-    p_val = fit._goodness_of_fit(model, data, 9.964913)
+    fix_pars = [False, False, False, True]  # custom fixed
+    p_val = fit._goodness_of_fit(model, data, 9.964913, fix_pars=fix_pars)
     assert mock_pars.call_count == 2
     assert np.allclose(mock_pars.call_args[0][1], [0.9, 1.1, 0.8])  # aux picked up
+    assert mock_count.call_args[1] == {"fix_pars": fix_pars}  # fixed pars picked up
     assert np.allclose(p_val, 0.91926079)  # same result as before

     # no auxdata and zero degrees of freedom in chi2 test
@@ -501,9 +503,9 @@ def test_fit(mock_fit, mock_print, mock_gof):
     assert fit_results.bestfit == [1.0]

     # goodness-of-fit test
-    fit_results_gof = fit.fit(model, data, goodness_of_fit=True)
+    fit_results_gof = fit.fit(model, data, goodness_of_fit=True, fix_pars=fix_pars)
     assert mock_gof.call_args[0] == (model, data, 2.0)
-    assert mock_gof.call_args[1] == {"fix_pars": None}
+    assert mock_gof.call_args[1] == {"fix_pars": fix_pars}
     assert fit_results_gof.goodness_of_fit == 0.1

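
The new assertions lean on unittest.mock's call_args structure, where index 0 holds the positional arguments of the most recent call and index 1 holds its keyword arguments. A minimal standalone illustration (the mock name below is made up, not taken from the test suite):

from unittest.mock import MagicMock

mocked_counts = MagicMock(return_value=3.0)  # hypothetical stand-in for a patched function
mocked_counts([1, 2, 3], fix_pars=[False, True])

assert mocked_counts.call_args[0] == ([1, 2, 3],)                 # positional arguments
assert mocked_counts.call_args[1] == {"fix_pars": [False, True]}  # keyword arguments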
