# Moons with Active Learning

This example is intended to give a quick, high-level overview of one kind of active learning
experiment that can be put together using the `physicsnemo` active learning modules and
protocols.

The experiment in `moon_example.py` uses a simple MLP classifier to label 2D
coordinates from the well-known two-moons data distribution. The premise of the
experiment is to initially show the MLP a minimal set of data (with some class imbalance),
and then use the model's prediction uncertainties to query the points that will be
most informative to it.

The main thing to monitor in this experiment is the `f1_metrology.json` output, which is
a product of the `F1Metrology` strategy: it records precision/recall/F1 values
as a function of the number of active learning cycles. Due to the class imbalance, the initial
precision will be quite poor, as the predictions are heavily biased towards false negatives,
but as more samples (chosen by `ClassifierUQQuery`) are added to the training set, the
precision, and with it the F1 score, will improve.
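
To inspect that output, a short script along these lines can be used. The schema of
`f1_metrology.json` (a list of per-cycle records with `precision`, `recall`, and `f1`
keys) is an assumption here; adjust the keys to whatever your run actually produces.

```python
import json
from pathlib import Path

# pick a run directory under the default log root; each run is keyed by
# an 8-character UUID prefix, so adjust the selection for your run id
run_dir = sorted(Path("active_learning_logs").iterdir())[0]
records = json.loads((run_dir / "f1_metrology.json").read_text())

for step, record in enumerate(records):
    print(f"step {step}: precision={record['precision']:.3f}, "
          f"recall={record['recall']:.3f}, f1={record['f1']:.3f}")
```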

## Quick Start

To run this example:

```bash
python moon_example.py
```

This will create an `active_learning_logs/<run_id>/` directory containing:

- **Model checkpoints**: `.mdlus` files saved according to `checkpoint_interval`
- **Driver logs**: `driver_log.json` tracking the active learning process
- **Metrology outputs**: `f1_metrology.json` with precision/recall/F1 scores over iterations
- **Console logs**: `console.log` with detailed execution logs

The `<run_id>` is an 8-character UUID prefix that uniquely identifies each run. You can
specify a custom `run_id` in `DriverConfig` if needed.

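For example, assuming `run_id` is accepted directly as a `DriverConfig` field (per the
note above), pinning it keeps the log path stable across repeat runs:

```python
driver_config = c.DriverConfig(
    run_id="moons-baseline",  # hypothetical fixed identifier
)
```
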
## Implementation notes

To illustrate a simple active learning process, this example implements only the
essential ingredients:

1. **Training step logic** - The `training_step` function defines the per-batch training
logic, computing the loss from model predictions. It is passed to the `Driver`, which uses
it within the training loop (a minimal sketch of such a function is shown after this list).

2. **Training loop** - The example uses `DefaultTrainingLoop`, a built-in training loop
that handles epoch iteration, progress bars, validation, and static capture optimizations.
For reference, a custom `training_loop` function is also defined in the example but not used,
showing how you could implement your own if needed.

3. **Query strategy** - `moon_strategies.ClassifierUQQuery` uses the classifier's uncertainty
to rank indices over the samples not yet in the training set, selecting points where
the model is most uncertain (predictions closest to 0.5).

4. **Label strategy** - `DummyLabelStrategy` handles obtaining data labels.
Because the ground truths are already known for this dataset, it is essentially a
no-op, but the `Driver` pipeline relies on it to append labeled data to the
training set.

5. **Metrology strategy** - `F1Metrology` computes precision/recall/F1 scores and serializes
them to JSON. This makes it easy to track how model performance improves with each active
learning iteration, helping inform hyperparameter choices for future experiments.

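As a concrete illustration of item 1, the following is a minimal sketch of a per-batch
training step. The `(coords, labels)` batch structure and the exact signature expected by
the `Driver` are assumptions here; `moon_example.py` contains the authoritative version.

```python
import torch
import torch.nn.functional as F


def training_step(model: torch.nn.Module, batch) -> torch.Tensor:
    """Hypothetical per-batch step: unpack a batch, run the MLP, return a loss."""
    coords, labels = batch
    logits = model(coords).squeeze(-1)
    # binary cross-entropy against the two-moons class labels
    return F.binary_cross_entropy_with_logits(logits, labels.float())
```
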
## Configuration

The rest is configuration: we take the components we've written and compose them into
the various configuration dataclasses in `moon_example.py::main`.

### TrainingConfig

The `train_datapool` specifies what set of data to train on. We configure the training
loop using `DefaultTrainingLoop` with progress bars enabled. The `OptimizerConfig` specifies
which optimizer to use and its hyperparameters. You can configure different epoch counts
for initial training vs. subsequent fine-tuning iterations.

```python
# configure how training/fine-tuning is done within active learning
training_config = c.TrainingConfig(
    train_datapool=dataset,
    optimizer_config=c.OptimizerConfig(
        torch.optim.SGD,
        optimizer_kwargs={"lr": 0.01},
    ),
    # configure different epoch budgets for initial training and
    # subsequent fine-tuning
    max_training_epochs=10,
    max_fine_tuning_epochs=5,
    # this configures the training loop
    train_loop_fn=DefaultTrainingLoop(
        use_progress_bars=True,
    ),
)
```

**Key options:**

- `max_training_epochs`: Epochs for initial training (step 0)
- `max_fine_tuning_epochs`: Epochs for subsequent fine-tuning steps
- `DefaultTrainingLoop(use_progress_bars=True)`: Built-in loop with tqdm progress bars
- `val_datapool`: Optional validation dataset (not used in this example)

### StrategiesConfig

The `StrategiesConfig` collects all of the different active learning components
in one place. The `queue_cls` is used to pass queried samples on to the labeling process.
Because we're carrying out a single-process workflow, `queue.Queue` is sufficient,
but multi-process variants, up to constructs like a Redis queue, can be used to pass
data around the pipeline.

```python
strategy_config = c.StrategiesConfig(
    query_strategies=[ClassifierUQQuery(max_samples=10)],
    queue_cls=queue.Queue,
    label_strategy=DummyLabelStrategy(),
    metrology_strategies=[F1Metrology()],
)
```
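
If the pipeline were split across processes, and assuming `queue_cls` only needs to be a
callable returning a queue-like object, the standard library's process-safe queue could be
swapped in without touching the rest of the configuration:

```python
import multiprocessing as mp

strategy_config = c.StrategiesConfig(
    query_strategies=[ClassifierUQQuery(max_samples=10)],
    queue_cls=mp.Queue,  # factory for a process-safe queue
    label_strategy=DummyLabelStrategy(),
    metrology_strategies=[F1Metrology()],
)
```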

**Key components:**

- `query_strategies`: List of strategies for selecting samples (can have multiple)
- `queue_cls`: Queue implementation for passing data between phases (e.g., `queue.Queue`)
- `label_strategy`: Single strategy for labeling queried samples
- `metrology_strategies`: List of strategies for measuring model performance
- `unlabeled_datapool`: Optional pool of unlabeled data for query strategies (not shown here)

### DriverConfig

Finally, the `DriverConfig` specifies orchestration parameters that control the overall
active learning loop execution:

```python
driver_config = c.DriverConfig(
    batch_size=16,
    max_active_learning_steps=70,
    fine_tuning_lr=0.005,
    device=torch.device("cpu"),  # set to other accelerators if needed
)
driver = Driver(
    config=driver_config,
    learner=uq_model,
    strategies_config=strategy_config,
    training_config=training_config,
)
# our model doesn't implement a `training_step` method, but if it did,
# we wouldn't need to pass the step function here
driver(train_step_fn=training_step)
```

**Key parameters:**

- `batch_size`: Batch size for training and validation dataloaders
- `max_active_learning_steps`: Total number of active learning iterations
- `fine_tuning_lr`: Learning rate to switch to after the first AL step (optional)
- `device`: Device for computation (e.g., `torch.device("cpu")`, `torch.device("cuda:0")`)
- `dtype`: Data type for tensors (defaults to `torch.get_default_dtype()`)
- `skip_training`: Set to `True` to skip the training phase (default: `False`)
- `skip_metrology`: Set to `True` to skip the metrology phase (default: `False`)
- `skip_labeling`: Set to `True` to skip the labeling phase (default: `False`)
- `checkpoint_interval`: Save the model every N steps (default: 1, set to 0 to disable)
- `root_log_dir`: Directory for logs and checkpoints (default: `"active_learning_logs"`)
- `dist_manager`: Optional `DistributedManager` for multi-GPU training

### Running the Driver

The final `driver(...)` call is syntactic sugar for `driver.run(...)`, which executes the
full active learning loop. The `train_step_fn` argument provides the per-batch training logic.

**Two ways to provide training logic:**

1. **Pass as a function** (shown in the example):

   ```python
   driver(train_step_fn=training_step)
   ```

2. **Implement in the model** (alternative):

   ```python
   class MLP(Module):
       def training_step(self, data):
           # training logic here
           ...

   driver()  # no train_step_fn needed
   ```

**Optional validation step:**

You can also provide a `validate_step_fn` parameter:

```python
driver(train_step_fn=training_step, validate_step_fn=validation_step)
```
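
The example's `validation_step` is not reproduced here; the following sketch assumes it
mirrors the training step's signature and returns a scalar loss, with gradients disabled:

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def validation_step(model: torch.nn.Module, batch) -> torch.Tensor:
    """Hypothetical mirror of training_step, without gradient tracking."""
    coords, labels = batch
    logits = model(coords).squeeze(-1)
    return F.binary_cross_entropy_with_logits(logits, labels.float())
```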

### Active Learning Workflow

Under the hood, `Driver.active_learning_step` is called repeatedly for the number of
iterations specified in `max_active_learning_steps`. Each iteration follows this sequence:

1. **Training Phase**: Train the model on the current `train_datapool` using the
training loop
2. **Metrology Phase**: Compute performance metrics via metrology strategies
3. **Query Phase**: Select new samples to label via query strategies →
`query_queue`
4. **Labeling Phase**: Label queued samples via the label strategy → `label_queue`
→ append to `train_datapool`

The logic for each phase is in methods like `Driver._training_phase`,
`Driver._query_phase`, etc.
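
Schematically, and leaving out checkpointing, logging, and the `skip_*` flags, one step
behaves roughly like the pseudocode below. This is an illustration of the sequence above,
not the actual `Driver` implementation, and the attribute names are assumptions:

```python
def active_learning_step(driver):
    # 1. training phase: (re)fit the learner on the current training pool
    driver._training_phase()
    # 2. metrology phase: record precision/recall/F1 and serialize
    for strategy in driver.metrology_strategies:
        strategy.compute()
    # 3. query phase: strategies push informative samples onto the queue
    for strategy in driver.query_strategies:
        strategy.sample(driver.query_queue)
    # 4. labeling phase: label queued samples and grow the training pool
    driver.label_strategy.label(driver.query_queue, driver.label_queue)
```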

## Advanced Customization

### Custom Training Loops

While `DefaultTrainingLoop` is suitable for most use cases, you can write
custom training loops that implement the `TrainingLoop` protocol, which is the
overarching logic for how to carry out model training and validation over some
number of epochs. Custom loops are useful when you need:

- Specialized training logic (e.g., alternating updates or multiple optimizers)
- Custom logging or checkpointing within the loop
- Non-standard epoch/batch iteration patterns
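
A sketch of the shape such a loop could take is shown below. The exact `TrainingLoop`
protocol signature is an assumption (the unused `training_loop` function in
`moon_example.py` is the authoritative reference); the essence is to iterate over epochs,
call the per-batch step, and apply the optimizer:

```python
def training_loop(model, dataloader, optimizer, train_step_fn, max_epochs):
    # hypothetical signature; match the TrainingLoop protocol of your version
    for epoch in range(max_epochs):
        for batch in dataloader:
            optimizer.zero_grad()
            loss = train_step_fn(model, batch)
            loss.backward()
            optimizer.step()
```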

### Custom Strategies

All strategies must implement their respective protocols:

- **QueryStrategy**: Implement `sample(query_queue, *args, **kwargs)` and `attach(driver)`
- **LabelStrategy**: Implement `label(queue_to_label, serialize_queue, *args, **kwargs)`
and `attach(driver)`
- **MetrologyStrategy**: Implement `compute(*args, **kwargs)`,
`serialize_records(*args, **kwargs)`, and `attach(driver)`

The `attach(driver)` method gives your strategy access to the driver's
attributes like `driver.learner`, `driver.train_datapool`,
`driver.unlabeled_datapool`, etc.
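
For instance, a custom metrology strategy could track plain accuracy. The sketch below
assumes only the `compute`/`serialize_records`/`attach` protocol listed above; the held-out
evaluation tensors and the record format are hypothetical:

```python
import json

import torch


class AccuracyMetrology:
    """Hypothetical MetrologyStrategy: records accuracy once per AL step."""

    def __init__(self, coords: torch.Tensor, labels: torch.Tensor):
        # a held-out evaluation set supplied up front (an assumption;
        # a real strategy might pull data from the driver instead)
        self.coords, self.labels = coords, labels
        self.records = []
        self.driver = None

    def attach(self, driver):
        self.driver = driver  # access to driver.learner, datapools, ...

    @torch.no_grad()
    def compute(self, *args, **kwargs):
        logits = self.driver.learner(self.coords).squeeze(-1)
        accuracy = ((logits > 0) == self.labels.bool()).float().mean().item()
        self.records.append({"accuracy": accuracy})

    def serialize_records(self, *args, **kwargs):
        with open("accuracy_metrology.json", "w") as f:
            json.dump(self.records, f)
```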

### Static Capture and Performance

The `DefaultTrainingLoop` supports static capture via CUDA graphs for
performance optimization:

```python
train_loop_fn=DefaultTrainingLoop(
    enable_static_capture=True,  # Enable CUDA graph capture (default)
    use_progress_bars=True,
)
```

For custom training loops, you can use:

- `StaticCaptureTraining` for training steps
- `StaticCaptureEvaluateNoGrad` for validation/inference steps
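
The decorator pattern below is a sketch of how `StaticCaptureTraining` is typically used;
the import path, the decorator arguments, and the requirement that the learner be a
`physicsnemo` `Module` should all be checked against your installed version:

```python
import physicsnemo
import torch
import torch.nn.functional as F

# import path is an assumption; check your physicsnemo version
from physicsnemo.utils.capture import StaticCaptureTraining


class TinyMLP(physicsnemo.Module):
    """Stand-in learner; capture utilities generally expect a physicsnemo Module."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(2, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1)
        )

    def forward(self, x):
        return self.net(x)


model = TinyMLP().to("cuda")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


# the decorated step is captured into a CUDA graph after warmup iterations
@StaticCaptureTraining(model=model, optim=optimizer)
def training_step(coords, labels):
    logits = model(coords).squeeze(-1)
    return F.binary_cross_entropy_with_logits(logits, labels.float())
```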

### Distributed Training

To use multiple GPUs, provide a `DistributedManager` in `DriverConfig`:

```python
from physicsnemo.distributed import DistributedManager

dist_manager = DistributedManager()
driver_config = c.DriverConfig(
    batch_size=16,
    max_active_learning_steps=70,
    dist_manager=dist_manager,  # Handles device placement and DDP
)
```

The driver will automatically wrap the model in `DistributedDataParallel` and use a
`DistributedSampler` for the dataloaders.
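
`DistributedManager` typically picks up rank and world-size information from environment
variables, so a multi-GPU run would usually be launched with a tool that sets them, e.g.
(assuming a standard single-node `torchrun` setup):

```bash
torchrun --nproc_per_node=4 moon_example.py
```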

## Experiment Ideas

Here we perform a relatively straightforward experiment without a baseline. One suitable
extension would be to train a model on the full dataset and see how the precision/recall/F1
scores of the `ClassifierUQQuery` learner compare to those of the full-data model (i.e., use
the latter as a roofline).

A suitable baseline to compare against would be random selection: to check the efficacy
of `ClassifierUQQuery`, samples could instead be chosen uniformly at random, to see if and
how the same metrology scores differ. If the UQ is performing as intended, then the
precision/recall/F1 scores should improve at a faster rate.
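
A minimal sketch of such a baseline is shown below, assuming only the `sample`/`attach`
protocol described under Custom Strategies and that the driver exposes an indexable
`unlabeled_datapool` of candidate samples:

```python
import random


class RandomQuery:
    """Hypothetical baseline QueryStrategy: uniform random selection."""

    def __init__(self, max_samples: int = 10):
        self.max_samples = max_samples
        self.driver = None

    def attach(self, driver):
        self.driver = driver

    def sample(self, query_queue, *args, **kwargs):
        # assumption: the pool is indexable and items can be queued as-is
        pool = self.driver.unlabeled_datapool
        for idx in random.sample(range(len(pool)), k=self.max_samples):
            query_queue.put(pool[idx])
```

Swapping `RandomQuery` into `query_strategies` in place of `ClassifierUQQuery`, while
keeping everything else fixed, gives the comparison described above.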