Skip to content

Commit 8654d8f

Browse files
author
Nathan Simpson
authored
Merge pull request #53 from gradhep/make-CLs-optional
add new `cls_method` kwarg for `infer.hypotest` to optionally give access to `CLsb` (mostly for experiments)
2 parents 4c5a148 + d474a44 commit 8654d8f

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

src/relaxed/infer.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def hypotest(
2222
return_mle_pars: bool = False,
2323
test_stat: str = "q",
2424
expected_pars: Array | None = None,
25+
cls_method: bool = True,
2526
) -> tuple[Array, Array] | Array:
2627
"""Calculate expected CLs/p-values via hypothesis tests.
2728
@@ -53,7 +54,9 @@ def hypotest(
5354
The MLE parameters, if `return_mle_pars` is True.
5455
"""
5556
if test_stat == "q":
56-
return qmu_test(test_poi, data, model, lr, return_mle_pars, expected_pars)
57+
return qmu_test(
58+
test_poi, data, model, lr, return_mle_pars, expected_pars, cls_method
59+
)
5760
elif test_stat == "q0":
5861
logging.info(
5962
"test_poi automatically set to 0 for q0 test (bkg-only null hypothesis)"
@@ -64,7 +67,7 @@ def hypotest(
6467

6568

6669
@partial(
67-
jit, static_argnames=["model", "return_mle_pars"]
70+
jit, static_argnames=["model", "return_mle_pars", "cls_method"]
6871
) # can remove model eventually
6972
def qmu_test(
7073
test_poi: float,
@@ -73,6 +76,7 @@ def qmu_test(
7376
lr: float,
7477
return_mle_pars: bool = False,
7578
expected_pars: Array | None = None,
79+
cls_method: bool = True,
7680
) -> tuple[Array, Array] | Array:
7781
# hard-code 1 as inits for now
7882
# TODO: need to parse different inits for constrained and global fits
@@ -93,9 +97,12 @@ def qmu_test(
9397
qmu = jnp.where(poi_hat < test_poi, profile_likelihood, 0.0)
9498

9599
CLsb = 1 - pyhf.tensorlib.normal_cdf(jnp.sqrt(qmu))
96-
altval = 0.0
97-
CLb = 1 - pyhf.tensorlib.normal_cdf(altval)
98-
CLs = CLsb / CLb
100+
if cls_method:
101+
altval = 0.0
102+
CLb = 1 - pyhf.tensorlib.normal_cdf(altval)
103+
CLs = CLsb / CLb
104+
else:
105+
CLs = CLsb
99106
return (CLs, mle_pars) if return_mle_pars else CLs
100107

101108

tests/test_infer.py

+24
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,30 @@ def pipeline(x):
8686
jacrev(pipeline)(jnp.asarray(0.5))
8787

8888

89+
@pytest.mark.parametrize("expected_pars", [True, False])
90+
def test_hypotest_grad_noCLs(expected_pars):
91+
pars = jnp.array([0.0, 1.0])
92+
if expected_pars:
93+
expars = pars
94+
else:
95+
expars = None
96+
97+
def pipeline(x):
98+
model = uncorrelated_background(x * 5.0, x * 20, x * 2)
99+
expected_cls = relaxed.infer.hypotest(
100+
1.0,
101+
model=model,
102+
data=model.expected_data(pars),
103+
lr=1e-2,
104+
test_stat="q",
105+
expected_pars=expars,
106+
cls_method=False,
107+
)
108+
return expected_cls
109+
110+
jacrev(pipeline)(jnp.asarray(0.5))
111+
112+
89113
def test_wrong_test_stat():
90114
with pytest.raises(ValueError):
91115
model = example_model(0.0)

0 commit comments

Comments
 (0)