@@ -34,43 +34,65 @@ def __init__(self,
3434 self .model_tpr_list = []
3535 self .model_fpr_list = []
3636
37+ def uncalibrated_model_rates (self , alphas ):
38+ # Compute TPR and FPR for the uncalibrated model
39+ model_tpr = []
40+ model_fpr = []
41+
42+ for alpha in alphas :
43+ prediction_sets = self .prediction_dataset .smx >= alpha
44+
45+ # TPR calculation
46+ tps = ((prediction_sets * self .prediction_dataset .labels_idx ).sum (axis = 1 ))
47+ fns = ((~ prediction_sets * self .prediction_dataset .labels_idx ).sum (axis = 1 ))
48+
49+ # Get sums of tps and fns
50+ tps = tps .sum ()
51+ fns = fns .sum ()
52+
53+ model_tpr .append (tps / (tps + fns ))
54+
55+ # FPS calculation
56+ fps = ((prediction_sets * (1 - self .prediction_dataset .labels_idx )).sum (axis = 1 ))
57+ tns = np .array ([
58+ len (np .setdiff1d (labels , predictions ))
59+ for labels , predictions in zip (self .prediction_dataset .labels_idx , prediction_sets )
60+ ])
61+
62+ # Get sums of tns and fps
63+ tns = tns .sum ()
64+ fps = fps .sum ()
65+
66+ model_fpr .append (fps / (fps + tns ))
67+
68+ return model_tpr , model_fpr
69+
3770 def run (self ):
3871 # Define a range of significance levels (alpha values)
3972 alpha_levels = np .linspace (self .min_alpha ,
4073 self .max_alpha ,
4174 self .n_alphas )
4275
76+ self .model_tpr_list , self .model_fpr_list = \
77+ self .uncalibrated_model_rates (alpha_levels )
78+
4379 # Compute TPR and FPR for different alpha thresholds
4480 for alpha in alpha_levels :
4581 print (f'alpha={ alpha } ' )
4682 cop = self .cop_class (self .prediction_dataset , alpha = alpha )
4783 trial = cop .do_validation_trial (n_runs = self .n_runs_per_alpha )
4884
4985 mean_cp_tpr = trial .mean_true_positive_rate ()
50- mean_model_tpr = trial .mean_model_true_positive_rate ()
51-
5286 mean_cp_fpr = trial .mean_FPR ()
53- mean_model_fpr = trial .mean_model_false_positive_rate ()
5487
5588 self .cp_tpr_list .append (mean_cp_tpr )
5689 self .cp_fpr_list .append (mean_cp_fpr )
5790
58- print (f'mean_cp_tpr={ mean_cp_tpr } , mean_model_tpr={ mean_model_tpr } ' )
59- print (f'mean_cp_fpr={ mean_cp_fpr } , mean_model_fpr={ mean_model_fpr } ' )
60-
61- self .model_tpr_list .append (mean_model_tpr )
62- self .model_fpr_list .append (mean_model_fpr )
63-
6491 # Ensure x values are sorted in ascending order
6592 sorted_indices = np .argsort (self .cp_fpr_list )
6693 self .cp_fpr_list = np .array (self .cp_fpr_list )[sorted_indices ]
6794 self .cp_tpr_list = np .array (self .cp_tpr_list )[sorted_indices ]
6895
69- # Do same for model
70- sorted_indices = np .argsort (self .model_fpr_list )
71- self .model_fpr_list = np .array (self .model_fpr_list )[sorted_indices ]
72- self .model_tpr_list = np .array (self .model_tpr_list )[sorted_indices ]
73-
7496 def run_reports (self ):
7597 plt .figure (figsize = (self .FIGURE_WIDTH , self .FIGURE_HEIGHT ))
7698 plt .tight_layout ()
0 commit comments