Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions configs/data/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@ max_val_lm_samples: null # Maximum number of samples for LM validation (null = u
seed: 42

# Evaluation datasets (optional)
# Set evals: null to disable all evals, or configure specific evals below
evals:
traitgym_mendelian_promoter:
dataset_name: songlab/TraitGym
dataset_config: mendelian_traits
genome_url: https://ftp.ensembl.org/pub/release-115/fasta/homo_sapiens/dna/Homo_sapiens.GRCh38.dna_sm.toplevel.fa.gz
genome_path: data/Homo_sapiens.GRCh38.dna_sm.toplevel.fa.gz
window_size: 512
batch_size: 128
# Set evals: null to disable all evals, or configure specific evals in dataset-specific configs
# Example structure:
# evals:
# - name: eval_name
# dataset_name: songlab/TraitGym
# dataset_config: mendelian_traits
# split: test # Dataset split to load (default: "test")
# genome_url: https://ftp.ensembl.org/...
# filter_name: traitgym_promoter # Filter from EVAL_FILTERS registry (default: "none")
# window_size: 512
# batch_size: 128
# label_column: label # Column to preserve as labels (default: "label")
# transform: minus # Transform to apply to raw LLR: minus, identity, abs (default: identity)
# metrics: [auprc] # Metrics to compute: auprc, auroc, spearman, pearson (default: [auprc])
evals: null
14 changes: 14 additions & 0 deletions configs/data/gpn_animal_promoter.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,17 @@ defaults:
batch_size: 2048 # Total effective batch size
per_device_batch_size: 128 # Batch size per device (adjust based on GPU memory)
num_workers: 8

# Evaluation datasets
evals:
- name: traitgym_mendelian_promoter
dataset_name: songlab/TraitGym
dataset_config: mendelian_traits
split: test # Dataset split to load (default: "test")
genome_url: https://ftp.ensembl.org/pub/release-115/fasta/homo_sapiens/dna/Homo_sapiens.GRCh38.dna_sm.toplevel.fa.gz
filter_name: traitgym_promoter
window_size: 512
batch_size: 128
label_column: label # Column to preserve as labels (default: "label")
transform: minus # Transform to apply to raw LLR: minus, identity, abs (default: identity)
metrics: [auprc] # Metrics to compute: auprc, auroc, spearman, pearson (default: [auprc])
24 changes: 24 additions & 0 deletions configs/data/plants.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
defaults:
- default

# Training dataset: Angiosperm 16 genomes
dataset_name: kuleshov-group/Angiosperm_16_genomes

# Batch size configuration
batch_size: 2048 # Total effective batch size
per_device_batch_size: 128 # Batch size per device (adjust based on GPU memory)
num_workers: 8

# Evaluation datasets
evals:
- name: maize_af
dataset_name: plantcad/maize-allele-frequency
dataset_config: null
split: validation
genome_url: https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-62/fasta/zea_mays/dna/Zea_mays.Zm-B73-REFERENCE-NAM-5.0.dna_sm.toplevel.fa.gz
filter_name: none
window_size: 512
batch_size: 128
label_column: AF # Allele frequency column
transform: identity # No transform for regression (default: identity)
metrics: [pearson, spearman] # Correlation metrics for regression task
1 change: 1 addition & 0 deletions configs/experiment/clm_transformer_small.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Short training run with small Transformer encoder for quick testing

defaults:
- override /data: plants
- override /model: clm_transformer_small

logger:
Expand Down
1 change: 1 addition & 0 deletions configs/model/bert_bytenet_small.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
_target_: glm_experiments.models.lm_lit_module.MLMLitModule

soft_masked_weight: ${data.soft_masked_weight}
evals: ${data.evals}

net:
_target_: glm_experiments.models.components.lm.MLM
Expand Down
1 change: 1 addition & 0 deletions configs/model/clm_transformer_base.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
_target_: glm_experiments.models.lm_lit_module.CLMLitModule

soft_masked_weight: ${data.soft_masked_weight}
evals: ${data.evals}

net:
_target_: glm_experiments.models.components.lm.CLM
Expand Down
1 change: 1 addition & 0 deletions configs/model/clm_transformer_small.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
_target_: glm_experiments.models.lm_lit_module.CLMLitModule

soft_masked_weight: ${data.soft_masked_weight}
evals: ${data.evals}

net:
_target_: glm_experiments.models.components.lm.CLM
Expand Down
1 change: 1 addition & 0 deletions configs/model/gpn_animal_promoter.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
_target_: glm_experiments.models.lm_lit_module.MLMLitModule

