Active learning abstraction #1174

Merged: laserkelvin merged 52 commits into NVIDIA:main from laserkelvin:active-learning-abstraction on Oct 27, 2025.

Changes from all commits (52 commits):

- c83b6ed laserkelvin: chore: define namespace for active learning components
- 45a48cb laserkelvin: feat: add protocols module defining active learning interfaces
- ab7190d laserkelvin: feat: add registry abstraction for managing active learning strategies
- d3f6e19 laserkelvin: feat: adding default training loop used for training and fine-tuning
- dcae0cd laserkelvin: feat: adding active learning specialized logger module
- 25b73bd laserkelvin: feat: adding config module for defining active learning behavior
- 3bb4674 laserkelvin: feat: adding driver module for active learning orchestration
- cfa833e laserkelvin: test: defining active learning unit test folder structure
- 430c124 laserkelvin: test: add reusable components for active learning unit tests
- 0d8f07b laserkelvin: test: add unit tests for active learning registry
- a49437d laserkelvin: test: add unit tests for default training loop
- 7660570 laserkelvin: test: add unit tests for active learning configuration
- c23eb45 laserkelvin: test: add unit tests for active learning driver orchestrator
- ad031c0 laserkelvin: test: add dedicated unit tests for verifying checkpointing behavior
- dab1fb4 laserkelvin: docs: adding example README for moons active learning example
- 46c7c62 laserkelvin: feat: adding data module used to construct moons example
- 3c3075e laserkelvin: feat: adding active learning strategy definitions for moons example
- e2c5a94 laserkelvin: script: adding example script for running the moons example
- 208d031 laserkelvin: chore: adding gitignore for active learning logs in moons example
- ba1591e laserkelvin: docs: added active learning entry to changelog
- ae4edac laserkelvin: docs: added brief readme to describe the active learning namespace
- 9b58a61 laserkelvin: refactor: renaming registry module to not clash, and exposing classes…
- ed2c275 laserkelvin: fix: using correct pool of training data in moon example
- e32285d laserkelvin: fix: correcting precision implementation
- be4576c laserkelvin: fix: correcting sanity check at the end of moon example
- 6e2a988 laserkelvin: fix: disabling static capture for moon example
- 1f67f7a laserkelvin: script: making moon example train longer in defaults
- 9ecfc3b laserkelvin: refactor: adding moon example components to registry
- 2717bc8 laserkelvin: fix: actually pulling static captured functions from cache
- cb122cb laserkelvin: fix: making sure max_fine_tuning_epochs sets defaults
- 0c4aab4 laserkelvin: refactor: adding F1Metrology to __all__
- 5f31a6d laserkelvin: fix: correcting logic for registry construct
- a31cd18 laserkelvin: fix: correcting label generation offset
- 8b8d324 laserkelvin: fix: correcting device mapping if mismatched in DefaultTrainingLoop
- da6c676 laserkelvin: Merge branch 'main' into active-learning-abstraction
- ad1c90c laserkelvin: fix: correcting configure optimizer behavior when optimizer might not…
- 1057ea3 laserkelvin: refactor: resetting optimizer state only if it's requested and confi…
- dfc40e0 laserkelvin: fix: correcting logic for computing true positives
- 41124e7 laserkelvin: Merge branch 'main' into active-learning-abstraction
- 38a8082 laserkelvin: script: adding restart instructions at the end of moon example
- eb05afe laserkelvin: feat: adding strategy_dir property and protocol_type attribute for fo…
- bddcd4d laserkelvin: refactor: updating methods in F1Metrology to serialize records
- 90a7f04 laserkelvin: refactor: adding option to pass a path to serialize metric records
- 7813c44 laserkelvin: feat: adding checkpoint_dir property to base protocol
- b89d915 laserkelvin: refactor: adding load_records interface to metrology
- b0cc691 laserkelvin: docs: adding load_checkpoint important message on strategy states
- aba57bc laserkelvin: test: fixing expected exception in registry test
- fda5468 dallasfoster: Merge branch 'main' into active-learning-abstraction
- 70e234d laserkelvin: docs: bumping year on SPDX header
- f979b40 laserkelvin: docs: addressing doctest issues
- d77995c laserkelvin: docs: revising protocol doctests with capture methods and fixing regi…
- 73a5fb8 ktangsali: Merge branch 'main' into active-learning-abstraction

New file (the `.gitignore` for the moons example, per commit 208d031):

```
active_learning_logs/
```

New file (the moons example README):

# Moons with Active Learning

This example is intended to give a quick, high-level overview of one kind of active learning
experiment that can be put together using the `physicsnemo` active learning modules and
protocols.

The experiment carried out in `moon_example.py` uses a simple MLP classifier
to label 2D coordinates from the famous two-moons data distribution. The premise of
the experiment is to initially show the MLP a minimal set of data (with some class imbalance),
and use the prediction uncertainties from the model to query the points that will be most
informative to it.

The main thing to monitor in this experiment is the `f1_metrology.json` output, which is
a product of the `F1Metrology` strategy: here, we compute precision/recall/F1 values
as a function of the number of active learning cycles. Due to class imbalance, the initial
precision will be quite poor, as the predictions will heavily bias towards false negatives,
but as more samples (chosen by `ClassifierUQQuery`) are added to the training set, the
precision and subsequently the F1 score will improve.
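
For reference, these metrics follow the standard definitions. The sketch below shows those definitions only; the actual bookkeeping inside `F1Metrology` may differ:

```python
def f1_scores(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """Standard precision/recall/F1 from true/false positive and false negative counts."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return precision, recall, f1
```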

## Quick Start

To run this example:

```bash
python moon_example.py
```

This will create an `active_learning_logs/<run_id>/` directory containing:

- **Model checkpoints**: `.mdlus` files saved according to `checkpoint_interval`
- **Driver logs**: `driver_log.json` tracking the active learning process
- **Metrology outputs**: `f1_metrology.json` with precision/recall/F1 scores over iterations
- **Console logs**: `console.log` with detailed execution logs

The `<run_id>` is an 8-character UUID prefix that uniquely identifies each run. You can
specify a custom `run_id` in `DriverConfig` if needed.
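
For instance, a one-line sketch (this assumes `run_id` is accepted directly as a `DriverConfig` field, per the note above):

```python
driver_config = c.DriverConfig(run_id="moons-baseline")  # replaces the random 8-character prefix
```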

## Implementation notes

To illustrate a simple active learning process, this example implements the bare minimum of
ingredients:

1. **Training step logic** - The `training_step` function defines the per-batch training logic,
computing the loss from model predictions. This is passed to the `Driver`, which uses it
within the training loop (a minimal sketch of such a function is shown after this list).

2. **Training loop** - The example uses `DefaultTrainingLoop`, a built-in training loop
that handles epoch iteration, progress bars, validation, and static capture optimizations.
For reference, a custom `training_loop` function is also defined in the example but not used,
showing how you could implement your own if needed.

3. **Query strategy** - `moon_strategies.ClassifierUQQuery` uses the classifier's uncertainty
to rank data indices from the full sample set (points not yet in the training set), selecting
points where the model is most uncertain (predictions closest to 0.5).

4. **Label strategy** - `DummyLabelStrategy` handles obtaining data labels.
Because ground truths are already known for this dataset, it is essentially a
no-op, but the `Driver` pipeline relies on it to append labeled data to the
training set.

5. **Metrology strategy** - `F1Metrology` computes precision/recall/F1 scores and serializes
them to JSON. This makes it easy to track how model performance improves with each active
learning iteration, helping inform hyperparameter choices for future experiments.
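
As referenced in item 1, here is a minimal sketch of what a per-batch training step can look like. The exact signature expected by the `Driver` is an assumption here; the binary cross-entropy loss mirrors the two-moons classification task:

```python
import torch
import torch.nn.functional as F


def training_step(model: torch.nn.Module, batch) -> torch.Tensor:
    """Hypothetical per-batch step: unpack a batch, predict, return the loss."""
    coords, labels = batch  # 2D points and their class labels
    logits = model(coords)
    # binary classification on the two-moons distribution
    return F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
```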

## Configuration

The rest is configuration: we take the components we've written and compose them in
the various configuration dataclasses in `moon_example.py::main`.

### TrainingConfig

The `train_datapool` specifies what set of data to train on. We configure the training
loop using `DefaultTrainingLoop` with progress bars enabled. The `OptimizerConfig` specifies
which optimizer to use and its hyperparameters. You can configure different epoch counts
for initial training vs. subsequent fine-tuning iterations.

```python
# configure how training/fine-tuning is done within active learning
training_config = c.TrainingConfig(
    train_datapool=dataset,
    optimizer_config=c.OptimizerConfig(
        torch.optim.SGD,
        optimizer_kwargs={"lr": 0.01},
    ),
    # configure different times for initial training and subsequent
    # fine-tuning
    max_training_epochs=10,
    max_fine_tuning_epochs=5,
    # this configures the training loop
    train_loop_fn=DefaultTrainingLoop(
        use_progress_bars=True,
    ),
)
```

**Key options:**

- `max_training_epochs`: Epochs for initial training (step 0)
- `max_fine_tuning_epochs`: Epochs for subsequent fine-tuning steps
- `DefaultTrainingLoop(use_progress_bars=True)`: Built-in loop with tqdm progress bars
- `val_datapool`: Optional validation dataset (not used in this example)

### StrategiesConfig

The `StrategiesConfig` gathers all of the different active learning components
in one place. The `queue_cls` is used to pass queried samples along to the labeling process.
Because we're carrying out a single-process workflow, `queue.Queue` is sufficient,
but multiprocessing variants, up to constructs like a Redis queue, can be used to pass
data around the pipeline.

```python
strategy_config = c.StrategiesConfig(
    query_strategies=[ClassifierUQQuery(max_samples=10)],
    queue_cls=queue.Queue,
    label_strategy=DummyLabelStrategy(),
    metrology_strategies=[F1Metrology()],
)
```

**Key components:**

- `query_strategies`: List of strategies for selecting samples (can have multiple)
- `queue_cls`: Queue implementation for passing data between phases (e.g., `queue.Queue`)
- `label_strategy`: Single strategy for labeling queried samples
- `metrology_strategies`: List of strategies for measuring model performance
- `unlabeled_datapool`: Optional pool of unlabeled data for query strategies (not shown here)

### DriverConfig

Finally, the `DriverConfig` specifies orchestration parameters that control the overall
active learning loop execution:

```python
driver_config = c.DriverConfig(
    batch_size=16,
    max_active_learning_steps=70,
    fine_tuning_lr=0.005,
    device=torch.device("cpu"),  # set to other accelerators if needed
)
driver = Driver(
    config=driver_config,
    learner=uq_model,
    strategies_config=strategy_config,
    training_config=training_config,
)
# our model doesn't implement a `training_step` method, but in principle
# it could, and then we wouldn't need to pass the step function here
driver(train_step_fn=training_step)
```

**Key parameters:**

- `batch_size`: Batch size for training and validation dataloaders
- `max_active_learning_steps`: Total number of active learning iterations
- `fine_tuning_lr`: Learning rate to switch to after the first AL step (optional)
- `device`: Device for computation (e.g., `torch.device("cpu")`, `torch.device("cuda:0")`)
- `dtype`: Data type for tensors (defaults to `torch.get_default_dtype()`)
- `skip_training`: Set to `True` to skip the training phase (default: `False`)
- `skip_metrology`: Set to `True` to skip the metrology phase (default: `False`)
- `skip_labeling`: Set to `True` to skip the labeling phase (default: `False`)
- `checkpoint_interval`: Save the model every N steps (default: 1; set to 0 to disable)
- `root_log_dir`: Directory for logs and checkpoints (default: `"active_learning_logs"`)
- `dist_manager`: Optional `DistributedManager` for multi-GPU training

### Running the Driver

The final `driver(...)` call is syntactic sugar for `driver.run(...)`, which executes the
full active learning loop. The `train_step_fn` argument provides the per-batch training logic.

**Two ways to provide training logic:**

1. **Pass as function** (shown in the example):

```python
driver(train_step_fn=training_step)
```

2. **Implement in model** (alternative):

```python
class MLP(Module):
    def training_step(self, data):
        # training logic here
        ...

driver()  # no train_step_fn needed
```

**Optional validation step:**

You can also provide a `validate_step_fn` parameter:

```python
driver(train_step_fn=training_step, validate_step_fn=validation_step)
```
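
A validation step typically mirrors the training step but runs without gradient updates. As a sketch (the exact signature expected by the `Driver` is an assumption):

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def validation_step(model: torch.nn.Module, batch) -> torch.Tensor:
    """Hypothetical per-batch validation: compute the loss without updating weights."""
    coords, labels = batch
    logits = model(coords)
    return F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
```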

### Active Learning Workflow

Under the hood, `Driver.active_learning_step` is called repeatedly for the number of
iterations specified in `max_active_learning_steps`. Each iteration follows this sequence:

1. **Training Phase**: Train the model on the current `train_datapool` using the
training loop
2. **Metrology Phase**: Compute performance metrics via metrology strategies
3. **Query Phase**: Select new samples to label via query strategies →
`query_queue`
4. **Labeling Phase**: Label queued samples via label strategy → `label_queue`
→ append to `train_datapool`

The logic for each phase is in methods like `Driver._training_phase`,
`Driver._query_phase`, etc. (a conceptual sketch follows).
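
Conceptually, one run looks something like the outline below. This is an illustrative sketch, not the actual `Driver` implementation; phase-method names other than `_training_phase` and `_query_phase` are guesses following the same naming pattern:

```python
# Illustrative outline of Driver.run(); see the Driver source for the real logic.
for step in range(driver_config.max_active_learning_steps):
    if not driver_config.skip_training:
        driver._training_phase()   # fit/fine-tune on train_datapool
    if not driver_config.skip_metrology:
        driver._metrology_phase()  # record precision/recall/F1, etc.
    driver._query_phase()          # query strategies fill query_queue
    if not driver_config.skip_labeling:
        driver._labeling_phase()   # label strategy drains the queue and
                                   # appends new samples to train_datapool
```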

## Advanced Customization

### Custom Training Loops

While `DefaultTrainingLoop` is suitable for most use cases, you can write
custom training loops that implement the `TrainingLoop` protocol, which defines the
overarching logic for how to carry out model training and validation over some
number of epochs (see the sketch after this list). Custom loops are useful when you need:

- Specialized training logic (e.g., alternating updates or multiple optimizers)
- Custom logging or checkpointing within the loop
- Non-standard epoch/batch iteration patterns
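
As a sketch of the shape such a loop can take, here is a hypothetical alternating-optimizer loop. The call signature is an assumption; consult the `TrainingLoop` protocol in the active learning namespace for the authoritative interface:

```python
class TwoOptimizerLoop:
    """Hypothetical custom loop alternating between two optimizers each epoch."""

    def __call__(self, model, dataloader, optimizers, train_step_fn, max_epochs):
        for epoch in range(max_epochs):
            # alternate which optimizer is active on even/odd epochs
            optimizer = optimizers[epoch % len(optimizers)]
            for batch in dataloader:
                optimizer.zero_grad()
                loss = train_step_fn(model, batch)
                loss.backward()
                optimizer.step()
```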

### Custom Strategies

All strategies must implement their respective protocols:

- **QueryStrategy**: Implement `sample(query_queue, *args, **kwargs)` and
`attach(driver)`
- **LabelStrategy**: Implement `label(queue_to_label, serialize_queue, *args,
**kwargs)` and `attach(driver)`
- **MetrologyStrategy**: Implement `compute(*args, **kwargs)`,
`serialize_records(*args, **kwargs)`, and `attach(driver)`

The `attach(driver)` method gives your strategy access to the driver's
attributes like `driver.learner`, `driver.train_datapool`,
`driver.unlabeled_datapool`, etc.
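
As an illustrative skeleton of a metrology strategy (only the method names come from the protocol summary above; the record format and evaluation logic are placeholders):

```python
import json
from pathlib import Path


class AccuracyMetrology:
    """Hypothetical metrology strategy tracking plain accuracy per AL step."""

    def attach(self, driver):
        # keep a handle on the driver to reach its learner and datapools
        self.driver = driver
        self.records = []

    def compute(self, *args, **kwargs):
        model = self.driver.learner
        # ... evaluate `model` on a held-out set, then append a record ...
        self.records.append({"step": len(self.records), "accuracy": 0.0})

    def serialize_records(self, path: Path = Path("accuracy_metrology.json"), **kwargs):
        path.write_text(json.dumps(self.records, indent=2))
```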

### Static Capture and Performance

The `DefaultTrainingLoop` supports static capture via CUDA graphs for
performance optimization:

```python
train_loop_fn=DefaultTrainingLoop(
    enable_static_capture=True,  # Enable CUDA graph capture (default)
    use_progress_bars=True,
)
```

For custom training loops, you can use:

- `StaticCaptureTraining` for training steps
- `StaticCaptureEvaluateNoGrad` for validation/inference steps
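
For example, following the decorator pattern these utilities use elsewhere in `physicsnemo` (the import path and exact keyword arguments here are assumptions; check them against your installed version):

```python
from physicsnemo.utils import StaticCaptureTraining  # import path may vary by version

# decorate the per-batch step so it can be captured into a CUDA graph after warmup;
# assumes `model`, `optimizer`, and `loss_fn` are already constructed
@StaticCaptureTraining(model=model, optim=optimizer)
def training_step(coords, labels):
    logits = model(coords)
    return loss_fn(logits, labels)
```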

### Distributed Training

To use multiple GPUs, provide a `DistributedManager` in `DriverConfig`:

```python
from physicsnemo.distributed import DistributedManager

DistributedManager.initialize()  # set up the process group before first use
dist_manager = DistributedManager()
driver_config = c.DriverConfig(
    batch_size=16,
    max_active_learning_steps=70,
    dist_manager=dist_manager,  # Handles device placement and DDP
)
```

The driver will automatically wrap the model in `DistributedDataParallel` and use
`DistributedSampler` for dataloaders.

## Experiment Ideas

Here, we perform a relatively straightforward experiment without a baseline. One suitable
extension would be to train a model on the full dataset and see how the precision/recall/F1
scores of the `ClassifierUQQuery` learner differ from those of the full-data model (i.e. use
the latter as a roofline).

Another suitable baseline to compare against would be random selection: to check the efficacy
of `ClassifierUQQuery`, samples could be chosen uniformly at random to see if and how the same
metrology scores differ (a sketch of such a strategy follows). If the UQ is performing as
intended, then precision/recall/F1 should improve at a faster rate.
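
As a sketch of that random baseline (the `sample`/`attach` signatures follow the protocol summary above; the datapool attribute access is an assumption):

```python
import random


class RandomQuery:
    """Hypothetical baseline strategy: queue uniformly random unlabeled indices."""

    def __init__(self, max_samples: int = 10):
        self.max_samples = max_samples

    def attach(self, driver):
        self.driver = driver

    def sample(self, query_queue, *args, **kwargs):
        # assumes the unlabeled pool supports len(); adjust to your datapool type
        pool_size = len(self.driver.unlabeled_datapool)
        for idx in random.sample(range(pool_size), k=min(self.max_samples, pool_size)):
            query_queue.put(idx)
```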