Skip to content

Commit a9560de

Browse files
Add gnomad
1 parent 334c1a9 commit a9560de

File tree

5 files changed

+35
-1
lines changed

5 files changed

+35
-1
lines changed

experiments/evals/config/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ scorings:
3333
datasets:
3434
- traitgym_mendelian_promoter
3535
- traitgym_complex_promoter
36+
- gnomad_promoter
3637
- sat_mut_mpra_promoter_F9
3738
- sat_mut_mpra_promoter_GP1BA
3839
- sat_mut_mpra_promoter_HBB

experiments/evals/workflow/Snakefile

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ configfile: "config/config.yaml"
22

33

44
include: "rules/common.smk"
5+
include: "rules/gnomad.smk"
56
include: "rules/metrics.smk"
67
include: "rules/model.smk"
78
include: "rules/sat_mut_mpra.smk"
@@ -32,3 +33,11 @@ rule all:
3233
"absLLR.plus.score",
3334
]
3435
),
36+
expand(
37+
"results/metrics/gnomad_promoter/{metric}/{model}_{scoring}.tsv",
38+
metric=["AUPRC", "AUROC"],
39+
model=config["models"].keys(),
40+
scoring=[
41+
"LLR.minus.score",
42+
]
43+
),

experiments/evals/workflow/rules/common.smk

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import gpn.model # noqa: F401 # Registers the GPN architecture
77
import numpy as np
88
import pandas as pd
99
from scipy.stats import spearmanr
10-
from sklearn.metrics import average_precision_score
10+
from sklearn.metrics import average_precision_score, roc_auc_score
1111
from transformers import AutoTokenizer, AutoModelForMaskedLM
1212

1313

experiments/evals/workflow/rules/gnomad.smk

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
# Build the gnomAD promoter evaluation dataset.
#
# Reads the test split of songlab/gnomad_balanced, keeps only rows whose
# `consequence` is "upstream_gene", draws a fixed-seed sample of 5000 rows for
# each distinct `label` value, orders the rows by COORDINATES and writes the
# result to the rule's single output file.
rule gnomad_promoter_dataset:
    output:
        "results/dataset/gnomad_promoter.parquet",
    run:
        variants = pd.read_parquet("hf://datasets/songlab/gnomad_balanced/test.parquet")
        upstream = variants.loc[variants["consequence"] == "upstream_gene"]
        # Fixed random_state so reruns select the same rows.
        # NOTE(review): sampling uses the pandas default (without replacement),
        # so every label group presumably has at least 5000 rows - confirm.
        sampled = upstream.groupby("label").sample(n=5000, random_state=42)
        # COORDINATES is defined outside this rule; presumably the variant
        # coordinate columns - verify against its definition.
        ordered = sampled.reset_index(drop=True).sort_values(COORDINATES)
        ordered.to_parquet(output[0], index=False)

experiments/evals/workflow/rules/metrics.smk

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,21 @@ rule metrics_AUPRC:
1313
pd.DataFrame({"AUPRC": [AUPRC]}).to_csv(output[0], sep="\t", index=False, float_format="%.3f")
1414

1515

# Compute the area under the ROC curve for one (dataset, model) pair.
#
# Inputs:  the dataset parquet (its `label` column is used) and the model's
#          prediction parquet (its `score` column is used).
# Output:  a one-row TSV with a single `AUROC` column, formatted to three
#          decimal places.
rule metrics_AUROC:
    input:
        "results/dataset/{dataset}.parquet",
        "results/prediction/{dataset}/{model}.parquet",
    output:
        "results/metrics/{dataset}/AUROC/{model}.tsv",
    wildcard_constraints:
        # Only the dataset names listed in config["datasets"] may match
        # {dataset}; they are joined into a regex alternation.
        dataset="|".join(config["datasets"]),
    run:
        labels = pd.read_parquet(input[0], columns=["label"])["label"]
        scores = pd.read_parquet(input[1], columns=["score"])["score"]
        result = pd.DataFrame({"AUROC": [roc_auc_score(labels, scores)]})
        result.to_csv(output[0], sep="\t", index=False, float_format="%.3f")
29+
30+
1631
rule metrics_Spearman:
1732
input:
1833
"results/dataset/{dataset}.parquet",

0 commit comments

Comments
 (0)