Commit 89dc308

Improve docs of MNPE and EnsemblePosterior
1 parent 8762b0e

4 files changed: +64 −23 lines

docs/sbi.rst

Lines changed: 1 addition & 0 deletions

@@ -57,6 +57,7 @@ Training

    sbi.inference.NPE_C
    sbi.inference.NPE_A
+   sbi.inference.MNPE
    sbi.inference.FMPE
    sbi.inference.NPSE
    sbi.inference.NLE_A

sbi/inference/posteriors/ensemble_posterior.py

Lines changed: 23 additions & 16 deletions

@@ -19,8 +19,9 @@
 class EnsemblePosterior(NeuralPosterior):
     r"""Wrapper for bundling together different posterior instances into an ensemble.

-    This class creates a posterior ensemble from a set of N different, already trained
-    posterior estimators :math:`p_{i}(\theta|x_o)`, where :math:`i \in \{1,...,N\}`.
+    This class creates a posterior ensemble from a set of :math:`N` different, already
+    trained posterior estimators :math:`p_{i}(\theta \mid x_o)`, where
+    :math:`i \in \{1, \ldots, N\}`.

     It can wrap all posterior classes available in ``sbi`` and even a mixture of
     different posteriors, e.g., obtained via SNLE and SNPE at the same time, since it
@@ -30,32 +31,38 @@ class EnsemblePosterior(NeuralPosterior):

     So far, ``log_prob()``, ``sample()`` and ``map()`` functionality are supported.

-    Attributes:
-        posteriors: List of the posterior estimators making up the ensemble.
-        num_components: Number of posterior estimators.
-        weights: Weight of each posterior distribution. If none are provided each
-            posterior is weighted with 1/N.
-        priors: Prior distributions of all posterior components.
-        theta_transform: If passed, this transformation will be applied during the
-            optimization performed when obtaining the map. It does not affect the
-            .sample() and .log_prob() methods.
-        device: device to host the posterior distribution.
-
     Example:
     --------

     ::

         import torch
-        from joblib import Parallel, delayed
-        from sbi.examples.minimal import simple
+        from sbi.inference import NPE, EnsemblePosterior
+
+        theta = torch.randn((100, 3))  # toy draws from a standard normal prior
+        x = theta + 0.1 * torch.randn_like(theta)  # toy simulator output

         n_ensembles = 10
-        posteriors = Parallel(n_jobs=-1)(delayed(simple)() for i in range(n_ensembles))
+        posteriors = []
+        for _ in range(n_ensembles):
+            inference = NPE()
+            inference.append_simulations(theta, x).train()
+            posteriors.append(inference.build_posterior())

         ensemble = EnsemblePosterior(posteriors)
         ensemble.set_default_x(torch.zeros((3,)))
         ensemble.sample((1,))
+
+    Attributes:
+        posteriors: List of the posterior estimators making up the ensemble.
+        num_components: Number of posterior estimators.
+        weights: Weight of each posterior distribution. If none are provided, each
+            posterior is weighted with 1/N.
+        priors: Prior distributions of all posterior components.
+        theta_transform: If passed, this transformation will be applied during the
+            optimization performed when obtaining the MAP. It does not affect the
+            .sample() and .log_prob() methods.
+        device: Device on which to host the posterior distribution.
     """

     def __init__(
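
The uniform 1/N weighting described above is only the default; below is a minimal
sketch of a custom weighting, continuing from the example in the diff and assuming
the constructor accepts the documented ``weights`` directly as a list of floats::

    # Down-weight all but the first component; one weight per posterior above.
    weighted_ensemble = EnsemblePosterior(posteriors, weights=[0.5] + [0.5 / 9] * 9)
    weighted_ensemble.set_default_x(torch.zeros((3,)))

    # The weighted ensemble then behaves like a single posterior: samples are
    # drawn from the components in proportion to their weights.
    samples = weighted_ensemble.sample((100,))
    log_probs = weighted_ensemble.log_prob(samples)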

sbi/inference/trainers/nle/mnle.py

Lines changed: 2 additions & 2 deletions

@@ -20,13 +20,13 @@


 class MNLE(LikelihoodEstimatorTrainer):
-    """Mixed Neural Likelihood Estimation (MNLE) [1].
+    """Method that can infer parameters given discrete and continuous data (Mixed NLE).

     Like NLE, but designed to be applied to data with mixed types, e.g., continuous
     data and discrete data like they occur in decision-making experiments
     (reaction times and choices).

-    [1] Flexible and efficient simulation-based inference for models of
+    Flexible and efficient simulation-based inference for models of
     decision-making, Boelts et al. 2021,
     https://www.biorxiv.org/content/10.1101/2021.12.22.473472v2
     """

sbi/inference/trainers/npe/mnpe.py

Lines changed: 38 additions & 5 deletions

@@ -21,12 +21,45 @@


 class MNPE(NPE_C):
-    r"""Mixed Neural Posterior Estimation (MNPE).
+    r"""Method that can infer discrete and continuous parameters (Mixed NPE).

-    Like NPE-C, but designed to be applied to data with mixed types, e.g.,
-    continuous parameters and discrete parameters like they occur in models with
-    switching components. The emebedding net will only operate on the continuous
-    parameters, note this to design the dimension of the embedding net.
+    MNPE (Mixed Neural Posterior Estimation) is similar to NPE: it directly
+    estimates a distribution over parameters given data. Unlike NPE, it is designed
+    to be applied to parameters with mixed types, i.e., continuous and discrete
+    parameters. This can occur, for example, in models with switching components.
+    The embedding net will only operate on the continuous parameters; keep this in
+    mind when choosing the input dimension of the embedding net.
+
+    Example:
+    --------
+
+    ::
+
+        import torch
+        from sbi.inference import MNPE
+
+        dim_theta_discrete = 3
+        dim_theta_continuous = 2
+        dim_theta = dim_theta_continuous + dim_theta_discrete
+        dim_x = 50
+
+        num_sims = 100
+
+        discrete_theta = torch.randint(low=0, high=2, size=(num_sims, dim_theta_discrete))
+        continuous_theta = torch.randn((num_sims, dim_theta_continuous))
+
+        # Important: theta must have all continuous parameters first, followed
+        # by the discrete parameters.
+        theta = torch.cat([continuous_theta, discrete_theta], dim=1)
+        x = torch.randn((num_sims, dim_x))
+
+        inference = MNPE()
+        _ = inference.append_simulations(theta, x).train()
+
+        posterior = inference.build_posterior()
+
+        x_o = torch.randn((1, dim_x))
+        samples = posterior.sample((100,), x=x_o)
     """

     def __init__(
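
A quick sanity check on the samples returned by the MNPE example above; a minimal
sketch assuming the discrete dimensions come back as integer-valued floats in the
same continuous-first column order as ``theta``::

    # Split the samples back into continuous and discrete blocks.
    continuous_samples = samples[:, :dim_theta_continuous]
    discrete_samples = samples[:, dim_theta_continuous:]

    # The discrete block should only contain the values seen in training (0/1).
    assert torch.allclose(discrete_samples, discrete_samples.round())
    print(continuous_samples.shape, discrete_samples.shape)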
