Skip to content

Commit 55f4f3d

Browse files
committed
fix issue where plotted threshold was b4 -log10
1 parent b8323d9 commit 55f4f3d

File tree

2 files changed

+31
-30
lines changed

2 files changed

+31
-30
lines changed

Diff for: analysis/workflow/scripts/midway_manhattan_summary.py

+28-27
Original file line numberDiff line numberDiff line change
@@ -445,40 +445,40 @@ def main(
445445
pos_labs[:,pos_type] = True
446446
y_true = pos_labs.flatten("F")
447447
y_score = vals.flatten("F")
448-
if bic:
449-
y_score = -y_score
450448
fpr, tpr, roc_threshold = roc_curve(y_true, y_score, drop_intermediate=True)
451449
precision, recall, prc_threshold = precision_recall_curve(y_true, y_score, drop_intermediate=True)
452-
if thresh is None:
453-
with warnings.catch_warnings():
454-
# safe to ignore runtime warning caused by division of 0 by 0
455-
warnings.simplefilter("ignore")
456-
# compute FDR for each threshold
457-
# note that this computation only works because number of TN == number of TP
458-
fdr = fpr / (fpr + tpr)
459-
# set nan to 0
460-
fdr[np.isnan(fdr)] = 0
461-
# find the threshold (last index) where FDR <= 0.05
462-
thresh_idx = np.where(fdr <= 0.05)[0][-1]
463-
if bic:
464-
thresh = roc_threshold[thresh_idx]
465-
# flip the thresh back around b/c we had made y_score negative, before
466-
final_metrics["Significance Threshold"] = -thresh
467-
else:
468-
thresh = 10**(-roc_threshold[thresh_idx])
469-
final_metrics["Significance Threshold"] = thresh
450+
451+
# now, try to get the optimal threshold
452+
with warnings.catch_warnings():
453+
# safe to ignore runtime warning caused by division of 0 by 0
454+
warnings.simplefilter("ignore")
455+
# compute FDR for each threshold
456+
# note that this computation only works because number of TN == number of TP
457+
fdr = fpr / (fpr + tpr)
458+
# set nan to 0
459+
fdr[np.isnan(fdr)] = 0
460+
# find the threshold (last index) where FDR <= 0.05
461+
thresh_idx = np.where(fdr <= 0.05)[0][-1]
462+
if bic:
463+
optimal_thresh = roc_threshold[thresh_idx]
470464
else:
465+
optimal_thresh = 10**(-roc_threshold[thresh_idx])
466+
final_metrics["Significance Threshold"] = optimal_thresh
467+
if thresh is not None:
471468
thresh_idx = np.argmax(roc_threshold < tsfm_pval(thresh))
472-
final_metrics["Significance Threshold"] = thresh
469+
else:
470+
thresh = optimal_thresh
473471
roc_auc = auc(fpr, tpr)
474472
prc_ap = average_precision_score(y_true, y_score)
475473
# Find the index where thresholds > log_thresh b/c prc_threshold increases from 0 to inf
476474
prc_thresh_idx = np.argmax(prc_threshold > tsfm_pval(thresh))
475+
477476
# Ensure index is within bounds (prc_threshold is shorter with precision/recall than roc)
478477
if prc_thresh_idx >= len(precision):
479478
prc_thresh_idx = len(precision) - 1
480479
final_metrics["AUROC"] = roc_auc
481480
final_metrics["Average Precision"] = roc_auc
481+
482482
# now, make the fig
483483
fig = plt.figure(figsize=(16, 6), layout='constrained')
484484
subfigs = fig.subfigures(1, 3, wspace=0, width_ratios=(6, 5, 5))
@@ -546,16 +546,17 @@ def main(
546546
scatter_hist(vals[:,1], vals[:,0], ax, ax_histx, ax_histy, colors=colors)
547547
max_val = vals.max()
548548
curr_thresh = final_metrics["Significance Threshold"]
549-
fig.text(0.98, 0.98, f'Threshold: {curr_thresh:.2f}', ha='right', va='top', fontsize=15)
549+
threshold_type = "Bayes Factor" if bic else "P-value"
550+
fig.text(0.98, 0.98, f'{threshold_type} Threshold: {curr_thresh:.2f}', ha='right', va='top', fontsize=15)
550551
curr_thresh = tsfm_pval(curr_thresh)
551552
ax.set_xlabel(case_type + ": " + ax_labs[1])
552553
ax.set_ylabel(case_type + ": " + ax_labs[0])
553554
ax.axline((0,0), (max_val, max_val), linestyle="--", color="orange")
554-
if thresh != 0:
555-
ax.axline((0,thresh), (thresh, thresh), color="red")
556-
ax_histx.axline((thresh,0), (thresh, thresh), color="red")
557-
ax.axline((thresh,0), (thresh, thresh), color="red")
558-
ax_histy.axline((0,thresh), (thresh, thresh), color="red")
555+
if curr_thresh != 0:
556+
ax.axline((0,curr_thresh), (curr_thresh, curr_thresh), color="red")
557+
ax_histx.axline((curr_thresh,0), (curr_thresh, curr_thresh), color="red")
558+
ax.axline((curr_thresh,0), (curr_thresh, curr_thresh), color="red")
559+
ax_histy.axline((0,curr_thresh), (curr_thresh, curr_thresh), color="red")
559560
ax_histy.spines['top'].set_visible(False)
560561
ax_histx.spines['top'].set_visible(False)
561562
ax_histy.spines['right'].set_visible(False)

Diff for: happler/tree/terminator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -284,12 +284,12 @@ def check(
284284
)
285285
return True
286286
else:
287-
if stat < self.thresh:
287+
if stat <= self.thresh:
288288
self.log.debug(
289-
f"Terminated with delta BIC {stat} and p-value {pval} >= {self.thresh}"
289+
f"Terminated with delta BIC {stat} and p-value {pval} <= {self.thresh}"
290290
)
291291
return True
292292
self.log.debug(
293-
f"Significant with delta BIC {stat} and p-value {pval} < {self.thresh}"
293+
f"Significant with delta BIC {stat} > {self.thresh}"
294294
)
295295
return False

0 commit comments

Comments
 (0)