88
99
1010class AlphaSelector (OutputDir ):
11+ FIGURE_FONTSIZE = 12
12+ FIGURE_WIDTH = 12
13+ FIGURE_HEIGHT = 8
14+ plt .rcParams .update ({'font.size' : FIGURE_FONTSIZE })
15+
1116 def __init__ (self ,
1217 prediction_dataset : PredictionDataset ,
1318 cop_class ,
@@ -32,6 +37,7 @@ def __init__(self,
3237 self .pcts_empty_sets = []
3338 self .pcts_singleton_sets = []
3439 self .pcts_singleton_or_duo_sets = []
40+ self .pcts_duo_plus_sets = []
3541 self .pcts_trio_plus_sets = []
3642 self .mean_false_negative_rates = []
3743 self .mean_softmax_threshold = []
@@ -54,6 +60,7 @@ def run(self):
5460 self .pcts_singleton_sets .append (trial .pct_singleton_sets ())
5561 self .pcts_singleton_or_duo_sets .append (
5662 trial .pct_singleton_or_duo_sets ())
63+ self .pcts_duo_plus_sets .append (trial .pct_duo_plus_sets ())
5764 self .pcts_trio_plus_sets .append (trial .pct_trio_plus_sets ())
5865 self .mean_false_negative_rates .append (
5966 trial .mean_false_negative_rate ())
@@ -69,7 +76,9 @@ def run_reports(self):
6976
7077 def visualize (self ):
7178 # MEAN SET SIZES GRAPH
72- plt .figure ()
79+ plt .figure (figsize = (self .FIGURE_WIDTH ,
80+ self .FIGURE_HEIGHT ))
81+ plt .tight_layout ()
7382
7483 data = pd .DataFrame ({
7584 'Alpha' : self .alphas ,
@@ -80,7 +89,10 @@ def visualize(self):
8089 plt .savefig (f'{ self .output_dir } /alpha_to_mean_set_size.png' )
8190
8291 # PERCENT EMPTY/SINGLETON SETS GRAPH
83- plt .figure ()
92+ # Use the same figure size as the mean set sizes graph above
93+ plt .figure (figsize = (self .FIGURE_WIDTH ,
94+ self .FIGURE_HEIGHT ))
95+ plt .tight_layout ()
8496
8597 # Labels
8698 x_label = 'Alpha'
@@ -90,9 +102,9 @@ def visualize(self):
90102 # Create a DataFrame with one column per set-size category, keyed by its legend label
91103 data = pd .DataFrame ({
92104 x_label : self .alphas ,
93- 'n = 0' : self .pcts_empty_sets ,
94- 'n ∈ {1, 2} ' : self .pcts_singleton_or_duo_sets ,
95- 'n ≥ 3 ' : self .pcts_trio_plus_sets
105+ 'empty ( n = 0) ' : self .pcts_empty_sets ,
106+ 'certain (n=1) ' : self .pcts_singleton_or_duo_sets ,
107+ 'uncertain ( n ≥ 2) ' : self .pcts_duo_plus_sets
96108 })
97109
98110 # Melt the DataFrame to have the set types as a separate column
@@ -110,14 +122,16 @@ def visualize(self):
110122 # Get the current x-tick labels
111123 labels = [item .get_text () for item in plt .gca ().get_xticklabels ()]
112124
125+ target = 'certain (n=1)'
126+
113127 # Draw a horizontal line across the top of the highest orange bar
114- max_singleton_or_duo_sets = data ['n ∈ {1, 2}' ].max ()
115- plt .axhline (y = max_singleton_or_duo_sets ,
128+ optimal_value = data [target ].max ()
129+ plt .axhline (y = optimal_value ,
116130 color = '#cccccc' ,
117131 linestyle = '--' )
118132
119133 # Get the index of the label with the highest value
120- idx = data ['n ∈ {1, 2}' ].idxmax ()
134+ idx = data [target ].idxmax ()
121135
122136 # Make this label bold
123137 labels [idx ] = f'$\\ bf{{{ labels [idx ]} }}$'
@@ -131,7 +145,9 @@ def visualize(self):
131145 plt .savefig (f'{ self .output_dir } /alpha_to_set_sizes.png' )
132146
133147 def visualize_lambdas (self ):
134- plt .figure ()
148+ plt .figure (figsize = (self .FIGURE_WIDTH ,
149+ self .FIGURE_HEIGHT ))
150+ plt .tight_layout ()
135151
136152 # Only use reasonable alphas
137153 alphas = [0.05 , 0.1 , 0.15 , 0.2 , 0.3 , 0.4 ]
@@ -156,7 +172,7 @@ def visualize_lambdas(self):
156172 plt .text (self .lamhats [a ], 0 + padding ,
157173 f'{ self .lamhats [a ]:.2f} ' ,
158174 ha = 'center' , va = 'bottom' ,
159- fontsize = 8 , color = 'black' ,
175+ color = 'black' ,
160176 weight = 'bold' )
161177 i += 1
162178
@@ -179,9 +195,8 @@ def save_summary(self):
179195 'alpha' : self .alphas ,
180196 'Mean set size' : self .mean_set_sizes ,
181197 '% sets n=0' : self .pcts_empty_sets ,
182- '% sets n={1}' : self .pcts_singleton_sets ,
183- '% sets n={1|2}' : self .pcts_singleton_or_duo_sets ,
184- '% sets n>=3' : self .pcts_trio_plus_sets ,
198+ '% sets n=1' : self .pcts_singleton_sets ,
199+ '% sets n>=2' : self .pcts_duo_plus_sets ,
185200 'Mean FNR' : self .mean_false_negative_rates ,
186201 'Mean softmax threshold' : self .mean_softmax_threshold
187202 }
0 commit comments