@@ -289,10 +289,13 @@ def _plot_roc_curve(self, output_dir, tp_results, obs_results, ground_truth):
289289 ax .plot (fpr_c , tpr_c , linewidth = 2 , label = f"Included (AUC={ auc_c :.3f} )" )
290290 ax .plot (fpr_d , tpr_d , linewidth = 2 , label = f"Excluded (AUC={ auc_d :.3f} )" )
291291 ax .plot ([0 , 1 ], [0 , 1 ], "k--" , linewidth = 1 , label = "Random" )
292- ax .set_xlabel ("False Positive Rate" )
293- ax .set_ylabel ("True Positive Rate" )
294- ax .set_title ("ROC Curve (MathWorld identifier included vs. excluded)" )
295- ax .legend (loc = "lower right" )
292+ ax .set_xlabel ("False Positive Rate" , fontsize = 12 )
293+ ax .set_ylabel ("True Positive Rate" , fontsize = 12 )
294+ ax .set_title (
295+ "ROC Curve (MathWorld identifier included vs. excluded)" , fontsize = 14
296+ )
297+ ax .tick_params (axis = "both" , labelsize = 12 )
298+ ax .legend (loc = "lower right" , fontsize = 12 )
296299 fig .tight_layout ()
297300 fig .savefig (os .path .join (output_dir , "roc_curve.png" ), dpi = 150 )
298301 plt .close (fig )
@@ -342,13 +345,15 @@ def _plot_roc_curve_labeled(self, output_dir, tp_preds, obs_preds, limit):
342345 label = f"Without MathWorld ID (AUC={ auc_wo :.3f} )" ,
343346 )
344347 ax .plot ([0 , 1 ], [0 , 1 ], "k--" , linewidth = 1 , label = "Random" )
345- ax .set_xlabel ("False Positive Rate" )
346- ax .set_ylabel ("True Positive Rate" )
348+ ax .set_xlabel ("False Positive Rate" , fontsize = 12 )
349+ ax .set_ylabel ("True Positive Rate" , fontsize = 12 )
347350 ax .set_title (
348351 f"ROC Curve (n={ len (common_ids )} , "
349- "labels = include-MW-ID aggregated answer)"
352+ "labels = include-MW-ID aggregated answer)" ,
353+ fontsize = 14 ,
350354 )
351- ax .legend (loc = "lower right" )
355+ ax .tick_params (axis = "both" , labelsize = 12 )
356+ ax .legend (loc = "lower right" , fontsize = 12 )
352357 fig .tight_layout ()
353358 fig .savefig (os .path .join (output_dir , "roc_curve_labeled.png" ), dpi = 150 )
354359 plt .close (fig )
0 commit comments