diff --git a/.gitignore b/.gitignore index eca74dbda..a55e04dfe 100644 --- a/.gitignore +++ b/.gitignore @@ -102,7 +102,3 @@ target/ # uv uv.lock .python-version - -# Serena cache -.serena/ -.claude/ diff --git a/docs/how_to_guide.rst b/docs/how_to_guide.rst index ed052ff19..1b87427ed 100644 --- a/docs/how_to_guide.rst +++ b/docs/how_to_guide.rst @@ -47,6 +47,7 @@ Training how_to_guide/07_save_and_load.ipynb how_to_guide/07_resume_training.ipynb how_to_guide/21_hyperparameter_tuning.ipynb + how_to_guide/22_experiment_tracking.ipynb Sampling diff --git a/docs/how_to_guide/22_experiment_tracking.ipynb b/docs/how_to_guide/22_experiment_tracking.ipynb new file mode 100644 index 000000000..a312c7483 --- /dev/null +++ b/docs/how_to_guide/22_experiment_tracking.ipynb @@ -0,0 +1,306 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7fb27b941602401d91542211134fc71a", + "metadata": {}, + "source": [ + "# How to track experiments" + ] + }, + { + "cell_type": "markdown", + "id": "9dcf97f2", + "metadata": {}, + "source": [ + "Experiment tracking helps compare model variants and keep a record of hyperparameters and training metrics. By default, `sbi` logs to TensorBoard. You can also bring your own tracker by implementing the lightweight `Tracker` protocol and passing it as `tracker=...`.\n", + "\n", + "If using your own tracker (e.g., `wandb`, `mlflow` or `trackio`), note that the run lifecycle (e.g., `wandb.init`, `mlflow.start_run`) is handled on the user side." 
+ ] + }, + { + "cell_type": "markdown", + "id": "6492510e", + "metadata": {}, + "source": [ + "## Define a minimal training setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "898c316a", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "from sbi.inference import NPE\n", + "from sbi.neural_nets import posterior_nn\n", + "from sbi.neural_nets.embedding_nets import FCEmbedding\n", + "from sbi.utils import BoxUniform\n", + "\n", + "torch.manual_seed(0)\n", + "\n", + "def simulator(theta):\n", + " return theta + 0.1 * torch.randn_like(theta)\n", + "\n", + "prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))\n", + "\n", + "theta = prior.sample((5000,))\n", + "x = simulator(theta)\n", + "\n", + "embedding_net = FCEmbedding(input_dim=x.shape[1], output_dim=32)\n", + "density_estimator = posterior_nn(\n", + " model=\"nsf\",\n", + " embedding_net=embedding_net,\n", + " num_transforms=4,\n", + ")\n", + "\n", + "train_kwargs = dict(\n", + " max_num_epochs=50,\n", + " training_batch_size=128,\n", + " validation_fraction=0.1,\n", + " show_train_summary=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9142d4b6", + "metadata": {}, + "source": [ + "## Train with a tracker\n", + "\n", + "By default, `sbi` uses a TensorBoard tracker to log training loss, validation loss,\n", + "number of epochs and more. 
\n", + "\n", + "When you want to track additional quantities, you instantiate the tracker yourself and\n", + "pass it to the inference class:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a62bc10", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.tensorboard.writer import SummaryWriter\n", + "\n", + "from sbi.utils.tracking import TensorBoardTracker\n", + "\n", + "tracker = TensorBoardTracker(SummaryWriter(\"sbi-logs\"))\n", + "tracker.log_params({\"embedding_dim\": 32, \"num_transforms\": 4})\n", + "\n", + "inference = NPE(prior=prior, tracker=tracker)\n", + "inference.append_simulations(theta, x)\n", + "estimator = inference.train(**train_kwargs)\n", + "posterior = inference.build_posterior(estimator)" + ] + }, + { + "cell_type": "markdown", + "id": "c4da9894", + "metadata": {}, + "source": [ + "## View TensorBoard results\n", + "\n", + "You can then view your tracked run(s) on a TensorBoard shown on your localhost in the\n", + "browser. By default, `sbi` will create a log directory `sbi-logs` at the location the\n", + "training script was called.\n", + "\n", + "```bash\n", + "tensorboard --logdir=sbi-logs\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "6f2777dd", + "metadata": {}, + "source": [ + "## Using other trackers\n", + "\n", + "To enable usage of other trackers, we provide a lightweight `Protocol` that trackers\n", + "need to follow. You can implement a small adapter that satisfies the `Tracker` protocol\n", + "and pass it to `tracker=`. Below are minimal examples for common tools." 
+ ] + }, + { + "cell_type": "markdown", + "id": "43644d68", + "metadata": {}, + "source": [ + "```python\n", + "# W&B adapter (requires `wandb.init()` before training)\n", + "class WandBAdapter:\n", + " log_dir = None\n", + "\n", + " def __init__(self, run):\n", + " self._run = run\n", + "\n", + " def log_metric(self, name, value, step=None):\n", + " self._run.log({name: value}, step=step)\n", + "\n", + " def log_metrics(self, metrics, step=None):\n", + " self._run.log(metrics, step=step)\n", + "\n", + " def log_params(self, params):\n", + " self._run.config.update(params)\n", + "\n", + " def add_figure(self, name, figure, step=None):\n", + " import wandb\n", + " self._run.log({name: wandb.Image(figure)}, step=step)\n", + "\n", + " def flush(self):\n", + " pass\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "031651b3", + "metadata": {}, + "source": [ + "```python\n", + "# MLflow adapter (configure tracking URI separately)\n", + "class MLflowAdapter:\n", + " log_dir = None\n", + "\n", + " def __init__(self, mlflow):\n", + " self._mlflow = mlflow\n", + "\n", + " def log_metric(self, name, value, step=None):\n", + " self._mlflow.log_metric(name, value, step=step)\n", + "\n", + " def log_metrics(self, metrics, step=None):\n", + " for name, value in metrics.items():\n", + " self.log_metric(name, value, step=step)\n", + "\n", + " def log_params(self, params):\n", + " self._mlflow.log_params(params)\n", + "\n", + " def add_figure(self, name, figure, step=None):\n", + " self._mlflow.log_figure(figure, f\"{name}.png\")\n", + "\n", + " def flush(self):\n", + " pass\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "6d891d85", + "metadata": {}, + "source": [ + "```python\n", + "# Trackio adapter (requires `trackio.init()` before training)\n", + "class TrackioAdapter:\n", + " log_dir = None\n", + "\n", + " def __init__(self, trackio):\n", + " self._trackio = trackio\n", + "\n", + " def log_metric(self, name, value, step=None):\n", + " 
self._trackio.log({name: value}, step=step)\n", + "\n", + " def log_metrics(self, metrics, step=None):\n", + " self._trackio.log(metrics, step=step)\n", + "\n", + " def log_params(self, params):\n", + " self._trackio.log(params)\n", + "\n", + " def add_figure(self, name, figure, step=None):\n", + " self._trackio.log_image(figure, name=name, step=step)\n", + "\n", + " def flush(self):\n", + " pass\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "43b453f7", + "metadata": {}, + "source": [ + "When using external trackers, create an adapter instance and pass it to `tracker=`:\n", + "\n", + "```python\n", + "# wandb.init(...)\n", + "tracker = WandBAdapter(wandb.run)\n", + "inference = NPE(prior=prior, density_estimator=density_estimator, tracker=tracker)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "5c748c81", + "metadata": {}, + "source": [ + "## Log figures\n", + "\n", + "Trackers can also store matplotlib figures. For example, after training you can log a pairplot:\n", + "\n", + "```python\n", + "from sbi.analysis import pairplot\n", + "\n", + "x_o = x[:1]\n", + "samples = posterior.sample((1000,), x=x_o)\n", + "fig, _ = pairplot(samples)\n", + "tracker.add_figure(\"posterior_pairplot\", fig, step=0)\n", + "```\n", + "\n", + "Figure logging depends on the tracker implementation (e.g., `wandb.Image`, `mlflow.log_figure`)." + ] + }, + { + "cell_type": "markdown", + "id": "ab99ec92", + "metadata": {}, + "source": [ + "## Custom training loop (optional)\n", + "\n", + "If you want to log custom diagnostics per epoch, use the training interface tutorial: https://sbi.readthedocs.io/en/latest/advanced_tutorials/18_training_interface.html." 
+ ] + }, + { + "cell_type": "markdown", + "id": "b206a6a7", + "metadata": {}, + "source": [ + "## Notes\n", + "\n", + "- Each tool supports richer logging (artifacts, checkpoints, plots), but the patterns above are enough to track hyperparameters, epoch-wise losses, and validation metrics.\n", + "- If you already use Optuna or other sweep tools, you can call the logger inside the objective function to log each trial." + ] + }, + { + "cell_type": "markdown", + "id": "46465002", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sbi (3.12.9)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sbi/analysis/tensorboard_output.py b/sbi/analysis/tensorboard_output.py index 447158de2..ee1f203a2 100644 --- a/sbi/analysis/tensorboard_output.py +++ b/sbi/analysis/tensorboard_output.py @@ -39,10 +39,10 @@ def plot_summary( ylabel: Optional[List[str]] = None, plot_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[Figure, Axes]: - """Plots data logged by the tensorboard summary writer of an inference object. + """Plots data logged by the TensorBoard tracker of an inference object. Args: - inference: inference object that holds a ._summary_writer.log_dir attribute. + inference: inference object that holds a tracker with a log_dir attribute. Optionally the log_dir itself. tags: list of summery writer tags to visualize. 
disable_tensorboard_prompt: flag to disable the logging of how to run @@ -67,7 +67,12 @@ def plot_summary( size_guidance.update(scalars=tensorboard_scalar_limit) if isinstance(inference, NeuralInference): - log_dir = inference._summary_writer.log_dir + log_dir = getattr(inference._tracker, "log_dir", None) + if log_dir is None: + raise ValueError( + "Inference tracker does not expose a log_dir. " + "Use a TensorBoard tracker or pass a log directory directly." + ) elif isinstance(inference, Path): log_dir = inference else: diff --git a/sbi/inference/trainers/base.py b/sbi/inference/trainers/base.py index 839296e81..c7f71a6ec 100644 --- a/sbi/inference/trainers/base.py +++ b/sbi/inference/trainers/base.py @@ -57,7 +57,7 @@ ConditionalEstimatorType, ConditionalVectorFieldEstimator, ) -from sbi.sbi_types import TorchTransform +from sbi.sbi_types import TorchTransform, Tracker from sbi.utils import ( check_prior, get_log_root, @@ -70,6 +70,7 @@ from sbi.utils.sbiutils import get_simulations_since_round from sbi.utils.simulation_utils import simulate_for_sbi from sbi.utils.torchutils import check_if_prior_on_device, process_device +from sbi.utils.tracking import TensorBoardTracker from sbi.utils.user_input_checks import ( check_sbi_inputs, process_prior, @@ -175,6 +176,7 @@ def __init__( device: str = "cpu", logging_level: Union[int, str] = "WARNING", summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): r"""Base class for inference methods. @@ -187,8 +189,10 @@ def __init__( perform all posterior operations, e.g. gpu or cpu. logging_level: Minimum severity of messages to log. One of the strings "INFO", "WARNING", "DEBUG", "ERROR" and "CRITICAL". - summary_writer: A `SummaryWriter` to control, among others, log - file location (default is `/logs`.) + summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training metrics. 
If None, a + TensorBoard tracker is used with a default log directory. show_progress_bars: Whether to show a progressbar during simulation and sampling. """ @@ -219,11 +223,19 @@ def __init__( self._best_val_loss = float("Inf") self._epochs_since_last_improvement = 0 - self._summary_writer = ( - self._default_summary_writer() if summary_writer is None else summary_writer - ) + if summary_writer is not None: + warn( + "summary_writer is deprecated. Use tracker instead.", + FutureWarning, + stacklevel=2, + ) + if tracker is not None: + raise ValueError("Pass only one of summary_writer or tracker.") + tracker = TensorBoardTracker(summary_writer) + + self._tracker = self._default_tracker() if tracker is None else tracker - # Logging during training (by SummaryWriter). + # Logging during training. self._summary = dict( epochs_trained=[], best_validation_loss=[], @@ -1134,14 +1146,14 @@ def _converged(self, epoch: int, stop_after_epochs: int) -> bool: return converged - def _default_summary_writer(self) -> SummaryWriter: - """Return summary writer logging to method- and simulator-specific directory.""" + def _default_tracker(self) -> Tracker: + """Return default tracker logging to a TensorBoard directory.""" method = self.__class__.__name__ logdir = Path( get_log_root(), method, datetime.now().isoformat().replace(":", "_") ) - return SummaryWriter(logdir) + return TensorBoardTracker(SummaryWriter(logdir)) def _report_convergence_at_end( self, epoch: int, stop_after_epochs: int, max_num_epochs: int @@ -1163,11 +1175,11 @@ def _summarize( self, round_: int, ) -> None: - """Update the summary_writer with statistics for a given round. + """Update the tracker with statistics for a given round. During training several performance statistics are added to the summary, e.g., using `self._summary['key'].append(value)`. This function writes these values - into summary writer object. + into the tracker. 
Args: round: index of round @@ -1186,17 +1198,17 @@ def _summarize( """ - # Add most recent training stats to summary writer. - self._summary_writer.add_scalar( - tag="epochs_trained", - scalar_value=self._summary["epochs_trained"][-1], - global_step=round_ + 1, + # Add most recent training stats to tracker. + self._tracker.log_metric( + name="epochs_trained", + value=self._summary["epochs_trained"][-1], + step=round_ + 1, ) - self._summary_writer.add_scalar( - tag="best_validation_loss", - scalar_value=self._summary["best_validation_loss"][-1], - global_step=round_ + 1, + self._tracker.log_metric( + name="best_validation_loss", + value=self._summary["best_validation_loss"][-1], + step=round_ + 1, ) # Add validation loss for every epoch. @@ -1207,27 +1219,27 @@ def _summarize( .item() ) for i, vlp in enumerate(self._summary["validation_loss"][offset:]): - self._summary_writer.add_scalar( - tag="validation_loss", - scalar_value=vlp, - global_step=offset + i, + self._tracker.log_metric( + name="validation_loss", + value=vlp, + step=int(offset + i), ) for i, tlp in enumerate(self._summary["training_loss"][offset:]): - self._summary_writer.add_scalar( - tag="training_loss", - scalar_value=tlp, - global_step=offset + i, + self._tracker.log_metric( + name="training_loss", + value=tlp, + step=int(offset + i), ) for i, eds in enumerate(self._summary["epoch_durations_sec"][offset:]): - self._summary_writer.add_scalar( - tag="epoch_durations_sec", - scalar_value=eds, - global_step=offset + i, + self._tracker.log_metric( + name="epoch_durations_sec", + value=eds, + step=int(offset + i), ) - self._summary_writer.flush() + self._tracker.flush() @staticmethod def _describe_round(round_: int, summary: Dict[str, list]) -> str: @@ -1266,11 +1278,11 @@ def __getstate__(self) -> Dict: "changes in the following two ways: " "1) `.train(..., retrain_from_scratch=True)` is not supported. 
" "2) When the loaded object calls the `.train()` method, it generates a new " - "tensorboard summary writer (instead of appending to the current one).", + "tracker instance (instead of appending to the current one).", stacklevel=2, ) dict_to_save = {} - unpicklable_attributes = ["_summary_writer", "_build_neural_net"] + unpicklable_attributes = ["_tracker", "_build_neural_net"] for key in self.__dict__: if key in unpicklable_attributes: dict_to_save[key] = None @@ -1281,13 +1293,13 @@ def __getstate__(self) -> Dict: def __setstate__(self, state_dict: Dict): """Sets the state when being loaded from pickle. - Also creates a new summary writer (because the previous one was set to `None` + Also creates a new tracker (because the previous one was set to `None` during serializing, see `__get_state__()`). Args: state_dict: State to be restored. """ - state_dict["_summary_writer"] = self._default_summary_writer() + state_dict["_tracker"] = self._default_tracker() vars(self).update(state_dict) diff --git a/sbi/inference/trainers/marginal/marginal_base.py b/sbi/inference/trainers/marginal/marginal_base.py index fa9f7105f..dd5787725 100644 --- a/sbi/inference/trainers/marginal/marginal_base.py +++ b/sbi/inference/trainers/marginal/marginal_base.py @@ -6,6 +6,7 @@ from datetime import datetime from pathlib import Path from typing import Any, Callable, Optional, Tuple, Union +from warnings import warn import torch from torch import Tensor @@ -20,8 +21,10 @@ reshape_to_batch_event, ) from sbi.neural_nets.factory import ZukoFlowType, marginal_nn +from sbi.sbi_types import Tracker from sbi.utils import check_estimator_arg, get_log_root from sbi.utils.torchutils import assert_all_finite, process_device +from sbi.utils.tracking import TensorBoardTracker DensityEstimatorType = Union[ZukoFlowType, str, Callable[[Tensor], Any]] @@ -38,6 +41,7 @@ def __init__( density_estimator: DensityEstimatorType = ZukoFlowType.NSF, device: str = "cpu", summary_writer: Optional[SummaryWriter] = 
None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): """Initialize the marginal trainer. @@ -55,8 +59,10 @@ def __init__( If a callable, it must be a function that returns a neural network that inherits from `UnconditionalDensityEstimator`. device: Device to use for training. Can be "cpu" or "cuda". - summary_writer: Summary writer for logging training progress. If None, - a new writer is created. + summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training progress. If None, + a TensorBoard tracker is created. show_progress_bars: Whether to show progress bars during training. """ @@ -66,11 +72,19 @@ def __init__( self._show_progress_bars = show_progress_bars self._val_loss = float("Inf") - self._summary_writer = ( - self._default_summary_writer() if summary_writer is None else summary_writer - ) + if summary_writer is not None: + warn( + "summary_writer is deprecated. Use tracker instead.", + FutureWarning, + stacklevel=2, + ) + if tracker is not None: + raise ValueError("Pass only one of summary_writer or tracker.") + tracker = TensorBoardTracker(summary_writer) + + self._tracker = self._default_tracker() if tracker is None else tracker - # Logging during training (by SummaryWriter). + # Logging during training. 
self._summary = dict( epochs_trained=[], best_validation_loss=[], @@ -283,14 +297,14 @@ def train( return deepcopy(self._neural_net) - def _default_summary_writer(self) -> SummaryWriter: - """Return summary writer logging to method- and simulator-specific directory.""" + def _default_tracker(self) -> Tracker: + """Return default tracker logging to a TensorBoard directory.""" method = self.__class__.__name__ logdir = Path( get_log_root(), method, datetime.now().isoformat().replace(":", "_") ) - return SummaryWriter(logdir) + return TensorBoardTracker(SummaryWriter(logdir)) def _converged(self, epoch: int, stop_after_epochs: int) -> bool: """Return whether the training converged yet and save best model state so far. @@ -328,11 +342,11 @@ def _summarize( self, round_: int, ) -> None: - """Update the summary_writer with statistics for a given round. + """Update the tracker with statistics for a given round. During training several performance statistics are added to the summary, e.g., using `self._summary['key'].append(value)`. This function writes these values - into summary writer object. + into the tracker. Args: round: index of round @@ -351,17 +365,17 @@ def _summarize( """ - # Add most recent training stats to summary writer. - self._summary_writer.add_scalar( - tag="epochs_trained", - scalar_value=self._summary["epochs_trained"][-1], - global_step=round_ + 1, + # Add most recent training stats to tracker. + self._tracker.log_metric( + name="epochs_trained", + value=self._summary["epochs_trained"][-1], + step=round_ + 1, ) - self._summary_writer.add_scalar( - tag="best_validation_loss", - scalar_value=self._summary["best_validation_loss"][-1], - global_step=round_ + 1, + self._tracker.log_metric( + name="best_validation_loss", + value=self._summary["best_validation_loss"][-1], + step=round_ + 1, ) # Add validation loss for every epoch. 
@@ -372,27 +386,27 @@ def _summarize( .item() ) for i, vlp in enumerate(self._summary["validation_loss"][offset:]): - self._summary_writer.add_scalar( - tag="validation_loss", - scalar_value=vlp, - global_step=offset + i, + self._tracker.log_metric( + name="validation_loss", + value=vlp, + step=int(offset + i), ) for i, tlp in enumerate(self._summary["training_loss"][offset:]): - self._summary_writer.add_scalar( - tag="training_loss", - scalar_value=tlp, - global_step=offset + i, + self._tracker.log_metric( + name="training_loss", + value=tlp, + step=int(offset + i), ) for i, eds in enumerate(self._summary["epoch_durations_sec"][offset:]): - self._summary_writer.add_scalar( - tag="epoch_durations_sec", - scalar_value=eds, - global_step=offset + i, + self._tracker.log_metric( + name="epoch_durations_sec", + value=eds, + step=int(offset + i), ) - self._summary_writer.flush() + self._tracker.flush() @staticmethod def _maybe_show_progress(show: bool, epoch: int) -> None: diff --git a/sbi/inference/trainers/nle/mnle.py b/sbi/inference/trainers/nle/mnle.py index 3c22536b9..7a52874bc 100644 --- a/sbi/inference/trainers/nle/mnle.py +++ b/sbi/inference/trainers/nle/mnle.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Literal, Optional, Union from torch.distributions import Distribution +from torch.utils.tensorboard.writer import SummaryWriter from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.inference.posteriors.posterior_parameters import ( @@ -15,7 +16,7 @@ from sbi.inference.trainers.nle.nle_base import LikelihoodEstimatorTrainer from sbi.neural_nets.estimators import MixedDensityEstimator from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder -from sbi.sbi_types import TensorBoardSummaryWriter +from sbi.sbi_types import Tracker from sbi.utils.sbiutils import del_entries @@ -40,7 +41,8 @@ def __init__( ] = "mnle", device: str = "cpu", logging_level: Union[int, str] = "WARNING", - summary_writer: 
Optional[TensorBoardSummaryWriter] = None, + summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): r"""Initialize MNLE. @@ -60,8 +62,10 @@ def __init__( device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". logging_level: Minimum severity of messages to log. One of the strings INFO, WARNING, DEBUG, ERROR and CRITICAL. - summary_writer: A tensorboard `SummaryWriter` to control, among others, log - file location (default is `/logs`.) + summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training metrics. If None, a + TensorBoard tracker is used with a default log directory. show_progress_bars: Whether to show a progressbar during simulation and sampling. """ diff --git a/sbi/inference/trainers/nle/nle_a.py b/sbi/inference/trainers/nle/nle_a.py index 99e29213f..2a50da925 100644 --- a/sbi/inference/trainers/nle/nle_a.py +++ b/sbi/inference/trainers/nle/nle_a.py @@ -4,13 +4,14 @@ from typing import Literal, Optional, Union from torch.distributions import Distribution +from torch.utils.tensorboard.writer import SummaryWriter from sbi.inference.trainers.nle.nle_base import LikelihoodEstimatorTrainer from sbi.neural_nets.estimators.base import ( ConditionalDensityEstimator, ConditionalEstimatorBuilder, ) -from sbi.sbi_types import TensorBoardSummaryWriter +from sbi.sbi_types import Tracker from sbi.utils.sbiutils import del_entries @@ -31,7 +32,8 @@ def __init__( ] = "maf", device: str = "cpu", logging_level: Union[int, str] = "WARNING", - summary_writer: Optional[TensorBoardSummaryWriter] = None, + summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): r"""Initialize Neural Likelihood Estimation. @@ -51,8 +53,10 @@ def __init__( device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". 
logging_level: Minimum severity of messages to log. One of the strings INFO, WARNING, DEBUG, ERROR and CRITICAL. - summary_writer: A tensorboard `SummaryWriter` to control, among others, log - file location (default is `/logs`.) + summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training metrics. If None, a + TensorBoard tracker is used with a default log directory. show_progress_bars: Whether to show a progressbar during simulation and sampling. """ diff --git a/sbi/inference/trainers/nle/nle_base.py b/sbi/inference/trainers/nle/nle_base.py index 6ebe5ddb2..74dc99535 100644 --- a/sbi/inference/trainers/nle/nle_base.py +++ b/sbi/inference/trainers/nle/nle_base.py @@ -27,7 +27,7 @@ from sbi.neural_nets.estimators.shape_handling import ( reshape_to_batch_event, ) -from sbi.sbi_types import TorchTransform +from sbi.sbi_types import TorchTransform, Tracker from sbi.utils import check_estimator_arg, x_shape_from_simulation from sbi.utils.torchutils import assert_all_finite @@ -43,6 +43,7 @@ def __init__( device: str = "cpu", logging_level: Union[int, str] = "WARNING", summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): r"""Base class for `Neural Likelihood Estimation` methods. 
@@ -69,6 +70,7 @@ def __init__( device=device, logging_level=logging_level, summary_writer=summary_writer, + tracker=tracker, show_progress_bars=show_progress_bars, ) diff --git a/sbi/inference/trainers/npe/mnpe.py b/sbi/inference/trainers/npe/mnpe.py index fd2f5c87a..668814c10 100644 --- a/sbi/inference/trainers/npe/mnpe.py +++ b/sbi/inference/trainers/npe/mnpe.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Literal, Optional, Union from torch.distributions import Distribution +from torch.utils.tensorboard.writer import SummaryWriter from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.inference.posteriors.posterior_parameters import ( @@ -16,7 +17,7 @@ from sbi.inference.trainers.npe.npe_c import NPE_C from sbi.neural_nets.estimators import MixedDensityEstimator from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder -from sbi.sbi_types import TensorBoardSummaryWriter +from sbi.sbi_types import Tracker from sbi.utils.sbiutils import del_entries @@ -71,7 +72,8 @@ def __init__( ] = "mnpe", device: str = "cpu", logging_level: Union[int, str] = "WARNING", - summary_writer: Optional[TensorBoardSummaryWriter] = None, + summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): r"""Initialize Mixed Neural Posterior Estimation (MNPE). @@ -91,8 +93,10 @@ def __init__( device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". logging_level: Minimum severity of messages to log. One of the strings INFO, WARNING, DEBUG, ERROR and CRITICAL. - summary_writer: A tensorboard `SummaryWriter` to control, among others, log - file location (default is `/logs`.) + summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training metrics. If None, a + TensorBoard tracker is used with a default log directory. show_progress_bars: Whether to show a progressbar during simulation and sampling. 
""" diff --git a/sbi/inference/trainers/npe/npe_a.py b/sbi/inference/trainers/npe/npe_a.py index fe5a0e09d..d680f21f5 100644 --- a/sbi/inference/trainers/npe/npe_a.py +++ b/sbi/inference/trainers/npe/npe_a.py @@ -11,6 +11,7 @@ from pyknos.nflows.transforms import CompositeTransform from torch import Tensor from torch.distributions import Distribution, MultivariateNormal +from torch.utils.tensorboard.writer import SummaryWriter from sbi.inference.posteriors.direct_posterior import DirectPosterior from sbi.inference.trainers.npe.npe_base import ( @@ -20,7 +21,7 @@ ConditionalDensityEstimator, ConditionalEstimatorBuilder, ) -from sbi.sbi_types import TensorBoardSummaryWriter +from sbi.sbi_types import Tracker from sbi.utils import torchutils from sbi.utils.sbiutils import ( batched_mixture_mv, @@ -59,7 +60,8 @@ def __init__( num_components: int = 10, device: str = "cpu", logging_level: Union[int, str] = "WARNING", - summary_writer: Optional[TensorBoardSummaryWriter] = None, + summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): r"""Initialize NPE-A [1]. @@ -89,8 +91,10 @@ def __init__( device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". logging_level: Minimum severity of messages to log. One of the strings INFO, WARNING, DEBUG, ERROR and CRITICAL. - summary_writer: A tensorboard `SummaryWriter` to control, among others, log - file location (default is `/logs`.) + summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training metrics. If None, a + TensorBoard tracker is used with a default log directory. show_progress_bars: Whether to show a progressbar during training. 
""" diff --git a/sbi/inference/trainers/npe/npe_b.py b/sbi/inference/trainers/npe/npe_b.py index 9e3406281..c9915c793 100644 --- a/sbi/inference/trainers/npe/npe_b.py +++ b/sbi/inference/trainers/npe/npe_b.py @@ -6,6 +6,7 @@ import torch from torch import Tensor from torch.distributions import Distribution +from torch.utils.tensorboard.writer import SummaryWriter import sbi.utils as utils from sbi.inference.trainers.npe.npe_base import ( @@ -16,7 +17,7 @@ ConditionalEstimatorBuilder, ) from sbi.neural_nets.estimators.shape_handling import reshape_to_sample_batch_event -from sbi.sbi_types import TensorBoardSummaryWriter +from sbi.sbi_types import Tracker from sbi.utils.sbiutils import del_entries @@ -45,7 +46,8 @@ def __init__( ] = "maf", device: str = "cpu", logging_level: Union[int, str] = "WARNING", - summary_writer: Optional[TensorBoardSummaryWriter] = None, + summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): r"""Initialize NPE-B. @@ -64,8 +66,10 @@ def __init__( device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". logging_level: Minimum severity of messages to log. One of the strings INFO, WARNING, DEBUG, ERROR and CRITICAL. - summary_writer: A tensorboard `SummaryWriter` to control, among others, log - file location (default is `/logs`.) + summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training metrics. If None, a + TensorBoard tracker is used with a default log directory. show_progress_bars: Whether to show a progressbar during training. 
""" diff --git a/sbi/inference/trainers/npe/npe_base.py b/sbi/inference/trainers/npe/npe_base.py index 4c36f2919..397fd13ea 100644 --- a/sbi/inference/trainers/npe/npe_base.py +++ b/sbi/inference/trainers/npe/npe_base.py @@ -39,7 +39,7 @@ reshape_to_batch_event, reshape_to_sample_batch_event, ) -from sbi.sbi_types import TorchTransform +from sbi.sbi_types import TorchTransform, Tracker from sbi.utils import ( RestrictedPrior, check_estimator_arg, @@ -68,6 +68,7 @@ def __init__( device: str = "cpu", logging_level: Union[int, str] = "WARNING", summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): """Base class for Sequential Neural Posterior Estimation methods. @@ -94,6 +95,7 @@ def __init__( device=device, logging_level=logging_level, summary_writer=summary_writer, + tracker=tracker, show_progress_bars=show_progress_bars, ) diff --git a/sbi/inference/trainers/npe/npe_c.py b/sbi/inference/trainers/npe/npe_c.py index ac6e0acc1..d2fc018e7 100644 --- a/sbi/inference/trainers/npe/npe_c.py +++ b/sbi/inference/trainers/npe/npe_c.py @@ -8,6 +8,7 @@ from pyknos.nflows.transforms import CompositeTransform from torch import Tensor, eye, ones from torch.distributions import Distribution, MultivariateNormal, Uniform +from torch.utils.tensorboard.writer import SummaryWriter from sbi.inference.posteriors.direct_posterior import DirectPosterior from sbi.inference.trainers.npe.npe_base import ( @@ -21,7 +22,7 @@ reshape_to_batch_event, reshape_to_sample_batch_event, ) -from sbi.sbi_types import TensorBoardSummaryWriter +from sbi.sbi_types import Tracker from sbi.utils import ( batched_mixture_mv, batched_mixture_vmv, @@ -78,7 +79,8 @@ def __init__( ] = "maf", device: str = "cpu", logging_level: Union[int, str] = "WARNING", - summary_writer: Optional[TensorBoardSummaryWriter] = None, + summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): 
r"""Initialize NPE-C. @@ -97,8 +99,10 @@ def __init__( device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". logging_level: Minimum severity of messages to log. One of the strings INFO, WARNING, DEBUG, ERROR and CRITICAL. - summary_writer: A tensorboard ``SummaryWriter`` to control, among others, - log file location (default is ``/logs``.) + summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training metrics. If None, a + TensorBoard tracker is used with a default log directory. show_progress_bars: Whether to show a progressbar during training. """ diff --git a/sbi/inference/trainers/nre/bnre.py b/sbi/inference/trainers/nre/bnre.py index 668e96c40..31805f29a 100644 --- a/sbi/inference/trainers/nre/bnre.py +++ b/sbi/inference/trainers/nre/bnre.py @@ -6,12 +6,13 @@ import torch from torch import Tensor, nn, ones from torch.distributions import Distribution +from torch.utils.tensorboard.writer import SummaryWriter from sbi.inference.trainers._contracts import LossArgs, LossArgsBNRE from sbi.inference.trainers.nre.nre_a import NRE_A from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder from sbi.neural_nets.ratio_estimators import RatioEstimator -from sbi.sbi_types import TensorBoardSummaryWriter +from sbi.sbi_types import Tracker from sbi.utils.sbiutils import del_entries from sbi.utils.torchutils import assert_all_finite @@ -34,7 +35,8 @@ def __init__( classifier: Union[str, ConditionalEstimatorBuilder[RatioEstimator]] = "resnet", device: str = "cpu", logging_level: Union[int, str] = "warning", - summary_writer: Optional[TensorBoardSummaryWriter] = None, + summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): r"""Balanced neural ratio estimation (BNRE). @@ -53,8 +55,10 @@ def __init__( device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". 
logging_level: Minimum severity of messages to log. One of the strings INFO, WARNING, DEBUG, ERROR and CRITICAL. - summary_writer: A tensorboard `SummaryWriter` to control, among others, log - file location (default is `/logs`.) + summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training metrics. If None, a + TensorBoard tracker is used with a default log directory. show_progress_bars: Whether to show a progressbar during simulation and sampling. """ diff --git a/sbi/inference/trainers/nre/nre_a.py b/sbi/inference/trainers/nre/nre_a.py index 464d77b90..84dfb8074 100644 --- a/sbi/inference/trainers/nre/nre_a.py +++ b/sbi/inference/trainers/nre/nre_a.py @@ -6,6 +6,7 @@ import torch from torch import Tensor, nn, ones from torch.distributions import Distribution +from torch.utils.tensorboard.writer import SummaryWriter from sbi.inference.trainers._contracts import LossArgsNRE_A from sbi.inference.trainers.nre.nre_base import ( @@ -13,7 +14,7 @@ ) from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder from sbi.neural_nets.ratio_estimators import RatioEstimator -from sbi.sbi_types import TensorBoardSummaryWriter +from sbi.sbi_types import Tracker from sbi.utils.sbiutils import del_entries from sbi.utils.torchutils import assert_all_finite @@ -31,7 +32,8 @@ def __init__( classifier: Union[str, ConditionalEstimatorBuilder[RatioEstimator]] = "resnet", device: str = "cpu", logging_level: Union[int, str] = "warning", - summary_writer: Optional[TensorBoardSummaryWriter] = None, + summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): r"""Initialize NRE_A. @@ -50,8 +52,10 @@ def __init__( device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". logging_level: Minimum severity of messages to log. One of the strings INFO, WARNING, DEBUG, ERROR and CRITICAL. 
- summary_writer: A tensorboard `SummaryWriter` to control, among others, log - file location (default is `/logs`.) + summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training metrics. If None, a + TensorBoard tracker is used with a default log directory. show_progress_bars: Whether to show a progressbar during simulation and sampling. """ diff --git a/sbi/inference/trainers/nre/nre_b.py b/sbi/inference/trainers/nre/nre_b.py index 698edda35..135c002cb 100644 --- a/sbi/inference/trainers/nre/nre_b.py +++ b/sbi/inference/trainers/nre/nre_b.py @@ -6,6 +6,7 @@ import torch from torch import Tensor from torch.distributions import Distribution +from torch.utils.tensorboard.writer import SummaryWriter from sbi.inference.trainers._contracts import LossArgsNRE from sbi.inference.trainers.nre.nre_base import ( @@ -13,7 +14,7 @@ ) from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder from sbi.neural_nets.ratio_estimators import RatioEstimator -from sbi.sbi_types import TensorBoardSummaryWriter +from sbi.sbi_types import Tracker from sbi.utils.sbiutils import del_entries from sbi.utils.torchutils import assert_all_finite @@ -31,7 +32,8 @@ def __init__( classifier: Union[str, ConditionalEstimatorBuilder[RatioEstimator]] = "resnet", device: str = "cpu", logging_level: Union[int, str] = "warning", - summary_writer: Optional[TensorBoardSummaryWriter] = None, + summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): r"""Initialize NRE_B. @@ -50,8 +52,10 @@ def __init__( device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". logging_level: Minimum severity of messages to log. One of the strings INFO, WARNING, DEBUG, ERROR and CRITICAL. - summary_writer: A tensorboard `SummaryWriter` to control, among others, log - file location (default is `/logs`.) 
+ summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training metrics. If None, a + TensorBoard tracker is used with a default log directory. show_progress_bars: Whether to show a progressbar during simulation and sampling. """ diff --git a/sbi/inference/trainers/nre/nre_base.py b/sbi/inference/trainers/nre/nre_base.py index 1c96a85fd..07b2cda0b 100644 --- a/sbi/inference/trainers/nre/nre_base.py +++ b/sbi/inference/trainers/nre/nre_base.py @@ -33,7 +33,7 @@ from sbi.neural_nets import classifier_nn from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder from sbi.neural_nets.ratio_estimators import RatioEstimator -from sbi.sbi_types import TorchTransform +from sbi.sbi_types import TorchTransform, Tracker from sbi.utils import ( check_estimator_arg, clamp_and_warn, @@ -49,6 +49,7 @@ def __init__( device: str = "cpu", logging_level: Union[int, str] = "warning", summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): r"""Neural Ratio Estimation. 
@@ -85,6 +86,7 @@ def __init__( device=device, logging_level=logging_level, summary_writer=summary_writer, + tracker=tracker, show_progress_bars=show_progress_bars, ) diff --git a/sbi/inference/trainers/nre/nre_c.py b/sbi/inference/trainers/nre/nre_c.py index cec865fcc..0bb36852b 100644 --- a/sbi/inference/trainers/nre/nre_c.py +++ b/sbi/inference/trainers/nre/nre_c.py @@ -6,6 +6,7 @@ import torch from torch import Tensor from torch.distributions import Distribution +from torch.utils.tensorboard.writer import SummaryWriter from sbi.inference.trainers._contracts import LossArgs, LossArgsNRE_C from sbi.inference.trainers.nre.nre_base import ( @@ -13,7 +14,7 @@ ) from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder from sbi.neural_nets.ratio_estimators import RatioEstimator -from sbi.sbi_types import TensorBoardSummaryWriter +from sbi.sbi_types import Tracker from sbi.utils.sbiutils import del_entries from sbi.utils.torchutils import assert_all_finite @@ -45,7 +46,8 @@ def __init__( classifier: Union[str, ConditionalEstimatorBuilder[RatioEstimator]] = "resnet", device: str = "cpu", logging_level: Union[int, str] = "warning", - summary_writer: Optional[TensorBoardSummaryWriter] = None, + summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, ): r"""Initialize NRE-C. @@ -64,8 +66,10 @@ def __init__( device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". logging_level: Minimum severity of messages to log. One of the strings INFO, WARNING, DEBUG, ERROR and CRITICAL. - summary_writer: A tensorboard `SummaryWriter` to control, among others, log - file location (default is `/logs`.) + summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training metrics. If None, a + TensorBoard tracker is used with a default log directory. 
show_progress_bars: Whether to show a progressbar during simulation and sampling. """ diff --git a/sbi/inference/trainers/vfpe/base_vf_inference.py b/sbi/inference/trainers/vfpe/base_vf_inference.py index d7aad91cd..131adf1ed 100644 --- a/sbi/inference/trainers/vfpe/base_vf_inference.py +++ b/sbi/inference/trainers/vfpe/base_vf_inference.py @@ -24,7 +24,7 @@ from sbi.inference.trainers.base import LossArgs from sbi.neural_nets.estimators import ConditionalVectorFieldEstimator from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder -from sbi.sbi_types import TorchTransform +from sbi.sbi_types import TorchTransform, Tracker from sbi.utils import ( check_estimator_arg, handle_invalid_x, @@ -48,6 +48,7 @@ def __init__( device: str = "cpu", logging_level: Union[int, str] = "WARNING", summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, **kwargs, ): @@ -68,7 +69,10 @@ def __init__( device: Device to run the training on. logging_level: Logging level for the training. Can be an integer or a string. - summary_writer: Tensorboard summary writer. + summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training metrics. If None, a + TensorBoard tracker is used with a default log directory. show_progress_bars: Whether to show progress bars during training. kwargs: Additional keyword arguments passed to the default builder if `vector_field_estimator_builder` is a string. 
@@ -79,6 +81,7 @@ def __init__( device=device, logging_level=logging_level, summary_writer=summary_writer, + tracker=tracker, show_progress_bars=show_progress_bars, ) diff --git a/sbi/inference/trainers/vfpe/fmpe.py b/sbi/inference/trainers/vfpe/fmpe.py index 16b4b914f..2e0994734 100644 --- a/sbi/inference/trainers/vfpe/fmpe.py +++ b/sbi/inference/trainers/vfpe/fmpe.py @@ -19,6 +19,7 @@ ConditionalVectorFieldEstimator, ) from sbi.neural_nets.factory import posterior_flow_nn +from sbi.sbi_types import Tracker class FMPE(VectorFieldTrainer): @@ -37,6 +38,7 @@ def __init__( device: str = "cpu", logging_level: Union[int, str] = "WARNING", summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, **kwargs, ) -> None: @@ -53,7 +55,10 @@ def __init__( warning is raised and the `vf_estimator="mlp"` default is used. device: Device to use for training. logging_level: Logging level. - summary_writer: Summary writer for tensorboard. + summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training metrics. If None, a + TensorBoard tracker is used with a default log directory. show_progress_bars: Whether to show progress bars. **kwargs: Additional keyword arguments passed to the default builder if `density_estimator` is a string. 
@@ -73,6 +78,7 @@ def __init__( device=device, logging_level=logging_level, summary_writer=summary_writer, + tracker=tracker, show_progress_bars=show_progress_bars, vector_field_estimator_builder=vf_estimator, **kwargs, diff --git a/sbi/inference/trainers/vfpe/npse.py b/sbi/inference/trainers/vfpe/npse.py index 1b90d0115..846c06adf 100644 --- a/sbi/inference/trainers/vfpe/npse.py +++ b/sbi/inference/trainers/vfpe/npse.py @@ -15,6 +15,7 @@ from sbi.neural_nets.estimators import ConditionalVectorFieldEstimator from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder from sbi.neural_nets.factory import posterior_score_nn +from sbi.sbi_types import Tracker class NPSE(VectorFieldTrainer): @@ -45,6 +46,7 @@ def __init__( device: str = "cpu", logging_level: Union[int, str] = "WARNING", summary_writer: Optional[SummaryWriter] = None, + tracker: Optional[Tracker] = None, show_progress_bars: bool = True, **kwargs, ): @@ -64,7 +66,10 @@ def __init__( device: Device to run the training on. logging_level: Logging level for the training. Can be an integer or a string. - summary_writer: Tensorboard summary writer. + summary_writer: Deprecated alias for the TensorBoard summary writer. + Use ``tracker`` instead. + tracker: Tracking adapter used to log training metrics. If None, a + TensorBoard tracker is used with a default log directory. show_progress_bars: Whether to show progress bars during training. kwargs: Additional keyword arguments passed to the default builder if `score_estimator` is a string. @@ -90,6 +95,7 @@ def __init__( device=device, logging_level=logging_level, summary_writer=summary_writer, + tracker=tracker, show_progress_bars=show_progress_bars, sde_type=sde_type, **kwargs, diff --git a/sbi/sbi_types.py b/sbi/sbi_types.py index 1d62e8eb4..7b01a05c4 100644 --- a/sbi/sbi_types.py +++ b/sbi/sbi_types.py @@ -1,7 +1,7 @@ # This file is part of sbi, a toolkit for simulation-based inference. 
sbi is licensed # under the Apache License Version 2.0, see -from typing import Optional, Protocol, Sequence, Tuple, TypeVar, Union +from typing import Any, Optional, Protocol, Sequence, Tuple, TypeVar, Union import numpy as np import torch @@ -56,6 +56,29 @@ class AcceptRejectFn(Protocol): def __call__(self, theta: Tensor) -> Tensor: ... +class Tracker(Protocol): + """Protocol for experiment tracking integrations.""" + + @property + def log_dir(self) -> Optional[str]: ... + + def log_metric( + self, name: str, value: float, step: Optional[int] = None + ) -> None: ... + + def log_metrics( + self, metrics: dict[str, float], step: Optional[int] = None + ) -> None: ... + + def log_params(self, params: dict[str, Any]) -> None: ... + + def add_figure( + self, name: str, figure: Any, step: Optional[int] = None + ) -> None: ... + + def flush(self) -> None: ... + + __all__ = [ "AcceptRejectFn", "Array", @@ -69,4 +92,5 @@ def __call__(self, theta: Tensor) -> Tensor: ... "TorchDistribution", "PyroTransformedDistribution", "TorchTensor", + "Tracker", ] diff --git a/sbi/utils/tracking.py b/sbi/utils/tracking.py new file mode 100644 index 000000000..095507d18 --- /dev/null +++ b/sbi/utils/tracking.py @@ -0,0 +1,46 @@ +# This file is part of sbi, a toolkit for simulation-based inference. 
sbi is licensed +# under the Apache License Version 2.0, see + +from __future__ import annotations + +from typing import Any, Optional + +from torch.utils.tensorboard.writer import SummaryWriter + +from sbi.sbi_types import Tracker + + +class TensorBoardTracker: + """Adapter for TensorBoard SummaryWriter.""" + + def __init__(self, summary_writer: SummaryWriter) -> None: + self._writer = summary_writer + + @property + def log_dir(self) -> str: + return self._writer.log_dir + + def log_metric(self, name: str, value: float, step: Optional[int] = None) -> None: + self._writer.add_scalar(tag=name, scalar_value=value, global_step=step) + + def log_metrics( + self, metrics: dict[str, float], step: Optional[int] = None + ) -> None: + for name, value in metrics.items(): + self.log_metric(name=name, value=value, step=step) + + def log_params(self, params: dict[str, Any]) -> None: + for name, value in params.items(): + self._writer.add_text(tag=f"params/{name}", text_string=str(value)) + + def add_figure(self, name: str, figure: Any, step: Optional[int] = None) -> None: + self._writer.add_figure(tag=name, figure=figure, global_step=step) + + def flush(self) -> None: + self._writer.flush() + + +__all__ = [ + "TensorBoardTracker", + "Tracker", +] diff --git a/tests/plot_test.py b/tests/plot_test.py index 6f6d6769c..3a982ba62 100644 --- a/tests/plot_test.py +++ b/tests/plot_test.py @@ -20,6 +20,7 @@ ) from sbi.inference import NLE, NPE, NRE from sbi.utils import BoxUniform +from sbi.utils.tracking import TensorBoardTracker @pytest.mark.parametrize("samples", (torch.randn(100, 1),)) @@ -109,11 +110,12 @@ def test_plot_summary(method, tmp_path): num_simulations = 6 summary_writer = SummaryWriter(tmp_path) + tracker = TensorBoardTracker(summary_writer) def simulator(theta): return theta + 1.0 + torch.randn_like(theta) * 0.1 - inference = method(prior=prior, summary_writer=summary_writer) + inference = method(prior=prior, tracker=tracker) theta = prior.sample((num_simulations,)) x = 
simulator(theta)