regev-lab/interpretable-splicing-model-torch

PyTorch implementation of the pre-trained splicing model from "Deciphering RNA splicing logic with interpretable machine learning" (Liao et al., 2023); the full reference is given in the Citation section below. Train and test sequences are provided in the data directory.

Setup

Requirements

Python-side requirements:

  • Python 3.9+
  • pandas
  • numpy
  • torch

Install them with:

pip install pandas numpy torch

System requirement:

  • ViennaRNA RNAfold

RNAfold Installation

See the ViennaRNA GitHub repository for RNAfold installation instructions. The preprocessing code expects RNAfold to be available on PATH, unless you pass an explicit rnafold_bin path.
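The PATH-lookup-with-override behavior described above can be sketched as follows. This is an illustrative helper, not the repository's own code; the rnafold_bin parameter name is taken from the preprocessing options mentioned in this README.

```python
import shutil
from typing import Optional

def resolve_rnafold(rnafold_bin: Optional[str] = None) -> str:
    """Return a usable RNAfold executable path.

    If an explicit path is given, use it; otherwise fall back to a
    PATH lookup, mirroring the default behavior described above.
    """
    candidate = rnafold_bin or shutil.which("RNAfold")
    if candidate is None:
        raise FileNotFoundError(
            "RNAfold not found on PATH; install ViennaRNA or pass rnafold_bin"
        )
    return candidate
```

Running this check once before preprocessing gives a clearer error than a failure deep inside the folding step.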

Usage

Expected Input Shapes

The model expects channel-first inputs throughout:

  • sequence one-hot: (N, 4, L)
  • wobble: (N, 1, L)
  • structure: (N, 3, L)

In this repository, preprocessing utilities already return arrays in those shapes.
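If you are building inputs by hand rather than through the preprocessing utilities, a channel-first one-hot encoding of the sequence can be produced like this. This is a minimal sketch; the A/C/G/U channel ordering is an assumption for illustration and may differ from the repository's own encoder.

```python
import numpy as np

def one_hot_channel_first(seq: str) -> np.ndarray:
    """Encode an RNA/DNA sequence as a (4, L) channel-first array.

    Channel order A/C/G/U is assumed here for illustration.
    """
    alphabet = "ACGU"
    seq = seq.upper().replace("T", "U")  # accept DNA input
    oh = np.zeros((4, len(seq)), dtype=np.float32)
    for i, base in enumerate(seq):
        oh[alphabet.index(base), i] = 1.0
    return oh

x = one_hot_channel_first("GAUUACA")
print(x.shape)  # (4, 7)
```

Stacking N such arrays with np.stack then yields the (N, 4, L) batch shape the model expects.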

From CSV to Full Inputs

The canonical input sequence is an unflanked exon in a column named exon. Other columns are allowed and are preserved as metadata in the output dataset. If your sequence column has a different name, pass it explicitly.

By default, preprocessing adds the fixed model flanks:

  • left flank: CATCCAGGTT
  • right flank: CAGGTCTGAC

That means a 70 nt exon becomes a 90 nt model input.

If you do not want flanks added, use add_flanks=False in Python or --no-flanks in the CLI. In that case, L is just the input sequence length.
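The flanking behavior above amounts to simple string concatenation. The flank sequences below come from this README; the helper itself is a sketch, not the repository's own function.

```python
LEFT_FLANK = "CATCCAGGTT"   # fixed left flank (from this README)
RIGHT_FLANK = "CAGGTCTGAC"  # fixed right flank (from this README)

def add_model_flanks(exon: str, add_flanks: bool = True) -> str:
    """Attach the fixed model flanks to an unflanked exon.

    Mirrors the default preprocessing behavior described above.
    """
    return LEFT_FLANK + exon + RIGHT_FLANK if add_flanks else exon

seq = add_model_flanks("A" * 70)
print(len(seq))  # 90
```

With add_flanks=False the sequence is returned unchanged, so L equals the raw input length, matching the --no-flanks CLI option.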

Dataset Preparation

You can prepare a dataset directly from a CSV file containing an exon column.

python prepare_dataset.py \
  --input-csv input.csv \
  --output-path dataset.npz

If your sequence column is not exon:

python prepare_dataset.py \
  --input-csv input.csv \
  --output-path dataset.npz \
  --sequence-column sequence

Optional RNAfold-related arguments: --rnafold-bin, --temperature, --max-bp-span, --num-threads, and --commands-file.

The output is a compressed .npz archive of numpy arrays. It contains seq_oh (one-hot-encoded sequences), struct_oh (one-hot-encoded structures), and wobbles (wobble-pair arrays), along with the per-sequence structure and mfe outputs from RNAfold, as shown in the loading example below. All additional dataframe columns are stored under a "metadata_" prefix (e.g. dataset["metadata_PSI"]).

Load dataset

import numpy as np

dataset = np.load("your_dataset_path.npz")

# Example shapes for a dataset of 2 sequences of model length 90:
print(dataset["seq_oh"].shape)      # (2, 4, 90)
print(dataset["struct_oh"].shape)   # (2, 3, 90)
print(dataset["wobbles"].shape)     # (2, 1, 90)
print(dataset["structure"].shape)   # (2,)
print(dataset["mfe"].shape)         # (2,)

Running the Model

After preprocessing, convert the NumPy arrays to torch tensors and pass them into PNASModel.forward():

Initialize PNASModel with an input_length that matches the prepared sequence length L. If flanks are added, this length includes the flanking nucleotides.

import torch
from model import PNASModel

x_seq = torch.tensor(dataset["seq_oh"], dtype=torch.float32)
x_struct = torch.tensor(dataset["struct_oh"], dtype=torch.float32)
x_wobble = torch.tensor(dataset["wobbles"], dtype=torch.float32)

model = PNASModel(input_length=x_seq.shape[-1])
state_dict = torch.load("model_weights.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

with torch.no_grad():
    prediction = model(x_seq, x_struct, x_wobble)

Sequence-Only Analysis

To inspect sequence properties such as SR balance and latent sequence activations, use the model methods compute_sr_balance() and compute_sequence_activations().

# Get one-hot sequences
x_seq = torch.tensor(dataset["seq_oh"], dtype=torch.float32)

# Compute inclusion, skipping sequence activations
a_incl, a_skip = model.compute_sequence_activations(x_seq, agg="mean")

# Compute SR balance
sr_balance = model.compute_sr_balance(x_seq, agg="mean")

Notes

  • The default public preprocessing path assumes unflanked exon input and adds model flanks automatically.
  • load_state_dict() in PNASModel resamples position-bias tensors when checkpoint and runtime input lengths differ.
  • load_weights_from_dict() is available for loading weights converted from an external TensorFlow/Keras export format.

Citation

Please cite: Liao, Susan E., Mukund Sudarshan, and Oded Regev. "Deciphering RNA splicing logic with interpretable machine learning." Proceedings of the National Academy of Sciences 120.41 (2023): e2221165120.
