Skip to content

Commit 7cdffa2

Browse files
committed
Add stats for sets of size 2 or more
1 parent f6bdd88 commit 7cdffa2

File tree

5 files changed

+46
-16
lines changed

5 files changed

+46
-16
lines changed

src/conformist/alpha_selector.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88

99

1010
class AlphaSelector(OutputDir):
11+
FIGURE_FONTSIZE = 12
12+
FIGURE_WIDTH = 12
13+
FIGURE_HEIGHT = 8
14+
plt.rcParams.update({'font.size': FIGURE_FONTSIZE})
15+
1116
def __init__(self,
1217
prediction_dataset: PredictionDataset,
1318
cop_class,
@@ -32,6 +37,7 @@ def __init__(self,
3237
self.pcts_empty_sets = []
3338
self.pcts_singleton_sets = []
3439
self.pcts_singleton_or_duo_sets = []
40+
self.pcts_duo_plus_sets = []
3541
self.pcts_trio_plus_sets = []
3642
self.mean_false_negative_rates = []
3743
self.mean_softmax_threshold = []
@@ -54,6 +60,7 @@ def run(self):
5460
self.pcts_singleton_sets.append(trial.pct_singleton_sets())
5561
self.pcts_singleton_or_duo_sets.append(
5662
trial.pct_singleton_or_duo_sets())
63+
self.pcts_duo_plus_sets.append(trial.pct_duo_plus_sets())
5764
self.pcts_trio_plus_sets.append(trial.pct_trio_plus_sets())
5865
self.mean_false_negative_rates.append(
5966
trial.mean_false_negative_rate())
@@ -69,7 +76,9 @@ def run_reports(self):
6976

7077
def visualize(self):
7178
# MEAN SET SIZES GRAPH
72-
plt.figure()
79+
plt.figure(figsize=(self.FIGURE_WIDTH,
80+
self.FIGURE_HEIGHT))
81+
plt.tight_layout()
7382

7483
data = pd.DataFrame({
7584
'Alpha': self.alphas,
@@ -80,7 +89,10 @@ def visualize(self):
8089
plt.savefig(f'{self.output_dir}/alpha_to_mean_set_size.png')
8190

8291
# PERCENT EMPTY/SINGLETON SETS GRAPH
83-
plt.figure()
92+
# (same figure dimensions as the mean set sizes graph)
93+
plt.figure(figsize=(self.FIGURE_WIDTH,
94+
self.FIGURE_HEIGHT))
95+
plt.tight_layout()
8496

8597
# Labels
8698
x_label = 'Alpha'
@@ -90,9 +102,9 @@ def visualize(self):
90102
# Create a DataFrame with the percentage of sets in each size category
91103
data = pd.DataFrame({
92104
x_label: self.alphas,
93-
'n = 0': self.pcts_empty_sets,
94-
'n ∈ {1, 2}': self.pcts_singleton_or_duo_sets,
95-
'n ≥ 3': self.pcts_trio_plus_sets
105+
'empty (n = 0)': self.pcts_empty_sets,
106+
'certain (n=1)': self.pcts_singleton_or_duo_sets,
107+
'uncertain (n ≥ 2)': self.pcts_duo_plus_sets
96108
})
97109

98110
# Melt the DataFrame to have the set types as a separate column
@@ -110,14 +122,16 @@ def visualize(self):
110122
# Get the current x-tick labels
111123
labels = [item.get_text() for item in plt.gca().get_xticklabels()]
112124

125+
target = 'certain (n=1)'
126+
113127
# Draw a horizontal line across the top of the highest orange bar
114-
max_singleton_or_duo_sets = data['n ∈ {1, 2}'].max()
115-
plt.axhline(y=max_singleton_or_duo_sets,
128+
optimal_value = data[target].max()
129+
plt.axhline(y=optimal_value,
116130
color='#cccccc',
117131
linestyle='--')
118132

119133
# Get the index of the label with the highest value
120-
idx = data['n ∈ {1, 2}'].idxmax()
134+
idx = data[target].idxmax()
121135

122136
# Make this label bold
123137
labels[idx] = f'$\\bf{{{labels[idx]}}}$'
@@ -131,7 +145,9 @@ def visualize(self):
131145
plt.savefig(f'{self.output_dir}/alpha_to_set_sizes.png')
132146

133147
def visualize_lambdas(self):
134-
plt.figure()
148+
plt.figure(figsize=(self.FIGURE_WIDTH,
149+
self.FIGURE_HEIGHT))
150+
plt.tight_layout()
135151

136152
# Only use reasonable alphas
137153
alphas = [0.05, 0.1, 0.15, 0.2, 0.3, 0.4]
@@ -156,7 +172,7 @@ def visualize_lambdas(self):
156172
plt.text(self.lamhats[a], 0 + padding,
157173
f'{self.lamhats[a]:.2f}',
158174
ha='center', va='bottom',
159-
fontsize=8, color='black',
175+
color='black',
160176
weight='bold')
161177
i += 1
162178

@@ -179,9 +195,8 @@ def save_summary(self):
179195
'alpha': self.alphas,
180196
'Mean set size': self.mean_set_sizes,
181197
'% sets n=0': self.pcts_empty_sets,
182-
'% sets n={1}': self.pcts_singleton_sets,
183-
'% sets n={1|2}': self.pcts_singleton_or_duo_sets,
184-
'% sets n>=3': self.pcts_trio_plus_sets,
198+
'% sets n=1': self.pcts_singleton_sets,
199+
'% sets n>=2': self.pcts_duo_plus_sets,
185200
'Mean FNR': self.mean_false_negative_rates,
186201
'Mean softmax threshold': self.mean_softmax_threshold
187202
}

src/conformist/performance_report.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,17 @@ def pct_singleton_or_duo_sets(prediction_sets):
2828
prediction_set in prediction_sets) / \
2929
len(prediction_sets)
3030

31-
def pct_trio_plus_sets(prediction_sets):
32-
return sum(sum(prediction_set) >= 3 for
31+
def _pct_sets_of_min_size(prediction_sets, min_size):
32+
return sum(sum(prediction_set) >= min_size for
3333
prediction_set in prediction_sets) / \
3434
len(prediction_sets)
3535

36+
def pct_duo_plus_sets(prediction_sets):
37+
return PerformanceReport._pct_sets_of_min_size(prediction_sets, 2)
38+
39+
def pct_trio_plus_sets(prediction_sets):
40+
return PerformanceReport._pct_sets_of_min_size(prediction_sets, 3)
41+
3642
def _class_report(self,
3743
items_by_class,
3844
output_file_prefix,

src/conformist/prediction_dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ class PredictionDataset(OutputDir):
2121
FIGURE_WIDTH = 12
2222
plt.rcParams.update({'font.size': FIGURE_FONTSIZE})
2323

24-
2524
def __init__(self,
2625
df=None,
2726
predictions_csv=None,

src/conformist/validation_run.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ def pct_singleton_sets(self):
5353
def pct_singleton_or_duo_sets(self):
5454
return PerformanceReport.pct_singleton_or_duo_sets(self.prediction_sets)
5555

56+
def pct_duo_plus_sets(self):
57+
return PerformanceReport.pct_duo_plus_sets(self.prediction_sets)
58+
5659
def pct_trio_plus_sets(self):
5760
return PerformanceReport.pct_trio_plus_sets(self.prediction_sets)
5861

@@ -130,6 +133,7 @@ def run_reports(self, base_output_dir):
130133
'pct_empty_sets': self.pct_empty_sets(),
131134
'pct_singleton_sets': self.pct_singleton_sets(),
132135
'pct_singleton_or_duo_sets': self.pct_singleton_or_duo_sets(),
136+
'pct_duo_plus_sets': self.pct_duo_plus_sets(),
133137
'pct_trio_plus_sets': self.pct_trio_plus_sets()
134138
}, index=[0])
135139

src/conformist/validation_trial.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ def pct_singleton_or_duo_sets(self):
3636
singleton_or_duo.append(run.pct_singleton_or_duo_sets())
3737
return statistics.mean(singleton_or_duo)
3838

39+
def pct_duo_plus_sets(self):
40+
duo_plus = []
41+
for run in self.runs:
42+
duo_plus.append(run.pct_duo_plus_sets())
43+
return statistics.mean(duo_plus)
44+
3945
def pct_trio_plus_sets(self):
4046
trio_plus = []
4147
for run in self.runs:

0 commit comments

Comments
 (0)