soft_masked_weight: ${data.soft_masked_weight}
evals: ${data.evals}

net:
_target_: glm_experiments.models.components.lm.MLM
Expand Down
1 change: 1 addition & 0 deletions configs/model/mlm_transformer_base.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
_target_: glm_experiments.models.lm_lit_module.MLMLitModule

soft_masked_weight: ${data.soft_masked_weight}
evals: ${data.evals}

net:
_target_: glm_experiments.models.components.lm.MLM
Expand Down
1 change: 1 addition & 0 deletions configs/model/mlm_transformer_small.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
_target_: glm_experiments.models.lm_lit_module.MLMLitModule

soft_masked_weight: ${data.soft_masked_weight}
evals: ${data.evals}

net:
_target_: glm_experiments.models.components.lm.MLM
Expand Down
186 changes: 186 additions & 0 deletions glm_experiments/data/evals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
"""Evaluation dataset loading for variant effect prediction.

This module provides functions to load and transform evaluation datasets
(TraitGym, PlantCAD, etc.) for variant effect prediction during training.

Supports configurable dataset-specific filtering via a registry pattern.
"""

import logging
import urllib.request
from functools import partial
from pathlib import Path

import pandas as pd
from biofoundation.data import Genome, transform_llr_clm, transform_llr_mlm
from biofoundation.model.base import Tokenizer
from datasets import Dataset, load_dataset

log = logging.getLogger(__name__)


def filter_traitgym_promoter(dataset: Dataset) -> Dataset:
    """Restrict a TraitGym dataset to non-exonic proximal promoter variants.

    Joins the dataset (inner merge on chrom/pos/ref/alt) against the
    precomputed ``nonexonic_AND_proximal`` subset published alongside
    TraitGym, keeping only variants present in that subset.

    Args:
        dataset: TraitGym dataset from HuggingFace

    Returns:
        Filtered dataset containing only non-exonic proximal promoter variants
    """
    log.info("Filtering to non-exonic proximal promoter variants...")
    subset_url = (
        "https://huggingface.co/datasets/songlab/TraitGym/resolve/main/"
        "mendelian_traits_matched_9/subset/nonexonic_AND_proximal.parquet"
    )
    keep = pd.read_parquet(subset_url)
    variants = dataset.to_pandas().merge(
        keep, on=["chrom", "pos", "ref", "alt"], how="inner"
    )
    log.info(f"Filtered dataset size: {len(variants)} variants (from {len(dataset)})")
    return Dataset.from_pandas(variants, preserve_index=False)


def _no_filter(dataset: Dataset) -> Dataset:
    """Identity filter: return the dataset unchanged."""
    return dataset


# Registry mapping filter names (as referenced by eval configs via
# ``filter_name``) to dataset-filtering callables.
EVAL_FILTERS = {
    "traitgym_promoter": filter_traitgym_promoter,
    "none": _no_filter,
}


def download_genome(url: str, data_dir: str | Path = "data") -> Path:
    """Download reference genome if not already present.

    Path is auto-derived from URL basename (e.g., genome.fa.gz). The file is
    fetched to a temporary ``<name>.part`` path and atomically renamed into
    place only on success, so an interrupted download can never be mistaken
    for a complete genome by the ``path.exists()`` fast-path on a later call.

    Args:
        url: URL to download genome from (e.g., Ensembl FTP)
        data_dir: Directory to save genome (default: "data")

    Returns:
        Path to the downloaded genome file
    """
    path = Path(data_dir) / Path(url).name
    if path.exists():
        log.info(f"Genome already exists at {path}, skipping download")
        return path

    log.info(f"Downloading genome from {url} to {path}...")
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(path.name + ".part")
    try:
        urllib.request.urlretrieve(url, tmp_path)  # nosec B310
    except BaseException:
        # Drop the partial file so a retry starts clean instead of leaving
        # a truncated genome behind.
        tmp_path.unlink(missing_ok=True)
        raise
    # Atomic on POSIX: the final path only ever exists fully written.
    tmp_path.replace(path)
    log.info(f"Genome download complete: {path}")
    return path


