diff --git a/CHANGELOG.md b/CHANGELOG.md index d15fd6de81..6044555ece 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 and three transient schemes. - Added a check to `stochastic_sampler` that helps handle the `EDMPrecond` model, which has a specific `.forward()` signature +- Added abstract interfaces for constructing active learning workflows, contained + under the `physicsnemo.active_learning` namespace. A preliminary example of how + to compose and define an active learning workflow is provided in `examples/active_learning`. + The `moons` example provides a minimal (pedagogical) composition that is meant to + illustrate how to define the necessary parts of the workflow. ### Changed diff --git a/examples/active_learning/moons/.gitignore b/examples/active_learning/moons/.gitignore new file mode 100644 index 0000000000..d535416f2e --- /dev/null +++ b/examples/active_learning/moons/.gitignore @@ -0,0 +1 @@ +active_learning_logs/ \ No newline at end of file diff --git a/examples/active_learning/moons/README.md b/examples/active_learning/moons/README.md new file mode 100644 index 0000000000..f7dc64b3ff --- /dev/null +++ b/examples/active_learning/moons/README.md @@ -0,0 +1,286 @@ +# Moons with Active Learning + +This example is intended to give a quick, high level overview of one kind of active learning +experiments that can be put together using the `physicsnemo` active learning modules and +protocols. + +The experiment that is being done in `moon_example.py` is to use a simple MLP classifier +to label 2D coordinates from the famous two-moons data distribution. The platform for +the experiment is to initially show the MLP a minimal set of data (with some class imbalance), +and use the prediction uncertainties from the model to query points that will be the most +informative to it. + +The main thing to monitor in this experiment is the `f1_metrology.json` output, which is +a product of the `F1Metrology` strategy: in here, we compute precision/recall/F1 values +as a function of the number of active learning cycles. Due to class imbalance, the initial +precision will be quite poor as the predictions will heavily bias towards false negatives, +but as more samples (chosen by `ClassifierUQQuery`) are added to the training set, the +precision and subsequently F1 scores will improve. + +## Quick Start + +To run this example: + +```bash +python moon_example.py +``` + +This will create an `active_learning_logs//` directory containing: + +- **Model checkpoints**: `.mdlus` files saved according to `checkpoint_interval` +- **Driver logs**: `driver_log.json` tracking the active learning process +- **Metrology outputs**: `f1_metrology.json` with precision/recall/F1 scores over iterations +- **Console logs**: `console.log` with detailed execution logs + +The `` is an 8-character UUID prefix that uniquely identifies each run. You can +specify a custom `run_id` in `DriverConfig` if needed. + +## Implementation notes + +To illustrate a simple active learning process, this example implements the bare necessary +ingredients: + +1. **Training step logic** - `training_step` function defines the per-batch training logic, +computing the loss from model predictions. This is passed to the `Driver` which uses it +within the training loop. + +2. **Training loop** - The example uses `DefaultTrainingLoop`, a built-in training loop +that handles epoch iteration, progress bars, validation, and static capture optimizations. 
+For reference, a custom `training_loop` function is also defined in the example but not used, +showing how you could implement your own if needed. + +3. **Query strategy** - `moon_strategies.ClassifierUQQuery` uses the classifier uncertainty +to rank data indices from the full (not in training) sample set, selecting points where +the model is most uncertain (predictions closest to 0.5). + +4. **Label strategy** - `DummyLabelStrategy` handles obtaining data labels. +Because ground truths are already known for this dataset, it's essentially a +no-op but the `Driver` pipeline relies on it to append labeled data to the +training set. + +5. **Metrology strategy** - `F1Metrology` computes precision/recall/F1 scores and serializes +them to JSON. This makes it easy to track how model performance improves with each active +learning iteration, helping inform hyperparameter choices for future experiments. + +## Configuration + +The rest is configuration: we take the components we've written, and compose them in +the various configuration dataclasses in `moon_example.py::main`. + +### TrainingConfig + +The `train_datapool` specifies what set of data to train on. We configure the training +loop using `DefaultTrainingLoop` with progress bars enabled. The `OptimizerConfig` specifies +which optimizer to use and its hyperparameters. You can configure different epoch counts +for initial training vs. subsequent fine-tuning iterations. + +```python +# configure how training/fine-tuning is done within active learning +training_config = c.TrainingConfig( + train_datapool=dataset, + optimizer_config=c.OptimizerConfig( + torch.optim.SGD, + optimizer_kwargs={"lr": 0.01}, + ), + # configure different times for initial training and subsequent + # fine-tuning + max_training_epochs=10, + max_fine_tuning_epochs=5, + # this configures the training loop + train_loop_fn=DefaultTrainingLoop( + use_progress_bars=True, + ), +) +``` + +**Key options:** + +- `max_training_epochs`: Epochs for initial training (step 0) +- `max_fine_tuning_epochs`: Epochs for subsequent fine-tuning steps +- `DefaultTrainingLoop(use_progress_bars=True)`: Built-in loop with tqdm progress bars +- `val_datapool`: Optional validation dataset (not used in this example) + +### StrategiesConfig + +The `StrategiesConfig` localizes all of the different active learning components +into one place. The `queue_cls` is used to pipeline query samples to label processes. +Because we're carrying out a single process workflow, `queue.Queue` is sufficient, +but multiprocess variants, up to constructs like Redis Queue, can be used to pass +data around the pipeline. 
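+
+As an illustrative (untested) sketch, a multiprocess-safe queue could be swapped
+in without changing the rest of the composition. `StrategiesConfig` only checks
+that `queue_cls` is callable, and the strategies in this example rely on the
+usual `put`/`get`/`empty` interface, so a factory such as `multiprocessing.Queue`
+should fit; note that checkpoint restarts resolve the queue type by class name
+and module path, so whatever type you pick must be importable that way:
+
+```python
+import multiprocessing as mp
+
+# hypothetical variant: same strategies, but a process-safe queue;
+# mp.Queue is a factory callable rather than a class, which is all
+# the configuration requires
+strategy_config = c.StrategiesConfig(
+    query_strategies=[ClassifierUQQuery(max_samples=10)],
+    queue_cls=mp.Queue,
+    label_strategy=DummyLabelStrategy(),
+    metrology_strategies=[F1Metrology()],
+)
+```
+
+The single-process composition used in this example is: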
+ +```python +strategy_config = c.StrategiesConfig( + query_strategies=[ClassifierUQQuery(max_samples=10)], + queue_cls=queue.Queue, + label_strategy=DummyLabelStrategy(), + metrology_strategies=[F1Metrology()], +) +``` + +**Key components:** + +- `query_strategies`: List of strategies for selecting samples (can have multiple) +- `queue_cls`: Queue implementation for passing data between phases (e.g., `queue.Queue`) +- `label_strategy`: Single strategy for labeling queried samples +- `metrology_strategies`: List of strategies for measuring model performance +- `unlabeled_datapool`: Optional pool of unlabeled data for query strategies (not shown here) + +### DriverConfig + +Finally, the `DriverConfig` specifies orchestration parameters that control the overall +active learning loop execution: + +```python +driver_config = c.DriverConfig( + batch_size=16, + max_active_learning_steps=70, + fine_tuning_lr=0.005, + device=torch.device("cpu"), # set to other accelerators if needed +) +driver = Driver( + config=driver_config, + learner=uq_model, + strategies_config=strategy_config, + training_config=training_config, +) +# our model doesn't implement a `training_step` method but in principle +# it could be implemented, and we wouldn't need to pass the step function here +driver(train_step_fn=training_step) +``` + +**Key parameters:** + +- `batch_size`: Batch size for training and validation dataloaders +- `max_active_learning_steps`: Total number of active learning iterations +- `fine_tuning_lr`: Learning rate to switch to after the first AL step (optional) +- `device`: Device for computation (e.g., `torch.device("cpu")`, `torch.device("cuda:0")`) +- `dtype`: Data type for tensors (defaults to `torch.get_default_dtype()`) +- `skip_training`: Set to `True` to skip training phase (default: `False`) +- `skip_metrology`: Set to `True` to skip metrology phase (default: `False`) +- `skip_labeling`: Set to `True` to skip labeling phase (default: `False`) +- `checkpoint_interval`: Save model every N steps (default: 1, set to 0 to disable) +- `root_log_dir`: Directory for logs and checkpoints (default: `"active_learning_logs"`) +- `dist_manager`: Optional `DistributedManager` for multi-GPU training + +### Running the Driver + +The final `driver(...)` call is syntactic sugar for `driver.run(...)`, which executes the +full active learning loop. The `train_step_fn` argument provides the per-batch training logic. + +**Two ways to provide training logic:** + +1. **Pass as function** (shown in example): + + ```python + driver(train_step_fn=training_step) + ``` + +1. **Implement in model** (alternative): + + ```python + class MLP(Module): + def training_step(self, data): + # training logic here + ... + + driver() # no train_step_fn needed + ``` + +**Optional validation step:** + +You can also provide a `validate_step_fn` parameter: + +```python +driver(train_step_fn=training_step, validate_step_fn=validation_step) +``` + +### Active Learning Workflow + +Under the hood, `Driver.active_learning_step` is called repeatedly for the number of +iterations specified in `max_active_learning_steps`. Each iteration follows this sequence: + +1. **Training Phase**: Train model on current `train_datapool` using the +training loop +2. **Metrology Phase**: Compute performance metrics via metrology strategies +3. **Query Phase**: Select new samples to label via query strategies → +`query_queue` +4. 
**Labeling Phase**: Label queued samples via label strategy → `label_queue` +→ append to `train_datapool` + +The logic for each phase is in methods like `Driver._training_phase`, +`Driver._query_phase`, etc. + +## Advanced Customization + +### Custom Training Loops + +While `DefaultTrainingLoop` is suitable for most use cases, you can write +custom training loops that implement the `TrainingLoop` protocol, which is the +overarching logic for how to carry out model training and validation over some +number of epochs. Custom loops are useful when you need: + +- Specialized training logic (e.g., alternating, or multiple optimizers) +- Custom logging or checkpointing within the loop +- Non-standard epoch/batch iteration patterns + +### Custom Strategies + +All strategies must implement their respective protocols: + +- **QueryStrategy**: Implement `sample(query_queue, *args, **kwargs)` and +`attach(driver)` +- **LabelStrategy**: Implement `label(queue_to_label, serialize_queue, *args, +**kwargs)` and `attach(driver)` +- **MetrologyStrategy**: Implement `compute(*args, **kwargs)`, +`serialize_records(*args, **kwargs)`, and `attach(driver)` + +The `attach(driver)` method gives your strategy access to the driver's +attributes like `driver.learner`, `driver.train_datapool`, +`driver.unlabeled_datapool`, etc. + +### Static Capture and Performance + +The `DefaultTrainingLoop` supports static capture via CUDA graphs for +performance optimization: + +```python +train_loop_fn=DefaultTrainingLoop( + enable_static_capture=True, # Enable CUDA graph capture (default) + use_progress_bars=True, +) +``` + +For custom training loops, you can use: + +- `StaticCaptureTraining` for training steps +- `StaticCaptureEvaluateNoGrad` for validation/inference steps + +### Distributed Training + +To use multiple GPUs, provide a `DistributedManager` in `DriverConfig`: + +```python +from physicsnemo.distributed import DistributedManager + +dist_manager = DistributedManager() +driver_config = c.DriverConfig( + batch_size=16, + max_active_learning_steps=70, + dist_manager=dist_manager, # Handles device placement and DDP +) +``` + +The driver will automatically wrap the model in `DistributedDataParallel` and use +`DistributedSampler` for dataloaders. + +## Experiment Ideas + +Here, we perform a relatively straightforward experiment without a baseline; suitable +ones could be to train a model using the full data, and see how the precision/recall/F1 +scores differ between the `ClassifierUQQuery` learner to the full data model (i.e. use +the latter as a roofline). + +A suitable baseline to compare against would be random selection: to check the efficacy +of `ClassifierUQQuery`, samples could be chosen uniformly and see if and how the same +metrology scores differ. If the UQ is performing as intended, then precision/recall/F1 +should improve at a faster rate. diff --git a/examples/active_learning/moons/moon_data.py b/examples/active_learning/moons/moon_data.py new file mode 100644 index 0000000000..8811696304 --- /dev/null +++ b/examples/active_learning/moons/moon_data.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Module that defines the classic two-moon classification dataset to +use as a demonstration dataset for a minimal active learning workflow. +""" + +import torch +from torch.utils.data import Dataset + +__all__ = ["MoonsDataset"] + + +def make_moons(n_samples: int = 2000, sigma: float = 0.25) -> torch.Tensor: + """ + Make the classic two-moon dataset. Code was adapted from + ``sklearn``, but modified slightly for the purposes of + this particular example. + + Parameters + ---------- + n_samples: int + The number of samples to generate. + + Returns + ------- + X_values: torch.Tensor + The input features. + y_values: torch.Tensor + The target labels. + """ + outer_grid = torch.linspace(0.0, torch.pi, int(n_samples * 0.7)) + inner_grid = torch.linspace(0.0, torch.pi, int(n_samples * 0.3)) + outer_x = torch.cos(outer_grid) + outer_y = torch.sin(outer_grid) + inner_x = 1 - torch.cos(inner_grid) + inner_y = 1 - torch.sin(inner_grid) - 0.5 + outer = torch.stack([outer_x, outer_y], dim=-1) + inner = torch.stack([inner_x, inner_y], dim=-1) + X_values = torch.cat([outer, inner], dim=0) + # add some noise to the coordinates + X_values += torch.randn_like(X_values) * sigma + y_values = torch.zeros(n_samples) + y_values[outer_grid.shape[0] :] = 1 + return X_values, y_values + + +class MoonsDataset(Dataset): + """ + Generate the classic two-moon dataset, repurposed for a minimal + active learning example. + + This class implements the `DataPool` protocol by subclassing + ``Dataset``, which provides all the methods except for ``append``, + which we implement here. + + The intuition is to have one of the moons be data poor, and a quasi- + intelligent query strategy will help overcome class imbalance to + some extent, as it will hopefully have higher uncertainty in its + classifier output to reflect this. + + Attributes + ---------- + initial_samples: float + The initial number of samples to hold out for training. + total_samples: int + The total number of samples to generate. + train_indices: torch.LongTensor | None + The indices of the samples to use for training. + X_values: torch.Tensor + The full set of input features; i.e. the coordinates + of a point in 2D space. + y_values: torch.Tensor + The target labels; 0 for the outer moon, 1 for the inner moon. + sigma: float + The standard deviation of the noise to add to the coordinates. 
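+
+    Examples
+    --------
+    A minimal usage sketch (which points end up in the initial training
+    subset is random)::
+
+        pool = MoonsDataset(initial_samples=0.05, total_samples=1000)
+        len(pool)                       # 50 == int(0.05 * 1000) exposed for training
+        x, y = pool[0]                  # a single coordinate-label pair
+        candidates = pool._sample_indices()  # indices not yet in training
+        pool.append(int(candidates[0]))      # 'label' one of them, growing the subset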
+ """ + + def __init__( + self, + initial_samples: float = 0.05, + total_samples: int = 1000, + train_indices: torch.LongTensor | None = None, + sigma: float = 0.25, + ): + super().__init__() + self.initial_samples = initial_samples + self.total_samples = total_samples + # this holds the full dataset for training + self.X_values, self.y_values = make_moons(total_samples, sigma) + # this corresponds to the subset that is actually exposed + # during training; it grows as we 'label' more samples + if train_indices is None: + # initial hold out for training + train_indices = torch.randperm(total_samples)[ + : int(total_samples * initial_samples) + ] + self.train_indices = train_indices + + def __len__(self) -> int: + """Return the length of the training subset.""" + return len(self.train_indices) + + def _sample_indices(self) -> torch.LongTensor: + """Return the indices that are not currently in training.""" + all_indices = torch.arange(self.total_samples) + mask = ~torch.isin(all_indices, self.train_indices) + return all_indices[mask] + + def append(self, item: int) -> None: + """Append a single index to the training set; needed for 'labeling'.""" + self.train_indices = torch.cat( + [self.train_indices, torch.tensor([item])], dim=0 + ) + + def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: + """Retrieve a single coordinate-label pair from the dataset.""" + actual_index = self.train_indices[index] + x_val = self.X_values[actual_index, :] + y_val = self.y_values[actual_index] + return x_val, y_val diff --git a/examples/active_learning/moons/moon_example.py b/examples/active_learning/moons/moon_example.py new file mode 100644 index 0000000000..9fd0605296 --- /dev/null +++ b/examples/active_learning/moons/moon_example.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Minimal example to show how the active learning workflow can be put together, +comprising the minimum of model training, and running some query strategy +to select data samples for labeling. + +This example implements a simple MLP that takes in 2D coordinates and +outputs logits for a binary classification task. The active learning +workflow here uses a query strategy that looks for samples that have +the highest classification uncertainty (i.e. closest to 0.5), and iterates +by adding those samples to the training set. Ideally, if uncertainty is +well-adjusted to this problem, then the query strategy will select samples +that are more likely to improve the model's general performance, as compared +to a random selection baseline. 
+""" + +import queue +import time + +import torch +from moon_data import MoonsDataset +from moon_strategies import ClassifierUQQuery, DummyLabelStrategy, F1Metrology +from torch import nn + +from physicsnemo import ModelMetaData, Module +from physicsnemo.active_learning import Driver, registry +from physicsnemo.active_learning import config as c +from physicsnemo.active_learning.loop import DefaultTrainingLoop + +torch.manual_seed(216167) + + +@registry.register("MLP") +class MLP(Module): + """ + Define a trivial MLIP model that will classify a 2D coordinate + into one of two classes, producing logits as the output. + + There is nothing to configure here, so focus on the active learning + components. + """ + + def __init__(self): + super().__init__(meta=ModelMetaData(amp=False)) + self.layers = nn.Sequential( + nn.Linear(2, 16), + nn.SiLU(), + nn.Linear(16, 1), + ) + self.loss_fn = nn.BCEWithLogitsLoss() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the model. + + Parameters + ---------- + x: torch.Tensor + The input tensor, shape [B, 2] for a batch + size of B. + + Returns + ------- + torch.Tensor + The output tensor, shape [B, 1] for a batch + size of B. Remember to ``squeeze`` the output. + """ + return self.layers(x) + + +# this implements the `TrainingProtocol` interface +@registry.register("training_step") +def training_step(model: MLP, data: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + """ + Implements the training logic for a single batch of data. + + Parameters + ---------- + model: MLP + The model to train. + data: tuple[torch.Tensor, torch.Tensor] + The data to train on. + + Returns + ------- + torch.Tensor + The loss tensor. + """ + x, y = data + logits = model(x).squeeze() + loss = model.loss_fn(logits, y) + return loss + + +def main(): + """ + Configure an end-to-end active learning workflow. + + The code below primarily demonstrates how to compose things together + to form the full workflow. There are three configurations structures + that ultimately dictate the behavior of ``Driver``, which orchestrates + the workflow: + + 1. ``TrainingConfig``: everything to do with the model training process. + 2. ``StrategiesConfig``: comprises the query, label, and metrology strategies. + 3. ``DriverConfig``: decides things like batch size, logging, and ``DistributedManager``. + + The workflow should completely quickly: an `active_learning_logs` folder will + be created, and within it, run-specific logs. You will find the model weights, + alongside JSON logs of the process and from the ``F1Metrology`` strategy, which + will records how precision/recall progresses as more data points are added to the + strategy. 
+ """ + # instantiate the model and data + dataset = MoonsDataset() + uq_model = MLP() + + # configure how training/fine-tuning is done within active learning + training_config = c.TrainingConfig( + train_datapool=dataset, + optimizer_config=c.OptimizerConfig( + torch.optim.SGD, + optimizer_kwargs={"lr": 0.01}, + ), + # configure different times for initial training and subsequent + # fine-tuning + max_training_epochs=30, + max_fine_tuning_epochs=30, + # this configures the training loop + train_loop_fn=DefaultTrainingLoop( + use_progress_bars=False, + enable_static_capture=False, + ), + ) + # this configuration packs all the strategy components together + strategy_config = c.StrategiesConfig( + query_strategies=[ClassifierUQQuery(max_samples=10)], + queue_cls=queue.Queue, + label_strategy=DummyLabelStrategy(), + metrology_strategies=[F1Metrology()], + ) + # this driver class handles the active learning loop + driver_config = c.DriverConfig( + batch_size=16, + max_active_learning_steps=70, + fine_tuning_lr=0.005, + device=torch.device("cpu"), # set to other accelerators if needed + ) + driver = Driver( + config=driver_config, + learner=uq_model, + strategies_config=strategy_config, + training_config=training_config, + ) + # our model doesn't implement a `training_step` method but in principle + # it could be implemented, and we wouldn't need to pass the step function here + driver(train_step_fn=training_step) + + # just some sanity checks + if not ( + len(dataset.train_indices) + == int(dataset.initial_samples * dataset.total_samples) + + driver_config.max_active_learning_steps + * strategy_config.query_strategies[0].max_samples + ): + raise RuntimeError( + "Number of samples added to the training pool inconsistent with expected value." + ) + + # restart the driver from a checkpoint; in practice the path would be provided + # train_datapool must be provided since it's not serialized + # learner must nominally have the same architecture as the one used to create the checkpoint + new_driver = Driver.load_checkpoint( + driver.log_dir / "checkpoints" / "step_42" / "labeling", + learner=uq_model, + train_datapool=dataset, + ) + assert new_driver.active_learning_step_idx == 42 + # enable this to re-run the driver training: be aware that this will overwrite subsequent checkpoints!! + RERUN = True + if RERUN: + new_driver.logger.info( + f"Rerunning driver from checkpoint {new_driver.last_checkpoint}" + ) + time.sleep(5) + new_driver(train_step_fn=training_step) + + +if __name__ == "__main__": + main() diff --git a/examples/active_learning/moons/moon_strategies.py b/examples/active_learning/moons/moon_strategies.py new file mode 100644 index 0000000000..8d96b7f6bb --- /dev/null +++ b/examples/active_learning/moons/moon_strategies.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from queue import Queue +from typing import Any + +import torch + +from physicsnemo.active_learning import registry +from physicsnemo.active_learning.protocols import ( + DriverProtocol, + LabelStrategy, + MetrologyStrategy, + QueryStrategy, +) + +__all__ = ["ClassifierUQQuery", "DummyLabelStrategy", "F1Metrology"] + + +@registry.register("ClassifierUQQuery") +class ClassifierUQQuery(QueryStrategy): + """ + This query strategy is representative of a more complex + uncertainty-based query strategy: since our model produces + logits, we can use the model's confidence in class label + predictions to select data points for labeling: specifically, + we pick ``max_samples`` each active learning iteration of + the data points with the most uncertainty (closest to 0.5). + """ + + def __init__(self, max_samples: int): + """ + Initialize the query strategy. + + Parameters + ---------- + max_samples: int + The maximum number of samples to query. + """ + self.max_samples = max_samples + + def sample(self, query_queue: Queue) -> None: + """ + Identify which data points that need labels by the query strategy. + + At a high level, this method will: + 1. Slice out the data indices not currently in the training set, + 2. Query the model for predictions on the 'unlabeled' data, + 3. Enqueue indices of data points with the class predictions closest to 0.5. + + Parameters + ---------- + query_queue: Queue + The queue to enqueue data to be labeled. + """ + # strategy will be attached to a driver to access model and data + model = self.driver.learner + data = self.driver.train_datapool + unlabeled_indices = data._sample_indices() + # grab all of the data that's currently not labeled and obtain + # predictions from the model + unlabeled_coords = data.X_values[unlabeled_indices] + unlabeled_coords = unlabeled_coords.to(model.device) + model.eval() + with torch.no_grad(): + pred_logits = model(unlabeled_coords) + pred_probs = torch.sigmoid(pred_logits).squeeze() + # find probabilities that are closet to 0.5; the lower this + # value is, the more uncertain the model is + uncertainties = torch.abs(pred_probs - 0.5) + chosen_indices = torch.argsort(uncertainties)[: self.max_samples] + # enqueue indices of the chosen data points + for idx in chosen_indices: + query_queue.put(unlabeled_indices[idx]) + + def attach(self, driver: DriverProtocol) -> None: + """Attach the driver to the query strategy.""" + self.driver = driver + + +@registry.register("DummyLabelStrategy") +class DummyLabelStrategy(LabelStrategy): + """ + Since we have labels for all of our data already, this label strategy + will simply just add the data points our model has chosen to the + training set. + """ + + __is_external_process__ = False + + def __init__(self): + super().__init__() + + def label(self, query_queue: Queue, serialize_queue: Queue) -> None: + """ + Label the data points in the query queue. + + This is trivial because we are just passing indices from one queue + to another, but in a real implementation this might call an external + process to obtain ground truth data for a set of data points. + + Parameters + ---------- + query_queue: Queue + The queue to dequeue data from. + serialize_queue: Queue + The queue to enqueue labeled data to. 
+ """ + while not query_queue.empty(): + selected_idx = query_queue.get() + serialize_queue.put(selected_idx) + + def attach(self, driver: DriverProtocol) -> None: + """Attach the driver to the label strategy.""" + self.driver = driver + + +@registry.register("F1Metrology") +class F1Metrology(MetrologyStrategy): + """ + While metrology is optional in the workflow, this provides observability + into how the model is performing over the course of active learning. + + For a simple use case like the Moons dataset, the margin between validation + and metrology is small, but for more complex use cases this strategy can + potentially represent a workflow beyond simple metrics (e.g. using the model + as a surrogate in a simulation loop). + """ + + def __init__(self): + self.records = [] + + def compute(self, *args: Any, **kwargs: Any) -> None: + """Compute the F1 score of the model on the validation set.""" + model = self.driver.learner + data = self.driver.train_datapool # this can be any `DataPool` + model.eval() + indices = torch.arange(data.total_samples) + input_data, labels = data.X_values[indices], data.y_values[indices] + input_data = input_data.to(model.device) + labels = labels.to(model.device) + with torch.no_grad(): + # pack the entire dataset into a single batch + pred_logits = model(input_data) + pred_probs = torch.sigmoid(pred_logits).squeeze() + pred_labels = torch.round(pred_probs) + precision = self.precision(pred_labels, labels) + recall = self.recall(pred_labels, labels) + # compute the F1 score + f1 = 2 * (precision * recall) / (precision + recall + 1e-8) + iteration = self.driver.active_learning_step_idx + num_train_samples = len(self.driver.train_datapool.train_indices) + report = { + "precision": precision, + "recall": recall, + "f1": f1, + "step": iteration, + "num_train_samples": num_train_samples, + } + self.append(report) + + @staticmethod + def precision(pred_labels: torch.Tensor, true_labels: torch.Tensor) -> float: + """ + Calculate precision for class 0. + + Precision is the ratio of true positives to all predicted positives: + how many of the samples predicted as class 0 are actually class 0. + + Parameters + ---------- + pred_labels : torch.Tensor + Predicted binary labels (0 or 1). + true_labels : torch.Tensor + Ground truth binary labels (0 or 1). + + Returns + ------- + float + Precision score for class 0. + """ + true_positives = ((true_labels == 1) & (pred_labels == 1)).sum().item() + predicted_positives = (pred_labels == 1).sum().item() + if predicted_positives == 0: + return 0.0 + return true_positives / predicted_positives + + @staticmethod + def recall(pred_labels: torch.Tensor, true_labels: torch.Tensor) -> float: + """ + Calculate recall for class 0. + + Recall is the ratio of true positives to all actual positives: + how many of the actual class 0 samples were predicted as class 0. + + Parameters + ---------- + pred_labels : torch.Tensor + Predicted binary labels (0 or 1). + true_labels : torch.Tensor + Ground truth binary labels (0 or 1). + + Returns + ------- + float + Recall score for class 0. 
+ """ + true_positives = ((pred_labels == 0) & (true_labels == 0)).sum().item() + actual_positives = (true_labels == 0).sum().item() + if actual_positives == 0: + return 0.0 + return true_positives / actual_positives + + def attach(self, driver: DriverProtocol) -> None: + """Attach the driver to the metrology strategy.""" + self.driver = driver + + @property + def is_attached(self) -> bool: + """Check if the metrology strategy is attached to a driver.""" + return hasattr(self, "driver") + + def serialize_records(self, *args: Any, **kwargs: Any) -> None: + """Serialize the records of the metrology strategy.""" + output_path = self.strategy_dir / f"step_{self.driver.active_learning_step_idx}" + output_path.mkdir(parents=True, exist_ok=True) + with open(output_path / "f1_metrology.json", "w") as f: + json.dump(self.records, f, indent=2) diff --git a/physicsnemo/active_learning/README.md b/physicsnemo/active_learning/README.md new file mode 100644 index 0000000000..eb53aee190 --- /dev/null +++ b/physicsnemo/active_learning/README.md @@ -0,0 +1,66 @@ +# Active Learning Module + +The `physicsnemo.active_learning` namespace is used for defining the "scaffolding" +that can be used to construct automated, end-to-end active learning workflows. +For areas of science that are difficult to source ground-truths to train on +(of which there are many), an active learning curriculum attempts to train a +model with improved data efficiency; better generalization performance but requiring +fewer training samples. + +Generally, an active learning workflow can be decomposed into three "phases" +that are - in the simplest case - run sequentially: + +- **Training/fine-tuning**: A "learner" or surrogate model is initially trained +on available data, and in subsequent active learning iterations, is fine-tuned +with the new data appended on the original dataset. +- **Querying**: One or more strategies that encode some heuristics for what +new data is most informative for the learner. Examples of this include +uncertainty-based methods, which may screen a pool of unlabeled data for +those the model is least confident with. +- **Labeling**: A method of obtaining ground truth (labels) for new data +points, pipelined from the querying stage. This may entail running an +expensive solver, or acquiring experimental data. + +The three phases are repeated until the learner converges. Because "convergence" +may not be easily defined, we define an additional phase which we call +**metrology**: this represents a phase most similar to querying, but allows +a user to define some set of criteria to monitor over the course of active +learning *beyond* simple validation metrics to ensure the model can be used +with confidence as surrogates (e.g. within a simulation loop). + +## How to use this module + +With the context above in mind, inspecting the `driver` module will give you +a sense for how the end-to-end workflow functions; the `Driver` class acts +as an orchestrator for all the phases of active learning we described above. + +From there, you should realize that `Driver` is written in a highly abstract +way: we need concrete *strategies* that implement querying, labeling, and metrology +concepts. 
The `protocols` module provides the scaffolding to do so - we implement +various components as `typing.Protocol` which are used for structural sub-typing: +they can be thought of as abstract classes that define an expected interface +in a function or class from which you can define your own classes by either +inheriting from them, or defining your own class that implements the expected +methods and attributes. + +In order to perform the training portion of active learning, we provide a +minimal yet functional `DefaultTrainingLoop` inside the `loop` module. This +loop simply requires a `protocols.TrainingProtocol` to be passed, which is +a function that defines the logic for computing the loss per batch/training +step. + +## Configuring workflows + +The `config` module defines some simple `dataclass`es that can be used +to configure the behavior of various parts of active learning, e.g. how +training is conducted, etc. Because `Driver` is designed to be checkpointable, +with the exception of a few parts such as datasets, everything should be +JSON-serializable. + +## Restarting workflows + +For classes and functions that are created at runtime, checkpointing requires +that these components can be recreated when restarting from a checkpoint. To +that end, the `_registry` module provides a user-friendly way to instantiate +objects: user-defined strategy classes can be added to the registry to enable +their creation in checkpoint restarts. diff --git a/physicsnemo/active_learning/__init__.py b/physicsnemo/active_learning/__init__.py new file mode 100644 index 0000000000..8563d24c66 --- /dev/null +++ b/physicsnemo/active_learning/__init__.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from physicsnemo.active_learning._registry import registry +from physicsnemo.active_learning.config import ( + DriverConfig, + OptimizerConfig, + StrategiesConfig, + TrainingConfig, +) +from physicsnemo.active_learning.driver import Driver +from physicsnemo.active_learning.loop import DefaultTrainingLoop + +__all__ = [ + "registry", + "Driver", + "DefaultTrainingLoop", + "DriverConfig", + "OptimizerConfig", + "StrategiesConfig", + "TrainingConfig", +] diff --git a/physicsnemo/active_learning/_registry.py b/physicsnemo/active_learning/_registry.py new file mode 100644 index 0000000000..b9f137c367 --- /dev/null +++ b/physicsnemo/active_learning/_registry.py @@ -0,0 +1,332 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib +import inspect +from typing import Any, Callable +from warnings import warn + +from physicsnemo.active_learning.protocols import ActiveLearningProtocol + +__all__ = ["registry"] + + +class ActiveLearningRegistry: + """ + Registry for active learning protocols. + + This class provides a centralized registry for user-defined active learning + protocols that implement the `ActiveLearningProtocol`. It enables string-based + lookups for checkpointing and provides argument validation when constructing + protocol instances. + + The registry supports two primary modes of interaction: + 1. Registration via decorator: `@registry.register("my_strategy")` + 2. Construction with validation: `registry.construct("my_strategy", **kwargs)` + + Attributes + ---------- + _registry : dict[str, type[ActiveLearningProtocol]] + Internal dictionary mapping protocol names to their class types. + + Methods + ------- + register(cls_name: str) -> Callable[[type[ActiveLearningProtocol]], type[ActiveLearningProtocol]] + Decorator to register a protocol class with a given name. + construct(cls_name: str, **kwargs) -> ActiveLearningProtocol + Construct an instance of a registered protocol with argument validation. + is_registered(cls_name: str) -> bool + Check if a protocol name is registered. + + Properties + ---------- + registered_names : list[str] + A list of all registered protocol names, sorted alphabetically. + + Examples + -------- + Register a custom strategy: + + >>> from physicsnemo.active_learning._registry import registry + >>> @registry.register("my_custom_strategy") + ... class MyCustomStrategy: + ... def __init__(self, param1: int, param2: str): + ... self.param1 = param1 + ... self.param2 = param2 + + Construct an instance with validation: + + >>> strategy = registry.construct("my_custom_strategy", param1=42, param2="test") + """ + + def __init__(self) -> None: + """Initialize an empty registry.""" + self._registry: dict[str, type[ActiveLearningProtocol]] = {} + + def register( + self, cls_name: str + ) -> Callable[[type[ActiveLearningProtocol]], type[ActiveLearningProtocol]]: + """ + Decorator to register an active learning protocol class. + + This decorator registers a class implementing the `ActiveLearningProtocol` + under the given name, allowing it to be retrieved and constructed later + using the `construct` method. + + Parameters + ---------- + cls_name : str + The name to register the protocol under. This will be used as the + key for later retrieval. + + Returns + ------- + Callable[[type[ActiveLearningProtocol]], type[ActiveLearningProtocol]] + A decorator function that registers the class and returns it unchanged. + + Raises + ------ + ValueError + If a protocol with the same name is already registered. + + Examples + -------- + >>> @registry.register("my_new_strategy") + ... class MyStrategy: + ... def __init__(self, param: int): + ... self.param = param + """ + + def decorator( + cls: type[ActiveLearningProtocol], + ) -> type[ActiveLearningProtocol]: + """ + Method for decorating a class to registry it with the registry. 
+ """ + if cls_name in self._registry: + raise ValueError( + f"Protocol '{cls_name}' is already registered. " + f"Existing class: {self._registry[cls_name].__name__}" + ) + self._registry[cls_name] = cls + return cls + + return decorator + + def construct( + self, cls_name: str, module_path: str | None = None, **kwargs: Any + ) -> ActiveLearningProtocol: + """ + Construct an instance of a registered protocol with argument validation. + + This method retrieves a registered protocol class by name, validates that + the provided keyword arguments match the class's constructor signature, + and returns a new instance of the class. + + Parameters + ---------- + cls_name : str + The name of the registered protocol to construct. + module_path: str | None + The path to the module to get the class from. + **kwargs : Any + Keyword arguments to pass to the protocol's constructor. + + Returns + ------- + ActiveLearningProtocol + A new instance of the requested protocol class. + + Raises + ------ + KeyError + If the protocol name is not registered. + TypeError + If the provided keyword arguments do not match the constructor signature. + This includes missing required parameters or unexpected parameters. + + Examples + -------- + >>> from physicsnemo.active_learning._registry import registry + >>> @registry.register("my_latest_strategy") + ... class MyStrategy: + ... def __init__(self, param: int): + ... self.param = param + >>> strategy = registry.construct("my_latest_strategy", param=42) + """ + cls = self.get_class(cls_name, module_path) + + # Validate arguments against the class signature + try: + sig = inspect.signature(cls.__init__) + except (ValueError, TypeError) as e: + raise TypeError( + f"Could not inspect signature of {cls.__name__}.__init__: {e}" + ) + + # Get parameters, excluding 'self' + params = { + name: param for name, param in sig.parameters.items() if name != "self" + } + + # Check if the signature accepts **kwargs + has_var_keyword = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values() + ) + + # Check for missing required parameters + missing = [] + for name, param in params.items(): + if ( + param.kind + not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL) + and param.default is inspect.Parameter.empty + and name not in kwargs + ): + missing.append(name) + + if missing: + raise TypeError( + f"Missing required arguments for {cls.__name__}: {', '.join(missing)}" + ) + + # Check for unexpected parameters (unless **kwargs is present) + if not has_var_keyword: + param_names = { + name + for name, param in params.items() + if param.kind + not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL) + } + unexpected = [name for name in kwargs if name not in param_names] + + if unexpected: + warn( + f"Unexpected arguments for {cls.__name__}: {', '.join(unexpected)}. " + f"Valid parameters: {', '.join(sorted(param_names))}" + ) + return cls(**kwargs) + + def __getitem__(self, cls_name: str) -> type[ActiveLearningProtocol]: + """ + Retrieve a registered protocol class by name using dict-like access. + + This method allows accessing registered protocol classes using square + bracket notation, e.g., `registry['my_strategy']`. + + Parameters + ---------- + cls_name : str + The name of the registered protocol to retrieve. + + Returns + ------- + type[ActiveLearningProtocol] + The class type of the registered protocol. + + Raises + ------ + KeyError + If the protocol name is not registered. 
+ + Examples + -------- + >>> from physicsnemo.active_learning._registry import registry + >>> @registry.register("my_strategy") + ... class MyStrategy: + ... def __init__(self, param: int): + ... self.param = param + >>> RetrievedClass = registry['my_strategy'] + >>> instance = RetrievedClass(param=42) + """ + if cls_name not in self._registry: + available = ", ".join(self._registry.keys()) if self._registry else "none" + raise KeyError( + f"Protocol '{cls_name}' is not registered. " + f"Available protocols: {available}" + ) + return self._registry[cls_name] + + def is_registered(self, cls_name: str) -> bool: + """ + Check if a protocol name is registered. + + Parameters + ---------- + cls_name : str + The name of the protocol to check. + + Returns + ------- + bool + True if the protocol is registered, False otherwise. + """ + return cls_name in self._registry + + @property + def registered_names(self) -> list[str]: + """ + A list of all registered protocol names, sorted alphabetically. + + Returns + ------- + list[str] + A list of all registered protocol names, sorted alphabetically. + """ + return sorted(self._registry.keys()) + + def get_class(self, cls_name: str, module_path: str | None = None) -> type: + """ + Get a class by name from the registry or from a module path. + + Parameters + ---------- + cls_name: str + The name of the class to get. + module_path: str | None + The path to the module to get the class from. + + Returns + ------- + type + The class. + + Raises + ------ + NameError: If the class is not found in the registry or module. + ModuleNotFoundError: If the module is not found with the specified module path. + """ + if cls_name in self.registered_names: + return self._registry[cls_name] + else: + if module_path: + module = importlib.import_module(module_path) + cls = getattr(module, cls_name, None) + if not cls: + raise NameError( + f"Class {cls_name} not found in module {module_path}" + ) + return cls + else: + raise NameError( + f"Class {cls_name} not found in registry, and no module path was provided." + ) + + +# Module-level registry instance for global access +registry = ActiveLearningRegistry() diff --git a/physicsnemo/active_learning/config.py b/physicsnemo/active_learning/config.py new file mode 100644 index 0000000000..06310cbb0f --- /dev/null +++ b/physicsnemo/active_learning/config.py @@ -0,0 +1,808 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Configuration dataclasses for the active learning driver. + +This module provides structured configuration classes that separate different +concerns in the active learning workflow: optimization, training, strategies, +and driver orchestration. 
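+
+A typical composition, mirroring ``examples/active_learning/moons``, builds the
+configs separately and hands them to the driver (``dataset`` and the strategy
+objects below are user-provided placeholders)::
+
+    optimizer_config = OptimizerConfig(torch.optim.SGD, optimizer_kwargs={"lr": 0.01})
+    training_config = TrainingConfig(
+        train_datapool=dataset,  # any DataPool implementation
+        max_training_epochs=30,
+        optimizer_config=optimizer_config,
+    )
+    strategies_config = StrategiesConfig(
+        query_strategies=[my_query_strategy],
+        queue_cls=queue.Queue,
+        label_strategy=my_label_strategy,
+    )
+    driver_config = DriverConfig(batch_size=16, max_active_learning_steps=10)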
+""" + +from __future__ import annotations + +import math +import uuid +from collections import defaultdict +from dataclasses import dataclass, field +from json import dumps +from pathlib import Path +from typing import Any +from warnings import warn + +import torch +from torch import distributed as dist +from torch.optim import AdamW, Optimizer +from torch.optim.lr_scheduler import _LRScheduler + +from physicsnemo.active_learning import protocols as p +from physicsnemo.active_learning._registry import registry +from physicsnemo.active_learning.loop import DefaultTrainingLoop +from physicsnemo.distributed import DistributedManager + + +@dataclass +class OptimizerConfig: + """ + Configuration for optimizer and learning rate scheduler. + + This encapsulates all training optimization parameters, keeping + them separate from the active learning orchestration logic. + + Attributes + ---------- + optimizer_cls: type[Optimizer] + The optimizer class to use. Defaults to AdamW. + optimizer_kwargs: dict[str, Any] + Keyword arguments to pass to the optimizer constructor. + Defaults to {"lr": 1e-4}. + scheduler_cls: type[_LRScheduler] | None + The learning rate scheduler class to use. If None, no + scheduler will be configured. + scheduler_kwargs: dict[str, Any] + Keyword arguments to pass to the scheduler constructor. + """ + + optimizer_cls: type[Optimizer] = AdamW + optimizer_kwargs: dict[str, Any] = field(default_factory=lambda: {"lr": 1e-4}) + scheduler_cls: type[_LRScheduler] | None = None + scheduler_kwargs: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + """Validate optimizer configuration.""" + # Validate learning rate if present + if "lr" in self.optimizer_kwargs: + lr = self.optimizer_kwargs["lr"] + if not isinstance(lr, (int, float)) or lr <= 0: + raise ValueError(f"Learning rate must be positive, got {lr}") + + # Validate that scheduler_kwargs is only set if scheduler_cls is provided + if self.scheduler_kwargs and self.scheduler_cls is None: + raise ValueError( + "scheduler_kwargs provided but scheduler_cls is None. " + "Provide a scheduler_cls or remove scheduler_kwargs." + ) + + def to_dict(self) -> dict[str, Any]: + """ + Returns a JSON-serializable dictionary representation of the OptimizerConfig. + + For round-tripping, the registry is used to de-serialize the optimizer and scheduler + classes. + + Returns + ------- + dict[str, Any] + A dictionary that can be JSON serialized. + """ + opt = { + "__name__": self.optimizer_cls.__name__, + "__module__": self.optimizer_cls.__module__, + } + if self.scheduler_cls: + scheduler = { + "__name__": self.scheduler_cls.__name__, + "__module__": self.scheduler_cls.__module__, + } + else: + scheduler = None + return { + "optimizer_cls": opt, + "optimizer_kwargs": self.optimizer_kwargs, + "scheduler_cls": scheduler, + "scheduler_kwargs": self.scheduler_kwargs, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> OptimizerConfig: + """ + Creates an OptimizerConfig instance from a dictionary. + + This method assumes that the optimizer and scheduler classes are + included in the ``physicsnemo.active_learning.registry``, or + a module path is specified to import the class from. + + Parameters + ---------- + data: dict[str, Any] + A dictionary that was previously serialized using the ``to_dict`` method. + + Returns + ------- + OptimizerConfig + A new ``OptimizerConfig`` instance. 
+ """ + optimizer_cls = registry.get_class( + data["optimizer_cls"]["__name__"], data["optimizer_cls"]["__module__"] + ) + if (s := data.get("scheduler_cls")) is not None: + scheduler_cls = registry.get_class(s["__name__"], s["__module__"]) + else: + scheduler_cls = None + return cls( + optimizer_cls=optimizer_cls, + optimizer_kwargs=data["optimizer_kwargs"], + scheduler_cls=scheduler_cls, + scheduler_kwargs=data["scheduler_kwargs"], + ) + + +@dataclass +class TrainingConfig: + """ + Configuration for the training phase of active learning. + + This groups all training-related components together, making it + clear when training is or isn't being used in the AL workflow. + + Attributes + ---------- + train_datapool: p.DataPool + The pool of labeled data to use for training. + max_training_epochs: int + The maximum number of epochs to train for. If ``max_fine_tuning_epochs`` + isn't specified, this value is used for all active learning steps. + val_datapool: p.DataPool | None + Optional pool of data to use for validation during training. + optimizer_config: OptimizerConfig + Configuration for the optimizer and scheduler. Defaults to + AdamW with lr=1e-4, no scheduler. + max_fine_tuning_epochs: int | None + The maximum number of epochs used during fine-tuning steps, i.e. after + the first active learning step. If ``None``, then the fine-tuning will + be performed for the duration of the active learning loop. + train_loop_fn: p.TrainingLoop + The training loop function that orchestrates the training process. + This defaults to a concrete implementation, ``DefaultTrainingLoop``, + which provides a very typical loop that includes the use of static + capture, etc. + """ + + train_datapool: p.DataPool + max_training_epochs: int + val_datapool: p.DataPool | None = None + optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig) + max_fine_tuning_epochs: int | None = None + train_loop_fn: p.TrainingLoop = field(default_factory=DefaultTrainingLoop) + + def __post_init__(self) -> None: + """Validate training configuration.""" + # Validate datapools have consistent interface + if not hasattr(self.train_datapool, "__len__"): + raise ValueError("train_datapool must implement __len__") + if self.val_datapool is not None and not hasattr(self.val_datapool, "__len__"): + raise ValueError("val_datapool must implement __len__") + + # Validate training loop is callable + if not callable(self.train_loop_fn): + raise ValueError("train_loop_fn must be callable") + + # set the same value for fine tuning epochs if not provided + if self.max_fine_tuning_epochs is None: + self.max_fine_tuning_epochs = self.max_training_epochs + + def to_dict(self) -> dict[str, Any]: + """ + Returns a JSON-serializable dictionary representation of the TrainingConfig. + + For round-tripping, the registry is used to de-serialize the training loop + and optimizer configuration. Note that datapools (train_datapool and val_datapool) + are NOT serialized as they typically contain large datasets, file handles, or other + non-serializable state. + + Returns + ------- + dict[str, Any] + A dictionary that can be JSON serialized. Excludes datapools. + + Warnings + -------- + This method will issue a warning about the exclusion of datapools. + """ + # Warn about datapool exclusion + warn( + "The `train_datapool` and `val_datapool` attributes are not supported for " + "serialization and will be excluded from the ``TrainingConfig`` dictionary. " + "You must re-provide these datapools when deserializing." 
+ ) + + # Serialize optimizer config + optimizer_dict = self.optimizer_config.to_dict() + + # Serialize training loop function + if not hasattr(self.train_loop_fn, "_args"): + raise ValueError( + f"Training loop {self.train_loop_fn} does not have an `_args` attribute " + "which is required for serialization. Make sure your training loop " + "either subclasses `ActiveLearningProtocol` or implements the `__new__` " + "method to capture object arguments." + ) + + train_loop_dict = self.train_loop_fn._args + + return { + "max_training_epochs": self.max_training_epochs, + "max_fine_tuning_epochs": self.max_fine_tuning_epochs, + "optimizer_config": optimizer_dict, + "train_loop_fn": train_loop_dict, + } + + @classmethod + def from_dict(cls, data: dict[str, Any], **kwargs: Any) -> TrainingConfig: + """ + Creates a TrainingConfig instance from a dictionary. + + This method assumes that the training loop class is included in the + ``physicsnemo.active_learning.registry``, or a module path is specified + to import the class from. Note that datapools must be provided via + kwargs as they are not serialized. + + Parameters + ---------- + data: dict[str, Any] + A dictionary that was previously serialized using the ``to_dict`` method. + **kwargs: Any + Additional keyword arguments to pass to the constructor. This is where + you must provide ``train_datapool`` and optionally ``val_datapool``. + + Returns + ------- + TrainingConfig + A new ``TrainingConfig`` instance. + + Raises + ------ + ValueError + If required datapools are not provided in kwargs, if the data contains + unexpected keys, or if object construction fails. + """ + # Ensure required datapools are provided + if "train_datapool" not in kwargs: + raise ValueError( + "``train_datapool`` must be provided in kwargs when deserializing " + "TrainingConfig, as datapools are not serialized." + ) + + # Reconstruct optimizer config + optimizer_config = OptimizerConfig.from_dict(data["optimizer_config"]) + + # Reconstruct training loop function + train_loop_data = data["train_loop_fn"] + train_loop_fn = registry.construct( + train_loop_data["__name__"], + module_path=train_loop_data["__module__"], + **train_loop_data["__args__"], + ) + + # Build the config + try: + config = cls( + max_training_epochs=data["max_training_epochs"], + max_fine_tuning_epochs=data.get("max_fine_tuning_epochs"), + optimizer_config=optimizer_config, + train_loop_fn=train_loop_fn, + **kwargs, + ) + except Exception as e: + raise ValueError( + "Failed to construct ``TrainingConfig`` from dictionary." + ) from e + + return config + + +@dataclass +class StrategiesConfig: + """ + Configuration for active learning strategies and data acquisition. + + This encapsulates the query-label-metrology cycle that is at the + heart of active learning: strategies for selecting data, labeling it, + and measuring model uncertainty/performance. + + Attributes + ---------- + query_strategies: list[p.QueryStrategy] + The query strategies to use for selecting data to label. + queue_cls: type[p.AbstractQueue] + The queue implementation to use for passing data between + query and labeling phases. + label_strategy: p.LabelStrategy | None + The strategy to use for labeling queried data. If None, + labeling will be skipped. + metrology_strategies: list[p.MetrologyStrategy] | None + Strategies for measuring model performance and uncertainty. + If None, metrology will be skipped. + unlabeled_datapool: p.DataPool | None + Pool of unlabeled data that query strategies can sample from. 
+ Not all strategies require this (some may generate synthetic data). + """ + + query_strategies: list[p.QueryStrategy] + queue_cls: type[p.AbstractQueue] + label_strategy: p.LabelStrategy | None = None + metrology_strategies: list[p.MetrologyStrategy] | None = None + unlabeled_datapool: p.DataPool | None = None + + def __post_init__(self) -> None: + """Validate strategies configuration.""" + # Must have at least one query strategy + if not self.query_strategies: + raise ValueError( + "At least one query strategy must be provided. " + "Active learning requires a mechanism to select data." + ) + + # All query strategies must be callable + for strategy in self.query_strategies: + if not callable(strategy): + raise ValueError(f"Query strategy {strategy} must be callable") + + # Label strategy must be callable if provided + if self.label_strategy is not None and not callable(self.label_strategy): + raise ValueError("label_strategy must be callable") + + # Metrology strategies must be callable if provided + if self.metrology_strategies is not None: + if not self.metrology_strategies: + raise ValueError( + "metrology_strategies is an empty list. " + "Either provide strategies or set to None to skip metrology." + ) + for strategy in self.metrology_strategies: + if not callable(strategy): + raise ValueError(f"Metrology strategy {strategy} must be callable") + + # Validate queue class has basic queue interface + if not hasattr(self.queue_cls, "__call__"): + raise ValueError("queue_cls must be a callable class") + + def to_dict(self) -> dict[str, Any]: + """ + Method that converts the present ``StrategiesConfig`` instance into a dictionary + that can be JSON serialized. + + This method, for the most part, assumes that strategies are subclasses of + ``ActiveLearningProtocol`` and/or they have an ``_args`` attribute that + captures the arguments to the constructor. + + One issue is the inability to reliably serialize the ``unlabeled_datapool``, + which for the most part, likely does not need serialization as a dataset. + Regardless, this method will trigger a warning if ``unlabeled_datapool`` is + not None. + + Returns + ------- + dict[str, Any] + A dictionary that can be JSON serialized. + """ + output = defaultdict(list) + for strategy in self.query_strategies: + if not hasattr(strategy, "_args"): + raise ValueError( + f"Query strategy {strategy} does not have an `_args` attribute" + " which is required for serialization. Make sure your strategy" + " either subclasses `ActiveLearningProtocol` or implements" + " the `__new__` method to capture object arguments." + ) + output["query_strategies"].append(strategy._args) + if self.label_strategy is not None: + if not hasattr(self.label_strategy, "_args"): + raise ValueError( + f"Label strategy {self.label_strategy} does not have an `_args` attribute" + " which is required for serialization. Make sure your strategy" + " either subclasses `ActiveLearningProtocol` or implements" + " the `__new__` method to capture object arguments." + ) + output["label_strategy"] = self.label_strategy._args + output["queue_cls"] = { + "__name__": self.queue_cls.__name__, + "__module__": self.queue_cls.__module__, + } + if self.metrology_strategies is not None: + for strategy in self.metrology_strategies: + if not hasattr(strategy, "_args"): + raise ValueError( + f"Metrology strategy {strategy} does not have an `_args` attribute" + " which is required for serialization. 
Make sure your strategy" + " either subclasses `ActiveLearningProtocol` or implements" + " the `__new__` method to capture object arguments." + ) + output["metrology_strategies"].append(strategy._args) + if self.unlabeled_datapool is not None: + warn( + "The `unlabeled_datapool` attribute is not supported for serialization" + " and will be excluded from the ``StrategiesConfig`` dictionary." + ) + return output + + @classmethod + def from_dict(cls, data: dict[str, Any], **kwargs: Any) -> StrategiesConfig: + """ + Create a ``StrategiesConfig`` instance from a dictionary. + + This method heavily relies on classes being added to the + ``physicsnemo.active_learning.registry``, which is used to instantiate + all strategies and custom types used in active learning. As a fall + back, the `registry.construct` method will try and import the class + from the module path if it is not found in the registry. + + Parameters + ---------- + data: dict[str, Any] + A dictionary that was previously serialized using the ``to_dict`` method. + **kwargs: Any + Additional keyword arguments to pass to the constructor. + + Returns + ------- + StrategiesConfig + A new ``StrategiesConfig`` instance. + + Raises + ------ + ValueError: + If the data contains unexpected keys or if the object construction fails. + NameError: + If a class is not found in the registry and no module path is provided. + ModuleNotFoundError: + If a module is not found with the specified module path. + """ + # ensure that the data contains no unexpected keys + data_keys = set(data.keys()) + expected_keys = set(cls.__dataclass_fields__.keys()) + extra_keys = data_keys - expected_keys + if extra_keys: + raise ValueError( + f"Unexpected keys in data: {extra_keys}. Expected keys are {expected_keys}." + ) + # instantiate objects from the serialized data; general strategy is to + # use `registry.construct` that will try and resolve the class within + # the registry first, and if not found, then it will try and import the + # class from the module path. + output_dict = defaultdict(list) + for entry in data["query_strategies"]: + output_dict["query_strategies"].append( + registry.construct( + entry["__name__"], + module_path=entry["__module__"], + **entry["__args__"], + ) + ) + if "metrology_strategies" in data: + for entry in data["metrology_strategies"]: + output_dict["metrology_strategies"].append( + registry.construct( + entry["__name__"], + module_path=entry["__module__"], + **entry["__args__"], + ) + ) + if "label_strategy" in data: + output_dict["label_strategy"] = registry.construct( + data["label_strategy"]["__name__"], + module_path=data["label_strategy"]["__module__"], + **data["label_strategy"]["__args__"], + ) + output_dict["queue_cls"] = registry.get_class( + data["queue_cls"]["__name__"], data["queue_cls"]["__module__"] + ) + # potentially override with keyword arguments + output_dict.update(kwargs) + try: + config = cls(**output_dict) + except Exception as e: + raise ValueError( + "Failed to construct ``StrategiesConfig`` from dictionary." + ) from e + return config + + +@dataclass +class DriverConfig: + """ + Configuration for driver orchestration and infrastructure. + + This contains parameters that control the overall loop execution, + logging, checkpointing, and distributed training setup - orthogonal + to the specific AL or training logic. + + Attributes + ---------- + batch_size: int + The batch size to use for data loaders. + max_active_learning_steps: int | None, default None + Maximum number of AL iterations to perform. 
None means infinite. + run_id: str, default auto-generated UUID + Unique identifier for this run. Auto-generated if not provided. + fine_tuning_lr: float | None, default None + Learning rate to switch to after the first AL step for fine-tuning. + reset_optim_states: bool, default True + Whether to reset optimizer states between AL steps. + skip_training: bool, default False + If True, skip the training phase entirely. + skip_metrology: bool, default False + If True, skip the metrology phase entirely. + skip_labeling: bool, default False + If True, skip the labeling phase entirely. + checkpoint_interval: int, default 1 + Save model checkpoint every N AL steps. 0 disables checkpointing. + checkpoint_on_training: bool, default False + If True, save checkpoint at the start of the training phase. + checkpoint_on_metrology: bool, default False + If True, save checkpoint at the start of the metrology phase. + checkpoint_on_query: bool, default False + If True, save checkpoint at the start of the query phase. + checkpoint_on_labeling: bool, default True + If True, save checkpoint at the start of the labeling phase. + model_checkpoint_frequency: int, default 0 + Save model weights every N epochs during training. 0 means only save + between active learning phases. Useful for mid-training restarts. + root_log_dir: str | Path, default Path.cwd() / "active_learning_logs" + Directory to save logs and checkpoints to. Defaults to + an 'active_learning_logs' directory in the current working directory. + dist_manager: DistributedManager | None, default None + Manager for distributed training configuration. + collate_fn: callable | None, default None + Custom collate function for batching data. + num_dataloader_workers: int, default 0 + Number of worker processes for data loading. + device: str | torch.device | None, default None + Device to use for model and data. This is intended for single process + workflows; for distributed workflows, the device should be set in + ``DistributedManager`` instead. If not specified, then the device + will default to ``torch.get_default_device()``. + dtype: torch.dtype | None, default None + The dtype to use for model and data, and AMP contexts. If not provided, + then the dtype will default to ``torch.get_default_dtype()``. + """ + + batch_size: int + max_active_learning_steps: int | None = None + run_id: str = field(default_factory=lambda: str(uuid.uuid4())) + fine_tuning_lr: float | None = None # TODO: move to TrainingConfig + reset_optim_states: bool = True + skip_training: bool = False + skip_metrology: bool = False + skip_labeling: bool = False + checkpoint_interval: int = 1 + checkpoint_on_training: bool = False + checkpoint_on_metrology: bool = False + checkpoint_on_query: bool = False + checkpoint_on_labeling: bool = True + model_checkpoint_frequency: int = 0 + root_log_dir: str | Path = field(default=Path.cwd() / "active_learning_logs") + dist_manager: DistributedManager | None = None + collate_fn: callable | None = None + num_dataloader_workers: int = 0 + device: str | torch.device | None = None + dtype: torch.dtype | None = None + + def __post_init__(self) -> None: + """Validate driver configuration.""" + if self.max_active_learning_steps is None: + self.max_active_learning_steps = float("inf") + + if ( + self.max_active_learning_steps is not None + and self.max_active_learning_steps <= 0 + ): + raise ValueError( + "`max_active_learning_steps` must be a positive integer or None." 
+ ) + + if not math.isfinite(self.batch_size) or self.batch_size <= 0: + raise ValueError("`batch_size` must be a positive integer.") + + if not math.isfinite(self.checkpoint_interval) or self.checkpoint_interval < 0: + raise ValueError( + "`checkpoint_interval` must be a non-negative integer. " + "Use 0 to disable checkpointing." + ) + + if self.fine_tuning_lr is not None and self.fine_tuning_lr <= 0: + raise ValueError("`fine_tuning_lr` must be positive if provided.") + + if self.num_dataloader_workers < 0: + raise ValueError("`num_dataloader_workers` must be non-negative.") + + if self.model_checkpoint_frequency < 0: + raise ValueError("`model_checkpoint_frequency` must be non-negative.") + + if isinstance(self.root_log_dir, str): + self.root_log_dir = Path(self.root_log_dir) + + # Validate collate_fn if provided + if self.collate_fn is not None and not callable(self.collate_fn): + raise ValueError("`collate_fn` must be callable if provided.") + + # device and dtype setup when not using DistributedManager + if self.device is None and not self.dist_manager: + self.device = torch.get_default_device() + if self.dtype is None: + self.dtype = torch.get_default_dtype() + + def to_json(self) -> str: + """ + Returns a JSON string representation of the ``DriverConfig``. + + Note that certain fields are not serialized and must be provided when + deserializing: ``dist_manager``, ``collate_fn``. + + Returns + ------- + str + A JSON string representation of the config. + """ + # base dict representation skips Python objects + dict_repr = { + key: self.__dict__[key] + for key in self.__dict__ + if key + not in ["dist_manager", "collate_fn", "root_log_dir", "device", "dtype"] + } + # Note: checkpoint flags are included in dict_repr automatically + dict_repr["default_dtype"] = str(torch.get_default_dtype()) + dict_repr["log_dir"] = str(self.root_log_dir) + # Convert dtype to string for JSON serialization + if self.dtype is not None: + dict_repr["dtype"] = str(self.dtype) + else: + dict_repr["dtype"] = None + if self.dist_manager is not None: + dict_repr["world_size"] = self.dist_manager.world_size + dict_repr["device"] = self.dist_manager.device.type + dict_repr["dist_manager_init_method"] = ( + self.dist_manager._initialization_method + ) + else: + if dist.is_initialized(): + world_size = dist.get_world_size() + else: + world_size = 1 + dict_repr["world_size"] = world_size + if self.device is not None: + dict_repr["device"] = ( + str(self.device) + if hasattr(self.device, "type") + else str(self.device) + ) + else: + dict_repr["device"] = torch.get_default_device().type + dict_repr["dist_manager_init_method"] = None + if self.collate_fn is not None: + dict_repr["collate_fn"] = self.collate_fn.__name__ + else: + dict_repr["collate_fn"] = None + return dumps(dict_repr, indent=2) + + @classmethod + def from_json(cls, json_str: str, **kwargs: Any) -> DriverConfig: + """ + Creates a DriverConfig instance from a JSON string. + + This method reconstructs a DriverConfig from JSON. Note that certain + fields cannot be serialized and must be provided via kwargs: + - ``dist_manager``: DistributedManager instance (optional) + - ``collate_fn``: Custom collate function (optional) + + Parameters + ---------- + json_str: str + A JSON string that was previously serialized using ``to_json()``. + **kwargs: Any + Additional keyword arguments to override or provide non-serializable + fields like ``dist_manager`` and ``collate_fn``. + + Returns + ------- + DriverConfig + A new ``DriverConfig`` instance. 
+ + Raises + ------ + ValueError + If the JSON cannot be parsed or required fields are missing. + + Notes + ----- + The device and dtype fields are reconstructed from their string + representations. The ``log_dir`` field in JSON is mapped to + ``root_log_dir`` in the config. + """ + import json + + try: + data = json.loads(json_str) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON string: {e}") from e + + # Define fields that are not actual DriverConfig constructor parameters + metadata_fields = [ + "default_dtype", + "world_size", + "dist_manager_init_method", + "log_dir", # handled separately as root_log_dir + ] + non_serializable_fields = [ + "dist_manager", + "collate_fn", + "root_log_dir", + "device", + "dtype", + ] + + # Extract serializable fields that map directly + config_fields = { + key: value + for key, value in data.items() + if key not in metadata_fields + non_serializable_fields + } + + # Handle root_log_dir (stored as "log_dir" in JSON) + if "log_dir" in data: + config_fields["root_log_dir"] = Path(data["log_dir"]) + + # Handle device reconstruction from string + if "device" in data and data["device"] is not None: + device_str = data["device"] + # Parse device strings like "cuda:0", "cpu", "cuda", etc. + config_fields["device"] = torch.device(device_str) + + # Handle dtype reconstruction from string + if "dtype" in data and data["dtype"] is not None: + dtype_str = data["dtype"] + # Map string representations to torch dtypes + dtype_map = { + "torch.float32": torch.float32, + "torch.float64": torch.float64, + "torch.float16": torch.float16, + "torch.bfloat16": torch.bfloat16, + "torch.int32": torch.int32, + "torch.int64": torch.int64, + "torch.int8": torch.int8, + "torch.uint8": torch.uint8, + } + if dtype_str in dtype_map: + config_fields["dtype"] = dtype_map[dtype_str] + else: + warn( + f"Unknown dtype string '{dtype_str}' in JSON. " + "Using default dtype instead." + ) + + # Merge with provided kwargs (allows overriding and adding non-serializable fields) + config_fields.update(kwargs) + + # Create the config + try: + config = cls(**config_fields) + except Exception as e: + raise ValueError( + "Failed to construct ``DriverConfig`` from JSON string." + ) from e + + return config diff --git a/physicsnemo/active_learning/driver.py b/physicsnemo/active_learning/driver.py new file mode 100644 index 0000000000..02169f5014 --- /dev/null +++ b/physicsnemo/active_learning/driver.py @@ -0,0 +1,1449 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module contains the definition for an active learning driver +class, which is responsible for orchestration and automation of +the end-to-end active learning process. 
+""" + +from __future__ import annotations + +import inspect +import pickle +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Generator + +import torch +from torch import distributed as dist +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DataLoader, DistributedSampler + +from physicsnemo import Module +from physicsnemo import __version__ as physicsnemo_version +from physicsnemo.active_learning import protocols as p +from physicsnemo.active_learning.config import ( + DriverConfig, + StrategiesConfig, + TrainingConfig, +) +from physicsnemo.active_learning.logger import ( + ActiveLearningLoggerAdapter, + setup_active_learning_logger, +) +from physicsnemo.distributed import DistributedManager + + +@dataclass +class ActiveLearningCheckpoint: + """ + Metadata associated with an ongoing (or completed) active + learning experiment. + + The information contained in this metadata should be sufficient + to restart the active learning experiment at the nearest point: + for example, training should be able to continue from an epoch, + while for querying/sampling, etc. we continue from a pre-existing + queue. + """ + + driver_config: DriverConfig + strategies_config: StrategiesConfig + active_learning_step_idx: int + active_learning_phase: p.ActiveLearningPhase + physicsnemo_version: str = physicsnemo_version + training_config: TrainingConfig | None = None + optimizer_state: dict[str, Any] | None = None + lr_scheduler_state: dict[str, Any] | None = None + has_query_queue: bool = False + has_label_queue: bool = False + + +class Driver(p.DriverProtocol): + """ + Provides a simple implementation of the ``DriverProtocol`` used to + orchestrate an active learning process within PhysicsNeMo. + + At a high level, the active learning process is broken down into four + phases: training, metrology, query, and labeling. + + To understand the orchestration, start by inspecting the + ``active_learning_step`` method, which defines a single iteration of + the active learning loop, which is dispatched by the ``run`` method. + From there, it should be relatively straightforward to trace the + remaining components. + + Attributes + ---------- + config: DriverConfig + Infrastructure and orchestration configuration. + learner: Module | p.LearnerProtocol + The learner module for the active learning process. + strategies_config: StrategiesConfig + Active learning strategies (query, label, metrology). + training_config: TrainingConfig | None + Training components. None if training is skipped. + inference_fn: p.InferenceProtocol | None + Custom inference function. + active_learning_step_idx: int + Current iteration index of the active learning loop. + query_queue: p.AbstractQueue + Queue populated with data by query strategies. + label_queue: p.AbstractQueue + Queue populated with labeled data by the label strategy. + optimizer: torch.optim.Optimizer | None + Configured optimizer (set after configure_optimizer is called). + lr_scheduler: torch.optim.lr_scheduler._LRScheduler | None + Configured learning rate scheduler. + logger: logging.Logger + Persistent logger for the active learning process. 
+ """ + + # Phase execution order for active learning step (immutable) + _PHASE_ORDER = [ + p.ActiveLearningPhase.TRAINING, + p.ActiveLearningPhase.METROLOGY, + p.ActiveLearningPhase.QUERY, + p.ActiveLearningPhase.LABELING, + ] + + def __init__( + self, + config: DriverConfig, + learner: Module | p.LearnerProtocol, + strategies_config: StrategiesConfig, + training_config: TrainingConfig | None = None, + inference_fn: p.InferenceProtocol | None = None, + ) -> None: + """ + Initializes the active learning driver. + + At the bare minimum, the driver requires a config, learner, and + strategies config to be used in a purely querying loop. Additional + arguments can be provided to enable training and other workflows. + + Parameters + ---------- + config: DriverConfig + Orchestration and infrastructure configuration, for example + the batch size, the log directory, the distributed manager, etc. + learner: Module | p.LearnerProtocol + The model to use for active learning. + strategies_config: StrategiesConfig + Container for active learning strategies (query, label, metrology). + training_config: TrainingConfig | None + Training components. Required if ``skip_training`` is False in + the ``DriverConfig``. + inference_fn: p.InferenceProtocol | None + Custom inference function. If None, uses ``learner.__call__``. + This is not actually called by the driver, but is stored as an + attribute for attached strategies to use as needed. + """ + # Configs have already validated themselves in __post_init__ + self.config = config + self.learner = learner + self.strategies_config = strategies_config + self.training_config = training_config + self.inference_fn = inference_fn + self.active_learning_step_idx = 0 + self.current_phase: p.ActiveLearningPhase | None = ( + None # Track current phase for logging context + ) + self._last_checkpoint_path: Path | None = None + + # Validate cross-config constraints + self._validate_config_consistency() + + self._setup_logger() + self.attach_strategies() + + # Initialize queues from strategies_config + self.query_queue = strategies_config.queue_cls() + self.label_queue = strategies_config.queue_cls() + + def _validate_config_consistency(self) -> None: + """ + Validate consistency across configs. + + Each config validates itself, but this method checks relationships + between configs that can only be validated when composed together. + """ + # If training is not skipped, training_config must be provided + if not self.config.skip_training and self.training_config is None: + raise ValueError( + "`training_config` must be provided when `skip_training` is False." + ) + + # If labeling is not skipped, must have label strategy and train datapool + if not self.config.skip_labeling: + if self.strategies_config.label_strategy is None: + raise ValueError( + "`label_strategy` must be provided in strategies_config " + "when `skip_labeling` is False." + ) + if ( + self.training_config is None + or self.training_config.train_datapool is None + ): + raise ValueError( + "`train_datapool` must be provided in training_config " + "when `skip_labeling` is False (labeled data is appended to it)." + ) + + # If fine-tuning lr is set, must have training enabled + if self.config.fine_tuning_lr is not None and self.config.skip_training: + raise ValueError( + "`fine_tuning_lr` has no effect when `skip_training` is True." 
+ ) + + @property + def query_strategies(self) -> list[p.QueryStrategy]: + """Returns the query strategies from strategies_config.""" + return self.strategies_config.query_strategies + + @property + def label_strategy(self) -> p.LabelStrategy | None: + """Returns the label strategy from strategies_config.""" + return self.strategies_config.label_strategy + + @property + def metrology_strategies(self) -> list[p.MetrologyStrategy] | None: + """Returns the metrology strategies from strategies_config.""" + return self.strategies_config.metrology_strategies + + @property + def unlabeled_datapool(self) -> p.DataPool | None: + """Returns the unlabeled datapool from strategies_config.""" + return self.strategies_config.unlabeled_datapool + + @property + def train_datapool(self) -> p.DataPool | None: + """Returns the training datapool from training_config.""" + return self.training_config.train_datapool if self.training_config else None + + @property + def val_datapool(self) -> p.DataPool | None: + """Returns the validation datapool from training_config.""" + return self.training_config.val_datapool if self.training_config else None + + @property + def train_loop_fn(self) -> p.TrainingLoop | None: + """Returns the training loop function from training_config.""" + return self.training_config.train_loop_fn if self.training_config else None + + @property + def device(self) -> torch.device: + """Return a consistent device interface to use across the driver.""" + if self.dist_manager is not None and self.dist_manager.is_initialized(): + return self.dist_manager.device + else: + return torch.get_default_device() + + @property + def run_id(self) -> str: + """Returns the run id from the ``DriverConfig``. + + Returns + ------- + str + The run id. + """ + return self.config.run_id + + @property + def log_dir(self) -> Path: + """Returns the log directory. + + Note that this is the ``DriverConfig.root_log_dir`` combined + with the shortened run ID for the current run. + + Effectively, this means that each run will have its own + directory for logs, checkpoints, etc. + + Returns + ------- + Path + The log directory. + """ + return self.config.root_log_dir / self.short_run_id + + @property + def short_run_id(self) -> str: + """Returns the first 8 characters of the run id. + + The 8 character limit assumes that the run ID is a UUID4. + This is particularly useful for user-facing interfaces, + where you do not necessarily want to reference the full UUID. + + Returns + ------- + str + The first 8 characters of the run id. + """ + return self.run_id[:8] + + @property + def last_checkpoint(self) -> Path | None: + """ + Returns path to the most recently saved checkpoint. + + Returns + ------- + Path | None + Path to the last checkpoint directory, or None if no checkpoint + has been saved yet. + """ + return self._last_checkpoint_path + + @property + def active_learning_step_idx(self) -> int: + """ + Returns the current active learning step index. + + This represents the number of times the active learning step + has been called, i.e. the number of iterations of the loop. + + Returns + ------- + int + The current active learning step index. + """ + return self._active_learning_step_idx + + @active_learning_step_idx.setter + def active_learning_step_idx(self, value: int) -> None: + """ + Sets the current active learning step index. + + Parameters + ---------- + value: int + The new active learning step index. + + Raises + ------ + ValueError + If the new active learning step index is negative. 
+ """ + if value < 0: + raise ValueError("Active learning step index must be non-negative.") + self._active_learning_step_idx = value + + @property + def dist_manager(self) -> DistributedManager | None: + """Returns the distributed manager, if it was specified as part + of the `DriverConfig` configuration. + + Returns + ------- + DistributedManager | None + The distributed manager. + """ + return self.config.dist_manager + + def configure_optimizer(self) -> None: + """Setup optimizer and LR schedulers from training_config.""" + if self.training_config is None: + self.optimizer = None + self.lr_scheduler = None + return + + opt_cfg = self.training_config.optimizer_config + + if opt_cfg.optimizer_cls is not None: + try: + _ = inspect.signature(opt_cfg.optimizer_cls).bind( + self.learner.parameters(), **opt_cfg.optimizer_kwargs + ) + except TypeError as e: + raise ValueError( + f"Invalid optimizer kwargs for {opt_cfg.optimizer_cls}; {e}" + ) + self.optimizer = opt_cfg.optimizer_cls( + self.learner.parameters(), **opt_cfg.optimizer_kwargs + ) + else: + self.optimizer = None + return + + if opt_cfg.scheduler_cls is not None and self.optimizer is not None: + try: + _ = inspect.signature(opt_cfg.scheduler_cls).bind( + self.optimizer, **opt_cfg.scheduler_kwargs + ) + except TypeError as e: + raise ValueError( + f"Invalid LR scheduler kwargs for {opt_cfg.scheduler_cls}; {e}" + ) + self.lr_scheduler = opt_cfg.scheduler_cls( + self.optimizer, **opt_cfg.scheduler_kwargs + ) + else: + self.lr_scheduler = None + # in the case where we want to reset optimizer states between active learning steps + if self.config.reset_optim_states and self.is_optimizer_configured: + self._original_optim_state = deepcopy(self.optimizer.state_dict()) + + @property + def is_optimizer_configured(self) -> bool: + """Returns whether the optimizer is configured.""" + return getattr(self, "optimizer", None) is not None + + @property + def is_lr_scheduler_configured(self) -> bool: + """Returns whether the LR scheduler is configured.""" + return getattr(self, "lr_scheduler", None) is not None + + def attach_strategies(self) -> None: + """Calls ``strategy.attach`` for all available strategies.""" + super().attach_strategies() + + def _setup_logger(self) -> None: + """ + Sets up a persistent logger for the driver. + + This logger is specialized in that it provides additional context + information depending on the part of the active learning cycle. + """ + base_logger = setup_active_learning_logger( + "core.active_learning", + run_id=self.run_id, + log_dir=self.log_dir, + ) + # Wrap with adapter to automatically include iteration context + self.logger = ActiveLearningLoggerAdapter(base_logger, driver_ref=self) + + def _should_checkpoint_at_step(self) -> bool: + """ + Determine if a checkpoint should be saved at the current AL step. + + Uses the `checkpoint_interval` from config to decide. If interval is 0, + checkpointing is disabled. Otherwise, checkpoint at step 0 and every + N steps thereafter. + + Returns + ------- + bool + True if checkpoint should be saved, False otherwise. + """ + if self.config.checkpoint_interval == 0: + return False + # Always checkpoint at step 0, then every checkpoint_interval steps + return self.active_learning_step_idx % self.config.checkpoint_interval == 0 + + def _serialize_queue(self, queue: p.AbstractQueue, file_path: Path) -> bool: + """ + Serialize queue to a file. + + If queue implements `to_list()`, serialize the list. Otherwise, use + torch.save to serialize the entire queue object. 
+ + Parameters + ---------- + queue: p.AbstractQueue + The queue to serialize. + file_path: Path + Path where the queue should be saved. + + Returns + ------- + bool + True if serialization succeeded, False otherwise. + """ + try: + if hasattr(queue, "to_list") and callable(getattr(queue, "to_list")): + # Use custom serialization method + queue_data = {"type": "list", "data": queue.to_list()} + else: + # Fallback to torch.save for the entire queue + queue_data = {"type": "torch", "data": queue} + + torch.save(queue_data, file_path) + return True + except (TypeError, AttributeError, pickle.PicklingError, RuntimeError) as e: + # Some queues cannot be pickled, e.g. stdlib queue.Queue with thread locks + # Clean up any partially written file + if file_path.exists(): + file_path.unlink() + + self.logger.warning( + f"Failed to serialize queue to {file_path}: {e}. Queue state will not be saved. " + f"Consider implementing to_list()/from_list() methods for custom serialization." + ) + return False + + def _deserialize_queue(self, queue: p.AbstractQueue, file_path: Path) -> None: + """ + Restore queue from a file. + + Parameters + ---------- + queue: p.AbstractQueue + The queue to restore data into. + file_path: Path + Path to the saved queue file. + """ + if not file_path.exists(): + return + + try: + queue_data = torch.load(file_path, map_location="cpu", weights_only=False) + + if queue_data["type"] == "list": + if hasattr(queue, "from_list") and callable( + getattr(queue, "from_list") + ): + queue.from_list(queue_data["data"]) + else: + # Manually populate queue from list + for item in queue_data["data"]: + queue.put(item) + elif queue_data["type"] == "torch": + # Restore from torch-saved queue - copy items to current queue + restored_queue = queue_data["data"] + # Copy items from restored queue to current queue + while not restored_queue.empty(): + queue.put(restored_queue.get()) + except Exception as e: + self.logger.warning( + f"Failed to deserialize queue from {file_path}: {e}. " + f"Queue will be empty." + ) + + def save_checkpoint( + self, path: str | Path | None = None, training_epoch: int | None = None + ) -> Path | None: + """ + Save a checkpoint of the active learning experiment. + + Saves AL orchestration state (configs, queues, step index, phase) and model weights. + Training-specific state (optimizer, scheduler) is handled by DefaultTrainingLoop + and saved to training_state.pt during training. + + Parameters + ---------- + path: str | Path | None + Path to save checkpoint. If None, creates path based on current + AL step index and phase: log_dir/checkpoints/step_{idx}/{phase}/ + training_epoch: int | None + Optional epoch number for mid-training checkpoints. + + Returns + ------- + Path | None + Checkpoint directory path, or None if checkpoint not saved (non-rank-0 in distributed). 
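+
+        Examples
+        --------
+        A minimal sketch of saving and later resuming from the same
+        directory; ``driver``, ``model``, and ``dataset`` are assumed to
+        already exist::
+
+            ckpt_dir = driver.save_checkpoint()
+            # ... later, rebuild a driver from that checkpoint directory
+            resumed = Driver.load_checkpoint(
+                ckpt_dir,
+                learner=model,
+                train_datapool=dataset,
+            )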
+ """ + # Determine checkpoint directory + if path is None: + phase_name = self.current_phase if self.current_phase else "init" + checkpoint_dir = ( + self.log_dir + / "checkpoints" + / f"step_{self.active_learning_step_idx}" + / phase_name + ) + if training_epoch is not None: + checkpoint_dir = checkpoint_dir / f"epoch_{training_epoch}" + else: + checkpoint_dir = Path(path) + + # Create checkpoint directory + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # Only rank 0 saves checkpoint in distributed setting + if self.dist_manager is not None and self.dist_manager.is_initialized(): + if self.dist_manager.rank != 0: + return None + + # Serialize configurations + driver_config_json = self.config.to_json() + strategies_config_dict = self.strategies_config.to_dict() + training_config_dict = ( + self.training_config.to_dict() if self.training_config else None + ) + + # Serialize queue states to separate files + query_queue_file = checkpoint_dir / "query_queue.pt" + label_queue_file = checkpoint_dir / "label_queue.pt" + has_query_queue = self._serialize_queue(self.query_queue, query_queue_file) + has_label_queue = self._serialize_queue(self.label_queue, label_queue_file) + + # Create checkpoint dataclass (only AL orchestration state) + checkpoint = ActiveLearningCheckpoint( + driver_config=driver_config_json, + strategies_config=strategies_config_dict, + active_learning_step_idx=self.active_learning_step_idx, + active_learning_phase=self.current_phase or p.ActiveLearningPhase.TRAINING, + physicsnemo_version=physicsnemo_version, + training_config=training_config_dict, + optimizer_state=None, # Training loop handles this + lr_scheduler_state=None, # Training loop handles this + has_query_queue=has_query_queue, + has_label_queue=has_label_queue, + ) + + # Add training epoch if in mid-training checkpoint + checkpoint_dict = { + "checkpoint": checkpoint, + } + if training_epoch is not None: + checkpoint_dict["training_epoch"] = training_epoch + + # Save checkpoint metadata + checkpoint_path = checkpoint_dir / "checkpoint.pt" + torch.save(checkpoint_dict, checkpoint_path) + + # Save model weights (separate from training state) + if isinstance(self.learner, Module): + model_name = ( + self.learner.meta.name + if self.learner.meta + else self.learner.__class__.__name__ + ) + model_path = checkpoint_dir / f"{model_name}.mdlus" + self.learner.save(str(model_path)) + elif hasattr(self.learner, "module") and isinstance( + self.learner.module, Module + ): + # Unwrap DDP + model_name = ( + self.learner.module.meta.name + if self.learner.module.meta + else self.learner.module.__class__.__name__ + ) + model_path = checkpoint_dir / f"{model_name}.mdlus" + self.learner.module.save(str(model_path)) + else: + model_name = self.learner.__class__.__name__ + model_path = checkpoint_dir / f"{model_name}.pt" + torch.save(self.learner.state_dict(), model_path) + + # Update last checkpoint path + self._last_checkpoint_path = checkpoint_dir + + # Log successful checkpoint save + self.logger.info( + f"Saved checkpoint at step {self.active_learning_step_idx}, " + f"phase {self.current_phase}: {checkpoint_dir}" + ) + + return checkpoint_dir + + @classmethod + def load_checkpoint( + cls, + checkpoint_path: str | Path, + learner: Module | p.LearnerProtocol | None = None, + train_datapool: p.DataPool | None = None, + val_datapool: p.DataPool | None = None, + unlabeled_datapool: p.DataPool | None = None, + **kwargs: Any, + ) -> Driver: + """ + Load a Driver instance from a checkpoint. 
+ + Given a checkpoint directory, this method will attempt to reconstruct + the driver and its associated components from the checkpoint. The + checkpoint path must contain a ``checkpoint.pt`` file, which contains + the metadata associated with the experiment. + + Additional parameters that might not be serialized with the checkpointing + mechanism can/need to be provided to this method; for example when + using non-`physicsnemo.Module` learners, and any data pools associated + with the workflow. + + .. important:: + + Currently, the strategy states are not reloaded from the checkpoint. + This will be addressed in a future patch, but for now it is recommended + to back up your strategy states (e.g. metrology records) manually + before restarting experiments. + + Parameters + ---------- + checkpoint_path: str | Path + Path to checkpoint directory containing checkpoint.pt and model weights. + learner: Module | p.LearnerProtocol | None + Learner model to load weights into. If None, will attempt to + reconstruct from checkpoint (only works for physicsnemo.Module). + train_datapool: p.DataPool | None + Training datapool. Required if training_config exists in checkpoint. + val_datapool: p.DataPool | None + Validation datapool. Optional. + unlabeled_datapool: p.DataPool | None + Unlabeled datapool for query strategies. Optional. + **kwargs: Any + Additional keyword arguments to override config values. + + Returns + ------- + Driver + Reconstructed Driver instance ready to resume execution. + """ + checkpoint_path = Path(checkpoint_path) + + # Load checkpoint file + checkpoint_file = checkpoint_path / "checkpoint.pt" + if not checkpoint_file.exists(): + raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_file}") + + checkpoint_dict = torch.load( + checkpoint_file, map_location="cpu", weights_only=False + ) + checkpoint: ActiveLearningCheckpoint = checkpoint_dict["checkpoint"] + training_epoch = checkpoint_dict.get("training_epoch", None) + + # Reconstruct configs + driver_config = DriverConfig.from_json( + checkpoint.driver_config, **kwargs.get("driver_config_overrides", {}) + ) + + # TODO add strategy state loading from checkpoint + strategies_config = StrategiesConfig.from_dict( + checkpoint.strategies_config, + unlabeled_datapool=unlabeled_datapool, + **kwargs.get("strategies_config_overrides", {}), + ) + + training_config = None + if checkpoint.training_config is not None: + training_config = TrainingConfig.from_dict( + checkpoint.training_config, + train_datapool=train_datapool, + val_datapool=val_datapool, + **kwargs.get("training_config_overrides", {}), + ) + + # Load or reconstruct learner + if learner is None: + # Attempt to reconstruct from checkpoint (only for Module) + # Try to find any .mdlus file in the checkpoint directory + mdlus_files = list(checkpoint_path.glob("*.mdlus")) + if mdlus_files: + # Use the first .mdlus file found + model_path = mdlus_files[0] + learner = Module.from_checkpoint(str(model_path)) + else: + raise ValueError( + "No learner provided and unable to reconstruct from checkpoint. " + "Please provide a learner instance." 
+ ) + else: + # Load model weights into provided learner + # Determine expected model filename based on learner type + if isinstance(learner, Module): + model_name = ( + learner.meta.name if learner.meta else learner.__class__.__name__ + ) + model_path = checkpoint_path / f"{model_name}.mdlus" + if model_path.exists(): + learner.load(str(model_path)) + else: + # Fallback: try to find any .mdlus file + mdlus_files = list(checkpoint_path.glob("*.mdlus")) + if mdlus_files: + learner.load(str(mdlus_files[0])) + elif hasattr(learner, "module") and isinstance(learner.module, Module): + # Unwrap DDP + model_name = ( + learner.module.meta.name + if learner.module.meta + else learner.module.__class__.__name__ + ) + model_path = checkpoint_path / f"{model_name}.mdlus" + if model_path.exists(): + learner.module.load(str(model_path)) + else: + # Fallback: try to find any .mdlus file + mdlus_files = list(checkpoint_path.glob("*.mdlus")) + if mdlus_files: + learner.module.load(str(mdlus_files[0])) + else: + # Non-Module learner: look for .pt file with class name + model_name = learner.__class__.__name__ + model_path = checkpoint_path / f"{model_name}.pt" + if model_path.exists(): + state_dict = torch.load(model_path, map_location="cpu") + learner.load_state_dict(state_dict) + else: + # Fallback: try to find any .pt file + pt_files = list(checkpoint_path.glob("*.pt")) + # Filter out checkpoint.pt and queue files + model_pt_files = [ + f + for f in pt_files + if f.name + not in [ + "checkpoint.pt", + "query_queue.pt", + "label_queue.pt", + "training_state.pt", + ] + ] + if model_pt_files: + state_dict = torch.load(model_pt_files[0], map_location="cpu") + learner.load_state_dict(state_dict) + + # Instantiate Driver + driver = cls( + config=driver_config, + learner=learner, + strategies_config=strategies_config, + training_config=training_config, + inference_fn=kwargs.get("inference_fn", None), + ) + + # Restore active learning state + driver.active_learning_step_idx = checkpoint.active_learning_step_idx + driver.current_phase = checkpoint.active_learning_phase + driver._last_checkpoint_path = checkpoint_path + + # Load training state (optimizer, scheduler) if training_config exists + # This delegates to the training loop's checkpoint loading logic + if driver.training_config is not None: + driver.configure_optimizer() + + # Use training loop to load training state (including model weights again if needed) + from physicsnemo.active_learning.loop import DefaultTrainingLoop + + DefaultTrainingLoop.load_training_checkpoint( + checkpoint_dir=checkpoint_path, + model=driver.learner, + optimizer=driver.optimizer, + lr_scheduler=driver.lr_scheduler + if hasattr(driver, "lr_scheduler") + else None, + ) + + # Restore queue states from separate files + if checkpoint.has_query_queue: + query_queue_file = checkpoint_path / "query_queue.pt" + driver._deserialize_queue(driver.query_queue, query_queue_file) + + if checkpoint.has_label_queue: + label_queue_file = checkpoint_path / "label_queue.pt" + driver._deserialize_queue(driver.label_queue, label_queue_file) + + driver.logger.info( + f"Loaded checkpoint from {checkpoint_path} at step " + f"{checkpoint.active_learning_step_idx}, phase {checkpoint.active_learning_phase}" + ) + if training_epoch is not None: + driver.logger.info(f"Resuming from training epoch {training_epoch}") + + return driver + + def barrier(self) -> None: + """ + Wrapper to call barrier on the correct device. 
+ + Becomes a no-op if distributed is not initialized, otherwise + will attempt to read the local device ID from either the distributed manager + or the default device. + """ + if dist.is_initialized(): + if ( + self.dist_manager is not None + and self.dist_manager.device.type == "cuda" + ): + dist.barrier(device_ids=[self.dist_manager.local_rank]) + elif torch.get_default_device().type == "cuda": + # this might occur if distributed manager is not used + dist.barrier(device_ids=[torch.cuda.current_device()]) + else: + dist.barrier() + + def _configure_model(self) -> None: + """ + Method that encapsulates all the logic for preparing the model + ahead of time. + + If the distributed manager has been configured and initialized + with a world size greater than 1, then we wrap the model in DDP. + Otherwise, we simply move the model to the correct device. + + After the model has been moved to device, we configure the optimizer + and learning rate scheduler if training is enabled. + """ + if self.dist_manager is not None and self.dist_manager.is_initialized(): + if self.dist_manager.world_size > 1 and not isinstance( + self.learner, DistributedDataParallel + ): + # wrap the model in DDP + self.learner = torch.nn.parallel.DistributedDataParallel( + self.learner, + device_ids=[self.dist_manager.local_rank], + output_device=self.dist_manager.device, + broadcast_buffers=self.dist_manager.broadcast_buffers, + find_unused_parameters=self.dist_manager.find_unused_parameters, + ) + else: + if self.config.device is not None: + self.learner = self.learner.to(self.config.device, self.config.dtype) + # assume all device management is done via the dist_manager, so at this + # point the model is on the correct device and we can set up the optimizer + # if we intend to train + if not self.config.skip_training and not self.is_optimizer_configured: + self.configure_optimizer() + if self.is_optimizer_configured and self.config.reset_optim_states: + self.optimizer.load_state_dict(self._original_optim_state) + + def _get_phase_index(self, phase: p.ActiveLearningPhase | None) -> int: + """ + Get index of phase in execution order. + + Parameters + ---------- + phase: p.ActiveLearningPhase | None + Phase to find index for. If None, returns 0 (start from beginning). + + Returns + ------- + int + Index in _PHASE_ORDER (0-3). + """ + if phase is None: + return 0 + try: + return self._PHASE_ORDER.index(phase) + except ValueError: + self.logger.warning( + f"Unknown phase {phase}, defaulting to start from beginning" + ) + return 0 + + def _build_phase_queue( + self, + train_step_fn: p.TrainingProtocol | None, + validate_step_fn: p.ValidationProtocol | None, + args: tuple, + kwargs: dict, + ) -> list[Any]: + """ + Build list of phase functions to execute for this AL step. + + If current_phase is set (e.g., from checkpoint), only phases at or after + current_phase are included. Otherwise, all non-skipped phases are included. + + Parameters + ---------- + train_step_fn: p.TrainingProtocol | None + Training function to pass to training phase. + validate_step_fn: p.ValidationProtocol | None + Validation function to pass to training phase. + args: tuple + Additional arguments to pass to phase methods. + kwargs: dict + Additional keyword arguments to pass to phase methods. + + Returns + ------- + list[Callable] + Queue of phase functions to execute in order. 
+ """ + # Define all possible phases with their execution conditions + all_phases = [ + ( + p.ActiveLearningPhase.TRAINING, + lambda: self._training_phase( + train_step_fn, validate_step_fn, *args, **kwargs + ), + not self.config.skip_training, + ), + ( + p.ActiveLearningPhase.METROLOGY, + lambda: self._metrology_phase(*args, **kwargs), + not self.config.skip_metrology, + ), + ( + p.ActiveLearningPhase.QUERY, + lambda: self._query_phase(*args, **kwargs), + True, # Query phase always runs + ), + ( + p.ActiveLearningPhase.LABELING, + lambda: self._labeling_phase(*args, **kwargs), + not self.config.skip_labeling, + ), + ] + + # Find starting index based on current_phase (resume point) + start_idx = self._get_phase_index(self.current_phase) + + if start_idx > 0: + self.logger.info( + f"Resuming AL step {self.active_learning_step_idx} from " + f"{self.current_phase}" + ) + + # Build queue: only phases from start_idx onwards that should run + phase_queue = [] + for idx, (phase, phase_fn, should_run) in enumerate(all_phases): + # Skip phases before current_phase + if idx < start_idx: + self.logger.debug( + f"Skipping {phase} (already completed in this AL step)" + ) + continue + + # Add phase to queue if not skipped by config + if should_run: + phase_queue.append(phase_fn) + else: + self.logger.debug(f"Skipping {phase} (disabled in config)") + + return phase_queue + + def _construct_dataloader( + self, pool: p.DataPool, shuffle: bool = False, drop_last: bool = False + ) -> DataLoader: + """ + Helper method to construct a data loader for a given data pool. + + In the case that a distributed manager was provided, then a distributed + sampler will be used, which will be bound to the current rank. + Otherwise, a regular sampler will be used. Similarly, if your data + structure requires a specialized function to construct batches, + then this function can be provided via the `collate_fn` argument. + + Parameters + ---------- + pool: p.DataPool + The data pool to construct a data loader for. + shuffle: bool = False + Whether to shuffle the data. + drop_last: bool = False + Whether to drop the last batch if it is not complete. + + Returns + ------- + DataLoader + The constructed data loader. + """ + # if a distributed manager was omitted, then we assume single process + if self.dist_manager is not None and self.dist_manager.is_initialized(): + sampler = DistributedSampler( + pool, + num_replicas=self.dist_manager.world_size, + rank=self.dist_manager.rank, + shuffle=shuffle, + drop_last=drop_last, + ) + # set to None, because sampler will handle instead + shuffle = None + else: + sampler = None + # fully spec out the data loader + pin_memory = False + if self.dist_manager is not None and self.dist_manager.is_initialized(): + if self.dist_manager.device.type == "cuda": + pin_memory = True + loader = DataLoader( + pool, + shuffle=shuffle, + sampler=sampler, + collate_fn=self.config.collate_fn, + batch_size=self.config.batch_size, + num_workers=self.config.num_dataloader_workers, + persistent_workers=self.config.num_dataloader_workers > 0, + pin_memory=pin_memory, + ) + return loader + + def active_learning_step( + self, + train_step_fn: p.TrainingProtocol | None = None, + validate_step_fn: p.ValidationProtocol | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Performs a single active learning iteration. + + This method will perform the following sequence of steps: + 1. 
Train the model stored in ``Driver.learner`` by creating data loaders + with ``Driver.train_datapool`` and ``Driver.val_datapool``. + 2. Run the metrology strategies stored in ``Driver.metrology_strategies``. + 3. Run the query strategies stored in ``Driver.query_strategies``, if available. + 4. Run the labeling strategy stored in ``Driver.label_strategy``, if available. + + When entering each stage, we check to ensure all components necessary for the + minimum function for that stage are available before proceeding. + + If current_phase is set (e.g., from checkpoint resumption), only phases at + or after current_phase will be executed. After completing all phases, + current_phase is reset to None for the next AL step. + + Parameters + ---------- + train_step_fn: p.TrainingProtocol | None = None + The training function to use for training. If not provided, then the + ``Driver.train_loop_fn`` will be used. + validate_step_fn: p.ValidationProtocol | None = None + The validation function to use for validation. If not provided, then + validation will not be performed. + args: Any + Additional arguments to pass to the method. These will be passed to the + training loop, metrology strategies, query strategies, and labeling strategies. + kwargs: Any + Additional keyword arguments to pass to the method. These will be passed to the + training loop, metrology strategies, query strategies, and labeling strategies. + + Raises + ------ + ValueError + If any of the required components for a stage are not available. + """ + self._setup_active_learning_step() + + # Build queue of phase functions based on current_phase + phase_queue = self._build_phase_queue( + train_step_fn, validate_step_fn, args, kwargs + ) + + # Execute each phase in order (de-populate queue) + for phase_fn in phase_queue: + phase_fn() + + # Reset current_phase after completing all phases in this AL step + self.current_phase = None + + self.logger.debug("Entering barrier for synchronization.") + self.barrier() + self.active_learning_step_idx += 1 + self.logger.info( + f"Completed active learning step {self.active_learning_step_idx}" + ) + + def _setup_active_learning_step(self) -> None: + """Initialize distributed manager and configure model for the active learning step.""" + if self.dist_manager is not None and not self.dist_manager.is_initialized(): + self.logger.info( + "Distributed manager configured but not initialized; initializing." + ) + self.dist_manager.initialize() + self._configure_model() + self.logger.info( + f"Starting active learning step {self.active_learning_step_idx}" + ) + + def _training_phase( + self, + train_step_fn: p.TrainingProtocol | None, + validate_step_fn: p.ValidationProtocol | None, + *args: Any, + **kwargs: Any, + ) -> None: + """Execute the training phase of the active learning step.""" + self._validate_training_requirements(train_step_fn, validate_step_fn) + + # don't need to barrier because it'll be done at the end of training anyway + with self._phase_context("training", call_barrier=False): + # Note: Training phase checkpointing is handled by the training loop itself + # during epoch execution based on model_checkpoint_frequency + + train_loader = self._construct_dataloader(self.train_datapool, shuffle=True) + self.logger.info( + f"There are {len(train_loader)} batches in the training loader." 
+ ) + val_loader = None + if self.val_datapool is not None: + if validate_step_fn or hasattr(self.learner, "validation_step"): + val_loader = self._construct_dataloader( + self.val_datapool, shuffle=False + ) + else: + self.logger.warning( + "Validation data is available, but no `validate_step_fn` " + "or `validation_step` method in Learner is provided." + ) + # if a fine-tuning lr is provided, adjust it after the first iteration + if ( + self.config.fine_tuning_lr is not None + and self.active_learning_step_idx > 0 + ): + self.optimizer.param_groups[0]["lr"] = self.config.fine_tuning_lr + + # Determine max epochs to train for this AL step + if self.active_learning_step_idx > 0: + target_max_epochs = self.training_config.max_fine_tuning_epochs + else: + target_max_epochs = self.training_config.max_training_epochs + + # Check if resuming from mid-training checkpoint + start_epoch = 1 + epochs_to_train = target_max_epochs + + if self._last_checkpoint_path and self._last_checkpoint_path.exists(): + training_state_path = self._last_checkpoint_path / "training_state.pt" + if training_state_path.exists(): + training_state = torch.load( + training_state_path, map_location="cpu", weights_only=False + ) + last_completed_epoch = training_state.get("training_epoch", 0) + if last_completed_epoch > 0: + start_epoch = last_completed_epoch + 1 + epochs_to_train = target_max_epochs - last_completed_epoch + self.logger.info( + f"Resuming training from epoch {start_epoch} " + f"({epochs_to_train} epochs remaining)" + ) + + # Skip training if all epochs already completed + if epochs_to_train <= 0: + self.logger.info( + f"Training already complete ({target_max_epochs} epochs), " + f"skipping training phase" + ) + return + + device = ( + self.dist_manager.device + if self.dist_manager is not None + else self.config.device + ) + dtype = self.config.dtype + + # Set checkpoint directory and frequency on training loop + # This allows the training loop to handle training state checkpointing internally + if hasattr(self.train_loop_fn, "checkpoint_base_dir") and hasattr( + self.train_loop_fn, "checkpoint_frequency" + ): + # Checkpoint base is the current AL step's training directory + checkpoint_base = ( + self.log_dir + / "checkpoints" + / f"step_{self.active_learning_step_idx}" + / "training" + ) + self.train_loop_fn.checkpoint_base_dir = checkpoint_base + self.train_loop_fn.checkpoint_frequency = ( + self.config.model_checkpoint_frequency + ) + + self.train_loop_fn( + self.learner, + self.optimizer, + train_step_fn=train_step_fn, + validate_step_fn=validate_step_fn, + train_dataloader=train_loader, + validation_dataloader=val_loader, + lr_scheduler=self.lr_scheduler, + max_epochs=epochs_to_train, # Only remaining epochs + device=device, + dtype=dtype, + **kwargs, + ) + + def _metrology_phase(self, *args: Any, **kwargs: Any) -> None: + """Execute the metrology phase of the active learning step.""" + + with self._phase_context("metrology"): + for strategy in self.metrology_strategies: + self.logger.info( + f"Running metrology strategy: {strategy.__class__.__name__}" + ) + strategy(*args, **kwargs) + self.logger.info( + f"Completed metrics for strategy: {strategy.__class__.__name__}" + ) + strategy.serialize_records(*args, **kwargs) + + def _query_phase(self, *args: Any, **kwargs: Any) -> None: + """Execute the query phase of the active learning step.""" + with self._phase_context("query"): + for strategy in self.query_strategies: + self.logger.info( + f"Running query strategy: {strategy.__class__.__name__}" + ) + 
strategy(self.query_queue, *args, **kwargs) + + if self.query_queue.empty(): + self.logger.warning( + "Querying strategies produced no samples this iteration." + ) + + def _labeling_phase(self, *args: Any, **kwargs: Any) -> None: + """Execute the labeling phase of the active learning step.""" + self._validate_labeling_requirements() + + if self.query_queue.empty(): + self.logger.warning("No samples to label. Skipping labeling phase.") + return + + with self._phase_context("labeling"): + try: + self.label_strategy(self.query_queue, self.label_queue, *args, **kwargs) + except Exception as e: + self.logger.error(f"Exception encountered during labeling: {e}") + self.logger.info("Labeling completed. Now appending to training pool.") + + # TODO this is done serially, could be improved with batched writes + sample_counter = 0 + while not self.label_queue.empty(): + self.train_datapool.append(self.label_queue.get()) + sample_counter += 1 + self.logger.info(f"Appended {sample_counter} samples to training pool.") + + def _validate_training_requirements( + self, + train_step_fn: p.TrainingProtocol | None, + validate_step_fn: p.ValidationProtocol | None, + ) -> None: + """Validate that all required components for training are available.""" + if self.training_config is None: + raise ValueError( + "`training_config` must be provided if `skip_training` is False." + ) + if self.train_loop_fn is None: + raise ValueError("`train_loop_fn` must be provided in training_config.") + if self.train_datapool is None: + raise ValueError("`train_datapool` must be provided in training_config.") + if not train_step_fn and not hasattr(self.learner, "training_step"): + raise ValueError( + "`train_step_fn` must be provided if the model does not implement " + "the `training_step` method." + ) + if validate_step_fn and self.val_datapool is None: + raise ValueError( + "`val_datapool` must be provided in training_config if " + "`validate_step_fn` is provided." + ) + + def _validate_labeling_requirements(self) -> None: + """Validate that all required components for labeling are available.""" + if self.label_strategy is None: + raise ValueError( + "`label_strategy` must be provided in strategies_config if " + "`skip_labeling` is False." + ) + if self.training_config is None or self.train_datapool is None: + raise ValueError( + "`train_datapool` must be provided in training_config for " + "labeling, as data will be appended to it." + ) + + @contextmanager + def _phase_context( + self, phase_name: p.ActiveLearningPhase, call_barrier: bool = True + ) -> Generator[None, Any, None]: + """ + Context manager for consistent phase tracking, error handling, and synchronization. + + Sets the current phase for logging context, handles exceptions, + and synchronizes distributed workers with a barrier. Also triggers + checkpoint saves at the start of each phase if configured. + + Parameters + ---------- + phase_name: p.ActiveLearningPhase + A discrete phase of the active learning workflow. + call_barrier: bool + Whether to call barrier for synchronization at the end. 
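The labeling hand-off described above is essentially queue plumbing: query strategies fill the query queue, the label strategy drains it into the label queue, and the driver appends whatever lands there to the training pool. A minimal, self-contained sketch of that flow, using a plain `queue.Queue` and a Python list as hypothetical stand-ins for the driver-managed queues and datapool (the toy strategies and sample values are illustrative only):

```python
import queue

# Hypothetical stand-ins for Driver.query_queue, Driver.label_queue, and
# Driver.train_datapool; any object with an ``append`` method works as a pool.
query_queue = queue.Queue()
label_queue = queue.Queue()
train_datapool = []


def toy_query_strategy(q):
    # enqueue unlabeled samples (here: plain 2D coordinates) for labeling
    for sample in [(0.1, 0.9), (0.4, 0.2)]:
        q.put(sample)


def toy_label_strategy(to_label, labeled):
    # attach a (dummy) label to each queued sample and pass it downstream
    while not to_label.empty():
        x = to_label.get()
        labeled.put((x, 1))


# Mirrors the query -> labeling -> data-integration ordering of the driver:
toy_query_strategy(query_queue)
toy_label_strategy(query_queue, label_queue)
while not label_queue.empty():
    train_datapool.append(label_queue.get())
print(f"Appended {len(train_datapool)} samples to the training pool.")
```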
+ """ + self.current_phase = phase_name + + # Save checkpoint at START of phase if configured + # Exception: training phase handles checkpointing internally + if phase_name != p.ActiveLearningPhase.TRAINING: + should_checkpoint = getattr( + self.config, f"checkpoint_on_{phase_name}", False + ) + # Check if we should checkpoint based on interval + if should_checkpoint and self._should_checkpoint_at_step(): + self.save_checkpoint() + + try: + yield + except Exception as e: + self.logger.error(f"Exception encountered during {phase_name}: {e}") + raise + finally: + if call_barrier: + self.logger.debug("Entering barrier for synchronization.") + self.barrier() + + def run( + self, + train_step_fn: p.TrainingProtocol | None = None, + validate_step_fn: p.ValidationProtocol | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Runs the active learning loop until the maximum number of + active learning steps is reached. + + Parameters + ---------- + train_step_fn: p.TrainingProtocol | None = None + The training function to use for training. If not provided, then the + ``Driver.train_loop_fn`` will be used. + validate_step_fn: p.ValidationProtocol | None = None + The validation function to use for validation. If not provided, then + validation will not be performed. + args: Any + Additional arguments to pass to the method. These will be passed to the + training loop, metrology strategies, query strategies, and labeling strategies. + kwargs: Any + Additional keyword arguments to pass to the method. These will be passed to the + training loop, metrology strategies, query strategies, and labeling strategies. + """ + # TODO: refactor initialization logic here instead of inside the step + while self.active_learning_step_idx < self.config.max_active_learning_steps: + self.active_learning_step( + train_step_fn=train_step_fn, + validate_step_fn=validate_step_fn, + *args, + **kwargs, + ) + + def __call__( + self, + train_step_fn: p.TrainingProtocol | None = None, + validate_step_fn: p.ValidationProtocol | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Provides syntactic sugar for running the active learning loop. + + Calls ``Driver.run`` internally. + + Parameters + ---------- + train_step_fn: p.TrainingProtocol | None = None + The training function to use for training. If not provided, then the + ``Driver.train_loop_fn`` will be used. + validate_step_fn: p.ValidationProtocol | None = None + The validation function to use for validation. If not provided, then + validation will not be performed. + args: Any + Additional arguments to pass to the method. These will be passed to the + training loop, metrology strategies, query strategies, and labeling strategies. + kwargs: Any + Additional keyword arguments to pass to the method. These will be passed to the + training loop, metrology strategies, query strategies, and labeling strategies. + """ + self.run( + train_step_fn=train_step_fn, + validate_step_fn=validate_step_fn, + *args, + **kwargs, + ) diff --git a/physicsnemo/active_learning/logger.py b/physicsnemo/active_learning/logger.py new file mode 100644 index 0000000000..0442f287be --- /dev/null +++ b/physicsnemo/active_learning/logger.py @@ -0,0 +1,330 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import logging +from contextlib import contextmanager +from datetime import datetime +from pathlib import Path +from threading import local +from typing import Any + +try: + from termcolor import colored +except ImportError: + colored = None + + +# Thread-local storage for context information +_context_storage = local() + + +class ActiveLearningLoggerAdapter(logging.LoggerAdapter): + """Logger adapter that automatically includes active learning iteration context. + + This adapter automatically adds iteration information to log messages + by accessing the driver's current iteration state. + """ + + def __init__(self, logger: logging.Logger, driver_ref: Any = None): + """Initialize the adapter with a logger and optional driver reference. + + Parameters + ---------- + logger : logging.Logger + The underlying logger to adapt + driver_ref : Any, optional + Reference to the driver object to get iteration context from + """ + super().__init__(logger, {}) + self.driver_ref = driver_ref + + def process(self, msg: str, kwargs: dict[str, Any]) -> tuple[str, dict[str, Any]]: + """Process the log message to add iteration, run ID, and phase context. + + Parameters + ---------- + msg : str + The log message + kwargs : dict[str, Any] + Additional keyword arguments + + Returns + ------- + tuple[str, dict[str, Any]] + Processed message and kwargs + """ + # Add iteration, run ID, and phase context if driver reference is available + if self.driver_ref is not None: + extra = kwargs.get("extra", {}) + + # Add iteration context + if hasattr(self.driver_ref, "active_learning_step_idx"): + iteration = getattr(self.driver_ref, "active_learning_step_idx", None) + if iteration is not None: + extra["iteration"] = iteration + + # Add run ID context + if hasattr(self.driver_ref, "run_id"): + run_id = getattr(self.driver_ref, "run_id", None) + if run_id is not None: + extra["run_id"] = run_id + + # Add current phase context + if hasattr(self.driver_ref, "current_phase"): + phase = getattr(self.driver_ref, "current_phase", None) + if phase is not None: + extra["phase"] = phase + + if extra: + kwargs["extra"] = extra + + return msg, kwargs + + +class JSONFormatter(logging.Formatter): + """JSON formatter for structured logging to files. + + This formatter converts log records to JSON format, including all + contextual information and metadata for structured analysis. + """ + + def format(self, record: logging.LogRecord) -> str: + """Format the log record as JSON. 
+ + Parameters + ---------- + record : logging.LogRecord + The log record to format + + Returns + ------- + str + JSON-formatted log message + """ + log_entry = { + "timestamp": datetime.fromtimestamp(record.created).isoformat(), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + "module": record.module, + "function": record.funcName, + "line": record.lineno, + } + + # Add contextual information if available + if hasattr(record, "context"): + log_entry["context"] = record.context + + if hasattr(record, "caller_object"): + log_entry["caller_object"] = record.caller_object + + if hasattr(record, "iteration"): + log_entry["iteration"] = record.iteration + + if hasattr(record, "phase"): + log_entry["phase"] = record.phase + + extra_keys = list(filter(lambda x: x not in log_entry, record.__dict__.keys())) + # Add any extra fields + for key in extra_keys: + log_entry[key] = record.__dict__[key] + + return json.dumps(log_entry) + + +def _get_context_stack(): + """Get the context stack for the current thread.""" + if not hasattr(_context_storage, "context_stack"): + _context_storage.context_stack = [] + return _context_storage.context_stack + + +class ContextFormatter(logging.Formatter): + """Standard formatter that includes active learning context information with colors.""" + + def format(self, record): + # Build context string + context_parts = [] + if hasattr(record, "caller_object") and record.caller_object: + context_parts.append(f"obj:{record.caller_object}") + if hasattr(record, "run_id") and record.run_id: + context_parts.append(f"run:{record.run_id}") + if hasattr(record, "iteration") and record.iteration is not None: + context_parts.append(f"iter:{record.iteration}") + if hasattr(record, "phase") and record.phase: + context_parts.append(f"phase:{record.phase}") + if hasattr(record, "context") and record.context: + for key, value in record.context.items(): + context_parts.append(f"{key}:{value}") + + context_str = f"[{', '.join(context_parts)}]" if context_parts else "" + + # Use standard formatting + base_msg = super().format(record) + + # Add color to the message based on level if termcolor is available + if colored is not None: + match record.levelno: + case level if level >= logging.ERROR: + base_msg = colored(base_msg, "red") + case level if level >= logging.WARNING: + base_msg = colored(base_msg, "yellow") + case level if level >= logging.INFO: + base_msg = colored(base_msg, "white") + case _: # DEBUG + base_msg = colored(base_msg, "cyan") + + # Add colored context string + if context_str: + if colored is not None: + context_str = colored(context_str, "blue") + base_msg += f" {context_str}" + + return base_msg + + +class ContextInjectingFilter(logging.Filter): + """Filter that injects contextual information into log records.""" + + def filter(self, record): + # Add context information from thread-local storage + context_stack = _get_context_stack() + if context_stack: + current_context = context_stack[-1] + if current_context["caller_object"]: + record.caller_object = current_context["caller_object"] + if current_context["iteration"] is not None: + record.iteration = current_context["iteration"] + if current_context.get("phase"): + record.phase = current_context["phase"] + if current_context["context"]: + record.context = current_context["context"] + return True + + +def setup_active_learning_logger( + name: str, + run_id: str, + log_dir: str | Path = Path("active_learning_logs"), + level: int = logging.INFO, +) -> logging.Logger: + """Set up 
a logger with active learning-specific formatting and handlers. + + Parameters + ---------- + name : str + Logger name + run_id : str + Unique identifier for this run, used in log filename + log_dir : str | Path, optional + Directory to store log files, by default "./logs" + level : int, optional + Logging level, by default logging.INFO + + Returns + ------- + logging.Logger + Configured standard Python logger + + Example + ------- + >>> logger = setup_active_learning_logger("experiment", "run_001") + >>> logger.info("Starting experiment") + >>> with log_context(caller_object="Trainer", iteration=5): + ... logger.info("Training step") + """ + # Get standard logger + logger = logging.getLogger(name) + logger.setLevel(level) + + # Clear any existing handlers to avoid duplicates + logger.handlers.clear() + + # Disable propagation to prevent duplicate messages from parent loggers + logger.propagate = False + + # Create log directory if it doesn't exist + if isinstance(log_dir, str): + log_dir_path = Path(log_dir) + else: + log_dir_path = log_dir + log_dir_path.mkdir(parents=True, exist_ok=True) + + # Set up console handler with standard formatting + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + console_handler.setFormatter( + ContextFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ) + console_handler.addFilter(ContextInjectingFilter()) + logger.addHandler(console_handler) + + # Set up file handler with JSON formatting + log_file = log_dir_path / f"{run_id}.log" + file_handler = logging.FileHandler(log_file, mode="w") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(JSONFormatter()) + file_handler.addFilter(ContextInjectingFilter()) + logger.addHandler(file_handler) + + return logger + + +@contextmanager +def log_context( + caller_object: str | None = None, + iteration: int | None = None, + phase: str | None = None, + **kwargs: Any, +): + """Context manager for adding contextual information to log messages. + + Parameters + ---------- + caller_object : str, optional + Name or identifier of the object making the log call + iteration : int, optional + Current iteration counter + phase : str, optional + Current phase of the active learning process + **kwargs : Any + Additional contextual key-value pairs + + Example + ------- + >>> from logging import getLogger + >>> from physicsnemo.active_learning.logger import log_context + >>> logger = getLogger("my_logger") + >>> with log_context(caller_object="Trainer", iteration=5, phase="training", epoch=2): + ... logger.info("Processing batch") + """ + context_info = { + "caller_object": caller_object, + "iteration": iteration, + "phase": phase, + "context": kwargs, + } + + context_stack = _get_context_stack() + context_stack.append(context_info) + + try: + yield + finally: + context_stack.pop() diff --git a/physicsnemo/active_learning/loop.py b/physicsnemo/active_learning/loop.py new file mode 100644 index 0000000000..f1ac87f74f --- /dev/null +++ b/physicsnemo/active_learning/loop.py @@ -0,0 +1,534 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +from pathlib import Path +from typing import Any + +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader +from tqdm import tqdm + +from physicsnemo import Module +from physicsnemo.active_learning import protocols as p +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.logging import LaunchLogger +from physicsnemo.utils.capture import StaticCaptureEvaluateNoGrad, StaticCaptureTraining + +__all__ = ["DefaultTrainingLoop"] + + +def _recursive_data_device_cast( + data: Any, + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, + **kwargs: Any, +) -> Any: + """ + Recursively moves/cast input data to a specified device and dtype. + + For iterable objects, we recurse through the elements depending on + the type of iterable until we reach an object that either has a ``to`` + method that can be called, or just returns the data unchanged. + + Parameters + ---------- + data: Any + The data to move to the device. + device: torch.device | str | None = None + The device to move the data to. + dtype: torch.dtype | None = None + The dtype to move the data to. + kwargs: Any + Additional keyword arguments to pass to the `to` method. + By default, `non_blocking` is set to `True` to allow + asynchronous data transfers. + + Returns + ------- + Any + The data moved to the device. + """ + kwargs.setdefault("non_blocking", True) + if hasattr(data, "to"): + # if there is a `to` method, then we can just call it + return data.to(device=device, dtype=dtype, **kwargs) + elif isinstance(data, dict): + return { + k: _recursive_data_device_cast(v, device, dtype) for k, v in data.items() + } + elif isinstance(data, list): + return [_recursive_data_device_cast(v, device, dtype) for v in data] + elif isinstance(data, tuple): + return tuple(_recursive_data_device_cast(v, device, dtype) for v in data) + else: + return data + + +class DefaultTrainingLoop(p.TrainingLoop): + def __new__(cls, *args: Any, **kwargs: Any) -> DefaultTrainingLoop: + """ + Wrapper for instantiating DefaultTrainingLoop. + + This method captures arguments used to instantiate the loop + and stores them in the `_args` attribute for serialization. + This follows the same pattern as `ActiveLearningProtocol.__new__`. + + Parameters + ---------- + args: Any + Arguments to pass to the loop's constructor. + kwargs: Any + Keyword arguments to pass to the loop's constructor. + + Returns + ------- + DefaultTrainingLoop + A new instance with an `_args` attribute for serialization. 
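As a quick illustration of the casting helper above (it is module-internal, so importing it directly is shown purely for demonstration), a sketch of how nested batches are moved and cast while non-tensor entries pass through untouched:

```python
import torch

from physicsnemo.active_learning.loop import _recursive_data_device_cast

# A nested batch mixing tensors and non-tensor metadata, as a dataloader might yield.
batch = {
    "coords": torch.randn(4, 2),
    "targets": [torch.zeros(4), torch.ones(4)],
    "tag": "moons",  # objects without a ``to`` method are returned unchanged
}

cast = _recursive_data_device_cast(batch, device="cpu", dtype=torch.float64)
print(cast["coords"].dtype)  # torch.float64
print(cast["tag"])           # "moons"
```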
+ """ + out = super().__new__(cls) + + # Get signature of __init__ function + sig = inspect.signature(cls.__init__) + + # Bind args and kwargs to signature + bound_args = sig.bind_partial( + *([None] + list(args)), **kwargs + ) # Add None to account for self + bound_args.apply_defaults() + + # Get args and kwargs (excluding self and unroll kwargs) + instantiate_args = {} + for param, (k, v) in zip(sig.parameters.values(), bound_args.arguments.items()): + # Skip self + if k == "self": + continue + + # Add args and kwargs to instantiate_args + if param.kind == param.VAR_KEYWORD: + instantiate_args.update(v) + else: + # Special handling for device: convert torch.device to string + if k == "device" and isinstance(v, torch.device): + instantiate_args[k] = str(v) + # Special handling for dtype: convert to string representation + elif k == "dtype" and isinstance(v, torch.dtype): + instantiate_args[k] = str(v) + else: + instantiate_args[k] = v + + # Store args needed for instantiation + out._args = { + "__name__": cls.__name__, + "__module__": cls.__module__, + "__args__": instantiate_args, + } + return out + + def __init__( + self, + train_step_fn: p.TrainingProtocol | None = None, + validate_step_fn: p.ValidationProtocol | None = None, + enable_static_capture: bool = True, + use_progress_bars: bool = True, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, + checkpoint_frequency: int = 0, + **capture_kwargs: Any, + ) -> None: + """ + Initializes the default training loop. + + The general usage of this loop is to + + TODO: add support for early stopping + + Parameters + ---------- + train_step_fn: TrainingProtocol | None = None + A callable that implements the logic for performing a single + training step. See ``protocols.TrainingProtocol`` for the expected + interface, but ultimately the function should return a scalar loss + value that has a ``backward`` method. + validate_step_fn: ValidationProtocol | None = None + A callable that implements the logic for performing a single + validation step. See ``protocols.ValidationProtocol`` for the expected + interface, but in contrast to ``train_step_fn`` this function should + not return anything. + enable_static_capture: bool = True + Whether to enable static capture for the training and validation steps. + use_progress_bars: bool = True + Whether to show ``tqdm`` progress bars to display epoch and step progress. + device: str | torch.device | None = None + The device used for performing the loop. If not provided, then the device + will default to the model's device at runtime. + dtype: torch.dtype | None = None + The dtype used for performing the loop. If not provided, then the dtype + will default to ``torch.get_default_dtype()``. + checkpoint_frequency: int = 0 + How often to save checkpoints during training (every N epochs). + If 0, no checkpoints are saved during training. Set via Driver before + training execution. + capture_kwargs: Any + Additional keyword arguments to pass to the static capture decorators. 
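To make the constructor options above concrete, here is a hedged sketch of building the loop directly; the specific argument values are illustrative only, and the final lines simply show the serialization metadata that `__new__` captures into `_args`:

```python
import torch

from physicsnemo.active_learning.loop import DefaultTrainingLoop

# Illustrative settings: disable graph capture, checkpoint every 5 epochs,
# keep progress bars on.
loop = DefaultTrainingLoop(
    enable_static_capture=False,
    use_progress_bars=True,
    checkpoint_frequency=5,
    dtype=torch.float32,
)

# ``__new__`` records how the loop was built so it can be reconstructed later.
print(loop._args["__name__"])   # "DefaultTrainingLoop"
print(loop._args["__args__"])   # {"enable_static_capture": False, ...}
```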
+ """ + self.train_step_fn = train_step_fn + self.validate_step_fn = validate_step_fn + self.enable_static_capture = enable_static_capture + if isinstance(device, str): + device = torch.device(device) + # check to see if we can rely on DistributedManager + if device is None and DistributedManager.is_initialized(): + device = DistributedManager.device + self.device = device + if dtype is None: + dtype = torch.get_default_dtype() + self.dtype = dtype + self.capture_kwargs = capture_kwargs + self.use_progress_bars = use_progress_bars + self.capture_functions = {} + self.checkpoint_frequency = checkpoint_frequency + self.checkpoint_base_dir: Path | None = None + + def save_training_checkpoint( + self, + checkpoint_dir: Path, + model: Module | p.LearnerProtocol, + optimizer: Optimizer, + lr_scheduler: _LRScheduler | None = None, + training_epoch: int | None = None, + ) -> None: + """ + Save training state to checkpoint directory. + + Model weights are saved separately. Optimizer, scheduler, and epoch + metadata are combined into a single training_state.pt file. + + Parameters + ---------- + checkpoint_dir: Path + Directory to save checkpoint files. + model: Module | p.LearnerProtocol + Model to save weights for. + optimizer: Optimizer + Optimizer to save state from. + lr_scheduler: _LRScheduler | None + Optional LR scheduler to save state from. + training_epoch: int | None + Current training epoch for metadata. + """ + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # Save model weights separately + if isinstance(model, Module): + model_path = checkpoint_dir / "model.mdlus" + model.save(str(model_path)) + else: + model_path = checkpoint_dir / "model_state.pt" + torch.save(model.state_dict(), model_path) + + # Combine optimizer, scheduler, and epoch metadata into single file + training_state = { + "optimizer_state": optimizer.state_dict(), + "lr_scheduler_state": lr_scheduler.state_dict() if lr_scheduler else None, + "training_epoch": training_epoch, + } + training_state_path = checkpoint_dir / "training_state.pt" + torch.save(training_state, training_state_path) + + @staticmethod + def load_training_checkpoint( + checkpoint_dir: Path, + model: Module | p.LearnerProtocol, + optimizer: Optimizer, + lr_scheduler: _LRScheduler | None = None, + ) -> int | None: + """ + Load training state from checkpoint directory. + + Model weights are loaded separately. Optimizer, scheduler, and epoch + metadata are loaded from the combined training_state.pt file. + + Parameters + ---------- + checkpoint_dir: Path + Directory containing checkpoint files. + model: Module | p.LearnerProtocol + Model to load weights into. + optimizer: Optimizer + Optimizer to load state into. + lr_scheduler: _LRScheduler | None + Optional LR scheduler to load state into. + + Returns + ------- + int | None + Training epoch from metadata if available, else None. 
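A minimal round trip through these two methods might look as follows; the tiny `torch.nn.Linear` is a hypothetical stand-in for a learner and deliberately exercises the non-`physicsnemo.Module` branch, so the checkpoint directory ends up holding `model_state.pt` plus `training_state.pt`:

```python
import tempfile
from pathlib import Path

import torch

from physicsnemo.active_learning.loop import DefaultTrainingLoop

model = torch.nn.Linear(2, 1)  # stand-in learner (not a physicsnemo.Module)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loop = DefaultTrainingLoop(use_progress_bars=False)

with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp) / "epoch_3"
    loop.save_training_checkpoint(ckpt_dir, model, optimizer, training_epoch=3)
    # model_state.pt holds the weights; training_state.pt bundles optimizer,
    # scheduler, and epoch metadata.
    resumed_epoch = DefaultTrainingLoop.load_training_checkpoint(
        ckpt_dir, model, optimizer
    )
    print(resumed_epoch)  # 3
```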
+ """ + # Load model weights separately + if isinstance(model, Module): + model_path = checkpoint_dir / "model.mdlus" + if model_path.exists(): + model.load(str(model_path)) + else: + model_state_path = checkpoint_dir / "model_state.pt" + if model_state_path.exists(): + state_dict = torch.load(model_state_path, map_location="cpu") + model.load_state_dict(state_dict) + + # Load combined training state (optimizer, scheduler, epoch) + training_state_path = checkpoint_dir / "training_state.pt" + if training_state_path.exists(): + training_state = torch.load(training_state_path, map_location="cpu") + + # Restore optimizer state + if "optimizer_state" in training_state: + optimizer.load_state_dict(training_state["optimizer_state"]) + + # Restore scheduler state if present + if lr_scheduler and training_state.get("lr_scheduler_state"): + lr_scheduler.load_state_dict(training_state["lr_scheduler_state"]) + + # Return epoch metadata + return training_state.get("training_epoch", None) + + return None + + @property + def amp_type(self) -> torch.dtype: + if self.dtype in [torch.float16, torch.bfloat16]: + return self.dtype + else: + return torch.float16 + + def _create_capture_functions( + self, + model: Module | p.LearnerProtocol, + optimizer: Optimizer, + train_step_fn: p.TrainingProtocol | None = None, + validate_step_fn: p.ValidationProtocol | None = None, + ) -> tuple[p.TrainingProtocol | None, p.ValidationProtocol | None]: + """ + Attempt to create static capture functions based off training and validation + functions. + + This uses the Python object IDs to unique identify functions, and adds the + decorated functions to an internal `capture_functions` dictionary. If the + decorated functions already exist, then this function will be no-op. + + Parameters + ---------- + model: Module | p.LearnerProtocol + The model to train. + optimizer: Optimizer + The optimizer to use for training. + train_step_fn: p.TrainingProtocol | None = None + The training function to use for training. + validate_step_fn: p.ValidationProtocol | None = None + The validation function to use for validation. + + Returns + ------- + tuple[p.TrainingProtocol | None, p.ValidationProtocol | None] + The training and validation functions with static capture applied. + """ + if not train_step_fn: + train_step_fn = self.train_step_fn + train_func_id = id(train_step_fn) + if train_func_id not in self.capture_functions: + try: + train_step_fn = StaticCaptureTraining( + model=model, + optim=optimizer, + amp_type=self.amp_type, + **self.capture_kwargs, + )(train_step_fn) + self.capture_functions[train_func_id] = train_step_fn + except Exception as e: + raise RuntimeError( + "Failed to create static capture for `train_step_fn`. " + ) from e + else: + train_step_fn = self.capture_functions[train_func_id] + if not validate_step_fn: + validate_step_fn = self.validate_step_fn + if validate_step_fn: + val_func_id = id(validate_step_fn) + if val_func_id not in self.capture_functions: + try: + validate_step_fn = StaticCaptureEvaluateNoGrad( + model=model, amp_type=self.amp_type, **self.capture_kwargs + )(validate_step_fn) + self.capture_functions[val_func_id] = validate_step_fn + except Exception as e: + raise RuntimeError( + "Failed to create static capture for `validate_step_fn`. 
" + ) from e + else: + validate_step_fn = self.capture_functions[val_func_id] + return train_step_fn, validate_step_fn + + def __call__( + self, + model: Module | p.LearnerProtocol, + optimizer: Optimizer, + train_dataloader: DataLoader, + max_epochs: int, + validation_dataloader: DataLoader | None = None, + train_step_fn: p.TrainingProtocol | None = None, + validate_step_fn: p.ValidationProtocol | None = None, + lr_scheduler: _LRScheduler | None = None, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Performs ``max_epochs`` epochs of training and optionally validation. + + Some of the arguments, such as ``train_step_fn`` and ``validate_step_fn``, + are optional only if the ``model`` implements the ``p.LearnerProtocol``. + If they are passed, however, they will take precedence over the methods + originally provided to the constructor method. + + The bare minimum required arguments for this loop to work are: + 1. A model to train + 2. An optimizer to step + 3. A training dataloader to iterate over + 4. The maximum number of epochs to train for + + If validation is required, then both ``validation_dataloader`` and + ``validate_step_fn`` must be specified. + + Parameters + ---------- + model: Module | p.LearnerProtocol + The model to train. + optimizer: torch.optim.Optimizer + The optimizer to use for training. + train_dataloader: DataLoader + The dataloader to use for training. + max_epochs: int + The number of epochs to train for. + validation_dataloader: DataLoader | None + The dataloader to use for validation. If not provided, then validation + will not be performed. + train_step_fn: p.TrainingProtocol | None = None + The training function to use for training. If passed, it will take + precedence over the method provided to the constructor method. + validate_step_fn: p.ValidationProtocol | None = None + The validation function to use for validation. + lr_scheduler: torch.optim.lr_scheduler._LRScheduler | None = None + The learning rate scheduler to use for training. + device: str | torch.device | None = None + The device used for performing the loop. If provided, it will + override the device specified in the constructor. If both values + are not provided, then we default to PyTorch's default device. + dtype: torch.dtype | None = None + The dtype used for performing the loop. If provided, it will + override the dtype specified in the constructor. If both values + are not provided, then we default to PyTorch's default dtype. + args: Any + Additional arguments to pass the training and validation + step functions. + kwargs: Any + Additional keyword arguments to pass the training and validation + step functions. + """ + if not train_step_fn and not self.train_step_fn: + raise RuntimeError( + """ + No training step function provided. + Either provide a `train_step_fn` to this constructor, or + provide a `train_step_fn` to the `__call__` method. + """ + ) + if not device and not self.device: + device = torch.get_default_device() + if not dtype and not self.dtype: + dtype = torch.get_default_dtype() + # if a device is specified, move the model + if device and device != model.device: + # not 100% sure this will trigger issues with the optimizer + # but allows a potentially different device to be used + model = model.to(device) + if self.enable_static_capture: + # if static capture is enabled, we check for a cache hit based on + # the incoming function IDs. If we miss, we then create new wrappers. 
+ train_step_fn, validate_step_fn = self._create_capture_functions( + model, optimizer, train_step_fn, validate_step_fn + ) + epoch_iter = range(1, max_epochs + 1) + if self.use_progress_bars: + epoch_iter = tqdm(epoch_iter, desc="Epoch", leave=False, position=0) + ########### EPOCH LOOP ########### + for epoch in epoch_iter: + model.train() + train_iter = iter(train_dataloader) + if self.use_progress_bars: + train_iter = tqdm( + train_iter, desc="Training step", leave=False, unit="batch" + ) + ########### TRAINING STEP LOOP ########### + with LaunchLogger( + "train", epoch=epoch, num_mini_batch=len(train_dataloader) + ) as log: + for batch in train_iter: + batch = _recursive_data_device_cast( + batch, device=device, dtype=dtype + ) + model.zero_grad(set_to_none=True) + loss = train_step_fn(model, batch, *args, **kwargs) + log.log_minibatch({"train_loss": loss.detach().item()}) + # normally, static capture will call backward because of AMP + if not self.enable_static_capture: + loss.backward() + optimizer.step() + if lr_scheduler: + lr_scheduler.step() + ########### VALIDATION STEP LOOP ########### + if validate_step_fn and validation_dataloader: + model.eval() + val_iter = iter(validation_dataloader) + if self.use_progress_bars: + val_iter = tqdm( + val_iter, desc="Validation step", leave=False, unit="batch" + ) + with LaunchLogger( + "validation", epoch=epoch, num_mini_batch=len(validation_dataloader) + ) as log: + for batch in val_iter: + batch = _recursive_data_device_cast( + batch, device=device, dtype=dtype + ) + validate_step_fn(model, batch, *args, **kwargs) + + ########### CHECKPOINT SAVE ########### + # Save training state at specified frequency + if self.checkpoint_base_dir and self.checkpoint_frequency > 0: + if epoch % self.checkpoint_frequency == 0: + epoch_checkpoint_dir = self.checkpoint_base_dir / f"epoch_{epoch}" + self.save_training_checkpoint( + checkpoint_dir=epoch_checkpoint_dir, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + training_epoch=epoch, + ) diff --git a/physicsnemo/active_learning/protocols.py b/physicsnemo/active_learning/protocols.py new file mode 100644 index 0000000000..6eb9f9e60e --- /dev/null +++ b/physicsnemo/active_learning/protocols.py @@ -0,0 +1,1394 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module contains base classes for active learning protocols. + +These are protocols intended to be abstract, and importing these +classes specifically is intended to either be subclassed, or for +type annotations. + +Protocol Architecture +--------------------- +Python ``Protocol``s are used for structural typing: essentially, they are used to +describe an expected interface in a way that is helpful for static type checkers +to make sure concrete implementations provide everything that is needed for a workflow +to function. 
``Protocol``s are not actually enforced at runtime, and inheritance is not +required for them to function: as long as the implementation provides the expected +attributes and methods, they will be compatible with the protocol. + +The active learning framework is built around several key protocol abstractions +that work together to orchestrate the active learning workflow: + +**Core Infrastructure Protocols:** + - `AbstractQueue[T]` - Generic queue protocol for passing data between components + - `DataPool[T]` - Protocol for data reservoirs that support appending and sampling + - `ActiveLearningProtocol` - Base protocol providing common interface for all AL strategies + +**Strategy Protocols (inherit from ActiveLearningProtocol):** + - `QueryStrategy` - Defines how to select data points for labeling + - `LabelStrategy` - Defines processes for adding ground truth labels to unlabeled data + - `MetrologyStrategy` - Defines procedures that assess model improvements beyond validation metrics + +**Model Interface Protocols:** + - `TrainingProtocol` - Interface for training step functions + - `ValidationProtocol` - Interface for validation step functions + - `InferenceProtocol` - Interface for inference step functions + - `TrainingLoop` - Interface for complete training loop implementations + - `LearnerProtocol` - Comprehensive interface for learner modules (combines training/validation/inference) + +**Orchestration Protocol:** + - `DriverProtocol` - Main orchestrator that coordinates all components in the active learning loop + +Protocol Relationships +---------------------- + +```mermaid +graph TB + subgraph "Core Infrastructure" + AQ[AbstractQueue<T>] + DP[DataPool<T>] + ALP[ActiveLearningProtocol] + end + + subgraph "Strategy Layer" + QS[QueryStrategy] + LS[LabelStrategy] + MS[MetrologyStrategy] + end + + subgraph "Model Interface Layer" + TP[TrainingProtocol] + VP[ValidationProtocol] + IP[InferenceProtocol] + TL[TrainingLoop] + LP[LearnerProtocol] + end + + subgraph "Orchestration Layer" + Driver[DriverProtocol] + end + + %% Inheritance relationships (thick blue arrows) + ALP ==>|inherits| QS + ALP ==>|inherits| LS + ALP ==>|inherits| MS + + %% Composition relationships (dashed green arrows) + Driver -.->|uses| LP + Driver -.->|manages| QS + Driver -.->|manages| LS + Driver -.->|manages| MS + Driver -.->|contains| DP + Driver -.->|contains| AQ + + %% Protocol usage relationships (dotted purple arrows) + TL -.->|can use| TP + TL -.->|can use| VP + TL -.->|can use| LP + LP -.->|implements| TP + LP -.->|implements| VP + LP -.->|implements| IP + + %% Data flow relationships (solid red arrows) + QS -->|enqueues to| AQ + AQ -->|consumed by| LS + LS -->|enqueues to| AQ + + %% Styling for different relationship types + linkStyle 0 stroke:#1976d2,stroke-width:4px + linkStyle 1 stroke:#1976d2,stroke-width:4px + linkStyle 2 stroke:#1976d2,stroke-width:4px + linkStyle 3 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5 + linkStyle 4 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5 + linkStyle 5 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5 + linkStyle 6 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5 + linkStyle 7 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5 + linkStyle 8 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5 + linkStyle 9 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2 + linkStyle 10 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2 + linkStyle 11 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2 + linkStyle 12 
stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2 + linkStyle 13 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2 + linkStyle 14 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2 + linkStyle 15 stroke:#d32f2f,stroke-width:3px + linkStyle 16 stroke:#d32f2f,stroke-width:3px + linkStyle 17 stroke:#d32f2f,stroke-width:3px + + %% Node styling + classDef coreInfra fill:#e3f2fd,stroke:#1976d2,stroke-width:2px + classDef strategy fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px + classDef modelInterface fill:#e8f5e8,stroke:#388e3c,stroke-width:2px + classDef orchestration fill:#fff3e0,stroke:#f57c00,stroke-width:3px + + class AQ,DP,ALP coreInfra + class QS,LS,MS strategy + class TP,VP,IP,TL,LP modelInterface + class Driver orchestration +``` + +**Relationship Legend:** +- **Blue thick arrows (==>)**: Inheritance relationships (subclass extends parent) +- **Green dashed arrows (-.->)**: Composition relationships (object contains/manages other objects) +- **Purple dotted arrows (-.->)**: Protocol usage relationships (can use or implements interface) +- **Red solid arrows (-->)**: Data flow relationships (data moves between components) + +Active Learning Workflow +------------------------ + +The typical active learning workflow orchestrated by `DriverProtocol` follows this sequence: + +1. **Training Phase**: Use `LearnerProtocol` or `TrainingLoop` to train the model on `training_pool` +2. **Metrology Phase** (optional): Apply `MetrologyStrategy` instances to assess model performance +3. **Query Phase**: Apply `QueryStrategy` instances to select samples from `unlabeled_pool` → `query_queue` +4. **Labeling Phase** (optional): Apply `LabelStrategy` instances to label queued samples → `label_queue` +5. **Data Integration**: Move labeled data from `label_queue` to `training_pool` + +Type Parameters +--------------- +- `T`: Data structure containing both inputs and ground truth labels +- `S`: Data structure containing only inputs (no ground truth labels) +""" + +from __future__ import annotations + +import inspect +import logging +from enum import StrEnum +from logging import Logger +from pathlib import Path +from typing import Any, Iterator, Protocol, TypeVar + +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader + +from physicsnemo import Module + +# T is used to denote a data structure that contains inputs for a model and ground truths +T = TypeVar("T") +# S is used to denote a data structure that has inputs for a model, but no ground truth labels +S = TypeVar("S") + + +class ActiveLearningPhase(StrEnum): + """ + An enumeration of the different phases of the active learning workflow. + + This is primarily used in the metadata for restarting an ongoing active + learning experiment. + """ + + TRAINING = "training" + METROLOGY = "metrology" + QUERY = "query" + LABELING = "labeling" + DATA_INTEGRATION = "data_integration" + + +class AbstractQueue(Protocol[T]): + """ + Defines a generic queue protocol for data that is passed between active + learning components. + + This can be a simple local `queue.Queue`, or a more sophisticated + distributed queue system. + + The primary use case for this is to allow a query strategy to + enqueue some data structure for the labeling strategy to consume, + and once the labeling is done, enqueue to a data serialization + workflow. 
While there is no explcit restriction on the **type** + of queue that is implemented, a reasonable assumption to make + would be a FIFO queue, unless otherwise specified by the concrete + implementation. + + Optional Serialization Methods + ------------------------------- + Implementations may optionally provide `to_list()` and `from_list()` + methods for checkpoint serialization. If not provided, the queue + will be serialized using `torch.save()` as a fallback. + + Type Parameters + --------------- + T + The type of items that will be stored in the queue. + """ + + def put(self, item: T) -> None: + """ + Method to put a data structure into the queue. + + Parameters + ---------- + item: T + The data structure to put into the queue. + """ + ... + + def get(self) -> T: + """ + Method to get a data structure from the queue. + + This method should remove the data structure from the queue, + and return it to a consumer. + + Returns + ------- + T + The data structure that was removed from the queue. + """ + ... + + def empty(self) -> bool: + """ + Method to check if the queue is empty/has been depleted. + + Returns + ------- + bool + True if the queue is empty, False otherwise. + """ + ... + + +class DataPool(Protocol[T]): + """ + An abstract protocol for some reservoir of data that is + used for some part of active learning, parametrized such + that it will return data structures of an arbitrary type ``T``. + + **All** methods are left abstract, and need to be defined + by concrete implementations. For the most part, a `torch.utils.data.Dataset` + would match this protocol, provided that it implements the ``append`` method + which will allow data to be persisted to a filesystem. + + Methods + ------- + __getitem__(self, index: int) -> T: + Method to get a single data structure from the data pool. + __len__(self) -> int: + Method to get the length of the data pool. + __iter__(self) -> Iterator[T]: + Method to iterate over the data pool. + append(self, item: T) -> None: + Method to append a data structure to the data pool. + """ + + def __getitem__(self, index: int) -> T: + """ + Method to get a data structure from the data pool. + + This method should retrieve an item from the pool by a + flat index. + + Parameters + ---------- + index: int + The index of the data structure to get. + + Returns + ------- + T + The data structure at the given index. + """ + ... + + def __len__(self) -> int: + """ + Method to get the length of the data pool. + + Returns + ------- + int + The length of the data pool. + """ + ... + + def __iter__(self) -> Iterator[T]: + """ + Method to iterate over the data pool. + + This method should return an iterator over the data pool. + + Returns + ------- + Iterator[T] + An iterator over the data pool. + """ + ... + + def append(self, item: T) -> None: + """ + Method to append a data structure to the data pool. + + For persistent storage pools, this will actually mean that the + ``item`` is serialized to a filesystem. + + Parameters + ---------- + item: T + The data structure to append to the data pool. + """ + ... + + +class ActiveLearningProtocol(Protocol): + """ + This protocol acts as a basis for all active learning protocols. + + This ensures that all protocols have some common interface, for + example the ability to `attach` to another object for scope + management. + + Attributes + ---------- + __protocol_name__: str + The name of the protocol. This is primarily used for `repr` + and `str` f-strings. This should be defined by concrete + implementations. 
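Because these protocols are purely structural, nothing has to inherit from them: the standard library `queue.Queue` already provides `put`/`get`/`empty`, and any list-backed dataset with an `append` method satisfies `DataPool`. A small, hypothetical sketch:

```python
import queue
from typing import Iterator


class InMemoryPool:
    """A tiny list-backed pool; structurally compatible with ``DataPool``."""

    def __init__(self) -> None:
        self._items: list = []

    def __getitem__(self, index: int):
        return self._items[index]

    def __len__(self) -> int:
        return len(self._items)

    def __iter__(self) -> Iterator:
        return iter(self._items)

    def append(self, item) -> None:
        # a persistent pool would serialize ``item`` to disk here
        self._items.append(item)


pool = InMemoryPool()
pool.append({"x": [0.1, 0.9], "y": 1})

q = queue.Queue()  # put/get/empty -> matches AbstractQueue
q.put(pool[0])
print(q.empty(), len(pool))  # False 1
```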
+ _args: dict[str, Any] + A dictionary of arguments that were used to instantiate the protocol. + This is used for serialization and deserialization of the protocol, + and follows the same pattern as the ``_args`` attribute of + ``physicsnemo.Module``. + + Methods + ------- + attach(self, other: object) -> None: + This method is used to attach the current object to another, + allowing the protocol to access the attached object's scope. + The use case for this is to allow a protocol access to the + driver's scope to access dataset, model, etc. as needed. + This needs to be implemented by concrete implementations. + is_attached: bool + Whether the current object is attached to another object. + This is left abstract, as it depends on how ``attach`` is implemented. + logger: Logger + The logger for this protocol. This is used to log information + about the protocol's progress. + _setup_logger(self) -> None: + This method is used to setup the logger for the protocol. + The default implementation is to configure the logger similarly + to how ``physicsnemo`` loggers are configured. + """ + + __protocol_name__: str + __protocol_type__: ActiveLearningPhase + _args: dict[str, Any] + + def __new__(cls, *args: Any, **kwargs: Any) -> ActiveLearningProtocol: + """ + Wrapper for instantiating any subclass of `ActiveLearningProtocol`. + + This method will use `inspect` to capture arguments and keyword + arguments that were used to instantiate the protocol, and stash + them into the `_args` attribute of the instance, following + what is done with `physicsnemo.Module`. + + This approach is useful for reconstructing strategies from checkpoints. + + Parameters + ---------- + args: Any + Arguments to pass to the protocol's constructor. + kwargs: Any + Keyword arguments to pass to the protocol's constructor. + + Returns + ------- + ActiveLearningProtocol + A new instance of the protocol class. The instance will have an + `_args` attribute that contains the keys `__name__`, `__module__`, + and `__args__` as metadata for the protocol. + """ + out = super().__new__(cls) + + # Get signature of __init__ function + sig = inspect.signature(cls.__init__) + + # Bind args and kwargs to signature + bound_args = sig.bind_partial( + *([None] + list(args)), **kwargs + ) # Add None to account for self + bound_args.apply_defaults() + + # Get args and kwargs (excluding self and unroll kwargs) + instantiate_args = {} + for param, (k, v) in zip(sig.parameters.values(), bound_args.arguments.items()): + # Skip self + if k == "self": + continue + + # Add args and kwargs to instantiate_args + if param.kind == param.VAR_KEYWORD: + instantiate_args.update(v) + else: + instantiate_args[k] = v + + # Store args needed for instantiation + out._args = { + "__name__": cls.__name__, + "__module__": cls.__module__, + "__args__": instantiate_args, + } + return out + + def attach(self, other: object) -> None: + """ + This method is used to attach another object to the current protocol, + allowing the attached object to access the scope of this protocol. + The primary reason for this is to allow the protocol to access + things like the dataset, the learner model, etc. as needed. + + Example use cases would be for a query strategy to access the ``unlabeled_pool``; + for a metrology strategy to access the ``validation_pool``, and for any + strategy to be able to access the surrogate/learner model. 
+ + This method can be as simple as setting ``self.driver = other``, but + is left abstract in case there are other potential use cases + where multiple protocols could share information. + + Parameters + ---------- + other: object + The object to attach to. + """ + ... + + @property + def is_attached(self) -> bool: + """ + Property to check if the current object is already attached. + + This is left abstract, as it depends on how ``attach`` is implemented. + + Returns + ------- + bool + True if the current object is attached, False otherwise. + """ + ... + + @property + def logger(self) -> Logger: + """ + Property to access the logger for this protocol. + + If the logger has not been configured yet, the property + will call the `_setup_logger` method to configure it. + + Returns + ------- + Logger + The logger for this protocol. + """ + if not hasattr(self, "_logger"): + self._setup_logger() + return self._logger + + @logger.setter + def logger(self, logger: Logger) -> None: + """ + Setter for the logger for this protocol. + + Parameters + ---------- + logger: Logger + The logger to set for this protocol. + """ + self._logger = logger + + def _setup_logger(self) -> None: + """ + Method to setup the logger for all active learning protocols. + + Each protocol should have their own logger + """ + self.logger = logging.getLogger( + f"core.active_learning.{self.__protocol_name__}" + ) + # Don't add handlers here - let the parent logger handle formatting + # This prevents duplicate console output + self.logger.setLevel(logging.WARNING) + + @property + def strategy_dir(self) -> Path: + """ + Returns the directory where the underlying strategy can use + to persist data. + + Depending on the strategy abstraction, further nesting may be + required (e.g active learning step index, phase, etc.). + + Returns + ------- + Path + The directory where the metrology strategy will persist + its records. + + Raises + ------ + RuntimeError + If the metrology strategy is not attached to a driver yet. + """ + if not self.is_attached: + raise RuntimeError( + f"{self.__class__.__name__} is not attached to a driver yet." + ) + path = ( + self.driver.log_dir / str(self.__protocol_type__) / self.__class__.__name__ + ) + path.mkdir(parents=True, exist_ok=True) + return path + + @property + def checkpoint_dir(self) -> Path: + """ + Utility property for strategies to conveniently access the checkpoint directory. + + This is useful for (de)serializing data tied to checkpointing. + + Returns + ------- + Path + The checkpoint directory, which includes the active learning step index. + + Raises + ------ + RuntimeError + If the strategy is not attached to a driver yet. + """ + if not self.is_attached: + raise RuntimeError( + f"{self.__class__.__name__} is not attached to a driver yet." + ) + path = ( + self.driver.log_dir + / "checkpoints" + / f"step_{self.driver.active_learning_step_idx}" + ) + path.mkdir(parents=True, exist_ok=True) + return path + + +class QueryStrategy(ActiveLearningProtocol): + """ + This protocol defines a query strategy for active learning. + + A query strategy is responsible for selecting data points for labeling. + In the most general sense, concrete instances of this protocol + will specify how many samples to query, and the heuristics for + selecting samples. + + Attributes + ---------- + max_samples: int + The maximum number of samples to query. This can be interpreted + as the exact number of samples to query, or as an upper limit + for querying methods that are threshold based. 
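A hedged sketch of a concrete query strategy, showing the common pattern of `attach`/`is_attached` plus a `sample` method that enqueues candidate indices; the class name, the random selection, and the assumed pool size are illustrative only, not part of the library:

```python
import queue
import random

from physicsnemo.active_learning.protocols import QueryStrategy


class RandomQuery(QueryStrategy):
    """Hypothetical strategy that queries a random subset of pool indices."""

    __protocol_name__ = "RandomQuery"

    def __init__(self, max_samples: int = 8) -> None:
        self.max_samples = max_samples

    # attaching can be as simple as holding a reference to the driver
    def attach(self, other: object) -> None:
        self.driver = other

    @property
    def is_attached(self) -> bool:
        return hasattr(self, "driver")

    def sample(self, query_queue, *args, **kwargs) -> None:
        # a real strategy would rank candidates by model uncertainty; here we
        # simply push random indices from an assumed unlabeled pool of 100 items
        for idx in random.sample(range(100), k=self.max_samples):
            query_queue.put(idx)


strategy = RandomQuery(max_samples=4)
q = queue.Queue()
strategy(q)  # __call__ forwards to sample()
print(q.qsize(), strategy._args["__args__"])  # 4 {'max_samples': 4}
```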
+ """ + + max_samples: int + __protocol_type__ = ActiveLearningPhase.QUERY + + def sample(self, query_queue: AbstractQueue[T], *args: Any, **kwargs: Any) -> None: + """ + Method that implements the logic behind querying data to be labeled. + + This method should be implemented by concrete implementations, + and assume that an active learning driver will pass a queue + for this method to enqueue data to be labeled. + + Additional ``args`` and ``kwargs`` are passed to the method, + and can be used to pass additional information to the query strategy. + + This method will enqueue in place, and should not return anything. + + Parameters + ---------- + query_queue: AbstractQueue[T] + The queue to enqueue data to be labeled. + args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + """ + ... + + def __call__( + self, query_queue: AbstractQueue[T], *args: Any, **kwargs: Any + ) -> None: + """ + Syntactic sugar for the ``sample`` method. + + This allows the object to be called as a function, and will pass + the arguments to the strategy's ``sample`` method. + + Parameters + ---------- + query_queue: AbstractQueue[T] + The queue to enqueue data to be labeled. + args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + """ + self.sample(query_queue, *args, **kwargs) + + +class LabelStrategy(ActiveLearningProtocol): + """ + This protocol defines a label strategy for active learning. + + A label strategy is responsible for labeling data points; this may + be an simple Python function for demonstrating a concept, or an external, + potentially time consuming and complex, process. + + Attributes + ---------- + __is_external_process__: bool + Whether the label strategy is running in an external process. + __provides_fields__: set[str] + The fields that the label strategy provides. This should be + set by concrete implementations, and should be used to write + and map labeled data to fields within the data structure ``T``. + """ + + __is_external_process__: bool + __provides_fields__: set[str] | None = None + __protocol_type__ = ActiveLearningPhase.LABELING + + def label( + self, + queue_to_label: AbstractQueue[T], + serialize_queue: AbstractQueue[T], + *args: Any, + **kwargs: Any, + ) -> None: + """ + Method that implements the logic behind labeling data. + + This method should be implemented by concrete implementations, + and assume that an active learning driver will pass a queue + for this method to dequeue data to be labeled. + + Parameters + ---------- + queue_to_label: AbstractQueue[T] + Queue containing data structures to be labeled. Generally speaking, + this should be passed over after running query strateg(ies). + serialize_queue: AbstractQueue[T] + Queue for enqueing labeled data to be serialized. + args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + """ + ... + + def __call__( + self, + queue_to_label: AbstractQueue[T], + serialize_queue: AbstractQueue[T], + *args: Any, + **kwargs: Any, + ) -> None: + """ + Syntactic sugar for the ``label`` method. + + This allows the object to be called as a function, and will pass + the arguments to the strategy's ``label`` method. + + Parameters + ---------- + queue_to_label: AbstractQueue[T] + Queue containing data structures to be labeled. + serialize_queue: AbstractQueue[T] + Queue for enqueing labeled data to be serialized. 
+ args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + """ + self.label(queue_to_label, serialize_queue, *args, **kwargs) + + +class MetrologyStrategy(ActiveLearningProtocol): + """ + This protocol defines a metrology strategy for active learning. + + A metrology strategy is responsible for assessing the improvements to the underlying + model, beyond simple validation metrics. This should reflect the application + requirements of the model, which may include running a simulation. + + Attributes + ---------- + records: list[S] + A sequence of record data structures that records the + history of the active learning process, as viewed by + this particular metrology view. + """ + + records: list[S] + __protocol_type__ = ActiveLearningPhase.METROLOGY + + def append(self, record: S) -> None: + """ + Method to append a record to the metrology strategy. + + Parameters + ---------- + record: S + The record to append to the metrology strategy. + """ + self.records.append(record) + + def __len__(self) -> int: + """ + Method to get the length of the metrology strategy. + + Returns + ------- + int + The length of the metrology strategy. + """ + return len(self.records) + + def serialize_records( + self, path: Path | None = None, *args: Any, **kwargs: Any + ) -> None: + """ + Method to serialize the records of the metrology strategy. + + This should be defined by a concrete implementation, which dictates + how the records are persisted, e.g. to a JSON file, database, etc. + + The `strategy_dir` property can be used to determine the directory where + the records should be persisted. + + Parameters + ---------- + path: Path | None + The path to serialize the records to. If not provided, the strategy + should provide a reasonable default, such as with the checkpointing + or within the corresponding metrology directory via `strategy_dir`. + args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + """ + ... + + def load_records(self, path: Path | None = None, *args: Any, **kwargs: Any) -> None: + """ + Method to load the records of the metrology strategy, i.e. + the reverse of `serialize_records`. + + This should be defined by a concrete implementation, which dictates + how the records are loaded, e.g. from a JSON file, database, etc. + + If no path is provided, the strategy should load the latest records + as sensible defaults. The `records` attribute should then be overwritten + in-place. + + Parameters + ---------- + path: Path | None + The path to load the records from. If not provided, the strategy + should load the latest records as sensible defaults. + args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + """ + ... + + def compute(self, *args: Any, **kwargs: Any) -> None: + """ + Method to compute the metrology strategy. No data is passed to + this method, as it is expected that the data be drawn as needed + from various ``DataPool`` connected to the driver. + + This method defines the core logic for computing a particular view + of performance by the underlying model on the data. Once computed, + the data needs to be formatted into a record data structure ``S``, + that is then appended to the ``records`` attribute. + + Parameters + ---------- + args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + """ + ... 
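A minimal, hypothetical metrology strategy might look like the sketch below: `compute` appends a record to `records`, and `serialize_records` writes them out as JSON; the record layout, class name, and fixed metric value are illustrative only:

```python
import json
import tempfile
from pathlib import Path

from physicsnemo.active_learning.protocols import MetrologyStrategy


class LossHistoryMetrology(MetrologyStrategy):
    """Hypothetical strategy recording one scalar per active learning step."""

    __protocol_name__ = "LossHistoryMetrology"

    def __init__(self) -> None:
        self.records: list[dict] = []

    def attach(self, other: object) -> None:
        self.driver = other

    @property
    def is_attached(self) -> bool:
        return hasattr(self, "driver")

    def compute(self, *args, **kwargs) -> None:
        # a real implementation would run the learner over a held-out pool;
        # here a fixed record is appended purely for illustration
        self.append({"step": len(self), "mse": 0.123})

    def serialize_records(self, path=None, *args, **kwargs) -> None:
        path = path or Path("metrology.json")
        path.write_text(json.dumps(self.records, indent=2))


metrology = LossHistoryMetrology()
metrology()  # __call__ -> compute()
with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "records.json"
    metrology.serialize_records(out)
    print(out.read_text())
```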
+ + def __call__(self, *args: Any, **kwargs: Any) -> None: + """ + Syntactic sugar for the ``compute`` method. + + This allows the object to be called as a function, and will pass + the arguments to the strategy's ``compute`` method. + + Parameters + ---------- + args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + """ + self.compute(*args, **kwargs) + + def reset(self) -> None: + """ + Method to reset any stateful attributes of the metrology strategy. + + By default, the ``records`` attribute is reset to an empty list. + """ + self.records = [] + + +class TrainingProtocol(Protocol): + """ + This protocol defines the interface for training steps: given + a model and some input data, compute the reduced, differentiable + loss tensor and return it. + + A concrete implementation can simply be a function with a signature that + matches what is defined in ``__call__``. + """ + + def __call__( + self, model: Module, data: T, *args: Any, **kwargs: Any + ) -> torch.Tensor: + """ + Implements the training logic for a single training sample or batch. + + For a PhysicsNeMo ``Module`` with trainable parameters, the output + of this function should correspond to a PyTorch tensor that is + ``backward``-ready. If there are any logging operations associated + with training, they should be performed within this function. + + For ideal performance, this function should also be wrappable with + ``StaticCaptureTraining`` for optimization. + + Parameters + ---------- + model: Module + The model to train. + data: T + The data to train on. This data structure should comprise + both input and ground truths to compute the loss. + args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + + Returns + ------- + torch.Tensor + The reduced, differentiable loss tensor. + + Example + ------- + Minimum viable implementation: + >>> import torch + >>> def training_step(model, data): + ... output = model(data) + ... loss = torch.sum(torch.pow(output - data, 2)) + ... return loss + """ + ... + + +class ValidationProtocol(Protocol): + """ + This protocol defines the interface for validation steps: given + a model and some input data, compute metrics of interest and if + relevant to do so, log the results. + + A concrete implementation can simply be a function with a signature that + matches what is defined in ``__call__``. + """ + + def __call__(self, model: Module, data: T, *args: Any, **kwargs: Any) -> None: + """ + Implements the validation logic for a single sample or batch. + + This method will be called in validation steps **only**, and not used + for training, query, or metrology steps. In those cases, implement the + ``inference_step`` method instead. + + This function should not return anything, but should contain the logic + for computing metrics of interest over a validation/test set. If there + are any logging operations that need to be performed, they should also + be performed here. + + Depending on the type of model architecture, consider wrapping this method + with ``StaticCaptureEvaluateNoGrad`` for performance optimizations. This + should be used if the model does not require autograd as part of its + forward pass. + + Parameters + ---------- + model: Module + The model to validate. + data: T + The data to validate on. This data structure should comprise + both input and ground truths to compute the loss. + args: Any + Additional arguments to pass to the method. 
+ kwargs: Any + Additional keyword arguments to pass to the method. + + Example + ------- + Minimum viable implementation: + >>> import torch + >>> def validation_step(model, data): + ... output = model(data) + ... loss = torch.sum(torch.pow(output - data, 2)) + ... return loss + """ + ... + + +class InferenceProtocol(Protocol): + """ + This protocol defines the interface for inference steps: given + a model and some input data, return the output of the model's forward pass. + + A concrete implementation can simply be a function with a signature that + matches what is defined in ``__call__``. + """ + + def __call__(self, model: Module, data: S, *args: Any, **kwargs: Any) -> Any: + """ + Implements the inference logic for a single sample or batch. + + This method will be called in query and metrology steps, and should + return the output of the model's forward pass, likely minimally processed + so that any transformations can be performed by strategies that utilize + this protocol. + + The key difference between this protocol and the other two training and + validation protocols is that the data structure ``S`` does not need + to contain ground truth values to compute a loss. + + Similar to ``ValidationProtocol``, if relevant to the underlying architecture, + consider wrapping a concrete implementation of this protocol with + ``StaticCaptureInference`` for performance optimizations. + + Parameters + ---------- + model: Module + The model to infer on. + data: S + The data to infer on. This data structure should comprise + only input values to compute the forward pass. + args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + + Returns + ------- + Any + The output of the model's forward pass. + + Example + ------- + Minimum viable implementation: + >>> def inference_step(model, data): + ... output = model(data) + ... return output + """ + ... + + +class TrainingLoop(Protocol): + """ + Defines a protocol that implements a training loop. + + This protocol is intended to be called within the active learning loop + during the training phase, where the model is trained on a specified + number of epochs or training steps, and optionally validated on a dataset. + + If a ``LearnerProtocol`` is provided, then ``train_fn`` and ``validate_fn`` + become optional as they will be defined within the ``LearnerProtocol``. If + they are provided, however, then they should override the ``LearnerProtocol`` + variants. + + If graph capture/compilation is intended, then ``train_fn`` and ``validate_fn`` + should be wrapped with ``StaticCaptureTraining`` and ``StaticCaptureEvaluateNoGrad``, + respectively. + """ + + def __call__( + self, + model: Module | LearnerProtocol, + optimizer: Optimizer, + train_dataloader: DataLoader, + validation_dataloader: DataLoader | None = None, + train_step_fn: TrainingProtocol | None = None, + validate_step_fn: ValidationProtocol | None = None, + max_epochs: int | None = None, + max_train_steps: int | None = None, + max_val_steps: int | None = None, + lr_scheduler: _LRScheduler | None = None, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Defines the signature for a minimal viable training loop. + + The protocol defines a ``model`` with trainable parameters + tracked by ``optimizer`` will go through multiple epochs or + training steps. 
In the latter, the ``train_dataloader`` will be + exhausted ``max_epochs`` times, while the mutually exclusive + ``max_train_steps`` will limit the number of training batches, + which can be greater or less than the length of the ``train_dataloader``. + + (Optional) Validation is intended to be performed either at the end of a training + epoch, or when the maximum number of training steps is reached. The + ``max_val_steps`` parameter can be used to limit the number of batches to validate with + on a per-epoch basis. Validation is only performed if a ``validate_step_fn`` is provided, + alongside ``validation_dataloader``. + + The pseudocode for training to ``max_epochs`` would look like this: + + .. code-block:: python + + max_epochs = 10 + for epoch in range(max_epochs): + for train_idx, batch in enumerate(train_dataloader): + optimizer.zero_grad() + loss = train_step_fn(model, batch) + loss.backward() + optimizer.step() + if train_idx + 1 == max_train_steps: + break + if validate_step_fn and validation_dataloader: + for val_idx, batch in enumerate(validation_dataloader): + validate_step_fn(model, batch) + if val_idx + 1 == max_val_steps: + break + + The pseudocode for training with a ``LearnerProtocol`` would look like this: + + .. code-block:: python + + for epoch in range(max_epochs): + for train_idx, batch in enumerate(train_dataloader): + loss = model.training_step(batch) + if train_idx + 1 == max_train_steps: + break + if validation_dataloader: + for val_idx, batch in enumerate(validation_dataloader): + model.validation_step(batch) + if val_idx + 1 == max_val_steps: + break + + The key difference between specifying ``train_step_fn`` and ``LearnerProtocol`` + is that the former excludes the backward pass and optimizer step logic, + whereas the latter encapsulates them. + + The ``device`` and ``dtype`` parameters are used to specify the device and + dtype to use for the training loop. If not provided, a reasonable default + should be used (e.g. from ``torch.get_default_device()`` and ``torch.get_default_dtype()``). + + Parameters + ---------- + model: Module | LearnerProtocol + The model to train. + optimizer: Optimizer + The optimizer to use for training. + train_dataloader: DataLoader + The dataloader to use for training. + validation_dataloader: DataLoader | None + The dataloader to use for validation. + train_step_fn: TrainingProtocol | None + The training function to use for training. This is optional only + if ``model`` implements the ``LearnerProtocol``. If this is + provided and ``model`` implements the ``LearnerProtocol``, + then this function will take precedence over the + ``LearnerProtocol.training_step`` method. + validate_step_fn: ValidationProtocol | None + The validation function to use for validation, only if it is + provided alongside ``validation_dataloader``. If ``model`` implements + the ``LearnerProtocol``, then this function will take precedence over + the ``LearnerProtocol.validation_step`` method. + max_epochs: int | None + The maximum number of epochs to train for. Mututally exclusive + with ``max_train_steps``. + max_train_steps: int | None + The maximum number of training steps to perform. Mututally exclusive + with ``max_epochs``. If this value is greater than the length + of ``train_dataloader``, then the training loop will recycle the data + (i.e. more than one epoch) until the maximum number of training steps + is reached. + max_val_steps: int | None + The maximum number of validation steps to perform per training + epoch. 
If ``None``, then the full validation set will be used. + lr_scheduler: _LRScheduler | None = None, + The learning rate scheduler to use for training. If provided, + this will be used to update the learning rate of the optimizer + during training. If not provided, then the learning rate will + not be adjusted within this function. + device: str | torch.device | None = None + The device to use for the training loop. + dtype: torch.dtype | None = None + The dtype to use for the training loop. + args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + """ + ... + + +class LearnerProtocol: + """ + This protocol represents the learner part of an active learning + algorithm. + + This corresponds to a set of trainable parameters that are optimized, + and subsequently used for inference and evaluation. + + The required methods make this classes that implement this protocol + provide all the required functionality across all active learning steps. + Keep in mind that, similar to all other protocols in this module, this + is merely the required interface and not the actual implementation. + """ + + def training_step(self, data: T, *args: Any, **kwargs: Any) -> None: + """ + Implements the training logic for a single batch. + + This method will be called in training steps **only**, and not used + for validation, query, or metrology steps. Specifically this means + that gradients will be computed and used to update parameters. + + In cases where gradients are not needed, consider implementing the + ``validation_step`` method instead. + + This should mirror the ``TrainingProtocol`` definition, except that + the model corresponds to this object. + + Parameters + ---------- + data: T + The data to train on. Typically assumed to be a batch + of data. + args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + """ + ... + + def validation_step(self, data: T, *args: Any, **kwargs: Any) -> None: + """ + Implements the validation logic for a single batch. + + This can match the forward pass, without the need for weight updates. + This method will be called in validation steps **only**, and not used + for query or metrology steps. In those cases, implement the ``inference_step`` + method instead. + + This should mirror the ``ValidationProtocol`` definition, except that + the model corresponds to this object. + + Parameters + ---------- + data: T + The data to validate on. Typically assumed to be a batch + of data. + args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + """ + ... + + def inference_step(self, data: T | S, *args: Any, **kwargs: Any) -> None: + """ + Implements the inference logic for a single batch. + + This can match the forward pass exactly, but provides an opportunity + to differentiate (or lack thereof, with no pun intended). Specifically, + this method will be called during query and metrology steps. + + This should mirror the ``InferenceProtocol`` definition, except that + the model corresponds to this object. + + Parameters + ---------- + data: T + The data to infer on. Typically assumed to be a batch + of data. + args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + """ + ... + + @property + def parameters(self) -> Iterator[torch.Tensor]: + """ + Returns an iterator over the parameters of the learner. 
+ + If subclassing from `torch.nn.Module`, this will automatically return + the parameters of the module. + + Returns + ------- + Iterator[torch.Tensor] + An iterator over the parameters of the learner. + """ + ... + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """ + Implements the forward pass for a single batch. + + This method is called between all active learning steps, and should + contain the logic for how a model ingests data and produces predictions. + + Parameters + ---------- + args: Any + Additional arguments to pass to the model. + kwargs: Any + Additional keyword arguments to pass to the model. + + Returns + ------- + Any + The output of the model's forward pass. + """ + ... + + +class DriverProtocol: + """ + This protocol specifies the expected interface for an active learning + driver: for a concrete implementation, refer to the `driver` module + instead. The specification is provided mostly as a reference, and for + ease of type hinting to prevent circular imports. + + Attributes + ---------- + learner: LearnerProtocol + The learner module that will be used as the surrogate within + the active learning loop. + query_strategies: list[QueryStrategy] + The query strategies that will be used for selecting data points to label. + A list of strategies can be included, and will sequentially be used to + populate the ``query_queue`` that passes samples over to labeling. + query_queue: AbstractQueue[T] + The queue containing data samples to be labeled. ``QueryStrategy`` instances + should enqueue samples to this queue. + label_strategy: LabelStrategy | None + The label strategy that will be used for labeling data points. In contrast + to the other strategies, only a single label strategy is supported. + This strategy will consume the ``query_queue`` and enqueue labeled data to + the ``label_queue``. + label_queue: AbstractQueue[T] | None + The queue containing freshly labeled data. ``LabelStrategy`` instances + should enqueue labeled data to this queue, and the driver will subsequently + serialize data contained within this queue to a persistent format. + metrology_strategies: list[MetrologyStrategy] | None + The metrology strategies that will be used for assessing the performance + of the surrogate. A list of strategies can be included, and will sequentially + be used to populate the ``metrology_queue`` that passes data over to the + learner. + training_pool: DataPool[T] + The pool of data to be used for training. This data will be used to train + the underlying model, and is assumed to be mutable in that additional data + can be added to the pool over the course of active learning. + validation_pool: DataPool[T] | None + The pool of data to be used for validation. This data will be used for both + conventional validation, as well as for metrology. This dataset is considered + to be immutable, and should not be modified over the course of active learning. + This dataset is considered optional, as both validation and metrology are. + unlabeled_pool: DataPool[T] | None + An optional pool of data to be used for querying and labeling. If supplied, + this dataset can be depleted by a query strategy to select data points for labeling. + In principle, this could also represent a generative model, i.e. not just a static + dataset, but at a high level represents a distribution of data. 
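+
+    Example
+    -------
+    A rough sketch of how these pieces relate at runtime (illustrative only;
+    ``MyLearner``, ``UncertaintyQuery``, and ``SimulationLabeler`` are
+    hypothetical concrete implementations):
+
+    .. code-block:: python
+
+        driver.learner = MyLearner()
+        driver.query_strategies = [UncertaintyQuery(max_samples=64)]
+        driver.label_strategy = SimulationLabeler()
+        # give strategies access to the driver's scope (pools, learner, queues)
+        driver.attach_strategies()
+        # one pass of training -> metrology -> query -> labeling
+        driver.active_learning_step()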
+ """ + + learner: LearnerProtocol + query_strategies: list[QueryStrategy] + query_queue: AbstractQueue[T] + label_strategy: LabelStrategy | None + label_queue: AbstractQueue[T] | None + metrology_strategies: list[MetrologyStrategy] | None + training_pool: DataPool[T] + validation_pool: DataPool[T] | None + unlabeled_pool: DataPool[T] | None + + def active_learning_step(self, *args: Any, **kwargs: Any) -> None: + """ + Implements the active learning step. + + This step performs a single pass of the active learning loop, with the + intended order being: training, metrology, query, labeling, with + the metrology and labeling steps being optional. + + Parameters + ---------- + args: Any + Additional arguments to pass to the method. + kwargs: Any + Additional keyword arguments to pass to the method. + """ + ... + + def _setup_logger(self) -> None: + """ + Sets up the logger for the driver. + + The intended concrete method should account for the ability to + scope logging, such that things like active learning iteration + counts, etc. can be logged. + """ + ... + + def attach_strategies(self) -> None: + """ + Attaches all provided strategies. + + This method relies on the ``attach`` method of the strategies, which + will subsequently give the strategy access to the driver's scope. + + Example use cases would be for any strategy (apart from label strategy) + to access the underlying model (``LearnerProtocol``); for a query + strategy to access the ``unlabeled_pool``; for a metrology strategy + to access the ``validation_pool``. + """ + for strategy in self.query_strategies: + strategy.attach(self) + if self.label_strategy: + self.label_strategy.attach(self) + if self.metrology_strategies: + for strategy in self.metrology_strategies: + strategy.attach(self) diff --git a/test/active_learning/__init__.py b/test/active_learning/__init__.py new file mode 100644 index 0000000000..b2340c62ce --- /dev/null +++ b/test/active_learning/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/test/active_learning/conftest.py b/test/active_learning/conftest.py new file mode 100644 index 0000000000..fe85622e69 --- /dev/null +++ b/test/active_learning/conftest.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Iterator +from unittest.mock import MagicMock + +import pytest +import torch + +from physicsnemo import Module +from physicsnemo.active_learning import protocols as p +from physicsnemo.active_learning._registry import registry + + +# Mock classes for testing serialization +class MockQueryStrategy: + """Mock query strategy for testing.""" + + def __init__(self): + pass + + def __call__(self, *args, **kwargs): + pass + + def attach(self, driver): + """Attach strategy to driver (no-op for mock).""" + pass + + +class MockLabelStrategy: + """Mock label strategy for testing.""" + + def __init__(self): + pass + + def __call__(self, *args, **kwargs): + pass + + def attach(self, driver): + """Attach strategy to driver (no-op for mock).""" + pass + + +class MockMetrologyStrategy: + """Mock metrology strategy for testing.""" + + def __init__(self): + pass + + def __call__(self, *args, **kwargs): + pass + + def attach(self, driver): + """Attach strategy to driver (no-op for mock).""" + pass + + +class MockTrainingLoop: + """Mock training loop for testing.""" + + def __init__(self): + pass + + def __call__(self, *args, **kwargs): + pass + + +# Register mock classes +registry.register("MockQueryStrategy")(MockQueryStrategy) +registry.register("MockLabelStrategy")(MockLabelStrategy) +registry.register("MockMetrologyStrategy")(MockMetrologyStrategy) +registry.register("MockTrainingLoop")(MockTrainingLoop) + + +@dataclass +class MockDataStructure: + """ + Essentially a stand-in that holds inputs, the target device, + and for testing labeling workflows, an optional target. 
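+
+    For instance, a test may enqueue an instance via
+    ``driver.query_queue.put(MockDataStructure(inputs=torch.randn(16, 64)))``,
+    leaving ``targets`` unset until a label strategy (or mock) fills it in.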
+ """ + + inputs: torch.Tensor + device: torch.device = torch.device("cpu") + targets: torch.Tensor | None = None + + +class MockModule(Module): + """A mock module that implements a linear layer and stands-in for a non-learner module.""" + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(64, 3) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Forward pass of the mock module""" + return self.linear(input_tensor) + + +class MockLearnerModule(p.LearnerProtocol): + """A mock learner module that implements a linear layer and a loss function""" + + def __init__(self): + super().__init__() + self.module = MockModule() + self.loss_fn = torch.nn.MSELoss() + + def training_step(self, data: MockDataStructure, *args: Any, **kwargs: Any) -> None: + """As this is a mock module, this is a no-op""" + return None + + def validation_step( + self, data: MockDataStructure, *args: Any, **kwargs: Any + ) -> None: + """As this is a mock module, this is a no-op""" + return None + + def inference_step( + self, data: MockDataStructure, *args: Any, **kwargs: Any + ) -> None: + """As this is a mock module, this is a no-op""" + return None + + @property + def parameters(self) -> Iterator[torch.Tensor]: + """Returns an iterator over the parameters of the learner.""" + return self.module.parameters() + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """Forward pass through the module.""" + return self.module.forward(*args, **kwargs) + + +@pytest.fixture(scope="function", autouse=True) +def learner_module() -> MockLearnerModule: + """Mocks a learner module""" + return MockLearnerModule() + + +@pytest.fixture(scope="function", autouse=True) +def mock_module() -> MockModule: + """Mocks a module""" + return MockModule() + + +@pytest.fixture(scope="function", autouse=True) +def mock_queue() -> p.AbstractQueue[MockDataStructure]: + """Mocks a query queue with a single data entry""" + mock = MagicMock(spec=p.AbstractQueue) + mock.empty.return_value = False + mock.get.return_value = MockDataStructure( + inputs=torch.randn(16, 64), + device=torch.device("cpu"), + ) + mock.put.return_value = None + return mock + + +@pytest.fixture(scope="function", autouse=True) +def mock_query_strategy() -> p.QueryStrategy: + mock = MagicMock(spec=p.QueryStrategy) + mock.sample.return_value = None + mock._args = { + "__name__": "MockQueryStrategy", + "__module__": "test.active_learning.conftest", + "__args__": {}, + } + return mock + + +@pytest.fixture(scope="function", autouse=True) +def mock_label_strategy() -> p.LabelStrategy: + mock = MagicMock(spec=p.LabelStrategy) + mock.label.return_value = None + mock._args = { + "__name__": "MockLabelStrategy", + "__module__": "test.active_learning.conftest", + "__args__": {}, + } + return mock + + +@pytest.fixture(scope="function", autouse=True) +def mock_metrology_strategy() -> p.MetrologyStrategy: + mock = MagicMock(spec=p.MetrologyStrategy) + mock.compute.return_value = None + mock.__call__ = mock.compute + mock.serialize_records.return_value = None + mock.records = [ + None, + ] + mock._args = { + "__name__": "MockMetrologyStrategy", + "__module__": "test.active_learning.conftest", + "__args__": {}, + } + return mock + + +@pytest.fixture(scope="function", autouse=True) +def mock_training_loop() -> p.TrainingLoop: + mock = MagicMock(spec=p.TrainingLoop) + mock._args = { + "__name__": "MockTrainingLoop", + "__module__": "test.active_learning.conftest", + "__args__": {}, + } + return mock + + +@pytest.fixture(scope="function", autouse=True) 
+def mock_data_pool() -> p.DataPool[MockDataStructure]: + """Mocks a data pool""" + mock = MagicMock(spec=p.DataPool) + mock.append.return_value = None + mock.__getitem__.return_value = MockDataStructure( + inputs=torch.randn(16, 64), + device=torch.device("cpu"), + ) + mock.__len__.return_value = 1 + mock.__iter__.return_value = iter( + [ + MockDataStructure( + inputs=torch.randn(16, 64), + device=torch.device("cpu"), + ) + ] + ) + return mock diff --git a/test/active_learning/test_checkpointing.py b/test/active_learning/test_checkpointing.py new file mode 100644 index 0000000000..5df0fe9db7 --- /dev/null +++ b/test/active_learning/test_checkpointing.py @@ -0,0 +1,883 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for active learning checkpoint functionality.""" + +from __future__ import annotations + +import queue +import shutil +from typing import Any + +import pytest +import torch + +from physicsnemo.active_learning import protocols as p +from physicsnemo.active_learning.config import ( + DriverConfig, + OptimizerConfig, + StrategiesConfig, + TrainingConfig, +) +from physicsnemo.active_learning.driver import ActiveLearningCheckpoint, Driver + +from .conftest import MockDataStructure, MockModule + + +class SimpleQueue: + """Simple queue implementation with serialization support for testing.""" + + def __init__(self): + """Initialize empty queue.""" + self._items = [] + + def put(self, item: Any) -> None: + """Add item to queue.""" + self._items.append(item) + + def get(self) -> Any: + """Remove and return item from queue.""" + return self._items.pop(0) if self._items else None + + def empty(self) -> bool: + """Check if queue is empty.""" + return len(self._items) == 0 + + def to_list(self) -> list[Any]: + """Serialize queue to list.""" + return self._items.copy() + + def from_list(self, items: list[Any]) -> None: + """Restore queue from list.""" + self._items = items.copy() + + +@pytest.fixture +def simple_queue(): + """Fixture for a simple queue with serialization support.""" + return SimpleQueue + + +@pytest.fixture +def temp_checkpoint_dir(tmp_path): + """Fixture for temporary checkpoint directory.""" + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + yield checkpoint_dir + # Cleanup + if checkpoint_dir.exists(): + shutil.rmtree(checkpoint_dir) + + +@pytest.fixture +def driver_config(temp_checkpoint_dir): + """Fixture for basic driver config with checkpointing enabled.""" + return DriverConfig( + batch_size=16, + max_active_learning_steps=2, + checkpoint_interval=1, + checkpoint_on_query=True, + skip_training=True, # Skip training for basic checkpoint tests + skip_labeling=True, # Skip labeling for basic checkpoint tests + root_log_dir=temp_checkpoint_dir.parent, + device=torch.device("cpu"), + ) + + +@pytest.fixture +def 
strategies_config(mock_query_strategy, simple_queue): + """Fixture for strategies config with simple queue.""" + return StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=simple_queue, + ) + + +@pytest.fixture +def training_config(mock_data_pool, mock_training_loop): + """Fixture for training config.""" + return TrainingConfig( + train_datapool=mock_data_pool, + max_training_epochs=2, + optimizer_config=OptimizerConfig( + optimizer_cls=torch.optim.SGD, + optimizer_kwargs={"lr": 0.01}, + ), + train_loop_fn=mock_training_loop, + ) + + +@pytest.mark.dependency() +def test_checkpoint_basic_save( + driver_config, mock_module, strategies_config, temp_checkpoint_dir +): + """Test basic checkpoint saving functionality.""" + driver = Driver( + config=driver_config, + learner=mock_module, + strategies_config=strategies_config, + ) + + # Manually set a phase + driver.current_phase = p.ActiveLearningPhase.QUERY + + # Save checkpoint + checkpoint_path = temp_checkpoint_dir / "test_checkpoint" + driver.save_checkpoint(path=checkpoint_path) + + # Verify checkpoint files exist + assert checkpoint_path.exists() + assert (checkpoint_path / "checkpoint.pt").exists() + assert (checkpoint_path / "MockModule.mdlus").exists() + + # Verify last_checkpoint property + assert driver.last_checkpoint == checkpoint_path + + +@pytest.mark.dependency(depends=["test_checkpoint_basic_save"]) +def test_checkpoint_contains_correct_metadata( + driver_config, mock_module, strategies_config, temp_checkpoint_dir +): + """Test that checkpoint contains all required metadata.""" + driver = Driver( + config=driver_config, + learner=mock_module, + strategies_config=strategies_config, + ) + + driver.active_learning_step_idx = 5 + driver.current_phase = p.ActiveLearningPhase.QUERY + + checkpoint_path = temp_checkpoint_dir / "metadata_test" + driver.save_checkpoint(path=checkpoint_path) + + # Load and verify checkpoint + checkpoint_dict = torch.load(checkpoint_path / "checkpoint.pt", weights_only=False) + checkpoint: ActiveLearningCheckpoint = checkpoint_dict["checkpoint"] + + assert checkpoint.active_learning_step_idx == 5 + assert checkpoint.active_learning_phase == p.ActiveLearningPhase.QUERY + assert checkpoint.driver_config is not None + assert checkpoint.strategies_config is not None + + +@pytest.mark.dependency(depends=["test_checkpoint_basic_save"]) +def test_checkpoint_with_training_state( + mock_module, strategies_config, training_config, temp_checkpoint_dir +): + """Test that training state is saved separately via training loop.""" + from physicsnemo.active_learning.loop import DefaultTrainingLoop + + # Create config without skip_training for this test + config = DriverConfig( + batch_size=16, + max_active_learning_steps=2, + checkpoint_interval=1, + checkpoint_on_query=True, + skip_labeling=True, # Skip labeling since we don't have a label strategy + root_log_dir=temp_checkpoint_dir.parent, + device=torch.device("cpu"), + ) + + driver = Driver( + config=config, + learner=mock_module, + strategies_config=strategies_config, + training_config=training_config, + ) + + # Configure optimizer + driver.configure_optimizer() + assert driver.optimizer is not None + + checkpoint_path = temp_checkpoint_dir / "training_state_test" + + # Use training loop to save training state + training_loop = DefaultTrainingLoop() + training_loop.save_training_checkpoint( + checkpoint_dir=checkpoint_path, + model=driver.learner, + optimizer=driver.optimizer, + lr_scheduler=driver.lr_scheduler if hasattr(driver, "lr_scheduler") else 
None, + training_epoch=5, + ) + + # Verify training state file exists and contains optimizer state + assert (checkpoint_path / "training_state.pt").exists() + training_state = torch.load(checkpoint_path / "training_state.pt") + assert "optimizer_state" in training_state + assert "param_groups" in training_state["optimizer_state"] + assert training_state["training_epoch"] == 5 + + +@pytest.mark.dependency(depends=["test_checkpoint_basic_save"]) +def test_queue_serialization_with_to_list( + driver_config, mock_module, strategies_config, temp_checkpoint_dir +): + """Test queue serialization using to_list/from_list methods.""" + driver = Driver( + config=driver_config, + learner=mock_module, + strategies_config=strategies_config, + ) + + # Add items to queues + test_data = MockDataStructure(inputs=torch.randn(16, 64)) + driver.query_queue.put(test_data) + driver.label_queue.put(test_data) + + checkpoint_path = temp_checkpoint_dir / "queue_test" + driver.save_checkpoint(path=checkpoint_path) + + # Load and verify queue files were created + checkpoint_dict = torch.load(checkpoint_path / "checkpoint.pt", weights_only=False) + checkpoint: ActiveLearningCheckpoint = checkpoint_dict["checkpoint"] + + assert checkpoint.has_query_queue is True + assert checkpoint.has_label_queue is True + assert (checkpoint_path / "query_queue.pt").exists() + assert (checkpoint_path / "label_queue.pt").exists() + + # Verify queue contents + query_queue_data = torch.load( + checkpoint_path / "query_queue.pt", weights_only=False + ) + assert query_queue_data["type"] == "list" + assert len(query_queue_data["data"]) == 1 + + +@pytest.mark.dependency(depends=["test_checkpoint_basic_save"]) +def test_queue_serialization_fallback( + driver_config, mock_module, mock_query_strategy, temp_checkpoint_dir +): + """Test queue serialization handles unpicklable queues gracefully.""" + # Use standard library queue without to_list method + # This queue cannot be pickled due to thread locks + strategies_config_stdlib = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=queue.Queue, + ) + + driver = Driver( + config=driver_config, + learner=mock_module, + strategies_config=strategies_config_stdlib, + ) + + # Add item to queue + test_data = MockDataStructure(inputs=torch.randn(16, 64)) + driver.query_queue.put(test_data) + + checkpoint_path = temp_checkpoint_dir / "queue_fallback_test" + driver.save_checkpoint(path=checkpoint_path) + + # Verify checkpoint was created + assert (checkpoint_path / "checkpoint.pt").exists() + + # Load and verify queue serialization failed (unpicklable) + checkpoint_dict = torch.load(checkpoint_path / "checkpoint.pt", weights_only=False) + checkpoint: ActiveLearningCheckpoint = checkpoint_dict["checkpoint"] + + # stdlib queue.Queue cannot be pickled, so has_query_queue should be False + assert checkpoint.has_query_queue is False + # Queue file should not exist + assert not (checkpoint_path / "query_queue.pt").exists() + + +@pytest.mark.dependency(depends=["test_checkpoint_basic_save"]) +def test_checkpoint_load_basic( + driver_config, mock_module, strategies_config, temp_checkpoint_dir +): + """Test basic checkpoint loading functionality.""" + # Create and save a checkpoint + driver = Driver( + config=driver_config, + learner=mock_module, + strategies_config=strategies_config, + ) + + driver.active_learning_step_idx = 3 + driver.current_phase = p.ActiveLearningPhase.QUERY + + checkpoint_path = temp_checkpoint_dir / "load_test" + driver.save_checkpoint(path=checkpoint_path) + + # Load 
checkpoint + loaded_driver = Driver.load_checkpoint( + checkpoint_path=checkpoint_path, + learner=MockModule(), + ) + + # Verify state was restored + assert loaded_driver.active_learning_step_idx == 3 + assert loaded_driver.current_phase == p.ActiveLearningPhase.QUERY + assert loaded_driver.last_checkpoint == checkpoint_path + + +@pytest.mark.dependency(depends=["test_checkpoint_basic_save"]) +def test_checkpoint_load_with_datapools( + mock_module, + strategies_config, + training_config, + mock_data_pool, + temp_checkpoint_dir, +): + """Test checkpoint loading with datapool restoration.""" + # Create config without skip_training for this test + config = DriverConfig( + batch_size=16, + max_active_learning_steps=2, + checkpoint_interval=1, + checkpoint_on_query=True, + skip_labeling=True, + root_log_dir=temp_checkpoint_dir.parent, + device=torch.device("cpu"), + ) + + driver = Driver( + config=config, + learner=mock_module, + strategies_config=strategies_config, + training_config=training_config, + ) + + checkpoint_path = temp_checkpoint_dir / "datapool_test" + driver.save_checkpoint(path=checkpoint_path) + + # Load with datapools provided + loaded_driver = Driver.load_checkpoint( + checkpoint_path=checkpoint_path, + learner=MockModule(), + train_datapool=mock_data_pool, + val_datapool=mock_data_pool, + ) + + # Verify datapools were set + assert loaded_driver.train_datapool is not None + assert loaded_driver.val_datapool is not None + + +@pytest.mark.dependency(depends=["test_checkpoint_basic_save"]) +def test_training_loop_restores_optimizer( + mock_module, + strategies_config, + training_config, + mock_data_pool, + temp_checkpoint_dir, +): + """Test that training loop can restore optimizer state from training_state.pt.""" + from physicsnemo.active_learning.loop import DefaultTrainingLoop + + # Create config without skip_training for this test + config = DriverConfig( + batch_size=16, + max_active_learning_steps=2, + checkpoint_interval=1, + checkpoint_on_query=True, + skip_labeling=True, + root_log_dir=temp_checkpoint_dir.parent, + device=torch.device("cpu"), + ) + + driver = Driver( + config=config, + learner=mock_module, + strategies_config=strategies_config, + training_config=training_config, + ) + + driver.configure_optimizer() + # Modify optimizer state + for param_group in driver.optimizer.param_groups: + param_group["lr"] = 0.1234 + + checkpoint_path = temp_checkpoint_dir / "training_state_restore_test" + + # Use training loop to save training state + training_loop = DefaultTrainingLoop() + training_loop.save_training_checkpoint( + checkpoint_dir=checkpoint_path, + model=driver.learner, + optimizer=driver.optimizer, + lr_scheduler=driver.lr_scheduler if hasattr(driver, "lr_scheduler") else None, + training_epoch=3, + ) + + # Create new driver and optimizer + new_module = MockModule() + new_driver = Driver( + config=config, + learner=new_module, + strategies_config=strategies_config, + training_config=training_config, + ) + new_driver.configure_optimizer() + + # Load training state using training loop + epoch = DefaultTrainingLoop.load_training_checkpoint( + checkpoint_dir=checkpoint_path, + model=new_driver.learner, + optimizer=new_driver.optimizer, + lr_scheduler=new_driver.lr_scheduler + if hasattr(new_driver, "lr_scheduler") + else None, + ) + + # Verify optimizer state was restored + assert new_driver.optimizer.param_groups[0]["lr"] == pytest.approx(0.1234) + assert epoch == 3 + + +@pytest.mark.dependency(depends=["test_checkpoint_basic_save"]) +def 
test_checkpoint_load_restores_queues( + driver_config, mock_module, strategies_config, temp_checkpoint_dir +): + """Test that checkpoint loading restores queue contents.""" + driver = Driver( + config=driver_config, + learner=mock_module, + strategies_config=strategies_config, + ) + + # Add items to queue + test_data = MockDataStructure(inputs=torch.randn(16, 64)) + driver.query_queue.put(test_data) + + checkpoint_path = temp_checkpoint_dir / "queue_restore_test" + driver.save_checkpoint(path=checkpoint_path) + + # Load checkpoint + loaded_driver = Driver.load_checkpoint( + checkpoint_path=checkpoint_path, + learner=MockModule(), + ) + + # Verify queue was restored + assert not loaded_driver.query_queue.empty() + restored_item = loaded_driver.query_queue.get() + assert isinstance(restored_item, MockDataStructure) + assert restored_item.inputs.shape == test_data.inputs.shape + + +@pytest.mark.dependency(depends=["test_checkpoint_basic_save"]) +def test_checkpoint_interval_controls_saving( + temp_checkpoint_dir, mock_module, mock_query_strategy, simple_queue +): + """Test that checkpoint_interval controls when checkpoints are saved.""" + # Set checkpoint_interval to 2 + config = DriverConfig( + batch_size=16, + max_active_learning_steps=5, + checkpoint_interval=2, + checkpoint_on_query=True, + skip_training=True, + skip_labeling=True, + root_log_dir=temp_checkpoint_dir.parent, + device=torch.device("cpu"), + ) + + strategies_config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=simple_queue, + ) + + driver = Driver( + config=config, + learner=mock_module, + strategies_config=strategies_config, + ) + + # Test at step 0 (should checkpoint) + driver.active_learning_step_idx = 0 + assert driver._should_checkpoint_at_step() + + # Test at step 1 (should NOT checkpoint) + driver.active_learning_step_idx = 1 + assert not driver._should_checkpoint_at_step() + + # Test at step 2 (should checkpoint) + driver.active_learning_step_idx = 2 + assert driver._should_checkpoint_at_step() + + # Test at step 4 (should checkpoint) + driver.active_learning_step_idx = 4 + assert driver._should_checkpoint_at_step() + + +@pytest.mark.dependency(depends=["test_checkpoint_basic_save"]) +def test_checkpoint_interval_zero_disables_checkpointing( + temp_checkpoint_dir, mock_module, mock_query_strategy, simple_queue +): + """Test that checkpoint_interval=0 disables checkpointing.""" + config = DriverConfig( + batch_size=16, + max_active_learning_steps=5, + checkpoint_interval=0, # Disabled + checkpoint_on_query=True, + skip_training=True, + skip_labeling=True, + root_log_dir=temp_checkpoint_dir.parent, + device=torch.device("cpu"), + ) + + strategies_config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=simple_queue, + ) + + driver = Driver( + config=config, + learner=mock_module, + strategies_config=strategies_config, + ) + + # Should never checkpoint when interval is 0 + for step in range(5): + driver.active_learning_step_idx = step + assert not driver._should_checkpoint_at_step() + + +@pytest.mark.dependency(depends=["test_checkpoint_basic_save"]) +def test_checkpoint_with_training_epoch( + mock_module, + strategies_config, + training_config, + mock_data_pool, + temp_checkpoint_dir, +): + """Test checkpoint saving with training epoch information.""" + # Create config without skip_training for this test + config = DriverConfig( + batch_size=16, + max_active_learning_steps=2, + checkpoint_interval=1, + checkpoint_on_query=True, + skip_labeling=True, + 
root_log_dir=temp_checkpoint_dir.parent, + device=torch.device("cpu"), + ) + + driver = Driver( + config=config, + learner=mock_module, + strategies_config=strategies_config, + training_config=training_config, + ) + + driver.current_phase = p.ActiveLearningPhase.TRAINING + + # Save checkpoint with epoch number + checkpoint_path = temp_checkpoint_dir / "epoch_test" + driver.save_checkpoint(path=checkpoint_path, training_epoch=5) + + # Verify epoch is saved + checkpoint_dict = torch.load(checkpoint_path / "checkpoint.pt", weights_only=False) + assert "training_epoch" in checkpoint_dict + assert checkpoint_dict["training_epoch"] == 5 + + +@pytest.mark.dependency(depends=["test_checkpoint_basic_save"]) +def test_checkpoint_auto_path_generation( + driver_config, mock_module, strategies_config, temp_checkpoint_dir +): + """Test that checkpoints are saved with auto-generated paths.""" + driver = Driver( + config=driver_config, + learner=mock_module, + strategies_config=strategies_config, + ) + + driver.active_learning_step_idx = 2 + driver.current_phase = p.ActiveLearningPhase.QUERY + + # Save without specifying path + driver.save_checkpoint() + + # Verify path was auto-generated + expected_path = driver.log_dir / "checkpoints" / "step_2" / "query" + assert expected_path.exists() + assert (expected_path / "checkpoint.pt").exists() + + +@pytest.mark.dependency(depends=["test_checkpoint_basic_save"]) +def test_checkpoint_preserves_model_weights( + driver_config, mock_module, strategies_config, temp_checkpoint_dir +): + """Test that model weights are correctly saved and loaded.""" + driver = Driver( + config=driver_config, + learner=mock_module, + strategies_config=strategies_config, + ) + + # Get initial weights + initial_weights = { + name: param.clone() for name, param in driver.learner.named_parameters() + } + + checkpoint_path = temp_checkpoint_dir / "weights_test" + driver.save_checkpoint(path=checkpoint_path) + + # Create new module and load weights + new_module = MockModule() + loaded_driver = Driver.load_checkpoint( + checkpoint_path=checkpoint_path, + learner=new_module, + ) + + # Verify weights match + for name, param in loaded_driver.learner.named_parameters(): + assert torch.allclose(param, initial_weights[name]) + + +def test_checkpoint_phase_specific_flags( + temp_checkpoint_dir, mock_module, mock_query_strategy, simple_queue +): + """Test that phase-specific checkpoint flags are respected.""" + config = DriverConfig( + batch_size=16, + max_active_learning_steps=2, + checkpoint_interval=1, + checkpoint_on_training=True, + checkpoint_on_metrology=False, + checkpoint_on_query=True, + checkpoint_on_labeling=False, + skip_training=True, + skip_labeling=True, + root_log_dir=temp_checkpoint_dir.parent, + device=torch.device("cpu"), + ) + + strategies_config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=simple_queue, + ) + + driver = Driver( + config=config, + learner=mock_module, + strategies_config=strategies_config, + ) + + # Verify flags are set correctly + assert driver.config.checkpoint_on_training is True + assert driver.config.checkpoint_on_metrology is False + assert driver.config.checkpoint_on_query is True + assert driver.config.checkpoint_on_labeling is False + + +# ============================================================================ +# Phase Resumption Tests +# ============================================================================ + + +def test_get_phase_index(driver_config, mock_module, strategies_config): + """Test _get_phase_index 
helper method.""" + driver = Driver( + config=driver_config, + learner=mock_module, + strategies_config=strategies_config, + ) + + # Test each phase + assert driver._get_phase_index(None) == 0 + assert driver._get_phase_index(p.ActiveLearningPhase.TRAINING) == 0 + assert driver._get_phase_index(p.ActiveLearningPhase.METROLOGY) == 1 + assert driver._get_phase_index(p.ActiveLearningPhase.QUERY) == 2 + assert driver._get_phase_index(p.ActiveLearningPhase.LABELING) == 3 + + +def test_build_phase_queue_from_fresh_start( + driver_config, mock_module, strategies_config +): + """Test phase queue includes only non-skipped phases when current_phase is None.""" + driver = Driver( + config=driver_config, # skip_training=True, skip_labeling=True, skip_metrology=False + learner=mock_module, + strategies_config=strategies_config, + ) + + # current_phase is None (fresh start) + assert driver.current_phase is None + + # Build phase queue + phase_queue = driver._build_phase_queue(None, None, (), {}) + + # Verify: metrology and query phases (training and labeling skipped) + assert len(phase_queue) == 2 # metrology + query + + +def test_build_phase_queue_from_query_phase( + driver_config, mock_module, strategies_config +): + """Test phase queue starts from query when current_phase=QUERY.""" + driver = Driver( + config=driver_config, + learner=mock_module, + strategies_config=strategies_config, + ) + + # Set current_phase to QUERY (as if loaded from checkpoint) + driver.current_phase = p.ActiveLearningPhase.QUERY + + # Build phase queue + phase_queue = driver._build_phase_queue(None, None, (), {}) + + # With skip_training=True, skip_labeling=True + # Queue should include: [query] (labeling skipped by config) + assert len(phase_queue) == 1 + + +def test_resume_from_query_phase_skips_earlier_phases( + driver_config, mock_module, strategies_config, temp_checkpoint_dir +): + """Test that resuming from query phase skips training and metrology.""" + driver = Driver( + config=driver_config, + learner=mock_module, + strategies_config=strategies_config, + ) + + # Create checkpoint at query phase + driver.active_learning_step_idx = 1 + driver.current_phase = p.ActiveLearningPhase.QUERY + driver.query_queue.put(MockDataStructure(inputs=torch.randn(16, 64))) + + checkpoint_path = temp_checkpoint_dir / "resume_query_test" + driver.save_checkpoint(path=checkpoint_path) + + # Load checkpoint + loaded_driver = Driver.load_checkpoint( + checkpoint_path=checkpoint_path, + learner=MockModule(), + ) + + # Track which phases execute + executed_phases = [] + + loaded_driver._training_phase = lambda *a, **k: executed_phases.append("training") + loaded_driver._metrology_phase = lambda *a, **k: executed_phases.append("metrology") + loaded_driver._query_phase = lambda *a, **k: executed_phases.append("query") + loaded_driver._labeling_phase = lambda *a, **k: executed_phases.append("labeling") + + # Execute one AL step + loaded_driver.active_learning_step() + + # Verify: only query executed (training/metrology skipped, labeling skipped by config) + assert "training" not in executed_phases + assert "metrology" not in executed_phases + assert "query" in executed_phases + assert "labeling" not in executed_phases + + # Verify: current_phase reset after step completion + assert loaded_driver.current_phase is None + assert loaded_driver.active_learning_step_idx == 2 + + +def test_current_phase_resets_after_step_completion( + driver_config, mock_module, strategies_config, temp_checkpoint_dir +): + """Test that current_phase is reset to None 
after completing an AL step.""" + driver = Driver( + config=driver_config, + learner=mock_module, + strategies_config=strategies_config, + ) + + # Save checkpoint at query phase + driver.active_learning_step_idx = 0 + driver.current_phase = p.ActiveLearningPhase.QUERY + checkpoint_path = temp_checkpoint_dir / "phase_reset_test" + driver.save_checkpoint(path=checkpoint_path) + + # Load checkpoint + loaded_driver = Driver.load_checkpoint( + checkpoint_path=checkpoint_path, + learner=MockModule(), + ) + + # Verify loaded state + assert loaded_driver.current_phase == p.ActiveLearningPhase.QUERY + + # Mock all phase methods as no-ops + loaded_driver._training_phase = lambda *a, **k: None + loaded_driver._metrology_phase = lambda *a, **k: None + loaded_driver._query_phase = lambda *a, **k: None + loaded_driver._labeling_phase = lambda *a, **k: None + + # Execute one AL step + loaded_driver.active_learning_step() + + # Verify: current_phase reset to None after step completion + assert loaded_driver.current_phase is None + assert loaded_driver.active_learning_step_idx == 1 + + +def test_resume_continues_to_next_al_step( + driver_config, mock_module, strategies_config, temp_checkpoint_dir +): + """Test that after resuming and completing one step, next step starts fresh.""" + driver = Driver( + config=driver_config, + learner=mock_module, + strategies_config=strategies_config, + ) + + # Checkpoint at step 1, query phase + driver.active_learning_step_idx = 1 + driver.current_phase = p.ActiveLearningPhase.QUERY + + checkpoint_path = temp_checkpoint_dir / "multi_step_test" + driver.save_checkpoint(path=checkpoint_path) + + # Load checkpoint + loaded_driver = Driver.load_checkpoint( + checkpoint_path=checkpoint_path, + learner=MockModule(), + ) + + all_executions = [] + + def track_query(): + all_executions.append(f"step_{loaded_driver.active_learning_step_idx}_query") + + loaded_driver._training_phase = lambda *a, **k: all_executions.append( + f"step_{loaded_driver.active_learning_step_idx}_training" + ) + loaded_driver._metrology_phase = lambda *a, **k: all_executions.append( + f"step_{loaded_driver.active_learning_step_idx}_metrology" + ) + loaded_driver._query_phase = lambda *a, **k: track_query() + loaded_driver._labeling_phase = lambda *a, **k: all_executions.append( + f"step_{loaded_driver.active_learning_step_idx}_labeling" + ) + + # Execute step 1 (resume from query) + loaded_driver.active_learning_step() + + # Verify step 1 completed, current_phase reset + assert loaded_driver.current_phase is None + assert loaded_driver.active_learning_step_idx == 2 + + # Execute step 2 (fresh start, should build full queue) + loaded_driver.active_learning_step() + + # Verify execution pattern: + # Step 1: query only (resumed from QUERY, training/labeling skipped by config) + # Step 2: metrology + query (fresh start, current_phase=None, training/labeling skipped by config) + assert all_executions == ["step_1_query", "step_2_metrology", "step_2_query"] + assert loaded_driver.current_phase is None + assert loaded_driver.active_learning_step_idx == 3 diff --git a/test/active_learning/test_config.py b/test/active_learning/test_config.py new file mode 100644 index 0000000000..4e69db4260 --- /dev/null +++ b/test/active_learning/test_config.py @@ -0,0 +1,504 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for active learning configuration classes.""" + +import json +from math import nan +from queue import Queue + +import pytest +from torch.optim import SGD, AdamW +from torch.optim.lr_scheduler import StepLR + +from physicsnemo.active_learning.config import ( + DriverConfig, + OptimizerConfig, + StrategiesConfig, + TrainingConfig, +) + + +class TestOptimizerConfig: + """Tests for OptimizerConfig.""" + + def test_default_config(self): + """Test default optimizer configuration.""" + config = OptimizerConfig() + assert config.optimizer_cls == AdamW + assert config.optimizer_kwargs == {"lr": 1e-4} + assert config.scheduler_cls is None + assert config.scheduler_kwargs == {} + + def test_custom_optimizer(self): + """Test custom optimizer configuration.""" + config = OptimizerConfig( + optimizer_cls=SGD, + optimizer_kwargs={"lr": 0.01, "momentum": 0.9}, + ) + assert config.optimizer_cls == SGD + assert config.optimizer_kwargs["lr"] == 0.01 + assert config.optimizer_kwargs["momentum"] == 0.9 + + def test_with_scheduler(self): + """Test optimizer config with scheduler.""" + config = OptimizerConfig( + scheduler_cls=StepLR, + scheduler_kwargs={"step_size": 10, "gamma": 0.1}, + ) + assert config.scheduler_cls == StepLR + assert config.scheduler_kwargs["step_size"] == 10 + + def test_invalid_learning_rate(self): + """Test that invalid learning rates are rejected.""" + with pytest.raises(ValueError, match="Learning rate must be positive"): + OptimizerConfig(optimizer_kwargs={"lr": -0.01}) + + with pytest.raises(ValueError, match="Learning rate must be positive"): + OptimizerConfig(optimizer_kwargs={"lr": 0}) + + def test_scheduler_kwargs_without_scheduler(self): + """Test that scheduler_kwargs without scheduler_cls raises error.""" + with pytest.raises( + ValueError, match="scheduler_kwargs provided but scheduler_cls is None" + ): + OptimizerConfig(scheduler_kwargs={"step_size": 10}) + + def test_to_dict_from_dict_round_trip(self): + """Test that OptimizerConfig can be serialized and deserialized.""" + config = OptimizerConfig( + optimizer_cls=SGD, + optimizer_kwargs={"lr": 0.01, "momentum": 0.9}, + scheduler_cls=StepLR, + scheduler_kwargs={"step_size": 10, "gamma": 0.1}, + ) + + # Serialize + config_dict = config.to_dict() + + # Deserialize + restored_config = OptimizerConfig.from_dict(config_dict) + + # Verify + assert restored_config.optimizer_cls == SGD + assert restored_config.optimizer_kwargs == {"lr": 0.01, "momentum": 0.9} + assert restored_config.scheduler_cls == StepLR + assert restored_config.scheduler_kwargs == {"step_size": 10, "gamma": 0.1} + + def test_to_dict_from_dict_no_scheduler(self): + """Test serialization round-trip without scheduler.""" + config = OptimizerConfig( + optimizer_cls=AdamW, + optimizer_kwargs={"lr": 1e-3, "weight_decay": 1e-4}, + ) + + config_dict = config.to_dict() + restored_config = OptimizerConfig.from_dict(config_dict) + + assert restored_config.optimizer_cls == AdamW + 
assert restored_config.optimizer_kwargs == {"lr": 1e-3, "weight_decay": 1e-4} + assert restored_config.scheduler_cls is None + assert restored_config.scheduler_kwargs == {} + + +class TestStrategiesConfig: + """Tests for StrategiesConfig.""" + + def test_minimal_config(self, mock_query_strategy): + """Test minimal strategies configuration.""" + config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + ) + assert len(config.query_strategies) == 1 + assert config.queue_cls == Queue + assert config.label_strategy is None + assert config.metrology_strategies is None + assert config.unlabeled_datapool is None + + def test_full_config( + self, mock_query_strategy, mock_label_strategy, mock_metrology_strategy + ): + """Test fully configured strategies.""" + config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + label_strategy=mock_label_strategy, + metrology_strategies=[mock_metrology_strategy], + ) + assert len(config.query_strategies) == 1 + assert config.label_strategy is not None + assert len(config.metrology_strategies) == 1 + + def test_empty_query_strategies(self): + """Test that empty query strategies list raises error.""" + with pytest.raises(ValueError, match="At least one query strategy"): + StrategiesConfig(query_strategies=[], queue_cls=Queue) + + def test_empty_metrology_strategies(self, mock_query_strategy): + """Test that empty metrology strategies list raises error.""" + with pytest.raises(ValueError, match="metrology_strategies is an empty list"): + StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + metrology_strategies=[], + ) + + def test_non_callable_query_strategy(self): + """Test that non-callable query strategies are rejected.""" + with pytest.raises(ValueError, match="must be callable"): + StrategiesConfig( + query_strategies=["not_callable"], + queue_cls=Queue, + ) + + def test_non_callable_label_strategy(self, mock_query_strategy): + """Test that non-callable label strategy is rejected.""" + with pytest.raises(ValueError, match="label_strategy must be callable"): + StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + label_strategy="not_callable", + ) + + def test_non_callable_metrology_strategy(self, mock_query_strategy): + """Test that non-callable metrology strategies are rejected.""" + with pytest.raises(ValueError, match="must be callable"): + StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + metrology_strategies=["not_callable"], + ) + + def test_to_dict_from_dict_round_trip( + self, mock_query_strategy, mock_label_strategy, mock_metrology_strategy + ): + """Test that StrategiesConfig can be serialized and deserialized.""" + config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + label_strategy=mock_label_strategy, + metrology_strategies=[mock_metrology_strategy], + ) + + # Serialize (no warning when unlabeled_datapool is None) + config_dict = config.to_dict() + + # Deserialize (without unlabeled_datapool as it's not serialized) + restored_config = StrategiesConfig.from_dict(config_dict) + + # Verify + assert len(restored_config.query_strategies) == 1 + assert restored_config.queue_cls == Queue + assert restored_config.label_strategy is not None + assert len(restored_config.metrology_strategies) == 1 + assert restored_config.unlabeled_datapool is None + + def test_to_dict_from_dict_minimal(self, mock_query_strategy): + """Test serialization round-trip with minimal config.""" + 
config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + ) + + # Serialize (no warning when unlabeled_datapool is None) + config_dict = config.to_dict() + + restored_config = StrategiesConfig.from_dict(config_dict) + + assert len(restored_config.query_strategies) == 1 + assert restored_config.queue_cls == Queue + assert restored_config.label_strategy is None + assert restored_config.metrology_strategies is None + + +class TestTrainingConfig: + """Tests for TrainingConfig.""" + + def test_minimal_config(self, mock_data_pool, mock_training_loop): + """Test minimal training configuration.""" + config = TrainingConfig( + train_datapool=mock_data_pool, + train_loop_fn=mock_training_loop, + max_training_epochs=10, + ) + assert config.train_datapool is not None + assert config.train_loop_fn is not None + assert config.val_datapool is None + assert isinstance(config.optimizer_config, OptimizerConfig) + assert config.max_training_epochs == 10 + assert ( + config.max_fine_tuning_epochs == 10 + ) # should default to max_training_epochs + + def test_with_validation(self, mock_data_pool, mock_training_loop): + """Test training config with validation data.""" + config = TrainingConfig( + train_datapool=mock_data_pool, + train_loop_fn=mock_training_loop, + max_training_epochs=10, + val_datapool=mock_data_pool, + ) + assert config.val_datapool is not None + + def test_custom_optimizer_config(self, mock_data_pool, mock_training_loop): + """Test training config with custom optimizer.""" + opt_config = OptimizerConfig( + optimizer_cls=SGD, + optimizer_kwargs={"lr": 0.01}, + ) + config = TrainingConfig( + train_datapool=mock_data_pool, + train_loop_fn=mock_training_loop, + max_training_epochs=10, + optimizer_config=opt_config, + ) + assert config.optimizer_config.optimizer_cls == SGD + + def test_non_callable_training_loop(self, mock_data_pool): + """Test that non-callable training loop is rejected.""" + with pytest.raises(ValueError, match="train_loop_fn must be callable"): + TrainingConfig( + train_datapool=mock_data_pool, + train_loop_fn="not_callable", + max_training_epochs=10, + ) + + def test_to_dict_from_dict_round_trip(self, mock_data_pool, mock_training_loop): + """Test that TrainingConfig can be serialized and deserialized.""" + opt_config = OptimizerConfig( + optimizer_cls=SGD, + optimizer_kwargs={"lr": 0.05}, + ) + config = TrainingConfig( + train_datapool=mock_data_pool, + train_loop_fn=mock_training_loop, + max_training_epochs=20, + max_fine_tuning_epochs=5, + optimizer_config=opt_config, + ) + + # Serialize + with pytest.warns(UserWarning, match="train_datapool.*not supported"): + config_dict = config.to_dict() + + # Deserialize (must provide datapools) + restored_config = TrainingConfig.from_dict( + config_dict, train_datapool=mock_data_pool + ) + + # Verify + assert restored_config.max_training_epochs == 20 + assert restored_config.max_fine_tuning_epochs == 5 + assert restored_config.optimizer_config.optimizer_cls == SGD + assert restored_config.optimizer_config.optimizer_kwargs == {"lr": 0.05} + assert restored_config.train_datapool is mock_data_pool + assert restored_config.val_datapool is None + + def test_from_dict_requires_train_datapool(self, mock_training_loop): + """Test that from_dict requires train_datapool in kwargs.""" + config_dict = { + "max_training_epochs": 10, + "max_fine_tuning_epochs": 10, + "optimizer_config": OptimizerConfig().to_dict(), + "train_loop_fn": mock_training_loop._args, + } + + with pytest.raises(ValueError, 
match="train_datapool.*must be provided"): + TrainingConfig.from_dict(config_dict) + + def test_to_dict_from_dict_with_validation( + self, mock_data_pool, mock_training_loop + ): + """Test serialization round-trip with validation datapool.""" + config = TrainingConfig( + train_datapool=mock_data_pool, + val_datapool=mock_data_pool, + train_loop_fn=mock_training_loop, + max_training_epochs=15, + ) + + with pytest.warns(UserWarning, match="train_datapool.*not supported"): + config_dict = config.to_dict() + + # Provide both datapools during deserialization + restored_config = TrainingConfig.from_dict( + config_dict, + train_datapool=mock_data_pool, + val_datapool=mock_data_pool, + ) + + assert restored_config.train_datapool is mock_data_pool + assert restored_config.val_datapool is mock_data_pool + assert restored_config.max_training_epochs == 15 + + +class TestDriverConfig: + """Tests for DriverConfig.""" + + def test_minimal_config(self): + """Test minimal driver configuration.""" + config = DriverConfig(batch_size=4) + assert config.batch_size == 4 + assert config.max_active_learning_steps == float("inf") + assert config.skip_training is False + assert config.skip_metrology is False + assert config.skip_labeling is False + + def test_to_json(self): + """Test JSON serialization.""" + config = DriverConfig(batch_size=8) + json_str = config.to_json() + json_dict = json.loads(json_str) + assert json_dict["batch_size"] == 8 + assert "run_id" in json_dict + assert "world_size" in json_dict + + @pytest.mark.parametrize("bad_number", [-1, float("inf"), nan]) + def test_invalid_batch_size(self, bad_number): + """Test that invalid batch sizes are rejected.""" + with pytest.raises(ValueError, match="`batch_size` must be a positive integer"): + DriverConfig(batch_size=bad_number) + + @pytest.mark.parametrize("bad_number", [-1, float("inf"), nan]) + def test_invalid_checkpoint_interval(self, bad_number): + """Test that invalid checkpoint intervals are rejected.""" + with pytest.raises( + ValueError, match="`checkpoint_interval` must be a non-negative integer" + ): + DriverConfig(batch_size=1, checkpoint_interval=bad_number) + + def test_zero_checkpoint_interval(self): + """Test that checkpoint_interval=0 is valid (disables checkpointing).""" + config = DriverConfig(batch_size=1, checkpoint_interval=0) + assert config.checkpoint_interval == 0 + + def test_invalid_fine_tuning_lr(self): + """Test that invalid fine-tuning learning rates are rejected.""" + with pytest.raises(ValueError, match="`fine_tuning_lr` must be positive"): + DriverConfig(batch_size=1, fine_tuning_lr=-0.01) + + with pytest.raises(ValueError, match="`fine_tuning_lr` must be positive"): + DriverConfig(batch_size=1, fine_tuning_lr=0) + + def test_invalid_num_workers(self): + """Test that negative num_workers is rejected.""" + with pytest.raises( + ValueError, match="`num_dataloader_workers` must be non-negative" + ): + DriverConfig(batch_size=1, num_dataloader_workers=-1) + + def test_invalid_collate_fn(self): + """Test that non-callable collate_fn is rejected.""" + with pytest.raises(ValueError, match="`collate_fn` must be callable"): + DriverConfig(batch_size=1, collate_fn="not_callable") + + def test_max_steps_invalid(self): + """Test that invalid max_active_learning_steps is rejected.""" + with pytest.raises( + ValueError, match="`max_active_learning_steps` must be a positive integer" + ): + DriverConfig(batch_size=1, max_active_learning_steps=0) + + with pytest.raises( + ValueError, match="`max_active_learning_steps` must be a 
positive integer" + ): + DriverConfig(batch_size=1, max_active_learning_steps=-5) + + def test_to_json_from_json_round_trip(self): + """Test that DriverConfig can be serialized and deserialized.""" + import torch + + config = DriverConfig( + batch_size=16, + max_active_learning_steps=100, + fine_tuning_lr=1e-5, + reset_optim_states=False, + skip_training=True, + checkpoint_interval=5, + num_dataloader_workers=4, + device="cpu", + dtype=torch.float32, + ) + + # Serialize + json_str = config.to_json() + + # Verify it's valid JSON + json_dict = json.loads(json_str) + assert json_dict["batch_size"] == 16 + assert json_dict["max_active_learning_steps"] == 100 + + # Deserialize + restored_config = DriverConfig.from_json(json_str) + + # Verify + assert restored_config.batch_size == 16 + assert restored_config.max_active_learning_steps == 100 + assert restored_config.fine_tuning_lr == 1e-5 + assert restored_config.reset_optim_states is False + assert restored_config.skip_training is True + assert restored_config.checkpoint_interval == 5 + assert restored_config.num_dataloader_workers == 4 + assert restored_config.device == torch.device("cpu") + assert restored_config.dtype == torch.float32 + + def test_to_json_from_json_minimal(self): + """Test serialization round-trip with minimal config.""" + config = DriverConfig(batch_size=8) + + json_str = config.to_json() + restored_config = DriverConfig.from_json(json_str) + + assert restored_config.batch_size == 8 + assert restored_config.max_active_learning_steps == float("inf") + assert restored_config.skip_training is False + + def test_from_json_with_kwargs_override(self): + """Test that kwargs can override deserialized values.""" + config = DriverConfig(batch_size=4, checkpoint_interval=10) + json_str = config.to_json() + + # Override batch_size during deserialization + restored_config = DriverConfig.from_json(json_str, batch_size=32) + + assert restored_config.batch_size == 32 + assert restored_config.checkpoint_interval == 10 + + def test_from_json_with_collate_fn(self): + """Test providing non-serializable collate_fn via kwargs.""" + + def custom_collate(batch): + return batch + + config = DriverConfig(batch_size=8) + json_str = config.to_json() + + # Provide collate_fn during deserialization + restored_config = DriverConfig.from_json(json_str, collate_fn=custom_collate) + + assert restored_config.collate_fn is custom_collate + assert restored_config.batch_size == 8 + + def test_to_json_from_json_different_dtypes(self): + """Test serialization with different dtype values.""" + import torch + + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + config = DriverConfig(batch_size=4, dtype=dtype) + json_str = config.to_json() + restored_config = DriverConfig.from_json(json_str) + assert restored_config.dtype == dtype diff --git a/test/active_learning/test_driver.py b/test/active_learning/test_driver.py new file mode 100644 index 0000000000..50fb0a5b8a --- /dev/null +++ b/test/active_learning/test_driver.py @@ -0,0 +1,365 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for active learning Driver class.""" + +from queue import Queue +from uuid import UUID + +import pytest +from torch.optim import SGD +from torch.optim.lr_scheduler import StepLR +from torch.utils.data import DataLoader + +from physicsnemo.active_learning.config import ( + DriverConfig, + OptimizerConfig, + StrategiesConfig, + TrainingConfig, +) +from physicsnemo.active_learning.driver import Driver + + +@pytest.fixture +def minimal_config(tmp_path) -> DriverConfig: + """A minimal functioning configuration""" + return DriverConfig(batch_size=4, root_log_dir=tmp_path) + + +def test_minimal_driver_init( + minimal_config: DriverConfig, mock_module, mock_query_strategy +): + """Test a minimal driver initialization""" + # Skip training and metrology for minimal initialization + minimal_config.skip_training = True + minimal_config.skip_metrology = True + minimal_config.skip_labeling = True + strategies_config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + ) + driver = Driver( + config=minimal_config, + learner=mock_module, + strategies_config=strategies_config, + ) + assert driver.config == minimal_config + assert driver.strategies_config == strategies_config + assert hasattr(driver, "logger") + assert driver.active_learning_step_idx == 0 + # assert strategies were attached + mock_query_strategy.attach.assert_called_once_with(driver) + # check that queue was initialized + assert isinstance(driver.query_queue, Queue) + # make sure run ID is assigned something valid + assert isinstance(driver.run_id, str) + assert UUID(driver.run_id, version=4) + assert isinstance(driver.short_run_id, str) + # make sure the log file exists + assert (driver.log_dir / f"{driver.run_id}.log").exists() + + +def test_driver_configure_optimizer( + minimal_config: DriverConfig, + mock_module, + mock_query_strategy, + mock_data_pool, + mock_training_loop, +): + """Test the driver's optimizer configuration""" + # Skip metrology and labeling for this test + minimal_config.skip_metrology = True + minimal_config.skip_labeling = True + strategies_config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + ) + training_config = TrainingConfig( + train_datapool=mock_data_pool, + train_loop_fn=mock_training_loop, + max_training_epochs=10, + ) + driver = Driver( + config=minimal_config, + learner=mock_module, + strategies_config=strategies_config, + training_config=training_config, + ) + driver.configure_optimizer() + assert driver.optimizer is not None + assert driver.is_optimizer_configured + + # now try with a non-default optimizer and scheduler + optimizer_config = OptimizerConfig( + optimizer_cls=SGD, + optimizer_kwargs={"lr": 1e-3}, + scheduler_cls=StepLR, + scheduler_kwargs={"step_size": 10}, + ) + training_config_custom = TrainingConfig( + train_datapool=mock_data_pool, + train_loop_fn=mock_training_loop, + max_training_epochs=10, + optimizer_config=optimizer_config, + ) + new_driver = Driver( + config=minimal_config, + learner=mock_module, + strategies_config=strategies_config, + training_config=training_config_custom, + ) + 
new_driver.configure_optimizer() + assert new_driver.is_optimizer_configured + assert new_driver.is_lr_scheduler_configured + # this should work without fail + new_driver.lr_scheduler.step() + + +def test_construct_loader( + minimal_config: DriverConfig, mock_data_pool, mock_module, mock_query_strategy +): + """Test the driver's loader construction""" + # Skip all phases for this construction test + minimal_config.skip_training = True + minimal_config.skip_metrology = True + minimal_config.skip_labeling = True + strategies_config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + ) + driver = Driver( + config=minimal_config, + learner=mock_module, + strategies_config=strategies_config, + ) + loader = driver._construct_dataloader(mock_data_pool, shuffle=False) + assert isinstance(loader, DataLoader) + + +def test_minimal_learning_step( + minimal_config: DriverConfig, + mock_module, + mock_query_strategy, + mock_training_loop, + mock_data_pool, +): + """Test the driver's learning step""" + # disable the extra steps + minimal_config.skip_labeling = True + minimal_config.skip_metrology = True + strategies_config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + ) + + # Test that without skip_training=True and no training_config, initialization fails + with pytest.raises(ValueError, match="`training_config` must be provided"): + driver = Driver( + config=minimal_config, + learner=mock_module, + strategies_config=strategies_config, + ) + + # Create training config + training_config = TrainingConfig( + train_datapool=mock_data_pool, + train_loop_fn=mock_training_loop, + max_training_epochs=10, + ) + driver = Driver( + config=minimal_config, + learner=mock_module, + strategies_config=strategies_config, + training_config=training_config, + ) + # should fail without passing a train_step_fn when not using a learner + with pytest.raises(ValueError, match="`train_step_fn` must be provided"): + driver.active_learning_step() + # this should work, despite that the train_step_fn is a dummy one + driver.active_learning_step(train_step_fn=lambda x: None) + assert driver.active_learning_step_idx == 1 + + +def test_query_label_pipeline( + minimal_config: DriverConfig, + mock_module, + mock_query_strategy, + mock_data_pool, + mock_label_strategy, + mock_queue, + mock_training_loop, +): + """Test the driver's query and label pipeline without training""" + minimal_config.skip_training = True + minimal_config.skip_metrology = True + strategies_config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + label_strategy=mock_label_strategy, + ) + # Create a minimal training config with just train_datapool for labeling + training_config = TrainingConfig( + train_datapool=mock_data_pool, + train_loop_fn=mock_training_loop, + max_training_epochs=10, + ) + driver = Driver( + config=minimal_config, + learner=mock_module, + strategies_config=strategies_config, + training_config=training_config, + ) + # this checks to make sure the query strategy was called + driver.active_learning_step() + assert driver.active_learning_step_idx == 1 + driver.query_strategies[0].assert_called_with(driver.query_queue) + # do it again, but this time we subsitute the queue with a populated one + driver.query_queue = mock_queue + driver.label_queue.put("literally anything") + driver.active_learning_step() + assert driver.active_learning_step_idx == 2 + driver.label_strategy.assert_called_with(driver.query_queue, driver.label_queue) + # check that the 
data pool was theoretically updated + assert len(driver.train_datapool) == 1 + mock_data_pool.append.assert_called_once() + + +def test_train_metrology_pipeline( + minimal_config: DriverConfig, + mock_module, + mock_query_strategy, + mock_data_pool, + mock_metrology_strategy, + mock_training_loop, +): + """Test the driver's train, metrology, and query pipeline""" + minimal_config.skip_labeling = True + # do not do any batching + minimal_config.collate_fn = lambda x: x + strategies_config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + metrology_strategies=[mock_metrology_strategy], + ) + training_config = TrainingConfig( + train_datapool=mock_data_pool, + val_datapool=mock_data_pool, + train_loop_fn=mock_training_loop, + max_training_epochs=10, + ) + driver = Driver( + config=minimal_config, + learner=mock_module, + strategies_config=strategies_config, + training_config=training_config, + ) + # patch the dataloader to return a dummy batch + driver.active_learning_step( + train_step_fn=lambda x: None, validate_step_fn=lambda x: None + ) + assert driver.active_learning_step_idx == 1 + driver.query_strategies[0].assert_called_with(driver.query_queue) + assert driver.train_loop_fn.call_count == 1 + driver.metrology_strategies[0].assert_called_once() + + +def test_full_step_pipeline( + minimal_config: DriverConfig, + mock_module, + mock_query_strategy, + mock_data_pool, + mock_label_strategy, + mock_metrology_strategy, + mock_training_loop, +): + """Test the fully specified active learning pipeline""" + minimal_config.collate_fn = lambda x: x + strategies_config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + label_strategy=mock_label_strategy, + metrology_strategies=[mock_metrology_strategy], + ) + training_config = TrainingConfig( + train_datapool=mock_data_pool, + val_datapool=mock_data_pool, + train_loop_fn=mock_training_loop, + max_training_epochs=10, + ) + driver = Driver( + config=minimal_config, + learner=mock_module, + strategies_config=strategies_config, + training_config=training_config, + ) + driver.query_queue.put("something from querying") + driver.label_queue.put("something from labeling") + driver.active_learning_step( + train_step_fn=lambda x: None, validate_step_fn=lambda x: None + ) + # every component should be called at least once + assert driver.active_learning_step_idx == 1 + driver.query_strategies[0].assert_called_with(driver.query_queue) + driver.metrology_strategies[0].assert_called_once() + # make sure the label strategy was called and new data added + driver.label_strategy.assert_called_with(driver.query_queue, driver.label_queue) + assert driver.train_datapool.append.call_count == 1 + + +def test_run_loop( + minimal_config: DriverConfig, + mock_module, + mock_query_strategy, + mock_data_pool, + mock_label_strategy, + mock_metrology_strategy, + mock_training_loop, +): + """Test a minimal configuration running the loop a few times""" + minimal_config.max_active_learning_steps = 5 + minimal_config.collate_fn = lambda x: x + minimal_config.fine_tuning_lr = 50.0 + strategies_config = StrategiesConfig( + query_strategies=[mock_query_strategy], + queue_cls=Queue, + label_strategy=mock_label_strategy, + metrology_strategies=[mock_metrology_strategy], + ) + training_config = TrainingConfig( + train_datapool=mock_data_pool, + val_datapool=mock_data_pool, + train_loop_fn=mock_training_loop, + max_training_epochs=10, + ) + driver = Driver( + config=minimal_config, + learner=mock_module, + 
strategies_config=strategies_config, + training_config=training_config, + ) + # artificially populate the queues + driver.query_queue.put("something from querying") + driver.label_queue.put("something from labeling") + driver( + train_step_fn=lambda x: None, + validate_step_fn=lambda x: None, + ) + assert driver.active_learning_step_idx == 5 + assert driver.label_strategy.call_count == 5 + assert driver.metrology_strategies[0].call_count == 5 + assert driver.query_strategies[0].call_count == 5 + assert driver.train_loop_fn.call_count == 5 + assert driver.optimizer.param_groups[0]["lr"] == 50.0 diff --git a/test/active_learning/test_loop.py b/test/active_learning/test_loop.py new file mode 100644 index 0000000000..4d17163d03 --- /dev/null +++ b/test/active_learning/test_loop.py @@ -0,0 +1,430 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for active learning training loop.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +import torch +from torch.utils.data import DataLoader, TensorDataset + +from physicsnemo.active_learning.loop import ( + DefaultTrainingLoop, + _recursive_data_device_cast, +) + +# Define device parametrization for reuse across tests +AVAILABLE_DEVICES = [torch.device("cpu")] + ( + [torch.device("cuda:0")] if torch.cuda.is_available() else [] +) +DEVICE_IDS = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + + +class TestRecursiveDataDeviceCast: + """Tests for _recursive_data_device_cast function.""" + + @pytest.mark.parametrize( + "dtype", + [torch.float32, torch.float16, torch.bfloat16], + ids=["float32", "float16", "bfloat16"], + ) + def test_tensor_cast_dtype(self, dtype: torch.dtype): + """Test casting tensor to different dtypes.""" + tensor = torch.randn(4, 8, dtype=torch.float32) + result = _recursive_data_device_cast(tensor, dtype=dtype) + assert result.dtype == dtype + assert result.shape == tensor.shape + + @pytest.mark.parametrize("device", AVAILABLE_DEVICES, ids=DEVICE_IDS) + def test_tensor_cast_device(self, device: torch.device): + """Test moving tensor to different devices.""" + tensor = torch.randn(4, 8) + result = _recursive_data_device_cast(tensor, device=device) + assert result.device == device + assert result.shape == tensor.shape + + @pytest.mark.parametrize("device", AVAILABLE_DEVICES, ids=DEVICE_IDS) + def test_dict_cast(self, device: torch.device): + """Test casting dictionary of tensors.""" + data = { + "input": torch.randn(4, 8), + "target": torch.randn(4, 1), + "metadata": torch.tensor([1, 2, 3, 4]), + } + result = _recursive_data_device_cast(data, device=device, dtype=torch.float16) + + assert isinstance(result, dict) + assert set(result.keys()) == set(data.keys()) + for key in result: + assert result[key].device == device + if result[key].dtype.is_floating_point: + assert result[key].dtype == torch.float16 
+ assert result[key].shape == data[key].shape + + @pytest.mark.parametrize("device", AVAILABLE_DEVICES, ids=DEVICE_IDS) + def test_list_cast(self, device: torch.device): + """Test casting list of tensors.""" + data = [ + torch.randn(4, 8), + torch.randn(4, 1), + torch.tensor([1, 2, 3, 4]), + ] + result = _recursive_data_device_cast(data, device=device, dtype=torch.float16) + + assert isinstance(result, list) + assert len(result) == len(data) + for i, tensor in enumerate(result): + assert tensor.device == device + if tensor.dtype.is_floating_point: + assert tensor.dtype == torch.float16 + assert tensor.shape == data[i].shape + + @pytest.mark.parametrize("device", AVAILABLE_DEVICES, ids=DEVICE_IDS) + def test_tuple_cast(self, device: torch.device): + """Test casting tuple of tensors.""" + data = ( + torch.randn(4, 8), + torch.randn(4, 1), + torch.tensor([1, 2, 3, 4]), + ) + result = _recursive_data_device_cast(data, device=device, dtype=torch.float16) + + assert isinstance(result, tuple) + assert len(result) == len(data) + for i, tensor in enumerate(result): + assert tensor.device == device + if tensor.dtype.is_floating_point: + assert tensor.dtype == torch.float16 + assert tensor.shape == data[i].shape + + def test_nested_structures(self): + """Test casting nested data structures.""" + data = { + "batch": { + "input": torch.randn(4, 8), + "target": torch.randn(4, 1), + }, + "metadata": [ + torch.tensor([1, 2, 3, 4]), + torch.tensor([5, 6, 7, 8]), + ], + } + + result = _recursive_data_device_cast( + data, device=torch.device("cpu"), dtype=torch.float32 + ) + + assert isinstance(result, dict) + assert isinstance(result["batch"], dict) + assert isinstance(result["metadata"], list) + assert result["batch"]["input"].dtype == torch.float32 + assert result["batch"]["target"].dtype == torch.float32 + + def test_non_tensor_passthrough(self): + """Test that non-tensor data passes through unchanged.""" + data = { + "tensor": torch.randn(4, 8), + "string": "some_string", + "int": 42, + "float": 3.14, + } + + result = _recursive_data_device_cast(data, device=torch.device("cpu")) + + assert result["string"] == "some_string" + assert result["int"] == 42 + assert result["float"] == 3.14 + assert isinstance(result["tensor"], torch.Tensor) + + +class TestDefaultTrainingLoop: + """Tests for DefaultTrainingLoop class.""" + + def test_instantiation_default(self): + """Test instantiation with default parameters.""" + loop = DefaultTrainingLoop() + + assert loop.train_step_fn is None + assert loop.validate_step_fn is None + assert loop.enable_static_capture is True + assert loop.use_progress_bars is True + assert loop.dtype == torch.get_default_dtype() + + def test_instantiation_with_train_step(self): + """Test instantiation with custom train step function.""" + + def mock_train_step(model, batch): + return torch.tensor(0.5) + + loop = DefaultTrainingLoop(train_step_fn=mock_train_step) + assert loop.train_step_fn == mock_train_step + + def test_call_without_train_step_raises_error(self, mock_module): + """Test that calling loop without train_step_fn raises error.""" + loop = DefaultTrainingLoop() + optimizer = torch.optim.SGD(mock_module.parameters(), lr=0.01) + + # Create mock dataloader + dataset = TensorDataset(torch.randn(8, 64), torch.randn(8, 3)) + dataloader = DataLoader(dataset, batch_size=4) + + with pytest.raises(RuntimeError, match="No training step function provided"): + loop(mock_module, optimizer, dataloader, max_epochs=1) + + def test_basic_training_loop_execution(self, mock_module): + """Test 
basic execution of training loop with mocked components.""" + # Create a mock train step that returns a loss with backward method + mock_loss = MagicMock() + mock_loss.detach.return_value.item.return_value = 0.5 + mock_loss.backward = MagicMock() + + def mock_train_step(model, batch, *args, **kwargs): + return mock_loss + + loop = DefaultTrainingLoop( + enable_static_capture=False, + use_progress_bars=False, + ) + + optimizer = torch.optim.SGD(mock_module.parameters(), lr=0.01) + + # Wrap optimizer.step to track calls + original_optimizer_step = optimizer.step + optimizer.step = MagicMock(side_effect=original_optimizer_step) + + # Create mock dataloader with 2 batches + dataset = TensorDataset(torch.randn(8, 64), torch.randn(8, 3)) + dataloader = DataLoader(dataset, batch_size=4) + + # Run the loop, passing train_step_fn to __call__ + loop( + mock_module, + optimizer, + dataloader, + max_epochs=2, + train_step_fn=mock_train_step, + ) + + # Verify train step was called (2 epochs * 2 batches = 4 times) + assert mock_loss.backward.call_count == 4 + assert mock_loss.detach.call_count == 4 + # Verify optimizer.step was called once per batch (4 times total) + assert optimizer.step.call_count == 4 + + def test_training_with_validation(self, mock_module): + """Test training loop with validation step.""" + # Create mock train step + mock_loss = MagicMock() + mock_loss.detach.return_value.item.return_value = 0.5 + mock_loss.backward = MagicMock() + + def mock_train_step(model, batch, *args, **kwargs): + return mock_loss + + # Create mock validation step + mock_validate_step = MagicMock() + + loop = DefaultTrainingLoop( + enable_static_capture=False, + use_progress_bars=False, + ) + + optimizer = torch.optim.SGD(mock_module.parameters(), lr=0.01) + + # Wrap optimizer.step to track calls + original_optimizer_step = optimizer.step + optimizer.step = MagicMock(side_effect=original_optimizer_step) + + # Create mock dataloaders + train_dataset = TensorDataset(torch.randn(8, 64), torch.randn(8, 3)) + train_dataloader = DataLoader(train_dataset, batch_size=4) + + val_dataset = TensorDataset(torch.randn(4, 64), torch.randn(4, 3)) + val_dataloader = DataLoader(val_dataset, batch_size=4) + + # Run the loop, passing both step functions to __call__ + loop( + mock_module, + optimizer, + train_dataloader, + max_epochs=1, + validation_dataloader=val_dataloader, + train_step_fn=mock_train_step, + validate_step_fn=mock_validate_step, + ) + + # Verify training step was called (1 epoch * 2 batches) + assert mock_loss.backward.call_count == 2 + assert mock_loss.detach.call_count == 2 + # Verify optimizer.step was called once per training batch (2 times) + assert optimizer.step.call_count == 2 + # Verify validation step was called (1 epoch * 1 validation batch) + assert mock_validate_step.call_count == 1 + + def test_training_with_lr_scheduler(self, mock_module): + """Test training loop with learning rate scheduler.""" + mock_loss = MagicMock() + mock_loss.detach.return_value.item.return_value = 0.5 + mock_loss.backward = MagicMock() + + def mock_train_step(model, batch, *args, **kwargs): + return mock_loss + + loop = DefaultTrainingLoop( + enable_static_capture=False, + use_progress_bars=False, + ) + + optimizer = torch.optim.SGD(mock_module.parameters(), lr=0.01) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + + # Wrap optimizer.step and scheduler.step to track calls + original_optimizer_step = optimizer.step + optimizer.step = MagicMock(side_effect=original_optimizer_step) + 
original_scheduler_step = scheduler.step + scheduler.step = MagicMock(side_effect=original_scheduler_step) + + # Create mock dataloader with 2 batches + dataset = TensorDataset(torch.randn(8, 64), torch.randn(8, 3)) + dataloader = DataLoader(dataset, batch_size=4) + + # Run the loop, passing train_step_fn to __call__ + loop( + mock_module, + optimizer, + dataloader, + max_epochs=2, + lr_scheduler=scheduler, + train_step_fn=mock_train_step, + ) + + # Verify training step was called (2 epochs * 2 batches = 4 times) + assert mock_loss.backward.call_count == 4 + assert mock_loss.detach.call_count == 4 + # Verify optimizer.step was called once per batch (4 times total) + assert optimizer.step.call_count == 4 + # Verify scheduler.step was called once per batch (4 times total) + assert scheduler.step.call_count == 4 + + def test_device_override_in_call(self, mock_module): + """Test that device specified in call overrides constructor device.""" + mock_loss = MagicMock() + mock_loss.detach.return_value.item.return_value = 0.5 + mock_loss.backward = MagicMock() + + def mock_train_step(model, batch, *args, **kwargs): + return mock_loss + + loop = DefaultTrainingLoop( + device="cpu", + enable_static_capture=False, + use_progress_bars=False, + ) + + optimizer = torch.optim.SGD(mock_module.parameters(), lr=0.01) + dataset = TensorDataset(torch.randn(8, 64), torch.randn(8, 3)) + dataloader = DataLoader(dataset, batch_size=4) + + # Call with device override and train_step_fn + loop( + mock_module, + optimizer, + dataloader, + max_epochs=1, + device="cpu", + train_step_fn=mock_train_step, + ) + + # Verify training completed + assert mock_loss.backward.call_count == 2 + + def test_dtype_override_in_call(self, mock_module): + """Test that dtype specified in call overrides constructor dtype.""" + mock_loss = MagicMock() + mock_loss.detach.return_value.item.return_value = 0.5 + mock_loss.backward = MagicMock() + + def mock_train_step(model, batch, *args, **kwargs): + return mock_loss + + loop = DefaultTrainingLoop( + dtype=torch.float32, + enable_static_capture=False, + use_progress_bars=False, + ) + + optimizer = torch.optim.SGD(mock_module.parameters(), lr=0.01) + dataset = TensorDataset(torch.randn(8, 64), torch.randn(8, 3)) + dataloader = DataLoader(dataset, batch_size=4) + + # Call with dtype override and train_step_fn + loop( + mock_module, + optimizer, + dataloader, + max_epochs=1, + dtype=torch.float16, + train_step_fn=mock_train_step, + ) + + # Verify training completed + assert mock_loss.backward.call_count == 2 + + def test_train_step_fn_override_in_call(self, mock_module): + """Test that train_step_fn in call overrides constructor train_step_fn.""" + # Constructor train step + constructor_loss = MagicMock() + constructor_loss.detach.return_value.item.return_value = 0.3 + constructor_loss.backward = MagicMock() + + def constructor_train_step(model, batch, *args, **kwargs): + return constructor_loss + + # Call train step + call_loss = MagicMock() + call_loss.detach.return_value.item.return_value = 0.5 + call_loss.backward = MagicMock() + + def call_train_step(model, batch, *args, **kwargs): + return call_loss + + loop = DefaultTrainingLoop( + train_step_fn=constructor_train_step, + enable_static_capture=False, + use_progress_bars=False, + ) + + optimizer = torch.optim.SGD(mock_module.parameters(), lr=0.01) + dataset = TensorDataset(torch.randn(8, 64), torch.randn(8, 3)) + dataloader = DataLoader(dataset, batch_size=4) + + # This should use the call_train_step due to static capture being disabled + # 
and the override logic + loop( + mock_module, + optimizer, + dataloader, + max_epochs=1, + train_step_fn=call_train_step, + ) + + # Since static capture is disabled and we pass train_step_fn to call, + # it won't be used directly in this implementation + # This test verifies the loop runs without error + assert True diff --git a/test/active_learning/test_registry.py b/test/active_learning/test_registry.py new file mode 100644 index 0000000000..d7b43c38d1 --- /dev/null +++ b/test/active_learning/test_registry.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from torch import nn + +from physicsnemo.active_learning._registry import ActiveLearningRegistry + + +@pytest.fixture(scope="function") +def simple_class(): + class SimpleClass: + def __init__(self, param1: int, param2: str): + self.param1 = param1 + self.param2 = param2 + + return SimpleClass + + +def test_initialization(): + """Test that the registry can be initialized.""" + registry = ActiveLearningRegistry() + assert registry._registry == {} + + +def test_registration(simple_class): + """Test that the register method registers a class""" + registry = ActiveLearningRegistry() + + registry.register("my_strategy")(simple_class) + + assert registry.is_registered("my_strategy") + assert "my_strategy" in registry.registered_names + assert registry._registry["my_strategy"] == simple_class + + +def test_missing_registration(): + """Test accessing a missing class raises a KeyError""" + registry = ActiveLearningRegistry() + with pytest.raises(KeyError): + registry["missing_strategy"] + + with pytest.raises(NameError): + registry.construct("missing_strategy") + + +def test_construction(simple_class): + """Test that the construct method returns an instance of the registered class.""" + registry = ActiveLearningRegistry() + + registry.register("my_strategy")(simple_class) + + strategy = registry.construct("my_strategy", param1=42, param2="test") + assert strategy.param1 == 42 + assert strategy.param2 == "test" + + +def test_torch_module(): + """Test that registry can construct different types of objects""" + registry = ActiveLearningRegistry() + + @registry.register("simple_model") + class SimpleModel(nn.Module): + def __init__(self, input_size: int, output_size: int): + super().__init__() + self.linear = nn.Linear(input_size, output_size) + + input_size = 10 + output_size = 10 + model = registry.construct( + "simple_model", input_size=input_size, output_size=output_size + ) + assert isinstance(model, nn.Module) + assert model.linear.weight.shape == (output_size, input_size) + + +def test_bad_construction(): + """Test that the construct method raises an error with bad arguments""" + registry = ActiveLearningRegistry() + + @registry.register("simple_model") + class SimpleModel(nn.Module): + def __init__(self, input_size: int, output_size: int): 
+            super().__init__()
+            self.linear = nn.Linear(input_size, output_size)
+
+    with pytest.raises(TypeError):
+        registry.construct("simple_model", input_size=10, bad_arg=215)
+
+
+def test_get_class_no_module_path():
+    """Test that get_class resolves a name from a top-level module when it is not registered."""
+    from time import monotonic
+
+    registry = ActiveLearningRegistry()
+    cls = registry.get_class("monotonic", "time")
+    assert cls == monotonic
+
+
+def test_get_class_with_module_path():
+    """Test that the get_class method returns a class from a dotted module path."""
+    registry = ActiveLearningRegistry()
+    cls = registry.get_class("Linear", "torch.nn")
+    model = cls(8, 16)
+
+    # import torch.nn afterwards and confirm the constructed object is an nn.Linear
+    from torch import nn
+
+    assert isinstance(model, nn.Linear)
+
+
+def test_get_class_missing():
+    """Test that the get_class method raises an error in three scenarios."""
+    registry = ActiveLearningRegistry()
+    # when the module is completely missing
+    with pytest.raises(ModuleNotFoundError):
+        registry.get_class("missing_class", "missing_module")
+    # when the class is missing from an existing module
+    with pytest.raises(NameError):
+        registry.get_class("missing_class", "torch.nn")
+    # when the class is missing from the registry and no module path is provided
+    with pytest.raises(NameError):
+        registry.get_class("missing_class")