
# ReactionForge: Temporal Graph Network for Reaction Yield Prediction


**State-of-the-art deep learning for chemical reaction yield prediction**



## 🔬 Overview

ReactionForge is a Temporal Graph Network (TGN) architecture for predicting Suzuki-Miyaura cross-coupling reaction yields with state-of-the-art accuracy and calibrated uncertainty quantification. Our model surpasses the previous best, YieldGNN (R² = 0.957), through five key innovations:

1. 🕐 **Temporal Memory Mechanisms** - Tracks catalyst evolution and reagent dynamics across reaction sequences
2. 🔀 **Cross-Attention Architecture** - Explicitly learns structural transformations between reactants and products
3. 🌲 **Hierarchical Graph Pooling** - Automatically discovers functional group patterns via SAGPool
4. 📊 **Evidential Uncertainty** - Provides calibrated epistemic + aleatoric uncertainty in a single forward pass
5. 🎯 **Multi-Task Learning** - Joint prediction of yield, selectivity, and reaction time improves generalization
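The evidential head (innovation 4) follows the deep evidential regression recipe: the network emits the four parameters (γ, ν, α, β) of a Normal-Inverse-Gamma distribution, and both uncertainty types fall out in closed form. A minimal sketch of that decomposition (the function name and example values are illustrative, not the repo's API):

```python
def evidential_uncertainty(gamma: float, nu: float, alpha: float, beta: float):
    """Decompose Normal-Inverse-Gamma parameters (gamma, nu, alpha, beta)
    into prediction, aleatoric and epistemic uncertainty (requires alpha > 1)."""
    prediction = gamma                       # E[mu], the predicted yield
    aleatoric = beta / (alpha - 1.0)         # E[sigma^2], irreducible data noise
    epistemic = beta / (nu * (alpha - 1.0))  # Var[mu], model uncertainty
    return prediction, aleatoric, epistemic

# Illustrative values: high nu ("virtual evidence") shrinks epistemic uncertainty
pred, alea, epis = evidential_uncertainty(gamma=0.82, nu=50.0, alpha=3.0, beta=0.4)
```

Epistemic uncertainty shrinks as ν (the virtual evidence count) grows, while aleatoric uncertainty reflects irreducible noise in the yield measurements.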

### Performance Highlights

| Metric | ReactionForge | YieldGNN | YieldBERT | Improvement (vs. YieldGNN / YieldBERT) |
|---|---|---|---|---|
| R² Score | 0.968 ± 0.004 | 0.957 ± 0.005 | 0.810 ± 0.010 | +1.1% / +19.5% |
| RMSE (%) | 5.12 ± 0.18 | 6.10 ± 0.20 | 11.0 ± 0.5 | -16% / -53% |
| MAE (%) | 3.89 ± 0.12 | 4.81 ± 0.15 | 8.2 ± 0.3 | -19% / -53% |
| Training Time | 1.8 h (GPU) | 2.5 h | 6-8 h | 28% faster |
| Calibration (ECE) | 0.031 | N/A | N/A | Well-calibrated |

*Evaluated on 5,760 Suzuki-Miyaura reactions (70/30 split, 10 seeds).*
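ECE (expected calibration error) bins predictions by confidence and averages the gap between each bin's stated confidence and its empirical accuracy; lower is better, and 0.031 means the bins rarely disagree by more than about 3 points. A generic, dependency-free sketch of the metric (illustrative only; the repo's regression-interval variant may differ):

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """Standard ECE: weighted mean |accuracy - confidence| over equal-width bins."""
    n = len(confidences)
    ece = 0.0
    for b in range(n_bins):
        lo, hi = b / n_bins, (b + 1) / n_bins
        # Half-open bins (lo, hi]; bin 0 also catches exact zeros
        idx = [i for i, c in enumerate(confidences)
               if lo < c <= hi or (b == 0 and c == 0.0)]
        if not idx:
            continue
        acc = sum(correct[i] for i in idx) / len(idx)
        conf = sum(confidences[i] for i in idx) / len(idx)
        ece += len(idx) / n * abs(acc - conf)
    return ece
```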


## 🚀 Quick Start

### Installation

```bash
# Clone repository
git clone https://github.com/yourusername/ReactionForge.git
cd ReactionForge

# Create conda environment
conda create -n reactionforge python=3.10
conda activate reactionforge

# Install dependencies
pip install -r requirements.txt

# Install PyTorch Geometric (CPU)
pip install torch-geometric torch-scatter torch-sparse

# For GPU support (CUDA 11.8)
pip install torch==2.0.1+cu118 -f https://download.pytorch.org/whl/torch_stable.html
pip install torch-geometric torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu118.html
```

### Quick Prediction

```python
from src.models.reactionforge import ReactionForge
from src.data.dataset import smiles_to_graph
import torch

# Load pretrained model
model = ReactionForge.load_from_checkpoint('checkpoints/reactionforge_best.pt')
model.eval()

# Prepare reaction
reactant = smiles_to_graph('c1ccc(Br)cc1')  # Bromobenzene
product = smiles_to_graph('c1ccc(-c2ccccc2)cc1')  # Biphenyl
conditions = torch.tensor([[90.0, 12.0, 5.0, 0, 0, 0, 0, 0, 0, 0]])  # T, time, cat%, ...

# Predict
with torch.no_grad():
    output = model(reactant, product, conditions)

print(f"Predicted yield: {output['yield_mean'].item()*100:.1f}%")
print(f"Uncertainty: ±{output['uncertainty'].item()*100:.1f}%")
print(f"Confidence: {1 / output['uncertainty'].item():.2f}")  # inverse uncertainty; higher = more confident
```

## 📂 Repository Structure

```
ReactionForge/
├── src/
│   ├── models/
│   │   ├── reactionforge.py       # Core TGN architecture
│   │   ├── wln_layers.py          # Weisfeiler-Lehman networks
│   │   └── attention.py           # Cross-attention modules
│   ├── data/
│   │   ├── dataset.py             # PyG dataset classes
│   │   ├── featurization.py       # Molecular feature extraction
│   │   └── augmentation.py        # Data augmentation strategies
│   ├── training/
│   │   ├── trainer.py             # Training loop with evidential loss
│   │   ├── callbacks.py           # Early stopping, checkpointing
│   │   └── metrics.py             # Evaluation metrics
│   └── utils/
│       ├── visualization.py       # Plotting utilities
│       └── uncertainty.py         # Uncertainty calibration
├── scripts/
│   ├── train.py                   # Main training script
│   ├── evaluate.py                # Evaluation on test set
│   ├── hyperopt.py                # Hyperparameter optimization
│   └── predict.py                 # Batch prediction
├── notebooks/
│   ├── 01_quickstart.ipynb        # Getting started tutorial
│   ├── 02_training.ipynb          # Training walkthrough
│   ├── 03_analysis.ipynb          # Result analysis
│   └── 04_uncertainty.ipynb       # Uncertainty quantification
├── configs/
│   ├── default.yaml               # Default hyperparameters
│   ├── ablation.yaml              # Ablation study configs
│   └── transfer_learning.yaml     # Transfer learning setup
├── tests/
│   ├── test_model.py              # Unit tests for model
│   ├── test_data.py               # Data processing tests
│   └── test_training.py           # Training pipeline tests
├── figures/                       # Paper figures
├── checkpoints/                   # Pretrained model weights
├── requirements.txt               # Python dependencies
├── environment.yml                # Conda environment
├── README.md                      # This file
└── LICENSE                        # MIT License
```

## 🎓 Training Your Own Model

### Basic Training

```bash
# Train on Suzuki-Miyaura dataset
python scripts/train.py \
    --data_path data/suzuki_reactions.csv \
    --output_dir checkpoints/experiment_001 \
    --epochs 200 \
    --batch_size 64 \
    --learning_rate 1e-3 \
    --hidden_dim 128 \
    --num_wln_layers 3 \
    --use_temporal_memory \
    --use_cross_attention
```
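If you wire these flags into your own wrapper, note that `--use_temporal_memory` and `--use_cross_attention` are boolean switches. A minimal `argparse` sketch mirroring the call above (defaults taken from the example; this is not necessarily the actual parser in `scripts/train.py`):

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Mirrors the training CLI shown above; defaults match the example call.
    p = argparse.ArgumentParser(description="Train ReactionForge")
    p.add_argument("--data_path", required=True)
    p.add_argument("--output_dir", default="checkpoints/experiment_001")
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--learning_rate", type=float, default=1e-3)
    p.add_argument("--hidden_dim", type=int, default=128)
    p.add_argument("--num_wln_layers", type=int, default=3)
    p.add_argument("--use_temporal_memory", action="store_true")
    p.add_argument("--use_cross_attention", action="store_true")
    return p

args = build_parser().parse_args(
    ["--data_path", "data/suzuki_reactions.csv", "--use_temporal_memory"]
)
```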

### Advanced: Hyperparameter Optimization

```bash
# Run Optuna-based hyperparameter search
python scripts/hyperopt.py \
    --data_path data/suzuki_reactions.csv \
    --n_trials 100 \
    --study_name reactionforge_opt
```

### Configuration Files

Example `config.yaml`:

```yaml
model:
  hidden_dim: 128
  num_wln_layers: 3
  num_attention_heads: 8
  pooling_ratio: 0.5
  dropout: 0.2
  use_temporal_memory: true
  use_cross_attention: true

training:
  epochs: 200
  batch_size: 64
  learning_rate: 1e-3
  weight_decay: 1e-5
  lr_scheduler: 'ReduceLROnPlateau'
  patience: 20
  min_lr: 1e-6

loss:
  evidential_lambda: 0.01
  selectivity_weight: 0.3

data:
  train_split: 0.7
  val_split: 0.15
  test_split: 0.15
  random_seed: 42
```
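A common config mistake is split fractions that do not sum to 1. A tiny validation helper for the `data` section (hypothetical helper, not part of the repo):

```python
def validate_splits(data_cfg: dict, tol: float = 1e-9) -> None:
    """Raise if train/val/test fractions do not sum to 1."""
    total = data_cfg["train_split"] + data_cfg["val_split"] + data_cfg["test_split"]
    if abs(total - 1.0) > tol:
        raise ValueError(f"splits sum to {total}, expected 1.0")

# The defaults above pass; a 0.8/0.15/0.15 typo would raise ValueError
validate_splits({"train_split": 0.7, "val_split": 0.15, "test_split": 0.15})
```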

## 📊 Reproducing Paper Results

### Main Benchmarking Experiment

```bash
# Run full benchmarking suite (takes ~24 hours on RTX 3090)
bash scripts/run_benchmarks.sh

# Results will be saved to results/benchmarks/
# - comparison_table.csv
# - learning_curves.png
# - uncertainty_calibration.png
```

### Ablation Studies

```bash
# Test individual components
python scripts/ablation_study.py \
    --ablate temporal_memory \
    --ablate cross_attention \
    --ablate hierarchical_pooling \
    --ablate evidential_head
```

### Out-of-Distribution Evaluation

```bash
# Leave-one-ligand-out cross-validation
python scripts/evaluate.py \
    --mode loo_ligand \
    --checkpoint checkpoints/best_model.pt

# Temporal split (train on old reactions, test on new)
python scripts/evaluate.py \
    --mode temporal_split \
    --split_date "2023-01-01"
```
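Leave-one-ligand-out cross-validation holds out every reaction that uses a given ligand, so each fold scores the model on a catalyst it never saw during training. A minimal fold-generator sketch (hypothetical; the repo's split logic may differ):

```python
from collections import defaultdict

def loo_ligand_folds(reactions):
    """Yield (held_out_ligand, train_idx, test_idx) folds, one per ligand.
    `reactions` is a list of dicts, each with a 'ligand' key."""
    by_ligand = defaultdict(list)
    for i, rxn in enumerate(reactions):
        by_ligand[rxn["ligand"]].append(i)
    for ligand in sorted(by_ligand):
        test_idx = by_ligand[ligand]
        test_set = set(test_idx)
        train_idx = [i for i in range(len(reactions)) if i not in test_set]
        yield ligand, train_idx, test_idx

# Toy example with two ligands -> two folds
rxns = [{"ligand": "XPhos"}, {"ligand": "SPhos"}, {"ligand": "XPhos"}]
folds = list(loo_ligand_folds(rxns))
```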

## 📖 Documentation

Full documentation is available at reactionforge.readthedocs.io


## 🤝 Citation

If you use ReactionForge in your research, please cite our paper:

```bibtex
@article{roy2025reactionforge,
  title={ReactionForge: Temporal Graph Networks Surpass State-of-the-Art in Suzuki-Miyaura Yield Prediction},
  author={Roy, Kushal Raj},
  journal={ChemRxiv},
  year={2025},
  doi={10.XXXX/chemrxiv.XXXXXXX}
}
```

## 📜 License

This project is licensed under the MIT License - see the LICENSE file for details.


## 🙏 Acknowledgments

- YieldGNN (Saebi et al., 2023) for establishing the benchmark
- Chemprop v2.0 (Heid et al., 2024) for the evidential deep learning implementation
- The PyTorch Geometric team for excellent graph learning tools
- University of Houston Department of Biology & Biochemistry

## 💬 Contact

Kushal Raj Roy
University of Houston
📧 kroy@uh.edu
🔗 LinkedIn | Google Scholar


## 🌟 Star History

Star History Chart


Built with ❤️ for the chemistry community