This repository implements a Graph Neural Network (GNN) enhanced version of the SE+ST model for cellular perturbation prediction, integrating gene regulatory networks from the STRING database.
- STRING Network Integration: Automatically downloads and processes gene regulatory networks from the STRING database
- Multiple GNN Architectures: Supports GCN, GAT, and GraphSAGE
- SE+ST Compatibility: Built on top of the proven SE+ST architecture
- Flexible Configuration: Switch between pure SE-ST and GNN-enhanced modes with the model.use_gnn flag
Input (Control Cells)
↓
SE Encoder (genes → cell state)
↓
GNN (propagate through gene regulatory network)
↓
Transformer (cell-cell interactions)
↓
Decoder (state → perturbed genes)
↓
Output (Perturbed Cells)
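
To make the GNN stage of this pipeline concrete, here is a minimal pure-Python sketch of one GCN-style propagation step (symmetric-normalized neighbor averaging with self-loops) on a toy gene graph. The gene names, edges, and feature values are made up for illustration, and the learned weight matrix and nonlinearity of a real layer are omitted. The repository's actual layer lives in models/gnn_perturbation.py and may differ from this.

```python
import math


def gcn_propagate(features, edges):
    """One GCN-style propagation step without learned weights.

    features: dict mapping node name -> list of floats
    edges: iterable of undirected (u, v) pairs between known nodes
    """
    # Build neighbor sets that include a self-loop for every node.
    neighbors = {node: {node} for node in features}
    for u, v in edges:
        neighbors[u].add(v)
        neighbors[v].add(u)

    degree = {node: len(nbrs) for node, nbrs in neighbors.items()}
    dim = len(next(iter(features.values())))

    updated = {}
    for node, nbrs in neighbors.items():
        aggregated = [0.0] * dim
        for nb in nbrs:
            # Symmetric normalization: 1 / sqrt(deg(i) * deg(j)).
            weight = 1.0 / math.sqrt(degree[node] * degree[nb])
            for k in range(dim):
                aggregated[k] += weight * features[nb][k]
        updated[node] = aggregated
    return updated


# Toy graph (illustrative only): TP53 is linked to MDM2 and CDKN1A.
toy_features = {"TP53": [1.0, 0.0], "MDM2": [0.0, 1.0], "CDKN1A": [0.0, 0.0]}
toy_edges = [("TP53", "MDM2"), ("TP53", "CDKN1A")]
updated = gcn_propagate(toy_features, toy_edges)
```

After one step, CDKN1A (which started with all-zero features) carries part of TP53's signal, which is the sense in which the GNN stage propagates information along network edges.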
# Core dependencies
pip install torch lightning hydra-core
# PyTorch Geometric (for GNN)
pip install torch-geometric
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
# Data processing
pip install anndata scanpy pandas numpy requests

# Clone repository
git clone https://github.com/maggie26375/gnn.git
cd gnn
# Install dependencies
pip install -r requirements.txt

# Train GNN model with STRING network
python -m gnn.cli.train_gnn \
model.use_gnn=true \
model.gnn_type=gcn \
model.gnn_layers=2 \
model.string_confidence=0.4 \
training.max_epochs=100

# Disable GNN to use pure SE-ST
python -m gnn.cli.train_gnn \
model.use_gnn=false \
training.max_epochs=100

Key parameters in configs/gnn_config.yaml:
model:
  # GNN settings
  use_gnn: true            # Enable/disable GNN
  gnn_type: "gcn"          # GNN architecture: "gcn", "gat", "sage"
  gnn_layers: 2            # Number of GNN layers
  gnn_hidden_dim: 512      # Hidden dimension
  string_confidence: 0.4   # STRING confidence threshold (0-1)
  # SE-ST settings
  se_model_path: "SE-600M"
  freeze_se_model: true
  st_hidden_dim: 512

The model automatically downloads gene regulatory networks from STRING:
- Species: Human (Homo sapiens, taxonomy ID 9606)
- Network Type: Physical interactions (direct protein-protein interactions)
- Confidence: Configurable threshold (default 0.4)
- Cache: Downloaded files are cached in data/string_cache/
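
As a rough illustration of what the confidence threshold does, the sketch below filters a handful of toy interaction rows and builds an edge index. It is not the repository's loader. STRING link files report combined_score as an integer from 0 to 1000, so the sketch assumes the 0-1 config value is rescaled by 1000. The helper name build_edge_index, the toy rows, and the choice to store both directions of each edge are assumptions. Real STRING files also identify proteins by STRING IDs rather than gene symbols, so a mapping step would be needed first.

```python
def build_edge_index(link_rows, gene_names, confidence_threshold=0.4):
    """Filter (gene_a, gene_b, combined_score) rows and index surviving edges.

    Returns (edge_index, gene_to_idx), where edge_index is a pair of parallel
    lists [sources, targets] holding both directions of every kept edge.
    """
    gene_to_idx = {gene: i for i, gene in enumerate(gene_names)}
    min_score = confidence_threshold * 1000  # assumed 0-1 -> 0-1000 rescaling

    sources, targets = [], []
    seen = set()
    for gene_a, gene_b, score in link_rows:
        if score < min_score:
            continue  # below the confidence threshold
        if gene_a not in gene_to_idx or gene_b not in gene_to_idx:
            continue  # at least one endpoint is not in the gene list
        a, b = gene_to_idx[gene_a], gene_to_idx[gene_b]
        for src, dst in ((a, b), (b, a)):
            if (src, dst) not in seen:
                seen.add((src, dst))
                sources.append(src)
                targets.append(dst)
    return [sources, targets], gene_to_idx


# Toy rows (illustrative scores, not real STRING data).
toy_rows = [
    ("TP53", "MDM2", 999),
    ("TP53", "CDKN1A", 650),
    ("MDM2", "CDKN1A", 300),
    ("TP53", "GAPDH", 900),  # dropped: GAPDH is not in the gene list
]
toy_genes = ["TP53", "MDM2", "CDKN1A"]

edge_index, gene_to_idx = build_edge_index(toy_rows, toy_genes, 0.4)
strict_edge_index, _ = build_edge_index(toy_rows, toy_genes, 0.7)
```

With these toy rows, a threshold of 0.4 keeps two interactions and 0.7 keeps one, which is the edges-versus-quality trade-off the threshold controls.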
from gnn.utils.string_network_loader import load_string_network_for_hvgs
# Load STRING network for your genes
edge_index, gene_to_idx = load_string_network_for_hvgs(
    hvg_gene_names=your_gene_list,
    cache_dir="./data/string_cache",
    confidence_threshold=0.4
)

| Model | Description | Use Case |
|---|---|---|
| SE-ST | Transformer-based, no explicit network | General purpose, fast training |
| GNN-enhanced | SE-ST + gene regulatory network | Biologically-informed, interpretable |
- GNN Type Selection:
  - GCN: Fastest, good for large networks
  - GAT: Attention mechanism, slower but more expressive
  - GraphSAGE: Good for inductive learning
- STRING Confidence:
  - Lower (0.15-0.4): More edges, more noise
  - Higher (0.5-0.9): Fewer edges, higher quality
- Number of Layers:
  - 1-2 layers: Local neighborhood
  - 3-4 layers: Broader network propagation
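
The effect of layer count can be sketched without any GNN library: after k rounds of message passing, a gene can only have received information from genes within k hops of it. The snippet below computes that k-hop receptive field on a toy chain graph. The function name and the graph are illustrative assumptions, not part of this repository.

```python
def receptive_field(adjacency, start, num_layers):
    """Return the nodes whose features can reach `start` after num_layers
    rounds of message passing, i.e. everything within num_layers hops."""
    reached = {start}
    frontier = {start}
    for _ in range(num_layers):
        # Expand one hop outward, keeping only newly discovered nodes.
        frontier = {
            nb for node in frontier for nb in adjacency.get(node, ())
        } - reached
        reached |= frontier
    return reached


# Toy chain (illustrative): G1 - G2 - G3 - G4 - G5
toy_adjacency = {
    "G1": ["G2"],
    "G2": ["G1", "G3"],
    "G3": ["G2", "G4"],
    "G4": ["G3", "G5"],
    "G5": ["G4"],
}

one_layer = receptive_field(toy_adjacency, "G1", 1)
three_layers = receptive_field(toy_adjacency, "G1", 3)
```

On this chain, one layer lets G1 see only G2, while three layers extend its view to G4. A lower STRING confidence threshold adds edges, so the same number of layers then covers more of the network.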
gnn/
├── models/
│ ├── gnn_perturbation.py # GNN model
│ ├── se_st_combined.py # SE+ST base model
│ └── ...
├── utils/
│ ├── string_network_loader.py # STRING database loader
│ └── ...
├── cli/
│ ├── train_gnn.py # Training script
│ └── ...
├── configs/
│ └── gnn_config.yaml # Configuration
└── README.md
If you encounter issues installing PyTorch Geometric:
# Check your PyTorch version
python -c "import torch; print(torch.__version__)"
# Install PyG for your specific PyTorch version
# Replace 2.0.0 and cu118 in the URL with your PyTorch and CUDA versions (cu102, cu113, cu118, etc.)
pip install torch-geometric -f https://data.pyg.org/whl/torch-2.0.0+cu118.html

If STRING download fails:
- Check internet connection
- Try manual download from: https://string-db.org/cgi/download
- Place the file in the data/string_cache/ directory
If training runs out of memory:
- Reduce batch_size: training.batch_size=8
- Reduce GNN layers: model.gnn_layers=1
- Use smaller GNN: model.gnn_hidden_dim=256
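
The settings above are dotted command-line overrides of the nested config. As a rough mental model only, the sketch below applies such key=value strings to a nested dictionary. It is a simplified stand-in rather than Hydra's actual implementation: Hydra is stricter, for example about keys that are missing from the config. The helper names and the base config values are assumptions for illustration.

```python
import copy


def parse_value(raw):
    """Convert an override value string to bool, int, float, or leave as str."""
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_overrides(config, overrides):
    """Apply 'a.b=value' strings to a nested dict and return a new dict."""
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = result
        for part in parents:
            # Unlike Hydra, this sketch silently creates missing sections.
            node = node.setdefault(part, {})
        node[leaf] = parse_value(raw_value)
    return result


# Illustrative base config; real defaults live in configs/gnn_config.yaml.
base_config = {
    "model": {"use_gnn": True, "gnn_layers": 2, "gnn_hidden_dim": 512},
    "training": {"max_epochs": 100},
}

low_memory = apply_overrides(
    base_config,
    ["training.batch_size=8", "model.gnn_layers=1", "model.gnn_hidden_dim=256"],
)
```

Here low_memory["model"]["gnn_layers"] becomes 1 and base_config is left unchanged, so the memory-saving settings can be tried per run without editing the YAML file.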
If you use this code, please cite:
@software{gnn_perturbation,
title={GNN-based Perturbation Prediction},
author={Your Name},
year={2025},
url={https://github.com/maggie26375/gnn}
}

MIT License
- STRING database: https://string-db.org/
- PyTorch Geometric: https://pytorch-geometric.readthedocs.io/
- SE+ST model: Original implementation