PyTorch implementation of the pre-trained splicing model from "Deciphering RNA splicing logic with interpretable machine learning" (Liao et al., 2023). The manuscript is available here. Training and test sequences are provided in the data directory.
Python-side requirements:
- Python 3.9+
- pandas
- numpy
- torch
Install them with:
pip install pandas numpy torch
System requirement:
- ViennaRNA (provides RNAfold)
See the ViennaRNA GitHub repository for instructions on installing RNAfold. The preprocessing code expects RNAfold to be available on PATH unless you pass an explicit rnafold_bin path (--rnafold-bin in the CLI).
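Before running preprocessing, one quick way to confirm that RNAfold is reachable is to look it up the same way a shell would. This is a minimal sketch using only the standard library; the helper name find_rnafold is hypothetical and is not part of this repository.

```python
import shutil
from typing import Optional


def find_rnafold(rnafold_bin: Optional[str] = None) -> Optional[str]:
    """Return a usable RNAfold path, or None if it cannot be found.

    Hypothetical helper: an explicit path is checked first; otherwise
    "RNAfold" is searched for on PATH.
    """
    candidate = rnafold_bin or "RNAfold"
    return shutil.which(candidate)


path = find_rnafold()
if path is None:
    print("RNAfold not found on PATH; install ViennaRNA or pass an explicit binary path.")
else:
    print(f"Using RNAfold at {path}")
```

If this prints the "not found" message, either fix PATH or hand the full binary path to the preprocessing code.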
The model expects channel-first inputs throughout, where N is the number of sequences and L is the sequence length:
- sequence one-hot: (N, 4, L)
- wobble: (N, 1, L)
- structure: (N, 3, L)
In this repository, preprocessing utilities already return arrays in those shapes.
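If you build inputs yourself, for example from another tool that produces channel-last (N, L, C) arrays, they need to be transposed to channel-first before reaching the model. The sketch below assumes such channel-last input; the helpers to_channel_first and check_model_inputs are hypothetical and only check the documented shapes.

```python
import numpy as np


def to_channel_first(x: np.ndarray) -> np.ndarray:
    """Convert an (N, L, C) array to (N, C, L)."""
    return np.ascontiguousarray(np.transpose(x, (0, 2, 1)))


def check_model_inputs(seq_oh, struct_oh, wobble):
    """Raise ValueError unless arrays are (N, 4, L), (N, 3, L) and (N, 1, L)."""
    n, c, length = seq_oh.shape
    if c != 4:
        raise ValueError(f"sequence one-hot must have 4 channels, got {c}")
    if struct_oh.shape != (n, 3, length):
        raise ValueError(f"structure must be {(n, 3, length)}, got {struct_oh.shape}")
    if wobble.shape != (n, 1, length):
        raise ValueError(f"wobble must be {(n, 1, length)}, got {wobble.shape}")
    return n, length


# Toy channel-last arrays: 2 sequences of length 90.
seq_last = np.zeros((2, 90, 4), dtype=np.float32)
struct_last = np.zeros((2, 90, 3), dtype=np.float32)
wobble_last = np.zeros((2, 90, 1), dtype=np.float32)

n, length = check_model_inputs(
    to_channel_first(seq_last),
    to_channel_first(struct_last),
    to_channel_first(wobble_last),
)
print(n, length)  # 2 90
```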
The canonical input sequence is an unflanked exon in a column named exon. Other columns are allowed and are preserved as metadata in the output dataset. If your sequence column has a different name, pass it explicitly.
By default, preprocessing adds the fixed model flanks:
- left flank: CATCCAGGTT
- right flank: CAGGTCTGAC
Each flank is 10 nt, so a 70 nt exon becomes a 90 nt model input.
If you do not want flanks added, use add_flanks=False in Python or --no-flanks in the CLI. In that case, L is just the input sequence length.
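The flank arithmetic can be checked with a few lines of Python. This sketch only mirrors the documented behavior (a fixed 10 nt flank on each side); add_model_flanks is a hypothetical helper, not the repository's preprocessing function.

```python
LEFT_FLANK = "CATCCAGGTT"
RIGHT_FLANK = "CAGGTCTGAC"


def add_model_flanks(exon: str, add_flanks: bool = True) -> str:
    """Return the model input sequence for an unflanked exon (hypothetical helper)."""
    return LEFT_FLANK + exon + RIGHT_FLANK if add_flanks else exon


exon = "A" * 70  # placeholder 70 nt exon, not a real sequence
print(len(add_model_flanks(exon)))                    # 90
print(len(add_model_flanks(exon, add_flanks=False)))  # 70
```

The resulting length L is also the input_length to pass to PNASModel for that dataset.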
You can prepare a dataset directly from a CSV file containing an exon column:
python prepare_dataset.py \
--input-csv input.csv \
--output-path dataset.npz
If your sequence column is not exon:
python prepare_dataset.py \
--input-csv input.csv \
--output-path dataset.npz \
--sequence-column sequence
Optional RNAfold-related arguments: --rnafold-bin, --temperature, --max-bp-span, --num-threads, and --commands-file.
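A minimal input CSV can be written with pandas. The sequences and the PSI column below are placeholders chosen for illustration only, not real data; per the description of the output, any extra column such as PSI should come back as a metadata_ key in the archive.

```python
import pandas as pd

# Placeholder 70 nt exons (not real data); extra columns are carried through as metadata.
df = pd.DataFrame(
    {
        "exon": ["ACGT" * 17 + "AC", "TTGCA" * 14],
        "PSI": [0.5, 0.25],
    }
)

# If your sequences live in another column (e.g. "sequence"), either pass
# --sequence-column sequence to prepare_dataset.py or rename the column:
# df = df.rename(columns={"sequence": "exon"})

df.to_csv("input.csv", index=False)
print(df["exon"].str.len().tolist())  # [70, 70]
```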
The output is a compressed .npz archive whose entries are NumPy arrays. It contains seq_oh (one-hot-encoded sequences), struct_oh (one-hot-encoded structures), and wobbles (wobble-pair arrays), plus per-sequence structure and mfe (minimum free energy) arrays from RNAfold. All additional dataframe columns are stored with a "metadata_" prefix (e.g. dataset["metadata_PSI"]).
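Because metadata columns share a common prefix, they can be gathered back into a DataFrame. The sketch below builds a tiny synthetic archive in memory that follows the documented key layout, so it runs on its own; with real data, pass the result of np.load on your dataset instead. The helper metadata_frame and the gene column are hypothetical.

```python
import io

import numpy as np
import pandas as pd

PREFIX = "metadata_"


def metadata_frame(archive) -> pd.DataFrame:
    """Collect every metadata_* array from a loaded .npz archive into a DataFrame."""
    columns = {
        key[len(PREFIX):]: archive[key] for key in archive.files if key.startswith(PREFIX)
    }
    return pd.DataFrame(columns)


# Synthetic stand-in for a prepared dataset (placeholder values, documented layout).
buffer = io.BytesIO()
np.savez_compressed(
    buffer,
    seq_oh=np.zeros((2, 4, 90), dtype=np.float32),
    metadata_PSI=np.array([0.5, 0.25]),
    metadata_gene=np.array(["geneA", "geneB"]),
)
buffer.seek(0)

with np.load(buffer) as dataset:
    meta = metadata_frame(dataset)
print(meta)
```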
import numpy as np
dataset = np.load("your_dataset_path.npz")
# Example shapes for 2 exons of 70 nt with default flanks (L = 90)
print(dataset["seq_oh"].shape) # (2, 4, 90)
print(dataset["struct_oh"].shape) # (2, 3, 90)
print(dataset["wobbles"].shape) # (2, 1, 90)
print(dataset["structure"].shape) # (2,)
print(dataset["mfe"].shape) # (2,)
After preprocessing, convert the NumPy arrays to torch tensors and pass them to the model (PNASModel.forward()):
Initialize PNASModel with an input_length that matches the prepared sequence length L. If flanks are added, this length includes the flanking nucleotides.
import numpy as np
import torch
from model import PNASModel
dataset = np.load("your_dataset_path.npz")
x_seq = torch.tensor(dataset["seq_oh"], dtype=torch.float32)
x_struct = torch.tensor(dataset["struct_oh"], dtype=torch.float32)
x_wobble = torch.tensor(dataset["wobbles"], dtype=torch.float32)
# input_length must equal L (including flanks, if they were added)
model = PNASModel(input_length=x_seq.shape[-1])
state_dict = torch.load("model_weights.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
with torch.no_grad():
    prediction = model(x_seq, x_struct, x_wobble)
To inspect sequence properties such as SR balance and latent sequence activations, you can use the model methods compute_sr_balance() and compute_sequence_activations().
# Get one-hot sequences
x_seq = torch.tensor(dataset["seq_oh"], dtype=torch.float32)
# Compute inclusion and skipping sequence activations
a_incl, a_skip = model.compute_sequence_activations(x_seq, agg="mean")
# Compute SR balance
sr_balance = model.compute_sr_balance(x_seq, agg="mean")
Notes:
- The default public preprocessing path assumes unflanked exon input and adds model flanks automatically.
- load_state_dict() in PNASModel resamples position-bias tensors when the checkpoint and runtime input lengths differ.
- load_weights_from_dict() is available for loading weights converted from an external TensorFlow/Keras export format.
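For larger datasets, the inference step above can be run in mini-batches to bound memory use. This is a sketch, not part of the repository: predict_in_batches is a hypothetical helper, and it assumes only that the model's forward takes (x_seq, x_struct, x_wobble) as shown above and returns a tensor whose first dimension is the batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def predict_in_batches(model, x_seq, x_struct, x_wobble, batch_size=256, device="cpu"):
    """Run model(x_seq, x_struct, x_wobble) over mini-batches and concatenate outputs."""
    loader = DataLoader(
        TensorDataset(x_seq, x_struct, x_wobble), batch_size=batch_size, shuffle=False
    )
    model = model.to(device)  # moves parameters in place and returns the module
    model.eval()
    outputs = []
    with torch.no_grad():
        for seq, struct, wobble in loader:
            out = model(seq.to(device), struct.to(device), wobble.to(device))
            outputs.append(out.cpu())  # keep results on CPU, in input order
    return torch.cat(outputs, dim=0)
```

Usage, with tensors prepared as above: prediction = predict_in_batches(model, x_seq, x_struct, x_wobble, batch_size=512). Because shuffle=False, output rows stay aligned with the input rows and their metadata.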
Please cite: Liao, Susan E., Mukund Sudarshan, and Oded Regev. "Deciphering RNA splicing logic with interpretable machine learning." Proceedings of the National Academy of Sciences 120.41 (2023): e2221165120.