@@ -71,6 +71,15 @@ def test_brier_score():
7171 expected_bs = 0.4
7272 assert_allclose (expected_bs , value_test , atol = 0.01 )
7373
74+ def test_root_brier_score ():
75+ ref_bs = [1 , 0 ]
76+ pred_bs = [[0.2 ,0.8 ],
77+ [0.4 ,0.6 ]]
78+ ppm = CalibrationMeasures (np .asarray (pred_bs ), np .asarray (ref_bs ))
79+ value_test = ppm .root_brier_score ()
80+ expected_bs = 0.6325
81+ assert_allclose (expected_bs , value_test , atol = 0.01 )
82+
7483#To use SN 2.14 p 99 of Metrics Reloaded
7584
7685def test_top_label_classification_error ():
@@ -86,6 +95,22 @@ def test_top_label_classification_error():
8695 value_test = cm .top_label_classification_error ()
8796 assert_allclose (value_test , expected_tce , atol = 0.001 )
8897
98+ def test_top_label_classification_error_oneemptyclass ():
99+ ref_tce = [1 , 0 , 1 , 1 ]
100+ pred_tce = [[0.1 , 0.8 , 0 , 0.1 ], [0.6 , 0.1 , 0.6 , 0.7 ], [0.3 , 0.1 , 0.4 , 0.2 ]]
101+ # 0.25 - 0.75 - 0
102+ #
103+ pred_tce = np .asarray (pred_tce ).T
104+ ref_tce = np .asarray (ref_tce )
105+ expected_prob = [0.75 , 0.25 , 0.75 , 0.75 ]
106+ best_prob = [0.6 , 0.8 , 0.6 , 0.7 ]
107+ pred_class = [1 , 0 , 1 , 1 ]
108+ # sqrt(0.15^2 + 0.55^2 + 0.15^2 + 0.05^2)/4
109+ expected_tce = 0.2958
110+ cm = CalibrationMeasures (pred_tce , ref_tce )
111+ value_test = cm .top_label_classification_error ()
112+ assert_allclose (value_test , expected_tce , atol = 0.001 )
113+
89114
90115def test_negative_log_likelihood ():
91116 ref_nll = [1 , 0 , 2 , 1 ]
@@ -111,32 +136,56 @@ def test_class_wise_expectation_calibration_error():
111136 pred_cwece = np .asarray (pred_cwece ).T
112137 dict_args = {"bins_ece" : 2 }
113138 cm = CalibrationMeasures (pred_cwece , ref_cwece , dict_args = dict_args )
139+ cm2 = CalibrationMeasures (pred_cwece , ref_cwece )
114140 value_test = cm .class_wise_expectation_calibration_error ()
141+ value_test2 = cm2 .class_wise_expectation_calibration_error ()
115142 expected_cwece = 0.150
143+ expected_cwece2 = 0.150
116144 assert_allclose (value_test , expected_cwece , atol = 0.001 )
145+ assert_allclose (value_test2 , expected_cwece2 , atol = 0.001 )
117146
118147
119148def test_gamma_ik ():
120149 pred = [[0.1 , 0.8 , 0 , 0.1 ], [0.6 , 0.1 , 0 , 0.7 ], [0.3 , 0.1 , 1 , 0.2 ]]
121150 pred = np .asarray (pred ).T
122151 ref = np .asarray ([1 , 0 , 2 , 1 ])
123152 cm = CalibrationMeasures (pred , ref )
153+ cm2 = CalibrationMeasures (pred , ref , dict_args = {'bandwidth' :0.5 })
124154 value_test = cm .gamma_ik (0 , 0 )
155+ value_test2 = cm2 .gamma_ik (0 ,0 )
125156 expected_gamma = gamma (1.2 )
157+ expected_gamma2 = gamma (1.2 )
126158 assert_allclose (value_test , expected_gamma , atol = 0.001 )
159+ assert_allclose (value_test2 , expected_gamma2 , atol = 0.001 )
127160
128161
129162def test_dirichlet_kernel ():
130163 pred = [[0.1 , 0.8 , 0 , 0.1 ], [0.6 , 0.1 , 0 , 0.7 ], [0.3 , 0.1 , 1 , 0.2 ]]
131164 pred = np .asarray (pred ).T
132165 ref = np .asarray ([1 , 0 , 2 , 1 ])
133166 cm = CalibrationMeasures (pred , ref )
167+ cm2 = CalibrationMeasures (pred ,ref ,dict_args = {'bandwidth' :0.5 })
134168 numerator = gamma (1.2 + 2.2 + 1.6 )
135169 denominator = gamma (1.2 ) * gamma (2.2 ) * gamma (1.6 )
136170 prod = np .power (0.8 , 0.2 ) * np .power (0.1 , 1.2 ) * np .power (0.1 , 0.6 )
137171 value_test = cm .dirichlet_kernel (1 , 0 )
172+ value_test2 = cm2 .dirichlet_kernel (1 ,0 )
138173 expected_dir = numerator * prod / denominator
174+ expected_dir2 = expected_dir
139175 assert_allclose (value_test , expected_dir , atol = 0.001 )
176+ assert_allclose (value_test2 , expected_dir2 , atol = 0.001 )
177+
178+
179+
180+ def test_kernel_calculation ():
181+ pred = [[0.1 , 0.8 , 0 , 0.1 ], [0.6 , 0.1 , 0 , 0.7 ], [0.3 , 0.1 , 1 , 0.2 ]]
182+ #sqrt(0.7^2 + 0.5^2 + 0.2^2)/0.2
183+ pred = np .asarray (pred ).T
184+ ref = np .asarray ([1 , 0 , 2 , 1 ])
185+ cm = CalibrationMeasures (pred ,ref ,dict_args = {'bandwidth_kce' :0.2 })
186+ value_test = cm .kernel_calculation (0 ,1 )[0 ,0 ]
187+ expected_value = 0.01208
188+ assert_allclose (value_test , expected_value , atol = 0.001 )
140189
141190def test_kernel_calibration_error ():
142191 pred = [[0.1 , 0.8 , 0 , 0.1 ], [0.6 , 0.1 , 0 , 0.7 ], [0.3 , 0.1 , 1 , 0.2 ]]
0 commit comments