Skip to content

Commit 8d018f1

Browse files
laserkelvindallasfosterktangsali
authored
Active learning abstraction (#1174)
* chore: define namespace for active learning components * feat: add protocols module defining active learning interfaces * feat: add registry abstraction for managing active learning strategies * feat: adding default training loop used for training and fine-tuning * feat: adding active learning specialized logger module * feat: adding config module for defining active learning behavior * feat: adding driver module for active learning orchestration * test: defining active learning unit test folder structure * test: add reusable components for active learning unit tests * test: add unit tests for active learning registry * test: add unit tests for default training loop * test: add unit tests for active learning configuration * test: add unit tests for active learning driver orchestrator * test: add dedicated unit tests for verifying checkpointing behavior * docs: adding example README for moons active learning example * feat: adding data module used to construct moons example * feat: adding active learning strategy definitions for moons example * script: adding example script for running the moons example * chore: adding gitignore for active learning logs in moons example * docs: added active learning entry to changelog * docs: added brief readme to describe the active learning namespace * refactor: renaming registry module to not clash, and exposing classes in active learning namespace * fix: using correct pool of training data in moon example * fix: correcting precision implementation * fix: correcting sanity check at the end of moon example * fix: disabling static capture for moon example * script: making moon example train longer in defaults * refactor: adding moon example components to registry * fix: actually pulling static captured functions from cache * fix: making sure max_fine_tuning_epochs sets defaults * refactor: adding F1Metrology to __all__ * fix: correcting logic for registry construct * fix: correcting label generation offset * fix: correcting device mapping if mismatched in DefaultTrainingLoop * fix: correcting configure optimizer behavior when optimizer might not be defined * refactor: resetting optimizer state only if it's requested and configured * fix: correcting logic for computing true positives * script: adding restart instructions at the end of moon example * feat: addoeg strategy_dir property and protocol_type attribute for folder organization * refactor: updating methods in F1Metrology to serialize records * refactor: adding option to pass a path to serialize metric records * feat: adding checkpoint_dir property to base protocol * refactor: adding load_records interface to metrology * docs: adding load_checkpoint important message on strategy states * test: fixing expected exception in registry test * docs: bumping year on SPDX header * docs: addressing doctest issues * docs: revising protocol doctests with capture methods and fixing registry register doctests --------- Co-authored-by: Dallas Foster <[email protected]> Co-authored-by: Kaustubh Tangsali <[email protected]>
1 parent 68c1854 commit 8d018f1

File tree

21 files changed

+8397
-0
lines changed

21 files changed

+8397
-0
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2626
and three transient schemes.
2727
- Added a check to `stochastic_sampler` that helps handle the `EDMPrecond` model,
2828
which has a specific `.forward()` signature
29+
- Added abstract interfaces for constructing active learning workflows, contained
30+
under the `physicsnemo.active_learning` namespace. A preliminary example of how
31+
to compose and define an active learning workflow is provided in `examples/active_learning`.
32+
The `moons` example provides a minimal (pedagogical) composition that is meant to
33+
illustrate how to define the necessary parts of the workflow.
2934

3035
### Changed
3136

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
active_learning_logs/
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
# Moons with Active Learning
2+
3+
This example is intended to give a quick, high level overview of one kind of active learning
4+
experiments that can be put together using the `physicsnemo` active learning modules and
5+
protocols.
6+
7+
The experiment that is being done in `moon_example.py` is to use a simple MLP classifier
8+
to label 2D coordinates from the famous two-moons data distribution. The platform for
9+
the experiment is to initially show the MLP a minimal set of data (with some class imbalance),
10+
and use the prediction uncertainties from the model to query points that will be the most
11+
informative to it.
12+
13+
The main thing to monitor in this experiment is the `f1_metrology.json` output, which is
14+
a product of the `F1Metrology` strategy: in here, we compute precision/recall/F1 values
15+
as a function of the number of active learning cycles. Due to class imbalance, the initial
16+
precision will be quite poor as the predictions will heavily bias towards false negatives,
17+
but as more samples (chosen by `ClassifierUQQuery`) are added to the training set, the
18+
precision and subsequently F1 scores will improve.
19+
20+
## Quick Start
21+
22+
To run this example:
23+
24+
```bash
25+
python moon_example.py
26+
```
27+
28+
This will create an `active_learning_logs/<run_id>/` directory containing:
29+
30+
- **Model checkpoints**: `.mdlus` files saved according to `checkpoint_interval`
31+
- **Driver logs**: `driver_log.json` tracking the active learning process
32+
- **Metrology outputs**: `f1_metrology.json` with precision/recall/F1 scores over iterations
33+
- **Console logs**: `console.log` with detailed execution logs
34+
35+
The `<run_id>` is an 8-character UUID prefix that uniquely identifies each run. You can
36+
specify a custom `run_id` in `DriverConfig` if needed.
37+
38+
## Implementation notes
39+
40+
To illustrate a simple active learning process, this example implements the bare necessary
41+
ingredients:
42+
43+
1. **Training step logic** - `training_step` function defines the per-batch training logic,
44+
computing the loss from model predictions. This is passed to the `Driver` which uses it
45+
within the training loop.
46+
47+
2. **Training loop** - The example uses `DefaultTrainingLoop`, a built-in training loop
48+
that handles epoch iteration, progress bars, validation, and static capture optimizations.
49+
For reference, a custom `training_loop` function is also defined in the example but not used,
50+
showing how you could implement your own if needed.
51+
52+
3. **Query strategy** - `moon_strategies.ClassifierUQQuery` uses the classifier uncertainty
53+
to rank data indices from the full (not in training) sample set, selecting points where
54+
the model is most uncertain (predictions closest to 0.5).
55+
56+
4. **Label strategy** - `DummyLabelStrategy` handles obtaining data labels.
57+
Because ground truths are already known for this dataset, it's essentially a
58+
no-op but the `Driver` pipeline relies on it to append labeled data to the
59+
training set.
60+
61+
5. **Metrology strategy** - `F1Metrology` computes precision/recall/F1 scores and serializes
62+
them to JSON. This makes it easy to track how model performance improves with each active
63+
learning iteration, helping inform hyperparameter choices for future experiments.
64+
65+
## Configuration
66+
67+
The rest is configuration: we take the components we've written, and compose them in
68+
the various configuration dataclasses in `moon_example.py::main`.
69+
70+
### TrainingConfig
71+
72+
The `train_datapool` specifies what set of data to train on. We configure the training
73+
loop using `DefaultTrainingLoop` with progress bars enabled. The `OptimizerConfig` specifies
74+
which optimizer to use and its hyperparameters. You can configure different epoch counts
75+
for initial training vs. subsequent fine-tuning iterations.
76+
77+
```python
78+
# configure how training/fine-tuning is done within active learning
79+
training_config = c.TrainingConfig(
80+
train_datapool=dataset,
81+
optimizer_config=c.OptimizerConfig(
82+
torch.optim.SGD,
83+
optimizer_kwargs={"lr": 0.01},
84+
),
85+
# configure different times for initial training and subsequent
86+
# fine-tuning
87+
max_training_epochs=10,
88+
max_fine_tuning_epochs=5,
89+
# this configures the training loop
90+
train_loop_fn=DefaultTrainingLoop(
91+
use_progress_bars=True,
92+
),
93+
)
94+
```
95+
96+
**Key options:**
97+
98+
- `max_training_epochs`: Epochs for initial training (step 0)
99+
- `max_fine_tuning_epochs`: Epochs for subsequent fine-tuning steps
100+
- `DefaultTrainingLoop(use_progress_bars=True)`: Built-in loop with tqdm progress bars
101+
- `val_datapool`: Optional validation dataset (not used in this example)
102+
103+
### StrategiesConfig
104+
105+
The `StrategiesConfig` localizes all of the different active learning components
106+
into one place. The `queue_cls` is used to pipeline query samples to label processes.
107+
Because we're carrying out a single process workflow, `queue.Queue` is sufficient,
108+
but multiprocess variants, up to constructs like Redis Queue, can be used to pass
109+
data around the pipeline.
110+
111+
```python
112+
strategy_config = c.StrategiesConfig(
113+
query_strategies=[ClassifierUQQuery(max_samples=10)],
114+
queue_cls=queue.Queue,
115+
label_strategy=DummyLabelStrategy(),
116+
metrology_strategies=[F1Metrology()],
117+
)
118+
```
119+
120+
**Key components:**
121+
122+
- `query_strategies`: List of strategies for selecting samples (can have multiple)
123+
- `queue_cls`: Queue implementation for passing data between phases (e.g., `queue.Queue`)
124+
- `label_strategy`: Single strategy for labeling queried samples
125+
- `metrology_strategies`: List of strategies for measuring model performance
126+
- `unlabeled_datapool`: Optional pool of unlabeled data for query strategies (not shown here)
127+
128+
### DriverConfig
129+
130+
Finally, the `DriverConfig` specifies orchestration parameters that control the overall
131+
active learning loop execution:
132+
133+
```python
134+
driver_config = c.DriverConfig(
135+
batch_size=16,
136+
max_active_learning_steps=70,
137+
fine_tuning_lr=0.005,
138+
device=torch.device("cpu"), # set to other accelerators if needed
139+
)
140+
driver = Driver(
141+
config=driver_config,
142+
learner=uq_model,
143+
strategies_config=strategy_config,
144+
training_config=training_config,
145+
)
146+
# our model doesn't implement a `training_step` method but in principle
147+
# it could be implemented, and we wouldn't need to pass the step function here
148+
driver(train_step_fn=training_step)
149+
```
150+
151+
**Key parameters:**
152+
153+
- `batch_size`: Batch size for training and validation dataloaders
154+
- `max_active_learning_steps`: Total number of active learning iterations
155+
- `fine_tuning_lr`: Learning rate to switch to after the first AL step (optional)
156+
- `device`: Device for computation (e.g., `torch.device("cpu")`, `torch.device("cuda:0")`)
157+
- `dtype`: Data type for tensors (defaults to `torch.get_default_dtype()`)
158+
- `skip_training`: Set to `True` to skip training phase (default: `False`)
159+
- `skip_metrology`: Set to `True` to skip metrology phase (default: `False`)
160+
- `skip_labeling`: Set to `True` to skip labeling phase (default: `False`)
161+
- `checkpoint_interval`: Save model every N steps (default: 1, set to 0 to disable)
162+
- `root_log_dir`: Directory for logs and checkpoints (default: `"active_learning_logs"`)
163+
- `dist_manager`: Optional `DistributedManager` for multi-GPU training
164+
165+
### Running the Driver
166+
167+
The final `driver(...)` call is syntactic sugar for `driver.run(...)`, which executes the
168+
full active learning loop. The `train_step_fn` argument provides the per-batch training logic.
169+
170+
**Two ways to provide training logic:**
171+
172+
1. **Pass as function** (shown in example):
173+
174+
```python
175+
driver(train_step_fn=training_step)
176+
```
177+
178+
1. **Implement in model** (alternative):
179+
180+
```python
181+
class MLP(Module):
182+
def training_step(self, data):
183+
# training logic here
184+
...
185+
186+
driver() # no train_step_fn needed
187+
```
188+
189+
**Optional validation step:**
190+
191+
You can also provide a `validate_step_fn` parameter:
192+
193+
```python
194+
driver(train_step_fn=training_step, validate_step_fn=validation_step)
195+
```
196+
197+
### Active Learning Workflow
198+
199+
Under the hood, `Driver.active_learning_step` is called repeatedly for the number of
200+
iterations specified in `max_active_learning_steps`. Each iteration follows this sequence:
201+
202+
1. **Training Phase**: Train model on current `train_datapool` using the
203+
training loop
204+
2. **Metrology Phase**: Compute performance metrics via metrology strategies
205+
3. **Query Phase**: Select new samples to label via query strategies →
206+
`query_queue`
207+
4. **Labeling Phase**: Label queued samples via label strategy → `label_queue`
208+
→ append to `train_datapool`
209+
210+
The logic for each phase is in methods like `Driver._training_phase`,
211+
`Driver._query_phase`, etc.
212+
213+
## Advanced Customization
214+
215+
### Custom Training Loops
216+
217+
While `DefaultTrainingLoop` is suitable for most use cases, you can write
218+
custom training loops that implement the `TrainingLoop` protocol, which is the
219+
overarching logic for how to carry out model training and validation over some
220+
number of epochs. Custom loops are useful when you need:
221+
222+
- Specialized training logic (e.g., alternating, or multiple optimizers)
223+
- Custom logging or checkpointing within the loop
224+
- Non-standard epoch/batch iteration patterns
225+
226+
### Custom Strategies
227+
228+
All strategies must implement their respective protocols:
229+
230+
- **QueryStrategy**: Implement `sample(query_queue, *args, **kwargs)` and
231+
`attach(driver)`
232+
- **LabelStrategy**: Implement `label(queue_to_label, serialize_queue, *args,
233+
**kwargs)` and `attach(driver)`
234+
- **MetrologyStrategy**: Implement `compute(*args, **kwargs)`,
235+
`serialize_records(*args, **kwargs)`, and `attach(driver)`
236+
237+
The `attach(driver)` method gives your strategy access to the driver's
238+
attributes like `driver.learner`, `driver.train_datapool`,
239+
`driver.unlabeled_datapool`, etc.
240+
241+
### Static Capture and Performance
242+
243+
The `DefaultTrainingLoop` supports static capture via CUDA graphs for
244+
performance optimization:
245+
246+
```python
247+
train_loop_fn=DefaultTrainingLoop(
248+
enable_static_capture=True, # Enable CUDA graph capture (default)
249+
use_progress_bars=True,
250+
)
251+
```
252+
253+
For custom training loops, you can use:
254+
255+
- `StaticCaptureTraining` for training steps
256+
- `StaticCaptureEvaluateNoGrad` for validation/inference steps
257+
258+
### Distributed Training
259+
260+
To use multiple GPUs, provide a `DistributedManager` in `DriverConfig`:
261+
262+
```python
263+
from physicsnemo.distributed import DistributedManager
264+
265+
dist_manager = DistributedManager()
266+
driver_config = c.DriverConfig(
267+
batch_size=16,
268+
max_active_learning_steps=70,
269+
dist_manager=dist_manager, # Handles device placement and DDP
270+
)
271+
```
272+
273+
The driver will automatically wrap the model in `DistributedDataParallel` and use
274+
`DistributedSampler` for dataloaders.
275+
276+
## Experiment Ideas
277+
278+
Here, we perform a relatively straightforward experiment without a baseline; suitable
279+
ones could be to train a model using the full data, and see how the precision/recall/F1
280+
scores differ between the `ClassifierUQQuery` learner to the full data model (i.e. use
281+
the latter as a roofline).
282+
283+
A suitable baseline to compare against would be random selection: to check the efficacy
284+
of `ClassifierUQQuery`, samples could be chosen uniformly and see if and how the same
285+
metrology scores differ. If the UQ is performing as intended, then precision/recall/F1
286+
should improve at a faster rate.

0 commit comments

Comments
 (0)