def load_eval_dataset(
    tokenizer: Tokenizer,
    dataset_name: str,
    genome_url: str,
    filter_name: str = "none",
    dataset_config: str | None = None,
    split: str = "test",
    window_size: int = 512,
    cache_dir: str | Path = "data/evals_cache",
    objective: str = "mlm",
    data_dir: str | Path = "data",
    label_column: str = "label",
) -> Dataset:
    """Load and transform evaluation dataset with optional filtering.

    Loads a variant dataset from HuggingFace, applies optional dataset-specific
    filtering, and transforms it using the appropriate objective (MLM or CLM).

    The Genome is only loaded if the transformed dataset is not cached.

    Args:
        tokenizer: Tokenizer implementing the biofoundation Tokenizer protocol
        dataset_name: HuggingFace dataset name (e.g., "songlab/TraitGym")
        genome_url: URL to reference genome (path auto-derived from basename)
        filter_name: Name of filter function in EVAL_FILTERS registry (default: "none")
        dataset_config: Dataset configuration (e.g., "mendelian_traits")
        split: Dataset split to load (default: "test")
        window_size: Size of the window around variants (must be even)
        cache_dir: Directory to cache transformed dataset
        objective: Training objective ("mlm" or "clm") - determines transform function
        data_dir: Directory for genome downloads (default: "data")
        label_column: Name of the label column to preserve (default: "label")

    Returns:
        Transformed dataset with columns: input_ids, pos, ref, alt, {label_column}

    Raises:
        ValueError: If filter_name not in EVAL_FILTERS, objective is not
            "mlm"/"clm", or window_size is odd
    """
    from datasets import load_from_disk

    # Validate all config up front, before any expensive download or
    # transform work is triggered.
    if filter_name not in EVAL_FILTERS:
        raise ValueError(
            f"Unknown filter_name: {filter_name}. Must be one of {list(EVAL_FILTERS.keys())}"
        )
    if objective == "mlm":
        transform_func = transform_llr_mlm
    elif objective == "clm":
        transform_func = transform_llr_clm
    else:
        raise ValueError(f"Unknown objective: {objective}. Must be 'mlm' or 'clm'.")
    # Enforce the documented contract instead of silently producing an
    # off-center window downstream.
    if window_size % 2 != 0:
        raise ValueError(f"window_size must be even, got {window_size}")

    # Create cache path based on config, filter, split, window, objective, and
    # label column. label_column must be part of the key because the transform
    # drops every other original column — a cache built for one label column
    # cannot serve another. The default "label" keeps the legacy cache name.
    cache_name_parts = [dataset_name.replace("/", "_")]
    if dataset_config:
        cache_name_parts.append(dataset_config)
    cache_name_parts.extend([split, filter_name, f"window{window_size}", objective])
    if label_column != "label":
        cache_name_parts.append(f"label_{label_column}")
    cache_name = "_".join(cache_name_parts)
    cache_path = Path(cache_dir) / cache_name
    # NOTE(review): the cache key does not encode tokenizer identity; sharing
    # one cache_dir across tokenizers with different vocabularies would serve
    # stale input_ids — confirm callers use a single tokenizer per cache_dir.

    # Check if cached transformed dataset exists
    if cache_path.exists():
        log.info(f"Loading cached evaluation dataset from {cache_path}")
        dataset = load_from_disk(str(cache_path))
        dataset.set_format(type="torch")
        return dataset

    # Not cached - need to transform with genome
    log.info(f"Loading evaluation dataset: {dataset_name}")
    if dataset_config:
        log.info(f"  Dataset config: {dataset_config}")
    log.info(f"  Split: {split}")
    dataset = load_dataset(dataset_name, dataset_config, split=split)  # nosec B615

    # Apply dataset-specific filtering
    if filter_name != "none":
        log.info(f"  Applying filter: {filter_name}")
        dataset = EVAL_FILTERS[filter_name](dataset)

    # Download genome (auto-derives path from URL)
    genome_path = download_genome(genome_url, data_dir)
    log.info(f"Loading reference genome from {genome_path} (this may take a minute)...")
    genome = Genome(genome_path)

    # Keep original columns for evaluation
    original_columns = dataset.column_names

    transform_fn = partial(
        transform_func,
        tokenizer=tokenizer,
        genome=genome,
        window_size=window_size,
    )

    # Transform the dataset, keeping only the label column from the originals.
    log.info("Transforming evaluation dataset...")
    dataset = dataset.map(
        transform_fn,
        remove_columns=[c for c in original_columns if c != label_column],
    )

    # Save to cache
    log.info(f"Saving transformed dataset to {cache_path}")
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    dataset.save_to_disk(str(cache_path))

    # Set format to PyTorch tensors for proper DataLoader collation
    dataset.set_format(type="torch")

    return dataset
Loading