Skip to content

Commit 7bef02a

Browse files
Carole SudreCarole Sudre
authored andcommitted
Complementing testing of calibration measures
1 parent 2a72bae commit 7bef02a

2 files changed

Lines changed: 54 additions & 2 deletions

File tree

MetricsReloaded/metrics/calibration_measures.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(
8181
self.measures = measures if measures is not None else self.measures_dict
8282

8383
def class_wise_expectation_calibration_error(self):
84-
r"""
84+
"""
8585
Class_wise version of the expectation calibration error
8686
8787
Ananya Kumar, Percy S Liang, and Tengyu Ma. 2019. Verified uncertainty calibration. Advances in Neural Information
@@ -98,6 +98,7 @@ def class_wise_expectation_calibration_error(self):
9898
nbins = self.dict_args["bins_ece"]
9999
else:
100100
nbins = 10
101+
print('number bins is ',nbins)
101102
step = 1.0 / nbins
102103
range_values = np.arange(0, 1.00001, step)
103104
list_values = []
@@ -360,8 +361,10 @@ def top_label_classification_error(self):
360361
prob_ref_values, prob_ref_counts = np.unique(self.ref, return_counts=True)
361362
for k in range(nclasses):
362363
idx = np.where(prob_ref_values == k)
363-
if len(idx) == 0:
364+
print(k, idx)
365+
if np.size(idx) == 0:
364366
prob[k] = 0
367+
print('nothing in ', k)
365368
else:
366369
prob[k] = prob_ref_counts[idx[0]] / numb_samples
367370

test/test_metrics/test_calibration_metrics.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,15 @@ def test_brier_score():
7171
expected_bs = 0.4
7272
assert_allclose(expected_bs, value_test, atol=0.01)
7373

74+
def test_root_brier_score():
75+
ref_bs = [1, 0]
76+
pred_bs = [[0.2,0.8],
77+
[0.4,0.6]]
78+
ppm = CalibrationMeasures(np.asarray(pred_bs), np.asarray(ref_bs))
79+
value_test = ppm.root_brier_score()
80+
expected_bs = 0.6325
81+
assert_allclose(expected_bs, value_test, atol=0.01)
82+
7483
#To use SN 2.14 p 99 of Metrics Reloaded
7584

7685
def test_top_label_classification_error():
@@ -86,6 +95,22 @@ def test_top_label_classification_error():
8695
value_test = cm.top_label_classification_error()
8796
assert_allclose(value_test, expected_tce, atol=0.001)
8897

98+
def test_top_label_classification_error_oneemptyclass():
99+
ref_tce = [1, 0, 1, 1]
100+
pred_tce = [[0.1, 0.8, 0, 0.1], [0.6, 0.1, 0.6, 0.7], [0.3, 0.1, 0.4, 0.2]]
101+
# 0.25 - 0.75 - 0
102+
#
103+
pred_tce = np.asarray(pred_tce).T
104+
ref_tce = np.asarray(ref_tce)
105+
expected_prob = [0.75, 0.25, 0.75, 0.75]
106+
best_prob = [0.6, 0.8, 0.6, 0.7]
107+
pred_class = [1, 0, 1, 1]
108+
# sqrt(0.15^2 + 0.55^2 + 0.15^2 + 0.05^2)/4
109+
expected_tce = 0.2958
110+
cm = CalibrationMeasures(pred_tce, ref_tce)
111+
value_test = cm.top_label_classification_error()
112+
assert_allclose(value_test, expected_tce, atol=0.001)
113+
89114

90115
def test_negative_log_likelihood():
91116
ref_nll = [1, 0, 2, 1]
@@ -111,32 +136,56 @@ def test_class_wise_expectation_calibration_error():
111136
pred_cwece = np.asarray(pred_cwece).T
112137
dict_args = {"bins_ece": 2}
113138
cm = CalibrationMeasures(pred_cwece, ref_cwece, dict_args=dict_args)
139+
cm2 = CalibrationMeasures(pred_cwece, ref_cwece)
114140
value_test = cm.class_wise_expectation_calibration_error()
141+
value_test2 = cm2.class_wise_expectation_calibration_error()
115142
expected_cwece = 0.150
143+
expected_cwece2 = 0.150
116144
assert_allclose(value_test, expected_cwece, atol=0.001)
145+
assert_allclose(value_test2, expected_cwece2, atol=0.001)
117146

118147

119148
def test_gamma_ik():
120149
pred = [[0.1, 0.8, 0, 0.1], [0.6, 0.1, 0, 0.7], [0.3, 0.1, 1, 0.2]]
121150
pred = np.asarray(pred).T
122151
ref = np.asarray([1, 0, 2, 1])
123152
cm = CalibrationMeasures(pred, ref)
153+
cm2 = CalibrationMeasures(pred, ref, dict_args={'bandwidth':0.5})
124154
value_test = cm.gamma_ik(0, 0)
155+
value_test2 = cm2.gamma_ik(0,0)
125156
expected_gamma = gamma(1.2)
157+
expected_gamma2 = gamma(1.2)
126158
assert_allclose(value_test, expected_gamma, atol=0.001)
159+
assert_allclose(value_test2, expected_gamma2, atol=0.001)
127160

128161

129162
def test_dirichlet_kernel():
130163
pred = [[0.1, 0.8, 0, 0.1], [0.6, 0.1, 0, 0.7], [0.3, 0.1, 1, 0.2]]
131164
pred = np.asarray(pred).T
132165
ref = np.asarray([1, 0, 2, 1])
133166
cm = CalibrationMeasures(pred, ref)
167+
cm2 = CalibrationMeasures(pred,ref,dict_args={'bandwidth':0.5})
134168
numerator = gamma(1.2 + 2.2 + 1.6)
135169
denominator = gamma(1.2) * gamma(2.2) * gamma(1.6)
136170
prod = np.power(0.8, 0.2) * np.power(0.1, 1.2) * np.power(0.1, 0.6)
137171
value_test = cm.dirichlet_kernel(1, 0)
172+
value_test2 = cm2.dirichlet_kernel(1,0)
138173
expected_dir = numerator * prod / denominator
174+
expected_dir2 = expected_dir
139175
assert_allclose(value_test, expected_dir, atol=0.001)
176+
assert_allclose(value_test2, expected_dir2, atol=0.001)
177+
178+
179+
180+
def test_kernel_calculation():
181+
pred = [[0.1, 0.8, 0, 0.1], [0.6, 0.1, 0, 0.7], [0.3, 0.1, 1, 0.2]]
182+
#sqrt(0.7^2 + 0.5^2 + 0.2^2)/0.2
183+
pred = np.asarray(pred).T
184+
ref = np.asarray([1, 0, 2, 1])
185+
cm = CalibrationMeasures(pred,ref,dict_args={'bandwidth_kce':0.2})
186+
value_test = cm.kernel_calculation(0,1)[0,0]
187+
expected_value = 0.01208
188+
assert_allclose(value_test, expected_value, atol=0.001)
140189

141190
def test_kernel_calibration_error():
142191
pred = [[0.1, 0.8, 0, 0.1], [0.6, 0.1, 0, 0.7], [0.3, 0.1, 1, 0.2]]

0 commit comments

Comments
 (0)