Changes from 2 commits
4 changes: 0 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,3 @@ target/
# uv
uv.lock
.python-version

# Serena cache
.serena/
.claude/
1 change: 1 addition & 0 deletions docs/how_to_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Training
how_to_guide/07_gpu_training.ipynb
how_to_guide/07_save_and_load.ipynb
how_to_guide/07_resume_training.ipynb
how_to_guide/22_experiment_tracking.ipynb


Sampling
Expand Down
306 changes: 306 additions & 0 deletions docs/how_to_guide/22_experiment_tracking.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7fb27b941602401d91542211134fc71a",
"metadata": {},
"source": [
"# How to track experiments"
]
},
{
"cell_type": "markdown",
"id": "9dcf97f2",
"metadata": {},
"source": [
"Experiment tracking helps compare model variants and keep a record of hyperparameters and training metrics. By default, `sbi` logs to TensorBoard. You can also bring your own tracker by implementing the lightweight `Tracker` protocol and passing it as `tracker=...`.\n",
"\n",
"If using your own tracker (e.g., `wandb`, `mlflow` or `trackio`), note that the run lifecycle (e.g., `wandb.init`, `mlflow.start_run`) is handled on the user side."
]
},
{
"cell_type": "markdown",
"id": "6492510e",
"metadata": {},
"source": [
"## Define a minimal training setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "898c316a",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"from sbi.inference import NPE\n",
"from sbi.neural_nets import posterior_nn\n",
"from sbi.neural_nets.embedding_nets import FCEmbedding\n",
"from sbi.utils import BoxUniform\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"def simulator(theta):\n",
" return theta + 0.1 * torch.randn_like(theta)\n",
"\n",
"prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))\n",
"\n",
"theta = prior.sample((5000,))\n",
"x = simulator(theta)\n",
"\n",
"embedding_net = FCEmbedding(input_dim=x.shape[1], output_dim=32)\n",
"density_estimator = posterior_nn(\n",
" model=\"nsf\",\n",
" embedding_net=embedding_net,\n",
" num_transforms=4,\n",
")\n",
"\n",
"train_kwargs = dict(\n",
" max_num_epochs=50,\n",
" training_batch_size=128,\n",
" validation_fraction=0.1,\n",
" show_train_summary=False,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "9142d4b6",
"metadata": {},
"source": [
"## Train with a tracker\n",
"\n",
"By default, `sbi` uses a TensorBoard tracker to log training loss, validation loss,\n",
"number of epochs and more. \n",
"\n",
"When you want to track additional quantities, you instantiate the tracker yourself and\n",
"pass it to the inference class:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a62bc10",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.tensorboard.writer import SummaryWriter\n",
"\n",
"from sbi.utils.tracking import TensorBoardTracker\n",
"\n",
"tracker = TensorBoardTracker(SummaryWriter(\"sbi-logs\"))\n",
"tracker.log_params({\"embedding_dim\": 32, \"num_transforms\": 4})\n",
"\n",
"inference = NPE(prior=prior, tracker=tracker)\n",
"inference.append_simulations(theta, x)\n",
"estimator = inference.train(**train_kwargs)\n",
"posterior = inference.build_posterior(estimator)"
]
},
{
"cell_type": "markdown",
"id": "c4da9894",
"metadata": {},
"source": [
"## View TensorBoard results\n",
"\n",
"You can then view your tracked run(s) on a TensorBoard shown on your localhost in the\n",
"browser. By default, `sbi` will create a log directory `sbi-logs` at the location the\n",
"training script was called.\n",
"\n",
"```bash\n",
"tensorboard --logdir=sbi-logs\n",
"```"
]
},
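{
"cell_type": "markdown",
"id": "b71f03a2",
"metadata": {},
"source": [
"Alternatively, you can plot logged scalars directly from Python with `plot_summary`, which reads the tracker's `log_dir` (a sketch; the available tag names depend on what was logged in your run):\n",
"\n",
"```python\n",
"from sbi.analysis import plot_summary\n",
"\n",
"# Accepts the inference object (uses its tracker's log_dir) or a log-dir Path.\n",
"fig, ax = plot_summary(inference, tags=[\"validation_loss\"])\n",
"```"
]
},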
{
"cell_type": "markdown",
"id": "6f2777dd",
"metadata": {},
"source": [
"## Using other trackers\n",
"\n",
"To enable usage of other trackers, we provide a lightweight `Protocol` that trackers\n",
"need to follow. You can implement a small adapter that satisfies the `Tracker` protocol\n",
"and pass it to `tracker=`. Below are minimal examples for common tools."
]
},
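{
"cell_type": "markdown",
"id": "e5a1c9f0",
"metadata": {},
"source": [
"The required interface can be sketched as follows (a sketch inferred from the adapters in this guide, not necessarily the exact definition in `sbi.utils.tracking`):\n",
"\n",
"```python\n",
"from typing import Any, Dict, Optional, Protocol\n",
"\n",
"class Tracker(Protocol):\n",
"    # Directory read by e.g. `plot_summary`; may be None for non-file trackers.\n",
"    log_dir: Optional[str]\n",
"\n",
"    def log_metric(self, name: str, value: float, step: Optional[int] = None) -> None: ...\n",
"    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: ...\n",
"    def log_params(self, params: Dict[str, Any]) -> None: ...\n",
"    def add_figure(self, name: str, figure: Any, step: Optional[int] = None) -> None: ...\n",
"    def flush(self) -> None: ...\n",
"```"
]
},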
{
"cell_type": "markdown",
"id": "43644d68",
"metadata": {},
"source": [
"```python\n",
"# W&B adapter (requires `wandb.init()` before training)\n",
"class WandBAdapter:\n",
" log_dir = None\n",
"\n",
" def __init__(self, run):\n",
" self._run = run\n",
"\n",
" def log_metric(self, name, value, step=None):\n",
" self._run.log({name: value}, step=step)\n",
"\n",
" def log_metrics(self, metrics, step=None):\n",
" self._run.log(metrics, step=step)\n",
"\n",
" def log_params(self, params):\n",
" self._run.config.update(params)\n",
"\n",
" def add_figure(self, name, figure, step=None):\n",
" import wandb\n",
" self._run.log({name: wandb.Image(figure)}, step=step)\n",
"\n",
" def flush(self):\n",
" pass\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "031651b3",
"metadata": {},
"source": [
"```python\n",
"# MLflow adapter (configure tracking URI separately)\n",
"class MLflowAdapter:\n",
" log_dir = None\n",
"\n",
" def __init__(self, mlflow):\n",
" self._mlflow = mlflow\n",
"\n",
" def log_metric(self, name, value, step=None):\n",
" self._mlflow.log_metric(name, value, step=step)\n",
"\n",
" def log_metrics(self, metrics, step=None):\n",
" for name, value in metrics.items():\n",
" self.log_metric(name, value, step=step)\n",
"\n",
" def log_params(self, params):\n",
" self._mlflow.log_params(params)\n",
"\n",
" def add_figure(self, name, figure, step=None):\n",
" self._mlflow.log_figure(figure, f\"{name}.png\")\n",
"\n",
" def flush(self):\n",
" pass\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "6d891d85",
"metadata": {},
"source": [
"```python\n",
"# Trackio adapter (requires `trackio.init()` before training)\n",
"class TrackioAdapter:\n",
" log_dir = None\n",
"\n",
" def __init__(self, trackio):\n",
" self._trackio = trackio\n",
"\n",
" def log_metric(self, name, value, step=None):\n",
" self._trackio.log({name: value}, step=step)\n",
"\n",
" def log_metrics(self, metrics, step=None):\n",
" self._trackio.log(metrics, step=step)\n",
"\n",
" def log_params(self, params):\n",
" self._trackio.log(params)\n",
"\n",
" def add_figure(self, name, figure, step=None):\n",
" self._trackio.log_image(figure, name=name, step=step)\n",
"\n",
" def flush(self):\n",
" pass\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "43b453f7",
"metadata": {},
"source": [
"When using external trackers, create an adapter instance and pass it to `tracker=`:\n",
"\n",
"```python\n",
"# wandb.init(...)\n",
"tracker = WandBAdapter(wandb.run)\n",
"inference = NPE(prior=prior, density_estimator=density_estimator, tracker=tracker)\n",
"```"
]
},
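{
"cell_type": "markdown",
"id": "f3d07b18",
"metadata": {},
"source": [
"The same pattern applies to the other adapters, e.g. for MLflow (a sketch; you manage the run lifecycle yourself):\n",
"\n",
"```python\n",
"import mlflow\n",
"\n",
"with mlflow.start_run():\n",
"    tracker = MLflowAdapter(mlflow)\n",
"    inference = NPE(prior=prior, density_estimator=density_estimator, tracker=tracker)\n",
"    inference.append_simulations(theta, x)\n",
"    estimator = inference.train(**train_kwargs)\n",
"```"
]
},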
{
"cell_type": "markdown",
"id": "5c748c81",
"metadata": {},
"source": [
"## Log figures\n",
"\n",
"Trackers can also store matplotlib figures. For example, after training you can log a pairplot:\n",
"\n",
"```python\n",
"from sbi.analysis import pairplot\n",
"\n",
"x_o = x[:1]\n",
"samples = posterior.sample((1000,), x=x_o)\n",
"fig, _ = pairplot(samples)\n",
"tracker.add_figure(\"posterior_pairplot\", fig, step=0)\n",
"```\n",
"\n",
"Figure logging depends on the tracker implementation (e.g., `wandb.Image`, `mlflow.log_figure`)."
]
},
{
"cell_type": "markdown",
"id": "ab99ec92",
"metadata": {},
"source": [
"## Custom training loop (optional)\n",
"\n",
"If you want to log custom diagnostics per epoch, use the training interface tutorial: https://sbi.readthedocs.io/en/latest/advanced_tutorials/18_training_interface.html."
]
},
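{
"cell_type": "markdown",
"id": "d4c2a7e9",
"metadata": {},
"source": [
"Inside such a custom loop, you can call the tracker directly, for example (a sketch; `my_diagnostic` is a placeholder for your own per-epoch quantity):\n",
"\n",
"```python\n",
"for epoch in range(num_epochs):\n",
"    ...  # training and validation steps\n",
"    tracker.log_metric(\"my_diagnostic\", my_diagnostic, step=epoch)\n",
"tracker.flush()\n",
"```"
]
},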
{
"cell_type": "markdown",
"id": "b206a6a7",
"metadata": {},
"source": [
"## Notes\n",
"\n",
"- Each tool supports richer logging (artifacts, checkpoints, plots), but the patterns above are enough to track hyperparameters, epoch-wise losses, and validation metrics.\n",
"- If you already use Optuna or other sweep tools, you can call the logger inside the objective function to log each trial."
]
},
{
"cell_type": "markdown",
"id": "46465002",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "sbi (3.12.9)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
11 changes: 8 additions & 3 deletions sbi/analysis/tensorboard_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def plot_summary(
ylabel: Optional[List[str]] = None,
plot_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Figure, Axes]:
"""Plots data logged by the tensorboard summary writer of an inference object.
"""Plots data logged by the TensorBoard tracker of an inference object.

Args:
inference: inference object that holds a ._summary_writer.log_dir attribute.
inference: inference object that holds a tracker with a log_dir attribute.
Optionally the log_dir itself.
tags: list of summary writer tags to visualize.
disable_tensorboard_prompt: flag to disable the logging of how to run
Expand All @@ -67,7 +67,12 @@ def plot_summary(
size_guidance.update(scalars=tensorboard_scalar_limit)

if isinstance(inference, NeuralInference):
log_dir = inference._summary_writer.log_dir
log_dir = getattr(inference._tracker, "log_dir", None)
if log_dir is None:
raise ValueError(
"Inference tracker does not expose a log_dir. "
"Use a TensorBoard tracker or pass a log directory directly."
)
elif isinstance(inference, Path):
log_dir = inference
else:
Expand Down