# SE+ST Combined Model

A combined State Embedding (SE) and State Transition (ST) model for cross-cell-type perturbation prediction in single-cell genomics.

## 🚀 Features

- **Cross-cell-type generalization**: Better prediction accuracy across different cell types
- **Cell-type-agnostic modeling**: Uses universal state embeddings for robust predictions
- **Pre-trained SE integration**: Leverages pre-trained State Embedding models
- **Easy installation**: Install via pip from GitHub
- **Complete pipeline**: Training, inference, and evaluation utilities

## 📦 Installation

### From GitHub (Recommended)

```bash
# Install directly from GitHub
pip install git+https://github.com/maggie26375/se-st-combined@main

# Or using uv (faster)
uv add git+https://github.com/maggie26375/se-st-combined@main
```

### From Source

```bash
# Clone the repository
git clone https://github.com/maggie26375/se-st-combined.git
cd se-st-combined

# Install in development mode
pip install -e .

# Or using uv
uv pip install -e .
```

## 🎯 Quick Start

### Basic Usage

```python
from se_st_combined.models.se_st_combined import SE_ST_CombinedModel
from se_st_combined.utils.se_st_utils import load_se_st_model, predict_perturbation_effects

# Load the model
model = load_se_st_model(
    model_dir="path/to/model",
    checkpoint_path="path/to/checkpoint.ckpt",
    se_model_path="path/to/se/model",
    se_checkpoint_path="path/to/se/checkpoint.ckpt",
)

# Make predictions
predictions = predict_perturbation_effects(
    model=model,
    ctrl_expressions=ctrl_expressions,
    pert_embeddings=pert_embeddings,
)
```

### Training

See `examples/se_st_virtual_cell_challenge.py` for a complete training example. The training command mirrors the original STATE training interface:

```bash
uv run se-st-train \
  data.kwargs.toml_config_path="data/starter.toml" \
  data.kwargs.perturbation_features_file="data/ESM2_pert_features.pt" \
  training.max_steps=40000 \
  model=se_st_combined \
  model.kwargs.se_model_path="SE-600M" \
  model.kwargs.se_checkpoint_path="SE-600M/se600m_epoch15.ckpt" \
  output_dir="results" \
  name="se_st_experiment"
```

πŸ—οΈ Architecture

The SE+ST Combined Model consists of two main components:

1. State Embedding (SE) Encoder

  • Converts raw gene expression to universal state embeddings
  • Uses pre-trained SE model (e.g., SE-600M)
  • Provides cell-type-agnostic representations

2. State Transition (ST) Predictor

  • Predicts perturbation effects in state embedding space
  • Uses transformer architecture (GPT2/Llama)
  • Learns set-to-set functions for perturbation modeling

### Data Flow

```
Raw Expression → SE Encoder → State Embeddings
                                    ↓
Perturbation Embeddings → ST Predictor → Predicted Expression
```
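The data flow above can be sketched in a few lines of plain Python. This is a toy illustration of the two-stage pipeline only — the stand-in `se_encode`/`st_predict` functions, the linear maps, and all shapes and numbers are hypothetical, not the real model or API:

```python
# Toy sketch: raw expression -> SE encoder -> state embedding;
# state embedding + perturbation embedding -> ST predictor -> predicted expression.

def se_encode(expression, weights):
    """Stand-in SE encoder: linearly project raw expression into embedding space."""
    return [sum(w * x for w, x in zip(row, expression)) for row in weights]

def st_predict(state_emb, pert_emb, decoder):
    """Stand-in ST predictor: shift the state embedding by the perturbation
    embedding, then decode back to expression space."""
    shifted = [s + p for s, p in zip(state_emb, pert_emb)]
    return [sum(w * z for w, z in zip(row, shifted)) for row in decoder]

# 4 genes -> 2-dim embedding -> 4 genes (illustrative numbers only)
encoder_w = [[0.5, 0.5, 0.0, 0.0], [0.0, 0.0, 0.5, 0.5]]
decoder_w = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]

ctrl_expression = [2.0, 0.0, 1.0, 1.0]
pert_embedding = [0.5, -0.5]

state_emb = se_encode(ctrl_expression, encoder_w)             # [1.0, 1.0]
predicted = st_predict(state_emb, pert_embedding, decoder_w)  # [1.5, 1.5, 0.5, 0.5]
print(predicted)
```

The point of the design is that the perturbation is applied in the shared embedding space, so the ST predictor never needs to know which cell type produced the expression vector.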

## 📊 Performance

Expected improvements over the baseline StateTransition model:

- **Cross-cell-type accuracy**: 10-20% improvement
- **Generalization**: Better performance on unseen cell types
- **Training stability**: More stable convergence
- **Robustness**: More consistent predictions across cell types

## 🔧 Configuration

The model can be configured via YAML files:

```yaml
# se_st_combined.yaml
name: se_st_combined
kwargs:
  se_model_path: "SE-600M"
  se_checkpoint_path: "SE-600M/se600m_epoch15.ckpt"
  freeze_se_model: true
  st_hidden_dim: 672
  st_cell_set_len: 128
  transformer_backbone_key: llama
```
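Any field in this config can also be set from the command line using the dotted `model.kwargs.*=...` syntax shown in the training command. A minimal sketch of that override semantics, assuming Hydra-style dotted keys over a nested dict (the real project delegates this to its config framework; `apply_override` is a hypothetical helper):

```python
# Sketch of how a dotted CLI override maps onto a nested config dict.

def apply_override(config, dotted_key, value):
    """Set config['a']['b']['c'] = value for dotted_key 'a.b.c',
    creating intermediate dicts as needed."""
    keys = dotted_key.split(".")
    node = config
    for k in keys[:-1]:
        node = node.setdefault(k, {})
    node[keys[-1]] = value
    return config

config = {
    "name": "se_st_combined",
    "kwargs": {
        "se_model_path": "SE-600M",
        "freeze_se_model": True,
        "st_hidden_dim": 672,
    },
}

# Equivalent of passing model.kwargs.st_hidden_dim=512 on the command line
# (rooted at the model config itself):
apply_override(config, "kwargs.st_hidden_dim", 512)
print(config["kwargs"]["st_hidden_dim"])  # 512
```

Overrides only touch the named leaf; the rest of the YAML defaults stay in effect.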

πŸ“ Project Structure

se-st-combined/
β”œβ”€β”€ se_st_combined/
β”‚   β”œβ”€β”€ models/
β”‚   β”‚   β”œβ”€β”€ se_st_combined.py      # Main SE+ST model
β”‚   β”‚   β”œβ”€β”€ base.py                # Base perturbation model
β”‚   β”‚   β”œβ”€β”€ state_transition.py    # ST model component
β”‚   β”‚   └── utils.py               # Model utilities
β”‚   β”œβ”€β”€ utils/
β”‚   β”‚   β”œβ”€β”€ se_st_utils.py         # Training/inference utilities
β”‚   β”‚   └── se_inference.py        # SE model inference
β”‚   β”œβ”€β”€ configs/
β”‚   β”‚   └── se_st_combined.yaml    # Model configuration
β”‚   └── data/                      # Data utilities
β”œβ”€β”€ examples/
β”‚   └── se_st_virtual_cell_challenge.py  # Complete training example
β”œβ”€β”€ setup.py                       # Package setup
β”œβ”€β”€ requirements.txt               # Dependencies
└── README.md                      # This file

## 🧪 Examples

### Virtual Cell Challenge

See `examples/se_st_virtual_cell_challenge.py` for a complete example of training and evaluating the SE+ST model on the Virtual Cell Challenge dataset.

### Cross-cell-type Prediction

```python
from se_st_combined.utils.se_st_utils import (
    compare_with_baseline,
    evaluate_cross_cell_type_performance,
)

# Evaluate performance across different cell types
se_st_results = evaluate_cross_cell_type_performance(
    model=model,
    test_data=test_data,
    cell_types=["k562", "hepg2", "rpe1", "jurkat"],
)

# Compare with baseline
comparison = compare_with_baseline(se_st_results, baseline_results)
```

## 🔬 Research Background

This model is based on the STATE (State Transition and Embedding) framework for single-cell perturbation prediction. The key innovation is combining:

1. **State Embedding** models for universal cell representations
2. **State Transition** models for perturbation effect prediction
3. **Cross-cell-type generalization** through a shared embedding space

## 📚 Dependencies

- Python >= 3.8
- PyTorch >= 1.12.0
- Lightning >= 2.0.0
- Scanpy >= 1.9.0
- And more (see `requirements.txt`)

## 🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## 📄 License

This project is licensed under the MIT License; see the LICENSE file for details.

πŸ™ Acknowledgments

  • Based on the STATE framework from Arc Institute
  • Uses pre-trained SE models for cell embeddings
  • Inspired by the Virtual Cell Challenge

## 📞 Support

If you encounter any issues or have questions, please:

1. Check the examples in the `examples/` directory
2. Review the configuration options
3. Open an issue on GitHub
