Skip to content

Commit 5983ce5

Browse files
committed
Fix model roc calculations
1 parent abd54d8 commit 5983ce5

File tree

2 files changed

+40
-21
lines changed

2 files changed

+40
-21
lines changed

src/conformist/roc.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,43 +34,65 @@ def __init__(self,
3434
self.model_tpr_list = []
3535
self.model_fpr_list = []
3636

37+
def uncalibrated_model_rates(self, alphas):
38+
# Compute TPR and FPR for the uncalibrated model
39+
model_tpr = []
40+
model_fpr = []
41+
42+
for alpha in alphas:
43+
prediction_sets = self.prediction_dataset.smx >= alpha
44+
45+
# TPR calculation
46+
tps = ((prediction_sets * self.prediction_dataset.labels_idx).sum(axis=1))
47+
fns = ((~prediction_sets * self.prediction_dataset.labels_idx).sum(axis=1))
48+
49+
# Get sums of tps and fns
50+
tps = tps.sum()
51+
fns = fns.sum()
52+
53+
model_tpr.append(tps / (tps + fns))
54+
55+
# FPS calculation
56+
fps = ((prediction_sets * (1 - self.prediction_dataset.labels_idx)).sum(axis=1))
57+
tns = np.array([
58+
len(np.setdiff1d(labels, predictions))
59+
for labels, predictions in zip(self.prediction_dataset.labels_idx, prediction_sets)
60+
])
61+
62+
# Get sums of tns and fps
63+
tns = tns.sum()
64+
fps = fps.sum()
65+
66+
model_fpr.append(fps / (fps + tns))
67+
68+
return model_tpr, model_fpr
69+
3770
def run(self):
3871
# Define a range of significance levels (alpha values)
3972
alpha_levels = np.linspace(self.min_alpha,
4073
self.max_alpha,
4174
self.n_alphas)
4275

76+
self.model_tpr_list, self.model_fpr_list = \
77+
self.uncalibrated_model_rates(alpha_levels)
78+
4379
# Compute TPR and FPR for different alpha thresholds
4480
for alpha in alpha_levels:
4581
print(f'alpha={alpha}')
4682
cop = self.cop_class(self.prediction_dataset, alpha=alpha)
4783
trial = cop.do_validation_trial(n_runs=self.n_runs_per_alpha)
4884

4985
mean_cp_tpr = trial.mean_true_positive_rate()
50-
mean_model_tpr = trial.mean_model_true_positive_rate()
51-
5286
mean_cp_fpr = trial.mean_FPR()
53-
mean_model_fpr = trial.mean_model_false_positive_rate()
5487

5588
self.cp_tpr_list.append(mean_cp_tpr)
5689
self.cp_fpr_list.append(mean_cp_fpr)
5790

58-
print(f'mean_cp_tpr={mean_cp_tpr}, mean_model_tpr={mean_model_tpr}')
59-
print(f'mean_cp_fpr={mean_cp_fpr}, mean_model_fpr={mean_model_fpr}')
60-
61-
self.model_tpr_list.append(mean_model_tpr)
62-
self.model_fpr_list.append(mean_model_fpr)
63-
6491
# Ensure x values are sorted in ascending order
6592
sorted_indices = np.argsort(self.cp_fpr_list)
6693
self.cp_fpr_list = np.array(self.cp_fpr_list)[sorted_indices]
6794
self.cp_tpr_list = np.array(self.cp_tpr_list)[sorted_indices]
6895

69-
# Do same for model
70-
sorted_indices = np.argsort(self.model_fpr_list)
71-
self.model_fpr_list = np.array(self.model_fpr_list)[sorted_indices]
72-
self.model_tpr_list = np.array(self.model_tpr_list)[sorted_indices]
73-
7496
def run_reports(self):
7597
plt.figure(figsize=(self.FIGURE_WIDTH, self.FIGURE_HEIGHT))
7698
plt.tight_layout()

src/conformist/validation_trial.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ class ValidationTrial(OutputDir):
1313
FIGURE_HEIGHT = 8
1414
plt.rcParams.update({'font.size': FIGURE_FONTSIZE})
1515

16-
def __init__(self, runs, class_names=[]):
16+
def __init__(self,
17+
runs,
18+
class_names=[]):
1719
self.runs = runs
1820
self.class_names = class_names
1921

@@ -71,6 +73,7 @@ def mean_true_positive_rate(self):
7173
tps.append(run.true_positive_rate())
7274
return statistics.mean(tps)
7375

76+
# TODO: rename
7477
def mean_FPR(self):
7578
fps = []
7679
for run in self.runs:
@@ -83,12 +86,6 @@ def mean_model_true_positive_rate(self):
8386
tps.append(run.model_true_positive_rate())
8487
return statistics.mean(tps)
8588

86-
def mean_model_false_positive_rate(self):
87-
fps = []
88-
for run in self.runs:
89-
fps.append(run.model_false_positive_rate())
90-
return statistics.mean(fps)
91-
9289
def mean_softmax_threshold(self):
9390
return sum(run.softmax_threshold for run in self.runs) / len(self.runs)
9491

0 commit comments

Comments
 (0)