SE+ST Upgrade: Advanced Optimization Pipeline

Enhanced SE+ST model with AutoTune, AdapterTune, and RL optimization for the competition.

🚀 Quick Start

# 1. Install dependencies
pip install -r requirements.txt

# 2. (Optional) Install as package
pip install -e .

# 3. Make scripts executable
chmod +x scripts/*.sh

# 4. Train baseline model (80K steps, ~16 hours)
./scripts/train_baseline.sh

📋 Features

1. AutoTune - Hyperparameter Optimization

  • Uses Optuna for Bayesian hyperparameter search
  • Automatically finds best model configuration
  • Saves the best config for retraining (a sketch of the search loop follows this list)
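
For orientation, here is a minimal sketch of what an Optuna-based Bayesian search can look like. The search space and the train_and_evaluate() helper are illustrative placeholders, not the actual API of utils/hyperparameter_search.py.

import optuna

def train_and_evaluate(learning_rate, hidden_dim, dropout):
    # Placeholder: the real pipeline would launch a short training run and
    # return the validation competition score for this configuration.
    return -abs(learning_rate - 3e-4)  # dummy score so the sketch runs end-to-end

def objective(trial):
    # Hypothetical search space for the SE+ST model
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    hidden_dim = trial.suggest_categorical("hidden_dim", [256, 512, 1024])
    dropout = trial.suggest_float("dropout", 0.0, 0.3)
    return train_and_evaluate(learning_rate, hidden_dim, dropout)

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)
print(study.best_params)  # best configuration, saved for retraining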

2. AdapterTune - Parameter-Efficient Fine-Tuning

  • LoRA (Low-Rank Adaptation) adapters
  • ~95% fewer trainable parameters
  • 3-5× faster training
  • Performance comparable to full fine-tuning (a minimal LoRA sketch follows this list)
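
The adapter idea, roughly: freeze the pretrained weights and learn a small low-rank update alongside them. Below is a minimal PyTorch sketch of a LoRA-wrapped linear layer; it is illustrative only and not the repository's models/adapters.py.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    # Wrap a frozen nn.Linear with a trainable low-rank update: W x + (alpha / r) * B A x
    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # only the adapter matrices are trained
        self.lora_a = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

Only lora_a and lora_b receive gradients, which is where the large reduction in trainable parameters comes from.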

3. RL Optimization - Direct Competition Metric Optimization

  • Policy gradient fine-tuning
  • Optimizes EXACT competition metrics (DES + PDS + MAE)
  • Outputs standard .ckpt files for inference
  • No separate agent needed (a simplified policy-gradient step is sketched after this list)
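
Conceptually, the RL stage treats the model's prediction as a stochastic policy and nudges it toward higher competition scores. The REINFORCE-style update below is a simplified sketch with the model, batch format, and reward_fn left abstract; see cli/rltune.py for the real training loop.

import torch

def policy_gradient_step(model, optimizer, batch, reward_fn, sigma=1.0):
    # Treat the predicted expression profile as the mean of a Gaussian policy,
    # sample a prediction, score it, and scale the sample's log-likelihood by
    # the reward so high-scoring predictions become more likely.
    mean = model(batch)
    dist = torch.distributions.Normal(mean, sigma)
    sample = dist.sample()
    reward = reward_fn(sample, batch)  # per-sample reward, e.g. weighted DES + PDS - MAE
    loss = -(dist.log_prob(sample).sum(dim=-1) * reward).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return reward.mean().item()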

🎯 Training Pipeline

See scripts/README.md for complete documentation.

Recommended: Systematic Testing

# Step 1: Train baseline (16 hours)
./scripts/train_baseline.sh
# → Run inference → Record baseline score

# Step 2: Test RL (10 hours)
./scripts/train_rl.sh
# → Run inference → Compare with baseline

# Step 3: Test AdapterTune (5 hours)
./scripts/train_adaptertune.sh
# → Run inference → Compare with baseline

# Step 4: Test combination (10 hours)
./scripts/train_rl_from_adapter.sh
# → Run inference → Find best approach

📊 Competition Metrics

The RL optimization uses EXACT competition scoring:

  • DES (Differential Expression Score): Wilcoxon rank-sum test + Benjamini-Hochberg FDR correction
  • PDS (Perturbation Discrimination Score): L1 Manhattan distance ranking
  • MAE (Mean Absolute Error): Pseudobulk-level prediction error

See utils/competition_metrics.py for implementation.
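
For intuition, here are simplified NumPy/SciPy sketches of the three components. Normalization, tie handling, and aggregation details are guesses; utils/competition_metrics.py is the source of truth.

import numpy as np
from scipy.stats import ranksums
from statsmodels.stats.multitest import multipletests

def differential_expression_score(pert_cells, ctrl_cells, true_de_genes, alpha=0.05):
    # Per-gene Wilcoxon rank-sum test, Benjamini-Hochberg FDR correction, then
    # overlap of the predicted significant genes with the ground-truth DE set.
    pvals = np.array([ranksums(pert_cells[:, g], ctrl_cells[:, g]).pvalue
                      for g in range(pert_cells.shape[1])])
    significant = multipletests(pvals, alpha=alpha, method="fdr_bh")[0]
    predicted_de = set(np.flatnonzero(significant))
    return len(predicted_de & true_de_genes) / max(len(true_de_genes), 1)

def perturbation_discrimination_score(pred_profile, true_profiles, target_idx):
    # Rank the target perturbation among all true profiles by L1 (Manhattan)
    # distance to the prediction; a lower rank yields a higher score.
    dists = np.abs(true_profiles - pred_profile).sum(axis=1)
    rank = int(np.argsort(dists).tolist().index(target_idx))
    return 1.0 - rank / (len(true_profiles) - 1)

def pseudobulk_mae(pred_cells, true_cells):
    # Mean absolute error between cell-averaged (pseudobulk) expression profiles.
    return float(np.abs(pred_cells.mean(axis=0) - true_cells.mean(axis=0)).mean())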

📁 Project Structure

se_st_upgrade/
├── cli/                    # Command-line interfaces
│   ├── train.py           # Baseline training
│   ├── adaptertune.py     # Adapter fine-tuning
│   ├── rltune.py          # RL optimization
│   ├── autotune.py        # Hyperparameter search
│   └── infer.py           # Inference
├── configs/               # Hydra configuration files
├── models/                # Model architectures
│   ├── se_st_combined.py # Main SE+ST model
│   ├── adapters.py       # LoRA adapters
│   └── adapter_se_st.py  # Adapter-enhanced model
├── utils/                 # Utilities
│   ├── competition_metrics.py      # Competition scoring
│   ├── hyperparameter_search.py    # AutoTune logic
│   └── rl_environment.py           # (Legacy, not used)
├── data/                  # Data loading
├── scripts/               # Training scripts
│   ├── README.md         # Complete script documentation
│   ├── train_baseline.sh
│   ├── train_rl.sh
│   ├── train_adaptertune.sh
│   ├── train_rl_from_adapter.sh
│   └── train_autotune.sh
└── OPTIMIZATION_GUIDE.md  # Detailed optimization guide

🔧 Configuration

All training is configured via Hydra YAML files in configs/.

Example: Modify RL learning rate

./scripts/train_rl.sh  # Uses configs/rltune.yaml

# Or override directly:
python -m cli.rltune rl.training.learning_rate=5e-6

📝 Inference

After training, run inference with your best checkpoint:

se-st-infer \
  --checkpoint competition/rl_from_adapter/best_model.ckpt \
  --adata /data/competition_val_template.h5ad \
  --output prediction.h5ad \
  --perturbation-features /data/ESM2_pert_features.pt \
  --se-model-path SE-600M \
  --batch-size 16 \
  --device cuda

📈 Expected Performance

Method            Training Time   Improvement   Parameters Trained
Baseline (80K)    16h             -             100%
+ RL              +10h            +2-5%         100%
+ AdapterTune     +5h             +3-8%         ~5%
+ Adapter+RL      +10h            +5-12%        ~5%
+ AutoTune        +60h            +10-20%       100% (new config)

🐛 Troubleshooting

CUDA Out of Memory

# Reduce batch size in scripts
# train_baseline.sh: BATCH_SIZE=4
# train_adaptertune.sh: BATCH_SIZE=8
# train_rl.sh: BATCH_SIZE=2

Checkpoint Not Found

# Check if previous step completed
ls competition/baseline_80k/final_model.ckpt

# Or update MODEL_CHECKPOINT in the script

Slow Training

# Check GPU utilization
nvidia-smi

# Reduce num_workers if CPU bottleneck
NUM_WORKERS=2

📖 Documentation

  • scripts/README.md - complete training script documentation
  • OPTIMIZATION_GUIDE.md - detailed optimization guide
  • utils/competition_metrics.py - competition scoring implementation

🤝 Contributing

This is a competition project. Feel free to experiment with:

  • New adapter architectures
  • Different RL reward functions (a toy reward reweighting follows this list)
  • Additional hyperparameters for AutoTune
  • Alternative optimization methods
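
For instance, a different RL reward might simply reweight the three competition metrics. The function below is a hypothetical starting point, not the reward the repository currently uses.

def combined_reward(des: float, pds: float, mae: float,
                    w_des: float = 1.0, w_pds: float = 1.0, w_mae: float = 1.0) -> float:
    # Higher DES and PDS should increase the reward; higher MAE should decrease it.
    return w_des * des + w_pds * pds - w_mae * mae

Sweeping the weights is itself a natural candidate for the AutoTune search space.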

📄 License

MIT License

🙏 Acknowledgments

  • SE model from State Embedding
  • Competition hosted by [...]
  • Built with PyTorch Lightning, Hydra, and Optuna

Good luck with your training! 🚀

For questions or issues, check the documentation or create an issue on GitHub.
