Skip to content

Commit f298935

Browse files
starting with evaluationtypes
1 parent d640dcd commit f298935

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

smodels/statistics/basicStats.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,41 @@
1212
from smodels.base.smodelsLogging import logger
1313
import numpy as np
1414
from smodels.statistics.exceptions import SModelSStatisticsError as SModelSError
15-
from typing import Text
15+
from typing import Text, Union
1616
from collections.abc import Callable
1717

1818
__all__ = [ "CLsfromNLL", "determineBrentBracket", "chi2FromLmax" ]
1919

20+
from enum import Enum
21+
22+
class NllEvalType(Enum):
23+
""" an enum to account for the different types of likelihood values: observed,
24+
a priori expected, a posteriori expected """
25+
observed = 0
26+
aposteriori = 1
27+
apriori = 2
28+
29+
@classmethod
30+
def init ( cls, evaluationType : Union[str,bool] ):
31+
""" get evaluationtype either from a string (e.g. 'priori') or a bool
32+
(true is posteriori, false is observed)
33+
"""
34+
evaluationType = str(evaluationType).lower().replace("_","")
35+
if evaluationType in [ "posteriori", "aposteriori", "posterior" ]:
36+
return cls.aposteriori
37+
if evaluationType in [ "apriori", "prior", "priori", "true" ]:
38+
return cls.apriori
39+
if evaluationType in [ "false", "observed", "obs" ]:
40+
return cls.observed
41+
raise SModelSError ( f"NllEvalType {evaluationType} unknown" )
42+
43+
def __eq__ ( self, other ):
44+
if type ( other ) == NllEvalType:
45+
return super().__eq__ ( other )
46+
if type ( other ) in [ bool, str ]:
47+
return super().__eq__ ( NllEvalType.init ( other ) )
48+
raise SModelSError ( f"comparing a NllEvalType with {type(other)}" )
49+
2050
def CLsfromNLL(
2151
nllA: float, nll0A: float, nll: float, nll0: float, big_muhat : bool,
2252
return_type: Text = "CLs-alpha" ) -> float:

unittests/testEvaluationType.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#!/usr/bin/env python3
2+
3+
"""
4+
.. module:: testEvaluationType
5+
:synopsis: Tests if the evalutionType works as expected (pun intended)
6+
7+
.. moduleauthor:: Wolfgang Waltenberger <wolfgang.waltenberger@gmail.com>
8+
9+
"""
10+
import sys
11+
sys.path.insert(0,"../")
12+
import unittest
13+
14+
from smodels.statistics.basicStats import NllEvalType
15+
16+
class EvalTypeTest(unittest.TestCase):
17+
18+
def testEvalType(self):
19+
""" test that NllEvalType thing """
20+
obs = NllEvalType.init ( "observed" )
21+
obs2 = NllEvalType.init ( False )
22+
prior = NllEvalType.init ( "prior" )
23+
prior2 = NllEvalType.init ( True )
24+
posteriori = NllEvalType.init ( "posteriori" )
25+
posteriori2 = NllEvalType.init ( "aposteriori" )
26+
self.assertTrue( obs == obs2 )
27+
self.assertTrue( obs == obs )
28+
self.assertTrue( obs != prior )
29+
self.assertTrue( obs != posteriori )
30+
self.assertTrue( prior != posteriori )
31+
self.assertTrue( prior == prior2 )
32+
self.assertTrue( posteriori == posteriori2 )
33+
34+
if __name__ == "__main__":
35+
unittest.main()

0 commit comments

Comments
 (0)