A combined State Embedding (SE) and State Transition (ST) model for cross-cell-type perturbation prediction in single-cell genomics.
- Cross-cell-type generalization: Better prediction accuracy across different cell types
- Cell-type-agnostic modeling: Uses universal state embeddings for robust predictions
- Pre-trained SE integration: Leverages pre-trained State Embedding models
- Easy installation: Install via pip from GitHub
- Complete pipeline: Training, inference, and evaluation utilities
Install via pip or uv:

```bash
# Install directly from GitHub
pip install git+https://github.com/maggie26375/se-st-combined@main

# Or using uv (faster)
uv add git+https://github.com/maggie26375/se-st-combined@main
```

For development, install from a local clone:

```bash
# Clone the repository
git clone https://github.com/maggie26375/se-st-combined.git
cd se-st-combined
# Install in development mode
pip install -e .
# Or using uv
uv pip install -e .
```

Basic usage:

```python
from se_st_combined.models.se_st_combined import SE_ST_CombinedModel
from se_st_combined.utils.se_st_utils import load_se_st_model, predict_perturbation_effects
# Load the model
model = load_se_st_model(
    model_dir="path/to/model",
    checkpoint_path="path/to/checkpoint.ckpt",
    se_model_path="path/to/se/model",
    se_checkpoint_path="path/to/se/checkpoint.ckpt",
)
# Make predictions
predictions = predict_perturbation_effects(
    model=model,
    ctrl_expressions=ctrl_expressions,
    pert_embeddings=pert_embeddings,
)
```
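The snippet above assumes `ctrl_expressions` and `pert_embeddings` already exist in memory. The sketch below shows one way they might be prepared; the `.h5ad` path, the `target_gene` column, the control label, and the layout of the ESM2 features file are assumptions for illustration only, not part of this package's documented API.

```python
import numpy as np
import scanpy as sc
import torch

# Illustrative sketch only: adapt paths, column names, and formats to your data.

# Control expression: unperturbed cells from an AnnData file (hypothetical path/column).
adata = sc.read_h5ad("data/my_dataset.h5ad")
ctrl = adata[adata.obs["target_gene"] == "non-targeting"]
X = ctrl.X.toarray() if hasattr(ctrl.X, "toarray") else np.asarray(ctrl.X)
ctrl_expressions = torch.tensor(X, dtype=torch.float32)

# Perturbation embeddings: the training command below references data/ESM2_pert_features.pt.
# Assumption: the file stores a dict mapping perturbation names to embedding tensors.
pert_features = torch.load("data/ESM2_pert_features.pt")
pert_embeddings = torch.stack([pert_features[g] for g in ["TP53", "MYC"]])  # hypothetical genes
```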
See `examples/se_st_virtual_cell_challenge.py` for a complete training example. Training uses the `se-st-train` command with configuration overrides, similar to the original STATE training:

```bash
uv run se-st-train \
  data.kwargs.toml_config_path="data/starter.toml" \
  data.kwargs.perturbation_features_file="data/ESM2_pert_features.pt" \
  training.max_steps=40000 \
  model=se_st_combined \
  model.kwargs.se_model_path="SE-600M" \
  model.kwargs.se_checkpoint_path="SE-600M/se600m_epoch15.ckpt" \
  output_dir="results" \
  name="se_st_experiment"
```

The SE+ST Combined Model consists of two main components:
**SE Encoder (State Embedding)**
- Converts raw gene expression to universal state embeddings
- Uses a pre-trained SE model (e.g., SE-600M)
- Provides cell-type-agnostic representations

**ST Predictor (State Transition)**
- Predicts perturbation effects in state embedding space
- Uses a transformer backbone (GPT-2 or Llama)
- Learns set-to-set functions for perturbation modeling
```
Raw Expression → SE Encoder → State Embeddings
                                    ↓
Perturbation Embeddings → ST Predictor → Predicted Expression
```
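For intuition, the data flow above corresponds roughly to the following sketch. The class and attribute names (`SEThenST`, `se_encoder`, `st_predictor`) are illustrative placeholders, not the actual implementation in `se_st_combined/models/se_st_combined.py`.

```python
import torch
import torch.nn as nn


class SEThenST(nn.Module):
    """Illustrative two-stage wiring: SE encoder -> ST predictor (names are hypothetical)."""

    def __init__(self, se_encoder: nn.Module, st_predictor: nn.Module, freeze_se: bool = True):
        super().__init__()
        self.se_encoder = se_encoder      # pre-trained SE model (e.g., SE-600M)
        self.st_predictor = st_predictor  # transformer-based ST model (GPT-2/Llama backbone)
        if freeze_se:
            # Mirrors the `freeze_se_model: true` config option: keep SE weights fixed.
            for p in self.se_encoder.parameters():
                p.requires_grad = False

    def forward(self, ctrl_expressions: torch.Tensor, pert_embeddings: torch.Tensor) -> torch.Tensor:
        # 1) Raw expression -> universal, cell-type-agnostic state embeddings
        state_embeddings = self.se_encoder(ctrl_expressions)
        # 2) State embeddings + perturbation embeddings -> predicted post-perturbation expression
        return self.st_predictor(state_embeddings, pert_embeddings)
```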
Expected improvements over the baseline StateTransition model:
- Cross-cell-type accuracy: 10-20% improvement
- Generalization: Better performance on unseen cell types
- Training stability: More stable convergence
- Robustness: More consistent predictions across cell types
The model can be configured via YAML files:
```yaml
# se_st_combined.yaml
name: se_st_combined
kwargs:
  se_model_path: "SE-600M"
  se_checkpoint_path: "SE-600M/se600m_epoch15.ckpt"
  freeze_se_model: true
  st_hidden_dim: 672
  st_cell_set_len: 128
  transformer_backbone_key: llama
```
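For illustration, the sketch below reads this YAML and passes the `kwargs` block to the model constructor. Whether every key maps one-to-one onto `SE_ST_CombinedModel`'s constructor is an assumption; in normal use the `se-st-train` command resolves the configuration for you.

```python
import yaml

from se_st_combined.models.se_st_combined import SE_ST_CombinedModel

# Load the model config shipped with the package (path taken from the project layout below).
with open("se_st_combined/configs/se_st_combined.yaml") as f:
    cfg = yaml.safe_load(f)

# Assumption: the `kwargs` section maps directly onto the model constructor.
model = SE_ST_CombinedModel(**cfg["kwargs"])
print(cfg["name"], "->", type(model).__name__)
```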
Project layout:

```
se-st-combined/
├── se_st_combined/
│   ├── models/
│   │   ├── se_st_combined.py       # Main SE+ST model
│   │   ├── base.py                 # Base perturbation model
│   │   ├── state_transition.py     # ST model component
│   │   └── utils.py                # Model utilities
│   ├── utils/
│   │   ├── se_st_utils.py          # Training/inference utilities
│   │   └── se_inference.py         # SE model inference
│   ├── configs/
│   │   └── se_st_combined.yaml     # Model configuration
│   └── data/                       # Data utilities
├── examples/
│   └── se_st_virtual_cell_challenge.py  # Complete training example
├── setup.py                        # Package setup
├── requirements.txt                # Dependencies
└── README.md                       # This file
```
See examples/se_st_virtual_cell_challenge.py for a complete example of training and evaluating the SE+ST model on the Virtual Cell Challenge dataset.
To evaluate cross-cell-type performance:

```python
from se_st_combined.utils.se_st_utils import evaluate_cross_cell_type_performance

# Evaluate performance across different cell types
results = evaluate_cross_cell_type_performance(
    model=model,
    test_data=test_data,
    cell_types=["k562", "hepg2", "rpe1", "jurkat"],
)
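
# Inspect per-cell-type metrics. Assumption: `results` maps cell type -> metrics;
# the exact return structure is defined in se_st_utils.
for cell_type, metrics in results.items():
    print(cell_type, metrics)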
# Compare with the baseline StateTransition model
# (baseline_results is obtained by evaluating the baseline in the same way)
comparison = compare_with_baseline(results, baseline_results)
```

This model is based on the STATE (State Transition and Embedding) framework for single-cell perturbation prediction. The key innovation is combining:
- State Embedding models for universal cell representations
- State Transition models for perturbation effect prediction
- Cross-cell-type generalization through shared embedding space
- Python >= 3.8
- PyTorch >= 1.12.0
- Lightning >= 2.0.0
- Scanpy >= 1.9.0
- And more (see requirements.txt)
Contributions are welcome! Please feel free to submit a Pull Request.
This project is licensed under the MIT License - see the LICENSE file for details.
- Based on the STATE framework from Arc Institute
- Uses pre-trained SE models for cell embeddings
- Inspired by the Virtual Cell Challenge
If you encounter any issues or have questions, please:
- Check the examples in the `examples/` directory
- Review the configuration options
- Open an issue on GitHub