@@ -445,40 +445,40 @@ def main(
445
445
pos_labs [:,pos_type ] = True
446
446
y_true = pos_labs .flatten ("F" )
447
447
y_score = vals .flatten ("F" )
448
- if bic :
449
- y_score = - y_score
450
448
fpr , tpr , roc_threshold = roc_curve (y_true , y_score , drop_intermediate = True )
451
449
precision , recall , prc_threshold = precision_recall_curve (y_true , y_score , drop_intermediate = True )
452
- if thresh is None :
453
- with warnings .catch_warnings ():
454
- # safe to ignore runtime warning caused by division of 0 by 0
455
- warnings .simplefilter ("ignore" )
456
- # compute FDR for each threshold
457
- # note that this computation only works because number of TN == number of TP
458
- fdr = fpr / (fpr + tpr )
459
- # set nan to 0
460
- fdr [np .isnan (fdr )] = 0
461
- # find the threshold (last index) where FDR <= 0.05
462
- thresh_idx = np .where (fdr <= 0.05 )[0 ][- 1 ]
463
- if bic :
464
- thresh = roc_threshold [thresh_idx ]
465
- # flip the thresh back around b/c we had made y_score negative, before
466
- final_metrics ["Significance Threshold" ] = - thresh
467
- else :
468
- thresh = 10 ** (- roc_threshold [thresh_idx ])
469
- final_metrics ["Significance Threshold" ] = thresh
450
+
451
+ # now, try to get the optimal threshold
452
+ with warnings .catch_warnings ():
453
+ # safe to ignore runtime warning caused by division of 0 by 0
454
+ warnings .simplefilter ("ignore" )
455
+ # compute FDR for each threshold
456
+ # note that this computation only works because number of TN == number of TP
457
+ fdr = fpr / (fpr + tpr )
458
+ # set nan to 0
459
+ fdr [np .isnan (fdr )] = 0
460
+ # find the threshold (last index) where FDR <= 0.05
461
+ thresh_idx = np .where (fdr <= 0.05 )[0 ][- 1 ]
462
+ if bic :
463
+ optimal_thresh = roc_threshold [thresh_idx ]
470
464
else :
465
+ optimal_thresh = 10 ** (- roc_threshold [thresh_idx ])
466
+ final_metrics ["Significance Threshold" ] = optimal_thresh
467
+ if thresh is not None :
471
468
thresh_idx = np .argmax (roc_threshold < tsfm_pval (thresh ))
472
- final_metrics ["Significance Threshold" ] = thresh
469
+ else :
470
+ thresh = optimal_thresh
473
471
roc_auc = auc (fpr , tpr )
474
472
prc_ap = average_precision_score (y_true , y_score )
475
473
# Find the index where thresholds > log_thresh b/c prc_threshold increases from 0 to inf
476
474
prc_thresh_idx = np .argmax (prc_threshold > tsfm_pval (thresh ))
475
+
477
476
# Ensure index is within bounds (prc_threshold is shorter with precision/recall than roc)
478
477
if prc_thresh_idx >= len (precision ):
479
478
prc_thresh_idx = len (precision ) - 1
480
479
final_metrics ["AUROC" ] = roc_auc
481
480
final_metrics ["Average Precision" ] = roc_auc
481
+
482
482
# now, make the fig
483
483
fig = plt .figure (figsize = (16 , 6 ), layout = 'constrained' )
484
484
subfigs = fig .subfigures (1 , 3 , wspace = 0 , width_ratios = (6 , 5 , 5 ))
@@ -546,16 +546,17 @@ def main(
546
546
scatter_hist (vals [:,1 ], vals [:,0 ], ax , ax_histx , ax_histy , colors = colors )
547
547
max_val = vals .max ()
548
548
curr_thresh = final_metrics ["Significance Threshold" ]
549
- fig .text (0.98 , 0.98 , f'Threshold: { curr_thresh :.2f} ' , ha = 'right' , va = 'top' , fontsize = 15 )
549
+ threshold_type = "Bayes Factor" if bic else "P-value"
550
+ fig .text (0.98 , 0.98 , f'{ threshold_type } Threshold: { curr_thresh :.2f} ' , ha = 'right' , va = 'top' , fontsize = 15 )
550
551
curr_thresh = tsfm_pval (curr_thresh )
551
552
ax .set_xlabel (case_type + ": " + ax_labs [1 ])
552
553
ax .set_ylabel (case_type + ": " + ax_labs [0 ])
553
554
ax .axline ((0 ,0 ), (max_val , max_val ), linestyle = "--" , color = "orange" )
554
- if thresh != 0 :
555
- ax .axline ((0 ,thresh ), (thresh , thresh ), color = "red" )
556
- ax_histx .axline ((thresh ,0 ), (thresh , thresh ), color = "red" )
557
- ax .axline ((thresh ,0 ), (thresh , thresh ), color = "red" )
558
- ax_histy .axline ((0 ,thresh ), (thresh , thresh ), color = "red" )
555
+ if curr_thresh != 0 :
556
+ ax .axline ((0 ,curr_thresh ), (curr_thresh , curr_thresh ), color = "red" )
557
+ ax_histx .axline ((curr_thresh ,0 ), (curr_thresh , curr_thresh ), color = "red" )
558
+ ax .axline ((curr_thresh ,0 ), (curr_thresh , curr_thresh ), color = "red" )
559
+ ax_histy .axline ((0 ,curr_thresh ), (curr_thresh , curr_thresh ), color = "red" )
559
560
ax_histy .spines ['top' ].set_visible (False )
560
561
ax_histx .spines ['top' ].set_visible (False )
561
562
ax_histy .spines ['right' ].set_visible (False )
0 commit comments