Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions experiments/evals/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,22 @@ scorings:
- LLR.minus.score
- absLLR.plus.score

datasets:
- traitgym_mendelian_promoter
- traitgym_complex_promoter
- gnomad_promoter
- sat_mut_mpra_promoter_F9
- sat_mut_mpra_promoter_GP1BA
- sat_mut_mpra_promoter_HBB
- sat_mut_mpra_promoter_HBG1
- sat_mut_mpra_promoter_HNF4A
- sat_mut_mpra_promoter_LDLR
- sat_mut_mpra_promoter_MSMB
- sat_mut_mpra_promoter_PKLR
- sat_mut_mpra_promoter_TERT
# Dataset evaluation configurations
# Each dataset specifies which metrics and scoring functions to compute
dataset_configs:
traitgym_mendelian_promoter:
metrics: [AUPRC]
scorings: [LLR.minus.score]

traitgym_complex_promoter:
metrics: [AUPRC]
scorings: [absLLR.plus.score]

gnomad_promoter:
metrics: [AUROC]
scorings: [LLR.minus.score]

sat_mut_mpra_promoter:
# This applies to all promoter-specific datasets
metrics: [Spearman]
scorings: [absLLR.plus.score]
90 changes: 60 additions & 30 deletions experiments/evals/workflow/Snakefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,64 @@
configfile: "config/config.yaml"


def get_all_datasets():
"""Get list of all dataset names for wildcard constraints."""
datasets = []
for dataset in config["dataset_configs"].keys():
if dataset == "sat_mut_mpra_promoter":
# Expand for each promoter
for promoter in config["sat_mut_mpra_promoter"]:
datasets.append(f"sat_mut_mpra_promoter_{promoter}")
else:
datasets.append(dataset)
return datasets


def get_all_metric_files():
"""Generate list of all metric files based on dataset_configs."""
files = []

for dataset, cfg in config["dataset_configs"].items():
# Handle sat_mut_mpra_promoter specially - expand for each promoter
if dataset == "sat_mut_mpra_promoter":
for promoter in config["sat_mut_mpra_promoter"]:
dataset_name = f"sat_mut_mpra_promoter_{promoter}"
for metric in cfg["metrics"]:
for model in config["models"].keys():
for scoring in cfg["scorings"]:
files.append(
f"results/metrics/{dataset_name}/{metric}/{model}_{scoring}.tsv"
)
else:
# Regular datasets
for metric in cfg["metrics"]:
for model in config["models"].keys():
for scoring in cfg["scorings"]:
files.append(
f"results/metrics/{dataset}/{metric}/{model}_{scoring}.tsv"
)

return files


def get_all_correlation_files():
"""Generate list of all correlation analysis output files."""
return [
"results/correlations/metrics_wide.parquet",
"results/correlations/metrics_long.parquet",
"results/correlations/pearson.tsv",
"results/correlations/spearman.tsv",
"results/correlations/pearson_heatmap.png",
"results/correlations/pearson_heatmap.pdf",
"results/correlations/spearman_heatmap.png",
"results/correlations/spearman_heatmap.pdf",
"results/correlations/metrics_vs_step.png",
"results/correlations/metrics_vs_step.pdf",
"results/correlations/metric_pairs.png",
"results/correlations/metric_pairs.pdf",
]


include: "rules/common.smk"
include: "rules/gnomad.smk"
include: "rules/metrics.smk"
Expand All @@ -11,33 +69,5 @@ include: "rules/traitgym.smk"

rule all:
input:
expand(
"results/metrics/traitgym_mendelian_promoter/AUPRC/{model}_{scoring}.tsv",
model=config["models"].keys(),
scoring=[
"LLR.minus.score",
]
),
expand(
"results/metrics/traitgym_complex_promoter/AUPRC/{model}_{scoring}.tsv",
model=config["models"].keys(),
scoring=[
"absLLR.plus.score",
]
),
expand(
"results/metrics/sat_mut_mpra_promoter_{promoter}/Spearman/{model}_{scoring}.tsv",
promoter=config["sat_mut_mpra_promoter"],
model=config["models"].keys(),
scoring=[
"absLLR.plus.score",
]
),
expand(
"results/metrics/gnomad_promoter/{metric}/{model}_{scoring}.tsv",
metric=["AUPRC", "AUROC"],
model=config["models"].keys(),
scoring=[
"LLR.minus.score",
]
),
get_all_metric_files(),
get_all_correlation_files()
Loading