From 7a07a0729c731e4586470f81e8e75e10c3752db4 Mon Sep 17 00:00:00 2001 From: dgedon Date: Fri, 6 Feb 2026 16:41:56 +0100 Subject: [PATCH] new abstraction level guide --- docs/how_to_guide.rst | 1 + docs/how_to_guide/22_abstraction_levels.ipynb | 1390 +++++++++++++++++ 2 files changed, 1391 insertions(+) create mode 100644 docs/how_to_guide/22_abstraction_levels.ipynb diff --git a/docs/how_to_guide.rst b/docs/how_to_guide.rst index 9af9aa1ad..805a533df 100644 --- a/docs/how_to_guide.rst +++ b/docs/how_to_guide.rst @@ -40,6 +40,7 @@ Training .. toctree:: :maxdepth: 1 + how_to_guide/22_abstraction_levels.ipynb how_to_guide/06_choosing_inference_method.ipynb how_to_guide/02_multiround_inference.ipynb how_to_guide/07_gpu_training.ipynb diff --git a/docs/how_to_guide/22_abstraction_levels.ipynb b/docs/how_to_guide/22_abstraction_levels.ipynb new file mode 100644 index 000000000..86408c792 --- /dev/null +++ b/docs/how_to_guide/22_abstraction_levels.ipynb @@ -0,0 +1,1390 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# API abstraction levels in sbi" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`sbi` offers flexibility ranging from simple, high-level workflows to full control over neural networks and sampling. This guide shows:\n", + "\n", + "1. **Four abstraction levels** for controlling the density estimator (common to NPE and NLE)\n", + "2. **Additional sampling control** for NLE (4 more levels)\n", + "\n", + "We'll use the same simple example throughout to keep things clear." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "First, let's define a simple linear Gaussian simulator and generate data we'll use for all examples:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated 2000 simulations for training\n", + "Parameter shape: torch.Size([2000, 3]), Data shape: torch.Size([2000, 3])\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "from sbi.inference import NLE, NPE\n", + "from sbi.utils import BoxUniform\n", + "\n", + "\n", + "# Define a simple linear Gaussian simulator\n", + "def simulator(theta):\n", + " \"\"\"Linear Gaussian simulator with noise.\"\"\"\n", + " return theta + 1.0 + torch.randn_like(theta) * 0.1\n", + "\n", + "# Define prior over 3 parameters\n", + "num_dim = 3\n", + "prior = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))\n", + "\n", + "# Generate training data (used for all examples)\n", + "num_simulations = 2000\n", + "theta = prior.sample((num_simulations,))\n", + "x = simulator(theta)\n", + "\n", + "# Generate a single observation for inference\n", + "theta_o = prior.sample((1,))\n", + "x_o = simulator(theta_o)\n", + "\n", + "print(f\"Generated {num_simulations} simulations for training\")\n", + "print(f\"Parameter shape: {theta.shape}, Data shape: {x.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Part 1: Density Estimator Abstraction Levels\n", + "\n", + "The following **4 levels apply to both NPE and NLE**. They control how the neural density estimator is specified and constructed. We'll demonstrate with NPE first." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Level 1: Trainer Classes (Recommended)\n", + "\n", + "**Use case**: Standard workflows, most common approach\n", + "\n", + "The trainer classes provide the recommended interface with string-based customization." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Neural network successfully converged after 100 epochs." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "1094it [00:00, 50601.21it/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Level 1 complete - used NSF with default settings\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Level 1: Simple trainer class with string specification\n", + "inference = NPE(prior=prior, density_estimator=\"nsf\")\n", + "\n", + "# Train on the data\n", + "inference.append_simulations(theta, x)\n", + "posterior_net = inference.train()\n", + "\n", + "# Build posterior and sample\n", + "posterior = inference.build_posterior()\n", + "samples_lvl1 = posterior.sample((1000,), x=x_o)\n", + "\n", + "print(\"Level 1 complete - used NSF with default settings\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Key features**:\n", + "- Simple string specification: `\"nsf\"`, `\"maf\"`, `\"zuko_nsf\"`, `\"mdn\"`, etc.\n", + "- Multi-round inference support\n", + "- Automatic handling of training loops\n", + "- **Start here** for most use cases" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Level 2: Factory Functions\n", + "\n", + "**Use case**: Need specific architecture hyperparameters\n", + "\n", + "Use factory functions like `posterior_nn()` when you need to tune the network architecture." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Neural network successfully converged after 64 epochs." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1000/1000 [00:00<00:00, 62516.64it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Level 2 complete - used MAF with custom hyperparameters\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "from sbi.neural_nets import posterior_nn\n", + "\n", + "# Level 2: Factory function with custom hyperparameters\n", + "density_estimator = posterior_nn(\n", + " model=\"maf\", # Masked Autoregressive Flow\n", + " hidden_features=50, # Customize hidden layer size\n", + " num_transforms=5, # Customize number of transform layers\n", + ")\n", + "\n", + "# Pass to NPE (rest of workflow is the same)\n", + "inference = NPE(prior=prior, density_estimator=density_estimator)\n", + "inference.append_simulations(theta, x)\n", + "posterior_net = inference.train()\n", + "\n", + "posterior = inference.build_posterior()\n", + "samples_lvl2 = posterior.sample((1000,), x=x_o)\n", + "\n", + "print(\"Level 2 complete - used MAF with custom hyperparameters\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Key features**:\n", + "- Fine-grained control over hyperparameters\n", + "- Can add embedding networks for high-dimensional data\n", + "- Still benefits from trainer conveniences\n", + "- For NPE: `posterior_nn()`, for NLE: `likelihood_nn()`, for NRE: `classifier_nn()`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Level 3: Direct Network Builders\n", + "\n", + "**Use case**: Custom neural network architecture with full parameter access\n", + "\n", + "Use direct builder functions like `build_nsf()` for maximum control over network 
construction." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Neural network successfully converged after 148 epochs." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "1092it [00:00, 72132.23it/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Level 3 complete - used custom NSF configuration\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "from functools import partial\n", + "\n", + "from sbi.neural_nets.net_builders.flow import build_nsf\n", + "\n", + "# Level 3: Direct builder with full parameter control\n", + "custom_builder = partial(\n", + " build_nsf,\n", + " hidden_features=60,\n", + " num_transforms=3,\n", + " num_bins=8, # Number of spline bins\n", + " tail_bound=3.0, # Spline tail bound\n", + ")\n", + "\n", + "# Pass to NPE (rest of workflow is the same)\n", + "inference = NPE(prior=prior, density_estimator=custom_builder)\n", + "inference.append_simulations(theta, x)\n", + "posterior_net = inference.train()\n", + "\n", + "posterior = inference.build_posterior()\n", + "samples_lvl3 = posterior.sample((1000,), x=x_o)\n", + "\n", + "print(\"Level 3 complete - used custom NSF configuration\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Key features**:\n", + "- Direct access to all builder parameters\n", + "- Maximum flexibility for architecture design\n", + "- Can implement fully custom architectures by subclassing `DensityEstimator`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Level 4: Custom Training Loops\n", + "\n", + "**Use case**: Custom training logic, loss functions, research applications\n", + "\n", + "For complete control over the training process, implement custom training loops. 
This is covered in detail in [advanced tutorial 18](https://sbi.readthedocs.io/en/latest/advanced_tutorials/18_training_interface.html).\n", + "\n", + "At this level, you:\n", + "- Manually construct the density estimator\n", + "- Define custom loss functions and regularization\n", + "- Implement your own training loops with custom data loaders\n", + "- Have full control over optimization, early stopping, etc.\n", + "\n", + "**When to use**: Research on new methods, custom loss functions, specialized data augmentation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Part 2: NLE - Same Levels + Sampling Control\n", + "\n", + "## Understanding the Difference\n", + "\n", + "**NPE** directly approximates the posterior $p(\\theta|x)$:\n", + "- Sampling is straightforward: just sample from the neural network\n", + "- No additional configuration typically needed\n", + "\n", + "**NLE** approximates the likelihood $p(x|\\theta)$:\n", + "- Must combine with prior using MCMC, VI, or rejection sampling to get posterior samples\n", + "- This adds a **second dimension of control**: choosing and configuring the sampling method\n", + "\n", + "**Important**: The 4 density estimator levels above work exactly the same for NLE - just use `likelihood_nn()` instead of `posterior_nn()` at Level 2." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## NLE Density Estimator (Same 4 Levels)\n", + "\n", + "Quick example showing NLE uses the same abstraction levels:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Neural network successfully converged after 78 epochs." 
+ ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating 20 MCMC inits via resample strategy: 100%|██████████| 20/20 [00:01<00:00, 14.46it/s]\n", + "Running vectorized MCMC with 20 chains: 100%|██████████| 6000/6000 [00:19<00:00, 301.30it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NLE Level 1 complete\n", + "Default sampling method: MCMC with slice_np_vectorized\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Level 1 with NLE - same pattern as NPE\n", + "inference_nle = NLE(prior=prior, density_estimator=\"nsf\")\n", + "inference_nle.append_simulations(theta, x)\n", + "likelihood_net = inference_nle.train()\n", + "\n", + "# Build posterior (defaults to MCMC)\n", + "posterior_nle = inference_nle.build_posterior()\n", + "samples_nle = posterior_nle.sample((1000,), x=x_o)\n", + "\n", + "print(\"NLE Level 1 complete\")\n", + "print(\"Default sampling method: MCMC with slice_np_vectorized\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note**: Levels 2-4 for the density estimator work identically:\n", + "- Level 2: Use `likelihood_nn()` instead of `posterior_nn()`\n", + "- Level 3: Use `build_nsf()` (same as NPE)\n", + "- Level 4: Custom training (see tutorial 18)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Part 3: NLE Sampling Control\n", + "\n", + "NLE provides **additional control over how posterior samples are generated**. This is independent of the density estimator configuration above.\n", + "\n", + "Four levels of sampling control:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sampling Level 1: Default\n", + "\n", + "**Use case**: Starting point, works well for most problems\n", + "\n", + "Just call `build_posterior()` with no arguments - uses slice sampling by default." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating 20 MCMC inits via resample strategy: 100%|██████████| 20/20 [00:01<00:00, 15.70it/s]\n", + "Running vectorized MCMC with 20 chains: 100%|██████████| 6000/6000 [00:19<00:00, 306.88it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sampling Level 1: Default MCMC (slice_np_vectorized)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Sampling Level 1: Use defaults\n", + "posterior = inference_nle.build_posterior()\n", + "samples = posterior.sample((1000,), x=x_o)\n", + "\n", + "print(\"Sampling Level 1: Default MCMC (slice_np_vectorized)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Default behavior**: MCMC with `slice_np_vectorized` method, 200 warmup steps, 20 chains." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sampling Level 2: Choose Method\n", + "\n", + "**Use case**: Different problem characteristics favor different sampling methods\n", + "\n", + "Use the `sample_with` parameter to choose between MCMC, rejection sampling, VI, or importance sampling." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Drawing 1000 posterior samples: 0%| | 0/1000 [00:0010), may be less accurate\n", + "- `\"importance\"`: Importance sampling - useful for refining VI posteriors\n", + "\n", + "**Usage**: Simply change `sample_with=\"rejection\"` to `sample_with=\"vi\"` or any other method.\n", + "\n", + "**See also**: [how_to_guide/09_sampler_interface.ipynb](https://sbi.readthedocs.io/en/latest/how_to_guide/09_sampler_interface.html) for detailed guidance on choosing sampling algorithms." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sampling Level 3: Configure Method Specifics\n", + "\n", + "**Use case**: Choose specific algorithms within a sampling method\n", + "\n", + "Use `mcmc_method` or `vi_method` parameters to select specific algorithms." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/danielgedon/Dropbox/05_Postdoc/projects/sbi/sbi/inference/trainers/base.py:577: FutureWarning: The following arguments are deprecated and will be removed in a future version: mcmc_method. Please use `posterior_parameters` instead. Refer to this guide for details:\n", + "https://sbi.readthedocs.io/en/latest/how_to_guide/19_posterior_parameters.html#\n", + " self._raise_deprecation_warning(deprecated_params, **kwargs)\n", + "Generating 20 MCMC inits via resample strategy: 100%|██████████| 20/20 [00:01<00:00, 15.86it/s]\n", + "/Users/danielgedon/Dropbox/05_Postdoc/projects/sbi/.venv/lib/python3.12/site-packages/pyro/infer/mcmc/api.py:499: UserWarning: num_chains=20 is more than available_cpu=7. Chains will be drawn sequentially.\n", + " warnings.warn(\n", + "Sample [0]: 100%|██████████| 250/250 [00:14, 17.72it/s, step size=6.63e-01, acc. prob=0.907]\n", + "Sample [1]: 100%|██████████| 250/250 [00:14, 17.27it/s, step size=7.33e-01, acc. prob=0.892]\n", + "Sample [2]: 100%|██████████| 250/250 [00:12, 19.88it/s, step size=7.17e-01, acc. prob=0.905]\n", + "Sample [3]: 100%|██████████| 250/250 [00:13, 17.96it/s, step size=6.11e-01, acc. prob=0.889]\n", + "Sample [4]: 100%|██████████| 250/250 [00:14, 17.42it/s, step size=7.32e-01, acc. prob=0.905]\n", + "Sample [5]: 100%|██████████| 250/250 [00:13, 18.08it/s, step size=7.03e-01, acc. prob=0.862]\n", + "Sample [6]: 100%|██████████| 250/250 [00:15, 15.72it/s, step size=6.64e-01, acc. 
prob=0.909]\n", + "Sample [7]: 100%|██████████| 250/250 [00:12, 19.67it/s, step size=8.07e-01, acc. prob=0.892]\n", + "Sample [8]: 100%|██████████| 250/250 [00:13, 18.29it/s, step size=8.93e-01, acc. prob=0.907]\n", + "Sample [9]: 100%|██████████| 250/250 [00:14, 17.19it/s, step size=5.45e-01, acc. prob=0.908]\n", + "Sample [10]: 100%|██████████| 250/250 [00:14, 17.07it/s, step size=7.45e-01, acc. prob=0.907]\n", + "Sample [11]: 100%|██████████| 250/250 [00:13, 18.42it/s, step size=7.69e-01, acc. prob=0.866]\n", + "Sample [12]: 100%|██████████| 250/250 [00:13, 18.10it/s, step size=8.39e-01, acc. prob=0.879]\n", + "Sample [13]: 100%|██████████| 250/250 [00:14, 16.97it/s, step size=7.64e-01, acc. prob=0.863]\n", + "Sample [14]: 100%|██████████| 250/250 [00:16, 15.62it/s, step size=5.90e-01, acc. prob=0.909]\n", + "Sample [15]: 100%|██████████| 250/250 [00:13, 17.93it/s, step size=7.72e-01, acc. prob=0.888]\n", + "Sample [16]: 100%|██████████| 250/250 [00:13, 17.88it/s, step size=8.95e-01, acc. prob=0.886]\n", + "Sample [17]: 100%|██████████| 250/250 [00:14, 17.84it/s, step size=7.79e-01, acc. prob=0.896]\n", + "Sample [18]: 100%|██████████| 250/250 [00:17, 14.09it/s, step size=4.37e-01, acc. prob=0.923]\n", + "Sample [19]: 100%|██████████| 250/250 [00:13, 18.77it/s, step size=6.80e-01, acc. 
prob=0.920]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sampling Level 3: Using NUTS from Pyro\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Sampling Level 3: Configure method specifics\n", + "# Use NUTS (No-U-Turn Sampler) instead of default slice sampling\n", + "posterior_nuts = inference_nle.build_posterior(\n", + " sample_with=\"mcmc\",\n", + " mcmc_method=\"nuts_pyro\"\n", + ")\n", + "samples_nuts = posterior_nuts.sample((1000,), x=x_o)\n", + "\n", + "print(\"Sampling Level 3: Using NUTS from Pyro\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Available MCMC methods** (use with `mcmc_method=`):\n", + "- `\"slice_np_vectorized\"`: Slice sampling (numpy, vectorized, **default**)\n", + "- `\"slice_np\"`: Slice sampling (numpy, sequential)\n", + "- `\"nuts_pyro\"`: No-U-Turn Sampler (Pyro)\n", + "- `\"hmc_pyro\"`: Hamiltonian Monte Carlo (Pyro)\n", + "- `\"slice_pymc\"`, `\"hmc_pymc\"`, `\"nuts_pymc\"`: PyMC samplers\n", + "\n", + "**Available VI methods** (use with `vi_method=`):\n", + "- `\"rKL\"`: Reverse KL divergence (mode-seeking, **default**)\n", + "- `\"fKL\"`: Forward KL divergence (mass-covering)\n", + "- `\"IW\"`: Importance weighted\n", + "- `\"alpha\"`: Alpha divergence\n", + "\n", + "**Usage**: Change `mcmc_method=\"nuts_pyro\"` to any other MCMC method, or use `vi_method=\"fKL\"` when `sample_with=\"vi\"`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sampling Level 4: Fine-Tune Parameters\n", + "\n", + "**Use case**: Optimize sampling performance, troubleshoot convergence issues\n", + "\n", + "Fine-tune sampling parameters using dictionaries or `PosteriorParameters` dataclasses." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sampling Level 4a: Dictionary-based parameter tuning\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/danielgedon/Dropbox/05_Postdoc/projects/sbi/sbi/inference/trainers/base.py:577: FutureWarning: The following arguments are deprecated and will be removed in a future version: mcmc_parameters. Please use `posterior_parameters` instead. Refer to this guide for details:\n", + "https://sbi.readthedocs.io/en/latest/how_to_guide/19_posterior_parameters.html#\n", + " self._raise_deprecation_warning(deprecated_params, **kwargs)\n" + ] + } + ], + "source": [ + "# Sampling Level 4a: Using parameter dictionaries\n", + "posterior_tuned = inference_nle.build_posterior(\n", + " sample_with=\"mcmc\",\n", + " mcmc_method=\"slice_np_vectorized\",\n", + " mcmc_parameters={\n", + " \"warmup_steps\": 100, # Burn-in samples to discard\n", + " \"num_chains\": 4, # Number of parallel chains\n", + " \"thin\": 2, # Thinning factor\n", + " \"num_workers\": 2, # CPU cores for parallelization\n", + " }\n", + ")\n", + "\n", + "print(\"Sampling Level 4a: Dictionary-based parameter tuning\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/danielgedon/Dropbox/05_Postdoc/projects/sbi/sbi/inference/posteriors/mcmc_posterior.py:626: UserWarning: As of sbi v0.19.0, the behavior of the SIR initialization for MCMC has changed. If you wish to restore the behavior of sbi v0.18.0, set `init_strategy='resample'.`\n", + " init_fn = self._build_mcmc_init_fn(\n", + "Generating 4 MCMC inits via sir strategy: 100%|██████████| 4/4 [00:04<00:00, 1.09s/it]\n", + "Warmup [1]: 0%| | 0/350 [00:00, ?it/s]\n", + "\u001b[A\n", + "\n", + "Warmup [1]: 0%| | 1/350 [00:04, 4.76s/it, step size=1.87e+00, acc. 
prob=0.641]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 1%| | 4/350 [00:05, 1.01s/it, step size=3.75e-02, acc. prob=0.433]\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 1%|▏ | 5/350 [00:05, 1.27it/s, step size=5.58e-02, acc. prob=0.545]\n", + "Warmup [1]: 2%|▏ | 7/350 [00:05, 2.10it/s, step size=7.14e-02, acc. prob=0.637]\n", + "Warmup [1]: 3%|▎ | 10/350 [00:05, 3.69it/s, step size=1.48e-01, acc. prob=0.711]\n", + "\u001b[A\n", + "Warmup [1]: 4%|▎ | 13/350 [00:06, 4.49it/s, step size=3.73e-02, acc. prob=0.697]\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "Warmup [1]: 4%|▍ | 14/350 [00:06, 4.57it/s, step size=5.04e-02, acc. prob=0.711]\n", + "Warmup [1]: 5%|▍ | 16/350 [00:06, 5.85it/s, step size=1.56e-01, acc. prob=0.744]\n", + "\u001b[A\n", + "\u001b[A\n", + "\u001b[A\n", + "\n", + "Warmup [1]: 5%|▌ | 18/350 [00:06, 5.51it/s, step size=2.18e-02, acc. prob=0.713]\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\u001b[A\n", + "\n", + "Warmup [1]: 5%|▌ | 19/350 [00:07, 4.51it/s, step size=4.12e-02, acc. prob=0.728]\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 6%|▌ | 20/350 [00:07, 4.62it/s, step size=7.23e-02, acc. prob=0.741]\n", + "Warmup [1]: 6%|▋ | 22/350 [00:07, 6.23it/s, step size=1.34e-01, acc. prob=0.755]\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 7%|▋ | 25/350 [00:07, 7.77it/s, step size=4.46e-02, acc. prob=0.744]\n", + "Warmup [1]: 7%|▋ | 26/350 [00:08, 7.09it/s, step size=7.82e-02, acc. prob=0.753]\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 8%|▊ | 27/350 [00:08, 7.48it/s, step size=9.55e-02, acc. prob=0.757]\n", + "Warmup [1]: 8%|▊ | 28/350 [00:08, 7.18it/s, step size=1.66e-01, acc. prob=0.765]\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "Warmup [1]: 9%|▉ | 31/350 [00:08, 8.68it/s, step size=5.53e-02, acc. prob=0.755]\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 9%|▉ | 32/350 [00:08, 7.62it/s, step size=9.26e-02, acc. 
prob=0.762]\n", + "\n", + "Warmup [1]: 10%|▉ | 34/350 [00:08, 9.40it/s, step size=2.23e-01, acc. prob=0.773]\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 11%|█ | 37/350 [00:09, 12.41it/s, step size=5.46e-02, acc. prob=0.760]\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 12%|█▏ | 41/350 [00:09, 12.88it/s, step size=1.22e-01, acc. prob=0.771]\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 12%|█▏ | 43/350 [00:09, 10.43it/s, step size=8.42e-02, acc. prob=0.768]\n", + "\n", + "Warmup [1]: 13%|█▎ | 45/350 [00:09, 10.66it/s, step size=1.52e-01, acc. prob=0.774]\n", + "\u001b[A\n", + "\n", + "Warmup [1]: 14%|█▎ | 48/350 [00:10, 13.12it/s, step size=9.03e-02, acc. prob=0.771]\n", + "\u001b[A\n", + "\n", + "Warmup [1]: 14%|█▍ | 50/350 [00:10, 13.51it/s, step size=6.72e-02, acc. prob=0.769]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 15%|█▍ | 52/350 [00:10, 12.72it/s, step size=1.24e-01, acc. prob=0.775]\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "Warmup [1]: 15%|█▌ | 54/350 [00:10, 9.76it/s, step size=4.23e-02, acc. prob=0.767]\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "Warmup [1]: 16%|█▌ | 56/350 [00:10, 8.85it/s, step size=8.95e-02, acc. prob=0.774]\n", + "Warmup [1]: 17%|█▋ | 58/350 [00:11, 10.10it/s, step size=2.17e-01, acc. prob=0.781]\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 17%|█▋ | 61/350 [00:11, 13.05it/s, step size=9.63e-02, acc. prob=0.776]\n", + "\u001b[A\n", + "\n", + "Warmup [1]: 18%|█▊ | 63/350 [00:11, 12.40it/s, step size=1.41e-01, acc. prob=0.779]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 19%|█▉ | 66/350 [00:11, 9.58it/s, step size=6.30e-02, acc. prob=0.774]\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 19%|█▉ | 68/350 [00:12, 8.77it/s, step size=6.50e-02, acc. 
prob=0.775]\n", + "\u001b[A\n", + "\n", + "Warmup [1]: 20%|██ | 70/350 [00:12, 9.29it/s, step size=1.16e-01, acc. prob=0.779]\n", + "\u001b[A\n", + "\n", + "Warmup [1]: 21%|██ | 72/350 [00:12, 9.75it/s, step size=1.42e-01, acc. prob=0.781]\n", + "Warmup [1]: 21%|██ | 74/350 [00:12, 11.26it/s, step size=4.30e-02, acc. prob=0.773]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "Warmup [1]: 22%|██▏ | 76/350 [00:12, 11.25it/s, step size=7.31e-02, acc. prob=0.777]\n", + "Warmup [1]: 22%|██▏ | 78/350 [00:12, 12.16it/s, step size=1.64e-01, acc. prob=0.783]\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "Warmup [1]: 23%|██▎ | 81/350 [00:13, 13.91it/s, step size=1.61e-01, acc. prob=0.783]\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "Warmup [1]: 25%|██▍ | 87/350 [00:13, 17.56it/s, step size=1.19e-01, acc. prob=0.782]\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "Warmup [1]: 26%|██▌ | 91/350 [00:13, 15.08it/s, step size=3.20e+00, acc. prob=0.778]\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 27%|██▋ | 93/350 [00:13, 15.08it/s, step size=3.62e-01, acc. prob=0.772]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "Warmup [1]: 27%|██▋ | 95/350 [00:14, 11.43it/s, step size=8.92e-01, acc. prob=0.776]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 29%|██▊ | 100/350 [00:14, 15.23it/s, step size=1.02e+00, acc. prob=0.780]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 29%|██▉ | 102/350 [00:14, 10.90it/s, step size=1.02e+00, acc. prob=0.984]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 30%|███ | 105/350 [00:14, 12.71it/s, step size=1.02e+00, acc. prob=0.826]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 31%|███ | 107/350 [00:14, 13.13it/s, step size=1.02e+00, acc. prob=0.836]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 32%|███▏ | 111/350 [00:15, 15.56it/s, step size=1.02e+00, acc. prob=0.890]\n", + "Sample [1]: 33%|███▎ | 114/350 [00:15, 17.44it/s, step size=1.02e+00, acc. 
prob=0.903]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 33%|███▎ | 117/350 [00:15, 17.13it/s, step size=1.02e+00, acc. prob=0.904]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "Sample [1]: 34%|███▍ | 120/350 [00:15, 17.19it/s, step size=1.02e+00, acc. prob=0.904]\n", + "Sample [1]: 35%|███▍ | 122/350 [00:15, 16.62it/s, step size=1.02e+00, acc. prob=0.903]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 35%|███▌ | 124/350 [00:15, 13.38it/s, step size=1.02e+00, acc. prob=0.898]\n", + "\n", + "Sample [1]: 36%|███▌ | 126/350 [00:16, 13.80it/s, step size=1.02e+00, acc. prob=0.905]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 37%|███▋ | 128/350 [00:16, 13.85it/s, step size=1.02e+00, acc. prob=0.909]\n", + "Sample [1]: 37%|███▋ | 131/350 [00:16, 16.49it/s, step size=1.02e+00, acc. prob=0.902]\n", + "\n", + "Sample [1]: 38%|███▊ | 134/350 [00:16, 18.38it/s, step size=1.02e+00, acc. prob=0.905]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 39%|███▉ | 136/350 [00:16, 17.24it/s, step size=1.02e+00, acc. prob=0.907]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 39%|███▉ | 138/350 [00:16, 12.22it/s, step size=1.02e+00, acc. prob=0.908]\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 40%|████ | 140/350 [00:17, 11.81it/s, step size=1.02e+00, acc. prob=0.905]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 41%|████▏ | 145/350 [00:17, 13.75it/s, step size=1.02e+00, acc. prob=0.896]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 42%|████▏ | 147/350 [00:17, 14.11it/s, step size=1.02e+00, acc. prob=0.900]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 43%|████▎ | 149/350 [00:17, 14.15it/s, step size=1.02e+00, acc. prob=0.901]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 43%|████▎ | 151/350 [00:18, 11.01it/s, step size=1.02e+00, acc. prob=0.896]\n", + "Sample [1]: 44%|████▎ | 153/350 [00:18, 11.98it/s, step size=1.02e+00, acc. 
prob=0.896]\n", + "\n", + "Sample [1]: 45%|████▍ | 156/350 [00:18, 14.77it/s, step size=1.02e+00, acc. prob=0.896]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 45%|████▌ | 159/350 [00:18, 16.99it/s, step size=1.02e+00, acc. prob=0.890]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 46%|████▌ | 161/350 [00:18, 16.36it/s, step size=1.02e+00, acc. prob=0.890]\n", + "Sample [1]: 47%|████▋ | 164/350 [00:18, 18.30it/s, step size=1.02e+00, acc. prob=0.893]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 47%|████▋ | 166/350 [00:18, 17.22it/s, step size=1.02e+00, acc. prob=0.892]\n", + "\n", + "Sample [1]: 48%|████▊ | 168/350 [00:18, 16.46it/s, step size=1.02e+00, acc. prob=0.894]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 49%|████▉ | 171/350 [00:19, 18.48it/s, step size=1.02e+00, acc. prob=0.894]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 49%|████▉ | 173/350 [00:19, 15.66it/s, step size=1.02e+00, acc. prob=0.892]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 50%|█████ | 176/350 [00:19, 17.82it/s, step size=1.02e+00, acc. prob=0.896]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 51%|█████ | 178/350 [00:19, 15.32it/s, step size=1.02e+00, acc. prob=0.895]\n", + "\n", + "Sample [1]: 51%|█████▏ | 180/350 [00:19, 13.21it/s, step size=1.02e+00, acc. prob=0.894]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "Sample [1]: 52%|█████▏ | 182/350 [00:20, 10.53it/s, step size=1.02e+00, acc. prob=0.889]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 53%|█████▎ | 185/350 [00:20, 12.13it/s, step size=1.02e+00, acc. prob=0.889]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 54%|█████▎ | 188/350 [00:20, 13.52it/s, step size=1.02e+00, acc. prob=0.886]\n", + "Sample [1]: 54%|█████▍ | 190/350 [00:20, 13.79it/s, step size=1.02e+00, acc. prob=0.888]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 55%|█████▍ | 192/350 [00:20, 14.01it/s, step size=1.02e+00, acc. 
prob=0.885]\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "Sample [1]: 55%|█████▌ | 194/350 [00:20, 11.89it/s, step size=1.02e+00, acc. prob=0.881]\n", + "Sample [1]: 56%|█████▋ | 197/350 [00:21, 14.64it/s, step size=1.02e+00, acc. prob=0.882]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 57%|█████▋ | 199/350 [00:21, 14.59it/s, step size=1.02e+00, acc. prob=0.883]\n", + "Sample [1]: 58%|█████▊ | 202/350 [00:21, 16.96it/s, step size=1.02e+00, acc. prob=0.882]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 58%|█████▊ | 204/350 [00:21, 16.25it/s, step size=1.02e+00, acc. prob=0.881]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 59%|█████▉ | 206/350 [00:21, 15.77it/s, step size=1.02e+00, acc. prob=0.881]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 60%|██████ | 211/350 [00:21, 17.95it/s, step size=1.02e+00, acc. prob=0.878]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 61%|██████ | 213/350 [00:21, 16.98it/s, step size=1.02e+00, acc. prob=0.877]\n", + "Sample [1]: 62%|██████▏ | 216/350 [00:22, 14.28it/s, step size=1.02e+00, acc. prob=0.877]\n", + "Sample [1]: 63%|██████▎ | 219/350 [00:22, 16.80it/s, step size=1.02e+00, acc. prob=0.878]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 63%|██████▎ | 222/350 [00:22, 16.93it/s, step size=1.02e+00, acc. prob=0.874]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 64%|██████▍ | 225/350 [00:22, 18.52it/s, step size=1.02e+00, acc. prob=0.874]\n", + "\n", + "Sample [1]: 65%|██████▍ | 227/350 [00:22, 17.58it/s, step size=1.02e+00, acc. prob=0.875]\n", + "Sample [1]: 65%|██████▌ | 229/350 [00:22, 16.92it/s, step size=1.02e+00, acc. prob=0.874]\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "Sample [1]: 66%|██████▌ | 231/350 [00:23, 16.23it/s, step size=1.02e+00, acc. prob=0.876]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 67%|██████▋ | 233/350 [00:23, 12.94it/s, step size=1.02e+00, acc. 
prob=0.877]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 67%|██████▋ | 235/350 [00:23, 12.14it/s, step size=1.02e+00, acc. prob=0.877]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 68%|██████▊ | 237/350 [00:23, 11.64it/s, step size=1.02e+00, acc. prob=0.877]\n", + "\n", + "Sample [1]: 69%|██████▊ | 240/350 [00:23, 14.50it/s, step size=1.02e+00, acc. prob=0.878]\n", + "Sample [1]: 69%|██████▉ | 243/350 [00:23, 15.48it/s, step size=1.02e+00, acc. prob=0.877]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 70%|███████ | 245/350 [00:24, 16.06it/s, step size=1.02e+00, acc. prob=0.878]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 71%|███████ | 247/350 [00:24, 15.68it/s, step size=1.02e+00, acc. prob=0.874]\n", + "\n", + "Sample [1]: 71%|███████ | 249/350 [00:24, 16.31it/s, step size=1.02e+00, acc. prob=0.875]\n", + "Sample [1]: 72%|███████▏ | 252/350 [00:24, 18.34it/s, step size=1.02e+00, acc. prob=0.875]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 73%|███████▎ | 254/350 [00:24, 17.20it/s, step size=1.02e+00, acc. prob=0.873]\n", + "\n", + "Sample [1]: 73%|███████▎ | 256/350 [00:24, 16.60it/s, step size=1.02e+00, acc. prob=0.873]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 74%|███████▎ | 258/350 [00:24, 16.11it/s, step size=1.02e+00, acc. prob=0.874]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 75%|███████▍ | 261/350 [00:24, 18.30it/s, step size=1.02e+00, acc. prob=0.874]\n", + "\n", + "Sample [1]: 75%|███████▌ | 264/350 [00:25, 20.08it/s, step size=1.02e+00, acc. prob=0.870]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 77%|███████▋ | 268/350 [00:25, 20.64it/s, step size=1.02e+00, acc. prob=0.869]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 77%|███████▋ | 271/350 [00:25, 19.46it/s, step size=1.02e+00, acc. prob=0.867]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 78%|███████▊ | 274/350 [00:25, 18.51it/s, step size=1.02e+00, acc. 
prob=0.868]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 79%|███████▉ | 276/350 [00:25, 17.58it/s, step size=1.02e+00, acc. prob=0.868]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 80%|████████ | 281/350 [00:26, 18.04it/s, step size=1.02e+00, acc. prob=0.866]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 81%|████████ | 283/350 [00:26, 17.27it/s, step size=1.02e+00, acc. prob=0.867]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 82%|████████▏ | 286/350 [00:26, 18.89it/s, step size=1.02e+00, acc. prob=0.869]\n", + "\n", + "Sample [1]: 83%|████████▎ | 289/350 [00:26, 20.58it/s, step size=1.02e+00, acc. prob=0.870]\n", + "Sample [1]: 83%|████████▎ | 292/350 [00:26, 21.53it/s, step size=1.02e+00, acc. prob=0.871]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 84%|████████▍ | 295/350 [00:26, 17.91it/s, step size=1.02e+00, acc. prob=0.868]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 85%|████████▌ | 298/350 [00:26, 19.52it/s, step size=1.02e+00, acc. prob=0.870]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 86%|████████▌ | 301/350 [00:27, 17.18it/s, step size=1.02e+00, acc. prob=0.869]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 87%|████████▋ | 303/350 [00:27, 17.40it/s, step size=1.02e+00, acc. prob=0.870]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 87%|████████▋ | 305/350 [00:27, 16.64it/s, step size=1.02e+00, acc. prob=0.870]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 88%|████████▊ | 308/350 [00:27, 16.98it/s, step size=1.02e+00, acc. prob=0.868]\n", + "Sample [1]: 89%|████████▉ | 311/350 [00:27, 18.83it/s, step size=1.02e+00, acc. prob=0.868]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 89%|████████▉ | 313/350 [00:27, 17.74it/s, step size=1.02e+00, acc. prob=0.867]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 91%|█████████ | 319/350 [00:28, 20.87it/s, step size=1.02e+00, acc. prob=0.867]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 92%|█████████▏| 322/350 [00:28, 19.73it/s, step size=1.02e+00, acc. 
prob=0.866]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 93%|█████████▎| 325/350 [00:28, 18.57it/s, step size=1.02e+00, acc. prob=0.866]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 94%|█████████▎| 328/350 [00:28, 19.92it/s, step size=1.02e+00, acc. prob=0.864]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 95%|█████████▍| 331/350 [00:28, 18.84it/s, step size=1.02e+00, acc. prob=0.864]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 96%|█████████▌| 335/350 [00:28, 21.45it/s, step size=1.02e+00, acc. prob=0.864]\n", + "\u001b[A\n", + "\n", + "Sample [1]: 97%|█████████▋| 338/350 [00:29, 19.90it/s, step size=1.02e+00, acc. prob=0.863]\n", + "Sample [1]: 97%|█████████▋| 341/350 [00:29, 20.98it/s, step size=1.02e+00, acc. prob=0.861]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 98%|█████████▊| 344/350 [00:29, 21.87it/s, step size=1.02e+00, acc. prob=0.861]\n", + "\n", + "\u001b[A\u001b[A\n", + "Sample [1]: 99%|█████████▉| 347/350 [00:29, 19.12it/s, step size=1.02e+00, acc. prob=0.862]\n", + "\n", + "Sample [1]: 100%|██████████| 350/350 [00:29, 20.74it/s, step size=1.02e+00, acc. 
prob=0.861]\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "\u001b[A\u001b[A\n", + "\n", + "Sample [1]: 100%|██████████| 350/350 [00:36, 9.55it/s, step size=1.02e+00, acc. 
prob=0.861]\n", + "Sample [2]: 100%|██████████| 350/350 [00:36, 9.55it/s, step size=5.77e-01, acc. prob=0.925]\n", + "Sample [3]: 100%|██████████| 350/350 [00:35, 9.73it/s, step size=5.48e-01, acc. prob=0.924]\n", + "Sample [4]: 100%|██████████| 350/350 [00:35, 9.73it/s, step size=5.24e-01, acc. prob=0.930]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sampling Level 4b: PosteriorParameters with validation\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Sampling Level 4b: Using PosteriorParameters (recommended)\n", + "from sbi.inference.posteriors import MCMCPosteriorParameters\n", + "\n", + "mcmc_params = MCMCPosteriorParameters(\n", + " method=\"nuts_pyro\",\n", + " warmup_steps=100,\n", + " num_chains=4,\n", + " init_strategy=\"sir\", # Sequential Importance Resampling for init\n", + " init_strategy_parameters={\"num_candidate_samples\": 1000},\n", + " num_workers=2,\n", + " mp_context=\"spawn\" # Multiprocessing context\n", + ")\n", + "\n", + "posterior_advanced = inference_nle.build_posterior(\n", + " posterior_parameters=mcmc_params\n", + ")\n", + "\n", + "samples_advanced = posterior_advanced.sample((1000,), x=x_o)\n", + "\n", + "print(\"Sampling Level 4b: PosteriorParameters with validation\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Key tuning parameters**:\n", + "- `warmup_steps`: Number of initial samples to discard (default: 200)\n", + "- `num_chains`: Number of parallel chains (default: 20)\n", + "- `thin`: Thinning factor - keep every nth sample (default: -1, auto)\n", + "- `init_strategy`: How to initialize chains (`\"proposal\"`, `\"sir\"`, `\"resample\"`)\n", + "- `num_workers`: Number of CPU cores for parallelization\n", + "\n", + "**Advantages of PosteriorParameters**:\n", + "- Type checking and validation\n", + "- Better IDE autocomplete support\n", + "- Clear documentation of available parameters\n", + "\n", + 
"**See also**: [how_to_guide/19_posterior_parameters.ipynb](https://sbi.readthedocs.io/en/latest/how_to_guide/19_posterior_parameters.html) for complete details." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Decision Guides\n", + "\n", + "## Guide 1: Which Density Estimator Level? (NPE and NLE)\n", + "\n", + "| You want to... | Use Level | Example |\n", + "|----------------|-----------|----------|\n", + "| Standard workflows with good defaults | **1** | `NPE(prior, density_estimator=\"nsf\")` |\n", + "| Try different density estimator types | **1** | Switch `\"nsf\"`, `\"maf\"`, `\"zuko_nsf\"` |\n", + "| Tune network depth or width | **2** | `posterior_nn(hidden_features=100)` |\n", + "| Add embedding networks for images/timeseries | **2** | `posterior_nn(embedding_net=my_cnn)` |\n", + "| Access specialized flow parameters | **3** | `build_nsf(num_bins=16, tail_bound=5.0)` |\n", + "| Implement custom network architecture | **3** | Subclass `DensityEstimator` |\n", + "| Define custom loss functions or training | **4** | See advanced tutorial 18 |\n", + "\n", + "**Rule of thumb**: Start with Level 1. Move to higher levels only when you need specific control." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Guide 2: Which Sampling Level? 
(NLE and NRE only)\n", + "\n", + "| Situation | Sampling Level | Example |\n", + "|-----------|----------------|----------|\n", + "| Starting out, need good defaults | **1** | `build_posterior()` |\n", + "| Very few parameters (<3) | **2** | `sample_with=\"rejection\"` |\n", + "| Many parameters (>10), speed is critical | **2** | `sample_with=\"vi\"` |\n", + "| Want to use NUTS or HMC | **3** | `mcmc_method=\"nuts_pyro\"` |\n", + "| MCMC not converging, need more warmup | **4** | `mcmc_parameters={\"warmup_steps\": 500}` |\n", + "| Want type checking and validation | **4** | `MCMCPosteriorParameters(...)` |\n", + "| Troubleshooting sampling issues | **4** | Tune `num_chains`, `init_strategy`, etc. |\n", + "\n", + "**Rule of thumb**: \n", + "- Start with default MCMC (Level 1)\n", + "- If too slow, try rejection (few params) or VI (many params) at Level 2\n", + "- Use Levels 3-4 for optimization or troubleshooting" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Summary\n", + "\n", + "## All Methods (NPE, NLE, NRE)\n", + "\n", + "**4 Abstraction Levels for Density Estimator:**\n", + "\n", + "- **Level 1**: Trainer classes with strings → `NPE(prior, density_estimator=\"nsf\")`\n", + "- **Level 2**: Factory functions → `posterior_nn(model=\"maf\", hidden_features=50)`\n", + "- **Level 3**: Direct builders → `build_nsf(num_bins=8, tail_bound=3.0)`\n", + "- **Level 4**: Custom training → Full control (see tutorial 18)\n", + "\n", + "## NPE Sampling\n", + "\n", + "- Direct sampling from neural network\n", + "- No additional configuration needed\n", + "- Can optionally use MCMC/VI for more control\n", + "\n", + "## NLE and NRE Sampling (Additional Dimension)\n", + "\n", + "**4 Sampling Control Levels:**\n", + "\n", + "- **Level 1**: Default → `build_posterior()` (uses slice_np_vectorized)\n", + "- **Level 2**: Choose method → `sample_with=\"mcmc\"/\"vi\"/\"rejection\"/\"importance\"`\n", + "- **Level 3**: Configure algorithm → 
`mcmc_method=\"nuts_pyro\"`, `vi_method=\"fKL\"`\n", + "- **Level 4**: Fine-tune parameters → `mcmc_parameters={...}` or `MCMCPosteriorParameters(...)`\n", + "\n", + "## General Principle\n", + "\n", + "**Start simple, add complexity only when needed:**\n", + "1. Begin with Level 1 for density estimator\n", + "2. For NLE and NRE, begin with default sampling (Level 1)\n", + "3. Move to higher levels only when you need specific control or encounter issues\n", + "4. Both dimensions are independent - you can use Level 1 density estimator with Level 4 sampling, or vice versa" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}