
Train from Scratch (SeiPlant / sei_model)

This section describes how to train a SeiPlant model from scratch, including data preparation, label generation, training, and evaluation.

The overall workflow is consistent with Fig. 1: genome tiling → chromatin label construction → model training → evaluation & genome-wide inference.

What You Need (Inputs)

To train SeiPlant for a single species (or a species group), you need:

  1. Reference genome

    • species.fa (FASTA)
    • species.size (chromosome sizes). The Zenodo archive used in this study contains the .fa and .size files for Arabidopsis thaliana.
  2. Chromatin tracks or peak labels for each epigenomic feature (per tissue/condition if applicable)

    • Recommended formats:

      • BED/narrowPeak/broadPeak peak files (e.g., MACS2 outputs), optionally converted to window-level labels
    • Typical histone marks (example):

      • H3K4ME3, H3K27AC, H3K4ME1, H3K9AC, H3K36ME3
    • The same workflow can be applied to other sequence-to-signal tasks, such as ATAC-seq/DNase-seq accessibility or TF ChIP-seq binding signals, as long as the final labels can be aligned to fixed genomic windows.
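Later steps (window tiling and bedGraphToBigWig) rely on the .size file. If your genome release does not include one, you can derive it from the FASTA. Below is a minimal sketch that assumes the UCSC-style layout of one `name<TAB>length` line per sequence. `fasta_to_sizes` and `write_sizes` are hypothetical helpers, not part of the SeiPlant repository.

```python
import io
from typing import Dict, Iterable, TextIO


def fasta_to_sizes(handle: Iterable[str]) -> Dict[str, int]:
    """Return {sequence_name: length} for every record in a FASTA stream."""
    sizes: Dict[str, int] = {}
    name = None
    for line in handle:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            # Sequence name = first whitespace-delimited token after '>'
            name = line[1:].split()[0]
            sizes[name] = 0
        elif name is not None:
            # Multi-line FASTA: accumulate the length of every sequence line
            sizes[name] += len(line)
    return sizes


def write_sizes(sizes: Dict[str, int], out: TextIO) -> None:
    """Write a two-column, tab-separated chromosome sizes file."""
    for chrom, length in sizes.items():
        out.write(f"{chrom}\t{length}\n")


if __name__ == "__main__":
    demo = io.StringIO(">Chr1 test\nACGT\nAC\n>Chr2\nGGGG\n")
    print(fasta_to_sizes(demo))  # {'Chr1': 6, 'Chr2': 4}
```

Running `samtools faidx species.fa` followed by `cut -f1,2 species.fa.fai > species.size` produces the same two columns.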

Step 1: Generate Genomic Windows (BED + FASTA)

In this project, genomic windows are constructed based on regulatory regions (e.g., peaks) and the reference genome sequence.

The BED cutting strategy used in this study requires:

  • the peak BED file(s) (e.g., narrowPeak.bed produced by MACS2), and
  • the corresponding genome FASTA sequence.

If multiple peak BED files are available for the same species (e.g., from different tissues, replicates, or peak-calling strategies), you may merge them into a unified regulatory region set before window generation. For example:

  • use bedtools merge to merge overlapping intervals, consistent with the merging strategy described in the Experimental Details section.
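Conceptually, merging peak sets from several tissues pools all intervals, sorts them, and collapses any that overlap or touch. The sketch below mirrors the default behaviour of `bedtools merge` on sorted input. `merge_intervals` is an illustrative helper only. In practice the command-line route is the usual choice: `cat tissueA.narrowPeak tissueB.narrowPeak | sort -k1,1 -k2,2n > all_peaks.bed`, then `bedtools merge -i all_peaks.bed > merged.bed`.

```python
from typing import List, Tuple

Interval = Tuple[str, int, int]  # (chrom, start, end), half-open BED coordinates


def merge_intervals(intervals: List[Interval]) -> List[Interval]:
    """Merge overlapping or book-ended intervals on the same chromosome."""
    merged: List[Interval] = []
    # Sorting tuples orders by chromosome name, then start, then end
    for chrom, start, end in sorted(intervals):
        if merged and merged[-1][0] == chrom and start <= merged[-1][2]:
            prev_chrom, prev_start, prev_end = merged[-1]
            merged[-1] = (prev_chrom, prev_start, max(prev_end, end))
        else:
            merged.append((chrom, start, end))
    return merged


tissue_a = [("Chr1", 100, 300), ("Chr1", 900, 1000)]
tissue_b = [("Chr1", 250, 500), ("Chr2", 10, 20), ("Chr1", 500, 600)]
print(merge_intervals(tissue_a + tissue_b))
# [('Chr1', 100, 600), ('Chr1', 900, 1000), ('Chr2', 10, 20)]
```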

Expected outputs:

  • data/windows/species_1024_128_filtered.bed
  • data/windows/species_1024_128.fa

Note

Even if signals come from multiple tissues, the DNA sequence is identical. Merging BED intervals helps build a shared candidate regulatory space for window sampling.

Step 2: Build Multi-Task Labels (Per Window, Per Mark)

SeiPlant is trained in a multi-task setting: each 1,024-bp window is associated with a label vector across multiple chromatin features (histone marks, accessibility, TF binding, etc.).

After window generation, the training dataset is constructed by pairing:

  • the window sequences (.fa), and
  • the window coordinates (.bed),

with a feature scoring strategy (see the Experimental Details section for the exact scoring definition).

As a practical reference implementation, you may also check the example script:

Two common label strategies

A) Peak-based binary labels (recommended for robust training)

  • For each mark, assign label = 1 if the window overlaps a peak (≥ X bp overlap), otherwise label = 0.
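Strategy A reduces to an overlap test between each window and the peaks of each mark. Below is a minimal sketch. It assumes half-open BED coordinates and treats the threshold `X` as a parameter `min_overlap`; the value used in this study is not given here. The brute-force loop is for illustration only. For whole genomes, `bedtools intersect` or an interval index is the practical route. `overlap_bp` and `binary_labels` are hypothetical names.

```python
from typing import Dict, List, Tuple

Interval = Tuple[str, int, int]  # (chrom, start, end), half-open


def overlap_bp(a: Interval, b: Interval) -> int:
    """Number of bases shared by two half-open intervals (0 if disjoint)."""
    if a[0] != b[0]:
        return 0
    return max(0, min(a[2], b[2]) - max(a[1], b[1]))


def binary_labels(windows: List[Interval],
                  peaks_by_mark: Dict[str, List[Interval]],
                  marks: List[str],
                  min_overlap: int = 1) -> List[List[int]]:
    """labels[i][j] = 1 if window i overlaps a peak of marks[j] by >= min_overlap bp."""
    labels = []
    for win in windows:
        row = []
        for mark in marks:
            hit = any(overlap_bp(win, p) >= min_overlap
                      for p in peaks_by_mark.get(mark, []))
            row.append(int(hit))
        labels.append(row)
    return labels
```

With a 50-bp threshold, a window that shares only 24 bp with a peak gets label 0. The same window gets label 1 when `min_overlap=1`.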

B) Signal-based continuous labels

  • For each mark, summarize the BigWig signal within each window (e.g., mean/max/center-bin aggregation), then apply optional normalization.
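For strategy B, you can sketch the aggregation options named above with NumPy on a per-base signal array for one chromosome, for example values read from a BigWig with a reader library such as pyBigWig. Per-mark min-max scaling is one possible normalization, chosen here for illustration. It is not necessarily the normalization used in the study. `window_signal` and `minmax_normalize` are hypothetical names.

```python
import numpy as np


def window_signal(coverage: np.ndarray, start: int, end: int,
                  how: str = "mean", center_bin: int = 128) -> float:
    """Summarise per-base signal coverage[start:end] for one window."""
    values = np.nan_to_num(coverage[start:end], nan=0.0)  # treat gaps (NaN) as 0
    if how == "mean":
        return float(values.mean())
    if how == "max":
        return float(values.max())
    if how == "center":
        # Mean over the central `center_bin` bases of the window
        mid = (end - start) // 2
        half = center_bin // 2
        return float(values[mid - half: mid + half].mean())
    raise ValueError(f"unknown aggregation: {how}")


def minmax_normalize(scores: np.ndarray) -> np.ndarray:
    """Scale each column (one mark) of an (n_windows, n_marks) array to [0, 1].

    Constant columns are mapped to 0 instead of dividing by zero.
    """
    lo = scores.min(axis=0)
    span = scores.max(axis=0) - lo
    safe_span = np.where(span > 0, span, 1.0)
    return np.where(span > 0, (scores - lo) / safe_span, 0.0)
```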

Dataset outputs (FA + tag file)

The dataset generated by the label-building program typically consists of two parts:

  1. A training FASTA-like file (two lines per record):

Below is an example of the training FASTA-like file generated by our label construction pipeline.

Each record consists of two lines:

  • Header line: a comma-separated list of MARK_SCORE pairs, followed by ::CHR:START-END
  • Sequence line: the corresponding 1,024 bp DNA sequence (A/T/C/G)

Important

The order and the set of marks in the header must be consistent with tag.txt (i.e., histone_modification_tag.txt).

For example, if you include H3K27ME3 in the header, it must also appear in the tag file.

Example:

>H3K27AC_0.5,H3K27ME3_0.3,H3K36ME3_1.0,H3K9AC_0.6::Chr1:3584-4608
ATAAATTGTCTTATTTAAACGCTGACTTCACTGTCTTCCTCCCTCCAAATTATTAGATATACCAAACCAGAGAAAACAAATACATAATCGGAGAAATACAGATTACAGAGAGCGAGAGAGATCGACGGCGAAGCTCTTTACCCGGAAACCATTGAAATCGGACGGTTTAGTGAAAATGGAGGATCAAGTTGGGTTTGGGTTCCGTCCGAACGACGAGGAGCTCGTTGGTCACTATCTCCGTAACAAAATCGAAGGAAACACTAGCCGCGACGTTGAAGTAGCCATCAGCGAGGTCAACATCTGTAGCTACGATCCTTGGAACTTGCGCTGTAAGTTCCGAATTTTCTGAATTTCATTTGCAAGTAATCGATTTAGGTTTTTGATTTTAGGGTTTTTTTTTGTTTTGAACAGTCCAGTCAAAGTACAAATCGAGAGATGCTATGTGGTACTTCTTCTCTCGTAGAGAAAACAACAAAGGGAATCGACAGAGCAGGACAACGGTTTCTGGTAAATGGAAGCTTACCGGAGAATCTGTTGAGGTCAAGGACCAGTGGGGATTTTGTAGTGAGGGCTTTCGTGGTAAGATTGGTCATAAAAGGGTTTTGGTGTTCCTCGATGGAAGATACCCTGACAAAACCAAATCTGATTGGGTTATCCACGAGTTCCACTACGACCTCTTACCAGAACATCAGGTTTTCTTCTATTCATATATATATATATATATATATGTGGATATATATATATGTGGTTTCTGCTGATTCATAGTTAGAATTTGAGTTATGCAAATTAGAAACTATGTAATGTAACTCTATTTAGGTTCAGCAGCTATTTTAGGCTTAGCTTACTCTCACCAATGTTTTATACTGATGAACTTATGTGCTTACCTCCGGAAATTTTACAGAGGACATATGTCATCTGCAGACTTGAGTACAAGGGTGATGATGCGGACATTCTATCTGCTTATGCAATAGATCCCACTCCCGCTTTTGTCCCCAATATGACTAGTAGTGCAGGTTCTGTGGTG
>H3K27AC_0.7,H3K27ME3_0.3,H3K36ME3_0.8::Chr1:8192-9216
GCAATGCTTGAAATCAAGAACTTGAATTGAAATAGTTTTTTACCTGAATATTGACAGTTGCTGGATTAATTGCATTGTAGAGGACGTGTCTATATACCTTTGGTCTGTGAAGGATTAAATCGATGAAAATAATCTGCCAAAGAAAACAATTAAAGAACCAAAAACCAAAATTGGAAAGAAATAGGGAAACACCCAAAAAGGGAAAGAAAGTGATTAAAACAGACCATGCGTTCACACTCGATGTACTCATCTGCTACTTCCTTGCAATTTCCCTAAATATAACAATATGATCAAAGATGGAAACTTTGAAGAAATTTAATAGAGAATCTTATAAACCCTAATTGGGTCAAAGAAGATCCATTAATACAAAAATCTTACGCATTTCATGAGACGAATGTTACCCGGAGAGTATTGAATGAACAATGACTTTACCCTAAAACCACATCCCACGCATCTGTGTTCACTCGCCGCCATTGCTCTCTCTCTCTCTCTCTCTCTCTCTCTCTCAAGAGAAGAAGAATACGGAGCAATTAGAGTCCGGGTCTGGGCTACTGTTTTAACCCTAAATGGGCTTATTCATGGGCCAAGTTTTTGAAGTCTTAACTTTAAATTTGTTAGGCCCACTTTTGCTCTAAGCCGGGGTATTTGTACCCCAAAATTTAAAAATCATATACACGTTGTAATTTATAAATAGTTCAATTTGGATCAAAATCTTGTCCATATGACATAGCATTTTAAAATGCGTAGGTTCATGAATGAAACATATTATAGGCCTCAGATAAAGATATACATATTAAGTCTAAATTATTTAGTCTTCAGAATTTACCACACTTACTGAAAAGTCTAGTGGTTCACTAATATTATTACTGTCGTGTTACTTTCTATATATAGTTCATGACTTGTGAGTTGTGATGGATAAGTTTATAAGAAAATAAATTATTTATTACAATTCAACAGTGAAGAAATTTATTTAGTTTGATTAAATAAGAAAGGTAAATAAATCTTCGTTTGCCACACCAAACAA

A header like:

H3K27AC_0.5,H3K27ME3_0.3,H3K36ME3_1.0,H3K9AC_0.6::Chr1:3584-4608

means that the sequence record corresponds to the genomic interval Chr1:3584–4608 (1,024 bp), and the values (0.5, 0.3, 1.0, 0.6) are the per-mark label scores assigned to this window.

  • MARK_SCORE pairs encode the regulatory signal strength for each chromatin feature within this interval.
  • The score can be defined in different ways depending on your labeling strategy, e.g.:
    • binary labels in {0,1} (peak overlap / presence–absence), or
    • continuous scores normalized to [0,1] (e.g., coverage / aggregated BigWig signal).
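Here is a minimal sketch for writing and reading this record format. It assumes that a mark absent from a header (e.g. H3K9AC in the second example record) means a score of 0, and that mark names contain no commas. The pipeline does not state either assumption. `format_record` and `parse_header` are hypothetical helpers. `parse_header` places each score at the position its mark occupies in the tag file, and it rejects any mark the tag file does not list.

```python
from typing import Dict, List, Sequence, Tuple


def format_record(scores: Dict[str, float], chrom: str, start: int, end: int,
                  sequence: str) -> str:
    """Build one two-line record: '>MARK_SCORE,...::CHR:START-END' + sequence."""
    pairs = ",".join(f"{mark}_{score}" for mark, score in scores.items())
    return f">{pairs}::{chrom}:{start}-{end}\n{sequence}\n"


def parse_header(header: str,
                 tags: Sequence[str]) -> Tuple[List[float], Tuple[str, int, int]]:
    """Return (label vector in tag-file order, (chrom, start, end))."""
    header = header.strip().lstrip(">")
    pairs, _, locus = header.rpartition("::")
    chrom, _, span = locus.rpartition(":")
    start, end = (int(x) for x in span.split("-"))
    index = {tag: i for i, tag in enumerate(tags)}
    labels = [0.0] * len(tags)  # assumption: marks missing from the header score 0
    for pair in pairs.split(","):
        mark, _, value = pair.rpartition("_")
        if mark not in index:
            raise KeyError(f"{mark} is not listed in the tag file")
        labels[index[mark]] = float(value)
    return labels, (chrom, start, end)
```

Scores are looked up by name, so the output vector always follows the tag-file order. The sketch therefore does not depend on the order of marks within a header. The tag list used to test it should cover every mark that appears in the headers, e.g. `["H3K27AC", "H3K27ME3", "H3K36ME3", "H3K9AC"]` for the example records above.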

Note

Regardless of whether scores are binary or continuous, the set and order of marks must be consistent with histone_modification_tag.txt, because it defines the output channel order used in both training and inference.

Save as:

  • <SPECIES>_1024_512.training.fa
  2. A tag.txt file specifying the task order of the multi-label outputs. This order is critical: it must match the model output channel order in both training and prediction.

Create a tag file describing the task order (this MUST match training outputs and inference tag order):

H3K4ME3
H3K27AC
H3K4ME1
H3K9AC
H3K36ME3

Save as:

  • histone_modification_tag.txt

Step 3: Train / Validate / Test (Stability Demonstration)

In this step, we train the model on the training split, monitor performance on the validation split, and finally evaluate on a held-out test split to demonstrate the stability and generalization of SeiPlant.

This script assumes you have already generated:

  • Step 1 outputs

    • data/windows/merged_{species}_1024_512.fa
  • Step 2 outputs

    • data/labels/<SPECIES>_1024_128_labels.npy (or your FASTA-like label-encoded file, depending on your implementation)
    • data/labels/histone_modification_tag.txt

Note

We recommend splitting data by chromosomes to avoid local sequence leakage between train/val/test.
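A chromosome-level split can be sketched as follows. The held-out chromosomes in the usage line (Chr4 for validation and Chr5 for testing, in Arabidopsis naming) are an illustrative choice. They are not the split used by `split_data_by_chromosome` in the repository. `split_by_chromosome` is a hypothetical helper that operates on the `CHR:START-END` loci carried in the training headers.

```python
from typing import Dict, List, Sequence, Set


def split_by_chromosome(loci: Sequence[str], val_chroms: Set[str],
                        test_chroms: Set[str]) -> Dict[str, List[int]]:
    """Return record indices per split; every window of a chromosome lands in one split."""
    splits: Dict[str, List[int]] = {"train": [], "val": [], "test": []}
    for i, locus in enumerate(loci):
        chrom = locus.rsplit(":", 1)[0]  # 'Chr1:3584-4608' -> 'Chr1'
        if chrom in test_chroms:
            splits["test"].append(i)
        elif chrom in val_chroms:
            splits["val"].append(i)
        else:
            splits["train"].append(i)
    return splits


loci = ["Chr1:0-1024", "Chr4:0-1024", "Chr5:512-1536", "Chr2:0-1024"]
print(split_by_chromosome(loci, val_chroms={"Chr4"}, test_chroms={"Chr5"}))
# {'train': [0, 3], 'val': [1], 'test': [2]}
```

Because every window of a held-out chromosome lands in the same split, overlapping windows cannot leak shared sequence between splits.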

Example training script (template)

import os
import random
import numpy as np
import torch

from utils.data_utils import (
    load_and_preprocess_data,
    split_data_by_chromosome,
    create_data_loaders,
)
from scripts.train import train_model
from scripts.evaluate import evaluate_model
from utils.data import NucDataset

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def experiment():
    """
    End-to-end experiment:
    load -> chromosome split -> train/val -> test evaluation
    """

    # =========================
    # 0) Inputs from Step 1/2
    # =========================
    # FASTA-like training records generated in Step 2
    fasta_file = "<SPECIES>_1024_512.training.fa"
    # Tag file generated in Step 2 (defines multi-task output order)
    tag_file = "histone_modification_tag.txt"

    # =========================
    # 1) Hyperparameters
    # =========================
    seq_len = 1024
    batch_size = 256
    epochs = 20
    lr = 1e-5
    seed = 42

    # =========================
    # 2) Output locations
    # =========================
    model_root = "models/saved_models/<SPECIES>/"
    result_root = "experiments/<SPECIES>/predictions/"
    model_pattern = "<SPECIES>_{}_feature{}"  # (seq_len, num_tasks)

    os.makedirs(model_root, exist_ok=True)
    os.makedirs(result_root, exist_ok=True)

    # =========================
    # 3) Reproducibility
    # =========================
    set_seed(seed)

    # =========================
    # 4) Load & preprocess
    # =========================
    x_all, labels, pos, tag_dict = load_and_preprocess_data(fasta_file, tag_file)

    # =========================
    # 5) Chromosome-based split
    # =========================
    print("Splitting data by chromosome...")
    x_train, y_train, x_val, y_val, x_test, y_test, tag_dict = split_data_by_chromosome(
        x_all, labels, pos, tag_dict
    )

    # =========================
    # 6) DataLoaders
    # =========================
    print("Creating data loaders...")
    train_loader, val_loader = create_data_loaders(
        x_train, y_train, x_val, y_val, batch_size=batch_size
    )

    # =========================
    # 7) Device
    # =========================
    # Use any available GPU, or CPU if CUDA is not available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # =========================
    # 8) Train
    # =========================
    print("Training model...")
    model_name = model_pattern.format(seq_len, len(tag_dict)) + ".model"
    model_path = os.path.join(model_root, model_name)

    train_model(
        train_loader=train_loader,
        val_loader=val_loader,
        tag_dict=tag_dict,
        seq_len=seq_len,
        device=device,
        model_dir=model_path,
        epoch=epochs,
        lr=lr,
        batch_size=batch_size,
    )

    # =========================
    # 9) Test evaluation
    # =========================
    print("Evaluating model on test set...")
    test_dataset = NucDataset(x=x_test, y=y_test)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )

    pred_name = model_pattern.format(seq_len, len(tag_dict)) + ".test_predictions.npy"
    pred_path = os.path.join(result_root, pred_name)

    evaluate_model(
        test_loader=test_loader,
        device=device,
        model_dir=model_path,
        tag_dict=tag_dict,
        prediction_dir=pred_path,
        seq_len=seq_len,
    )

    print(f"Saved model to: {model_path}")
    print(f"Saved test predictions to: {pred_path}")

if __name__ == "__main__":
    experiment()
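Within the repository, `evaluate_model` handles test-set scoring. To inspect the saved predictions independently, you can compute a per-mark Pearson correlation as shown below. This sketch assumes the `.npy` file holds an array of shape `(n_test_windows, len(tag_dict))` in the same row order as `y_test`. For binary labels, AUROC/AUPRC (e.g. from scikit-learn) would be the more usual metrics. `per_mark_pearson` is a hypothetical helper.

```python
from typing import Dict, Sequence

import numpy as np


def per_mark_pearson(y_true: np.ndarray, y_pred: np.ndarray,
                     tags: Sequence[str]) -> Dict[str, float]:
    """Pearson r between true and predicted scores for each output channel."""
    if y_true.shape != y_pred.shape or y_true.shape[1] != len(tags):
        raise ValueError("expected arrays of shape (n_windows, len(tags))")
    scores: Dict[str, float] = {}
    for j, tag in enumerate(tags):
        t, p = y_true[:, j], y_pred[:, j]
        if t.std() == 0 or p.std() == 0:
            scores[tag] = float("nan")  # correlation is undefined for constant columns
        else:
            scores[tag] = float(np.corrcoef(t, p)[0, 1])
    return scores
```

Typical use: `per_mark_pearson(y_test, np.load(pred_path), tags)`, where `tags` lists the marks in tag-file order.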

Step 4: Cross-Species Inference on a New Species (Model → Genome-wide Signals)

In this step, we use the trained model from Step 3 and perform cross-species prediction on a new species.

Given the new species reference genome (.fa + .size), we tile the genome into fixed windows (same seq_len and step size as training), run inference using the pretrained model, and export predicted chromatin signals into genome-browser-friendly formats (BedGraph/BigWig).

Inputs

  • Trained model checkpoint (from Step 3)

    • models/saved_models/<TRAIN_SPECIES>/<TRAIN_SPECIES>_<SEQ_LEN>_feature<TASK_NUM>.model
  • Tag file (same as Step 2; defines output channel order)

    • data/labels/histone_modification_tag.txt
  • New species reference genome

    • scripts/fasta/<NEW_SPECIES>.fa
    • scripts/fasta/<NEW_SPECIES>.size

Important

In our training setting, we use 1024 bp sequences with a 512 bp stride (1024_512) to construct training samples.

For cross-species genome-wide inference, we instead use 1024 bp sequences with a 128 bp stride (1024_128).

  • The model input length remains fixed at 1024 bp (must match the trained model).
  • The stride controls the prediction resolution along the genome. Using a smaller stride (128 bp) yields denser windows and therefore a more fine-grained predicted track.
  • Conceptually, we use a 1024 bp context window to predict the chromatin signal at the central 128 bp region (or equivalently, assign the window’s prediction to the window center), so that the exported BedGraph/BigWig signal is as detailed as possible.
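The center-bin assignment is simple arithmetic. A 1,024-bp window starting at `s` has its center at `s + 512`, so its central 128 bp span `[s + 448, s + 576)`. With a 128-bp stride, the central bins of consecutive windows abut exactly, which gives a gap-free, non-overlapping track. Here is a sketch (`center_bin` is a hypothetical helper):

```python
from typing import Tuple


def center_bin(start: int, end: int, bin_size: int = 128) -> Tuple[int, int]:
    """Half-open central bin of a window (e.g. the central 128 bp of 1,024 bp)."""
    center = (start + end) // 2
    half = bin_size // 2
    return center - half, center + half


print(center_bin(0, 1024))    # (448, 576)
print(center_bin(128, 1152))  # (576, 704): starts where the previous bin ends
```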

Tip

Training with a coarser stride (1024_512) improves efficiency, while inference with a finer stride (1024_128) improves track resolution without changing the model architecture.


Step 4.1: Tile the New Genome into Windows (BED + FASTA)

python scripts/make_prediction_bed.py \
  --fasta scripts/fasta/<NEW_SPECIES>.fa \
  --size  scripts/fasta/<NEW_SPECIES>.size \
  --species <NEW_SPECIES> \
  --output_path data/windows/ \
  --window_size <SEQ_LEN> \
  --step_size <STEP_SIZE>

Expected outputs:

  • data/windows/<NEW_SPECIES>_<SEQ_LEN>_<STEP_SIZE>_filtered.bed
  • data/windows/<NEW_SPECIES>_<SEQ_LEN>_<STEP_SIZE>.fa
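You can sketch the tiling itself from the `.size` file alone. This sketch assumes that windows which would extend past a chromosome end are dropped. The additional filtering that produces the `_filtered.bed` file (for example, removing windows that contain `N` bases) is not documented here, so it is left out. `tile_genome` is a hypothetical helper, not the implementation in `make_prediction_bed.py`.

```python
from typing import Dict, Iterator, Tuple


def tile_genome(sizes: Dict[str, int], window: int = 1024,
                step: int = 128) -> Iterator[Tuple[str, int, int]]:
    """Yield half-open windows of exactly `window` bp, one every `step` bp.

    Windows that would run past the chromosome end are skipped (assumption).
    """
    for chrom, length in sizes.items():
        for start in range(0, length - window + 1, step):
            yield chrom, start, start + window


for chrom, start, end in tile_genome({"Chr1": 1300, "Chr2": 1000}):
    print(f"{chrom}\t{start}\t{end}")  # BED lines; Chr2 is shorter than one window
```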

Step 4.2: Run Prediction Using the Trained SeiPlant Model

Use the Step 3 trained model to predict multi-task chromatin signals for each genomic window in the new species.

python scripts/prediction.py \
  --model_path models/saved_models/<TRAIN_SPECIES>/<TRAIN_SPECIES>_<SEQ_LEN>_feature<TASK_NUM>.model \
  --model_tag_file data/labels/histone_modification_tag.txt \
  --species <NEW_SPECIES> \
  --fa_path data/windows/<NEW_SPECIES>_<SEQ_LEN>_<STEP_SIZE>.fa \
  --output_dir outputs/predictions/<NEW_SPECIES>/ \
  --bed_file data/windows/<NEW_SPECIES>_<SEQ_LEN>_<STEP_SIZE>_filtered.bed \
  --seq_len <SEQ_LEN> \
  --batch_size <BATCH_SIZE>

Output:

  • a .npy file containing prediction scores aligned to each genomic window (multi-task output)
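Before converting, it is worth confirming that the prediction array lines up with the window BED and the tag file. This sketch assumes the `.npy` holds one row per BED line (in file order) and one column per tag. `check_alignment` is a hypothetical helper.

```python
from typing import Sequence

import numpy as np


def check_alignment(preds: np.ndarray, bed_lines: Sequence[str],
                    tags: Sequence[str]) -> None:
    """Raise ValueError if predictions cannot be matched 1:1 to windows and tags."""
    n_windows = sum(1 for line in bed_lines if line.strip())  # ignore blank lines
    if preds.ndim != 2:
        raise ValueError(f"expected a 2-D array, got shape {preds.shape}")
    if preds.shape[0] != n_windows:
        raise ValueError(f"{preds.shape[0]} prediction rows vs {n_windows} BED windows")
    if preds.shape[1] != len(tags):
        raise ValueError(f"{preds.shape[1]} output channels vs {len(tags)} tags")
```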

Step 4.3: Convert Predictions to BedGraph / BigWig

To visualize predicted tracks in genome browsers, convert per-mark predictions into BedGraph and optionally BigWig.

  1. Convert predicted arrays to per-mark BedGraph (implementation may vary):
python scripts/make_bedgraph.py \
  --pred_npy outputs/predictions/<NEW_SPECIES>/predictions.npy \
  --bed_file data/windows/<NEW_SPECIES>_<SEQ_LEN>_<STEP_SIZE>_filtered.bed \
  --tag_file data/labels/histone_modification_tag.txt \
  --out_dir outputs/bedgraph/<NEW_SPECIES>/
  2. Convert BedGraph → BigWig (UCSC utility):
bedGraphToBigWig \
  outputs/bedgraph/<NEW_SPECIES>/<MARK>.bedgraph \
  scripts/fasta/<NEW_SPECIES>.size \
  outputs/bigwig/<NEW_SPECIES>/<MARK>.bw

Note

BigWig conversion requires that the BedGraph is sorted by chromosome and coordinate, and that <NEW_SPECIES>.size matches the reference genome assembly used for window tiling.
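Within the repository, `make_bedgraph.py` is the converter. To illustrate roughly what such a step involves, the sketch below does three things. It assigns each window's score to the window's central bin. It sorts by chromosome and start, which is equivalent to `sort -k1,1 -k2,2n` in the C locale. It then emits one list of BedGraph lines per mark. The sketch assumes predictions of shape `(n_windows, n_tags)` aligned to the BED rows. `predictions_to_bedgraph` is a hypothetical name.

```python
from typing import Dict, List, Sequence, Tuple

import numpy as np


def predictions_to_bedgraph(preds: np.ndarray,
                            windows: Sequence[Tuple[str, int, int]],
                            tags: Sequence[str],
                            bin_size: int = 128) -> Dict[str, List[str]]:
    """Build sorted, non-overlapping BedGraph lines for each mark.

    Each window's score goes to its central `bin_size` bp, so windows tiled
    with step == bin_size yield abutting intervals (no overlaps, which
    bedGraphToBigWig rejects).
    """
    order = sorted(range(len(windows)), key=lambda i: (windows[i][0], windows[i][1]))
    tracks: Dict[str, List[str]] = {tag: [] for tag in tags}
    for i in order:
        chrom, start, end = windows[i]
        center = (start + end) // 2
        b_start, b_end = center - bin_size // 2, center + bin_size // 2
        for j, tag in enumerate(tags):
            tracks[tag].append(f"{chrom}\t{b_start}\t{b_end}\t{preds[i, j]:.4f}")
    return tracks
```

Write each list to `outputs/bedgraph/<NEW_SPECIES>/<MARK>.bedgraph` (one line per entry, with a trailing newline), then run `bedGraphToBigWig` as above.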