# EP-01: Pluggable Training Infrastructure for sbi

Status: Discussion
Feedback: See GitHub Discussion → [EP-01 Discussion](https://github.com/sbi-dev/sbi/discussions/new?category=ideas)

## Summary

This enhancement proposal describes the refactoring of `sbi`'s neural network training
infrastructure to address technical debt, improve maintainability, and provide users
with better experiment tracking capabilities. The refactoring will introduce typed
configuration objects, extract a unified training loop, and implement pluggable
interfaces for logging and early stopping while maintaining full backward compatibility.

## Motivation

The `sbi` package has evolved organically over five years to serve researchers in
simulation-based inference. This growth has resulted in several architectural issues in
the training infrastructure:

1. **Code duplication**: Training logic is duplicated across NPE, NLE, NRE, and other
   inference methods, spanning over 3,000 lines of code
2. **Type safety issues**: Unstructured dictionary parameters lead to runtime errors and
   poor IDE support
3. **Limited logging options**: Users are restricted to TensorBoard with no easy way to
   integrate their preferred experiment tracking tools
4. **Missing features**: No built-in early stopping, leading to wasted compute and
   potential overfitting
5. **Integration barriers**: Community contributions like PR #1629 cannot be easily
   integrated due to architectural constraints

These issues affect both users and maintainers, slowing down development and making the
codebase harder to work with.

## Goals

### For Users

- **Seamless experiment tracking**: Support for TensorBoard, WandB, MLflow, and stdout
  without changing existing code
- **Clear API for early stopping**: Better documentation for the patience-based strategy
  implemented internally, plus possibly a lightweight interface to external early
  stopping tools
- **Better debugging**: Typed configurations with clear error messages
- **Zero migration effort**: Full backward compatibility with existing code

### For Maintainers

- **Reduced code duplication**: Extract shared training logic into reusable components
- **Type safety**: Prevent entire classes of bugs through static typing
- **Easier feature integration**: Clean interfaces for community contributions
- **Future extensibility**: Lightweight interfaces for external tools, e.g., logging and
  early stopping

## Non-Goals

We want to avoid removing code duplication just for the sake of reducing lines of code.
Sometimes, code duplication is acceptable and even preferable to overcomplicated class
hierarchies with unclear separation of concerns. In other words, having clear API
interfaces is more important than reducing code duplication.

We also want to avoid adding complexity by trying to implement every possible feature.
For example, we probably should not implement all kinds of early stopping tools
internally, because this would add maintenance and documentation burden for us, and
overhead for users who have to understand the API and the docs. Instead, we should
either implement a lightweight interface that allows plugging in external early stopping
tools, or implement just a basic version in our internal training loop (as we do now)
and refer users to the flexible training interface when they want other approaches.

## Design

The following are rough sketches of what this could look like. They are open for
discussion and may change substantially during implementation (subject to the non-goals
defined above).

### Current API

```python
# Current: each method has its own training implementation, mixing general
# training options with method-specific loss options.
inference = NPE(prior=prior)
inference.train(
    training_batch_size=50,
    learning_rate=5e-4,
    validation_fraction=0.1,
    stop_after_epochs=20,
    max_num_epochs=2**31 - 1,
    clip_max_norm=5.0,
    exclude_invalid_x=True,
    resume_training=False,
    show_train_summary=False,
)
```

### Proposed API

```python
# Proposed: cleaner API with typed configurations.
from sbi.training import TrainConfig, LossArgsNPE

# Configure training (with IDE autocomplete and validation).
train_config = TrainConfig(
    batch_size=50,
    learning_rate=5e-4,
    max_epochs=1000,
    device="cuda",
)

# Method-specific loss configuration.
loss_args = LossArgsNPE(exclude_invalid_x=True)

# Train with the clean API.
inference = NPE(prior=prior)
inference.train(train_config, loss_args)
```
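
For discussion purposes, `TrainConfig` could be a small frozen dataclass that validates
its fields eagerly. The sketch below is one possible shape; its field names, defaults,
and checks are assumptions rather than a final design.

```python
# Sketch only, not part of the proposal's final design: TrainConfig as a frozen
# dataclass with eager validation. Field names and defaults are illustrative.
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    batch_size: int = 50
    learning_rate: float = 5e-4
    validation_fraction: float = 0.1
    max_epochs: int = 2**31 - 1
    clip_max_norm: float | None = 5.0
    device: str = "cpu"

    def __post_init__(self) -> None:
        # Fail early with a clear message instead of deep inside the training loop.
        if self.batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {self.batch_size}.")
        if not 0.0 < self.validation_fraction < 1.0:
            raise ValueError(
                f"validation_fraction must be in (0, 1), got {self.validation_fraction}."
            )
```

Whether this ends up as a plain dataclass, a pydantic model, or something else is
exactly the kind of detail the discussion should settle.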

### Logging Interface

Users can seamlessly switch between logging backends:

```python
from sbi.training import LoggingConfig

# Choose your backend - no other code changes needed.
logging = LoggingConfig(backend="wandb", project="my-experiment")
# or: LoggingConfig(backend="tensorboard", log_dir="./runs")
# or: LoggingConfig(backend="mlflow", experiment_name="sbi-run")
# or: LoggingConfig(backend="stdout")  # default

inference.train(train_config, loss_args, logging=logging)
```
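
Under the hood, one option is to hide each backend behind a tiny logger protocol so the
training loop only ever calls a fixed set of methods. The `TrainingLogger` and
`StdoutLogger` names below are assumptions, shown only to illustrate the shape of such
an interface.

```python
# Sketch only: a minimal logger protocol the training loop could call, plus a
# stdout implementation. Names and method signatures are assumptions.
from typing import Protocol


class TrainingLogger(Protocol):
    def log_scalar(self, name: str, value: float, step: int) -> None: ...
    def close(self) -> None: ...


class StdoutLogger:
    def log_scalar(self, name: str, value: float, step: int) -> None:
        print(f"[epoch {step}] {name}: {value:.4f}")

    def close(self) -> None:
        pass  # Nothing to flush for stdout.
```

A TensorBoard, WandB, or MLflow adapter would implement the same methods, so
`LoggingConfig(backend=...)` only has to pick which adapter to construct.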

### Early Stopping

Multiple strategies available out of the box:

```python
from sbi.training import EarlyStopping

# Stop when the validation loss plateaus.
early_stop = EarlyStopping.validation_loss(patience=20, min_delta=1e-4)

# Stop when the learning rate drops too low.
early_stop = EarlyStopping.lr_threshold(min_lr=1e-6)

inference.train(train_config, loss_args, early_stopping=early_stop)
```
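
As an illustration of the patience-based strategy (roughly what `stop_after_epochs`
does today), a plateau criterion could be as small as the sketch below; the class name
and `update()` signature are assumptions.

```python
# Sketch only: a patience-based early stopping criterion. The training loop
# would call update() once per epoch and stop when it returns True.
class ValidationLossEarlyStopping:
    def __init__(self, patience: int = 20, min_delta: float = 1e-4) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def update(self, val_loss: float) -> bool:
        """Return True if training should stop after this epoch."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience
```

The learning-rate-threshold variant would follow the same pattern, just tracking the
optimizer's current learning rate instead of the validation loss.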

### Backward Compatibility

All existing code continues to work:

```python
# Old API still supported - no breaking changes.
inference.train(training_batch_size=100, learning_rate=1e-3)

# Mix old and new as needed during migration.
inference.train(
    training_batch_size=100,  # old style
    logging=LoggingConfig(backend="wandb"),  # new feature
)
```
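
To keep the old keyword arguments working, `.train(...)` could translate them into the
new config object before handing off to the shared training loop. The helper and the
(incomplete) mapping below are a sketch under assumed names and reuse the `TrainConfig`
sketched earlier.

```python
# Sketch only: translating legacy keyword arguments into a TrainConfig so that
# both call styles reach the same code path. Names and the (incomplete) mapping
# are assumptions; TrainConfig refers to the dataclass sketched above.
_LEGACY_KWARG_MAP = {
    "training_batch_size": "batch_size",
    "learning_rate": "learning_rate",
    "max_num_epochs": "max_epochs",
}


def config_from_legacy_kwargs(config=None, **kwargs):
    """Build a TrainConfig from either a config object or legacy keyword arguments."""
    if config is not None and kwargs:
        raise ValueError("Pass either a TrainConfig or legacy keyword arguments, not both.")
    if config is not None:
        return config
    unknown = set(kwargs) - set(_LEGACY_KWARG_MAP)
    if unknown:
        raise TypeError(f"Unknown training arguments: {sorted(unknown)}")
    translated = {_LEGACY_KWARG_MAP[name]: value for name, value in kwargs.items()}
    return TrainConfig(**translated)
```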

### Unified Backend

All inference methods share the same training infrastructure:

```python
# NPE, NLE, and NRE all use the same configuration.
npe = NPE(prior=prior)
npe.train(train_config, loss_args)

nle = NLE(prior=prior)
nle.train(train_config, loss_args)
```

## Example: Complete Training Pipeline

```python
import torch

from sbi import utils
from sbi.inference import NPE, simulate_for_sbi
from sbi.training import TrainConfig, LossArgsNPE, LoggingConfig, EarlyStopping

# Set up the prior and simulator.
prior = utils.BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
simulator = lambda theta: theta + 0.1 * torch.randn_like(theta)

# Configure training with type safety and autocomplete.
config = TrainConfig(
    batch_size=100,
    learning_rate=1e-3,
    max_epochs=1000,
)

# Set up logging and early stopping.
logging = LoggingConfig(backend="wandb", project="sbi-experiment")
early_stop = EarlyStopping.validation_loss(patience=20)

# Train with the new features.
inference = NPE(prior=prior)
theta, x = simulate_for_sbi(simulator, prior, num_simulations=10000)
inference.append_simulations(theta, x)

neural_net = inference.train(
    config,
    LossArgsNPE(exclude_invalid_x=True),
    logging=logging,
    early_stopping=early_stop,
)
```

## Next steps

Centralizing training logic in `base.py` has historically increased the size and
responsibilities of the `NeuralInference` “god class”. As a natural next step, we
propose extracting the entire training loop into a standalone function that takes the
configured options and training components, and returns the trained network (plus
optional artifacts), e.g., something like:

```python
def run_training(
    config: TrainConfig,
    model: torch.nn.Module,
    loss_fn: Callable[..., torch.Tensor],
    train_loader: DataLoader,
    val_loader: DataLoader | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    scheduler: torch.optim.lr_scheduler._LRScheduler | None = None,
    callbacks: Sequence[Callback] | None = None,  # logging, early stopping, etc.
    device: str | torch.device | None = None,
) -> tuple[torch.nn.Module, TrainingSummary]:
    """Runs the unified training loop and returns the trained model and summary."""
```
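
The `Callback` type referenced in the signature above does not exist yet; a minimal
version could look like the sketch below, where both logging and early stopping hook
into the same two methods (names are assumptions).

```python
# Sketch only: a minimal Callback protocol that run_training() could invoke.
# Method names and the stop-request mechanism are assumptions for discussion.
from typing import Mapping, Protocol


class Callback(Protocol):
    def on_epoch_end(self, epoch: int, metrics: Mapping[str, float]) -> None:
        """Called once per epoch with, e.g., training and validation losses."""
        ...

    def should_stop(self) -> bool:
        """Polled by run_training() to decide whether to stop early."""
        ...
```

The logging adapters and early stopping criteria sketched earlier would both fit behind
this protocol, keeping `run_training()` free of backend-specific branches.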

Benefits:

- Shrinks `NeuralInference` and makes responsibilities explicit.
- Improves testability (train loop covered independently; inference classes can be
  tested with lightweight mocks).
- Enables pluggable logging/early-stopping via callbacks without entangling method-
  specific logic.
- Keeps backward compatibility: inference classes compose `run_training()` internally
  while still exposing the existing `.train(...)` entry point.

This should be tackled in a follow-up EP or PR that would introduce `run_training()`
(and a minimal `Callback` protocol), migrate NPE/NLE/NRE to call it, and add focused
unit tests for the training runner.

## Feedback Wanted

We welcome feedback and implementation interest in GitHub Discussions:

1. Which logging backends are most important?
2. What early stopping strategies would be useful?
3. Any concerns about the proposed API?
4. What do you think about the external training function?

- Discussion thread: [EP-01 Discussion](https://github.com/sbi-dev/sbi/discussions/new?category=ideas)

## References

- [PR #1629](https://github.com/sbi-dev/sbi/pull/1629): Community early stopping implementation
- [NUMFOCUS SDG Proposal](https://github.com/numfocus/small-development-grant-proposals/issues/60): Related funding proposal