
Commit 333427f

Authored by abelaba and janfb
refactor: bake RatioEstimator into the general ConditionalEstimator API (#1652)
* refactor: update RatioEstimator class to inherit from ConditionalEstimator
* refactor: update type hints from Union to ConditionalEstimator type
* refactor(nre): remove RatioEstimatorBuilder protocol and update type hints
* refactor(npe, nle): update estimator builder type hints to not include RatioEstimator
* refactor: remove VectorFieldEstimatorBuilder and update estimator type hints to use DensityEstimatorBuilder
* docs: remove empty line
* test: add FMPE and NPSE for testing valid and invalid estimator builders
* chore: add loss function to RatioEstimator
* chore(nle, npe): remove ConditionalVectorFieldEstimator builder type
* test: update parameters for checking invalid builder
* chore(fmpe): assign default value for prior
* test(npse, fmpe): correct test methods
* chore: rename DensityEstimatorBuilder to ConditionalEstimatorBuilder
* chore(nle, npe): update density_estimator type annotation
* test: check embedding_net property is not None
* chore: update density_estimator type annotation for mnle and mnpe
* Update sbi/diagnostics/misspecification.py
  Co-authored-by: Jan <janfb@users.noreply.github.com>
* chore(nle): update import structure
* chore: update formatting

---------

Co-authored-by: Jan <janfb@users.noreply.github.com>
1 parent: eb62d96

File tree: 21 files changed (+189 / -179 lines)
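The diffs below all point at the same API change: the former `DensityEstimatorBuilder` (and the NRE-specific `RatioEstimatorBuilder`) are replaced by a single `ConditionalEstimatorBuilder` protocol that is generic over the estimator type the builder returns. As orientation only, and not the library's actual definition, such a builder can be sketched as a callable protocol over the first batch of simulations:

# Hypothetical sketch of the builder shape implied by the diffs below; the real
# definition lives in sbi.neural_nets.estimators.base and may differ in detail.
from typing import Protocol, TypeVar

from torch import Tensor

# Covariant type variable: a builder returning a MixedDensityEstimator is also
# usable where a builder of a more general ConditionalEstimator is expected.
EstimatorT_co = TypeVar("EstimatorT_co", covariant=True)


class ConditionalEstimatorBuilderSketch(Protocol[EstimatorT_co]):
    def __call__(self, theta: Tensor, x: Tensor) -> EstimatorT_co:
        """Build an estimator from the first batch of simulations (theta, x)."""
        ...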

sbi/diagnostics/misspecification.py

Lines changed: 6 additions & 0 deletions
@@ -155,6 +155,12 @@ def calc_misspecification_mmd(
             "in that case the MMD is computed in the x-space.",
             stacklevel=2,
         )
+        if inference._neural_net.embedding_net is None:
+            raise AttributeError(
+                "embedding_net attribute is None but is required for misspecification"
+                " detection."
+            )
+
         z_obs = inference._neural_net.embedding_net(x_obs).detach()
         z = inference._neural_net.embedding_net(x).detach()
     else:
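The new guard fails fast when the trained estimator has no embedding net. As a usage sketch (the `NPE` trainer alias and the `posterior_nn` keyword arguments are assumptions, not taken from this diff), one way to ensure `embedding_net` is set is to pass an embedding when building the density estimator:

# Usage sketch, not from the diff: provide an embedding net so that the trained
# estimator exposes `.embedding_net` and the new guard does not raise.
from torch import nn

from sbi.inference import NPE  # assumed import path for the posterior trainer
from sbi.neural_nets import posterior_nn

# Hypothetical embedding: compress 100-dim raw data x to 10 summary features.
embedding = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10))

density_estimator_builder = posterior_nn(model="maf", embedding_net=embedding)
trainer = NPE(density_estimator=density_estimator_builder)
# After trainer.append_simulations(theta, x).train(), the trained estimator's
# embedding_net is available for the misspecification MMD check.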

sbi/inference/trainers/base.py

Lines changed: 7 additions & 8 deletions
@@ -49,7 +49,6 @@
     ConditionalEstimator,
     ConditionalVectorFieldEstimator,
 )
-from sbi.neural_nets.ratio_estimators import RatioEstimator
 from sbi.sbi_types import TorchTransform
 from sbi.utils import (
     check_prior,
@@ -314,7 +313,7 @@ def train(
     def _get_potential_function(
         self,
         prior: Distribution,
-        estimator: Union[RatioEstimator, ConditionalEstimator],
+        estimator: ConditionalEstimator,
     ) -> Tuple[BasePotential, TorchTransform]:
         """Subclass-specific potential creation"""
         ...
@@ -416,7 +415,7 @@ def get_dataloaders(

     def build_posterior(
         self,
-        estimator: Optional[Union[RatioEstimator, ConditionalEstimator]],
+        estimator: Optional[ConditionalEstimator],
         prior: Optional[Distribution],
         sample_with: Literal[
             "mcmc", "rejection", "vi", "importance", "direct", "sde", "ode"
@@ -501,8 +500,8 @@ def _resolve_prior(self, prior: Optional[Distribution]) -> Distribution:
         return prior

     def _resolve_estimator(
-        self, estimator: Optional[Union[RatioEstimator, ConditionalEstimator]]
-    ) -> Tuple[Union[RatioEstimator, ConditionalEstimator], str]:
+        self, estimator: Optional[ConditionalEstimator]
+    ) -> Tuple[ConditionalEstimator, str]:
         """
         Resolves the estimator and determines its device.

@@ -525,9 +524,9 @@ def _resolve_estimator(
             # If internal net is used device is defined.
             device = self._device
         else:
-            if not isinstance(estimator, (ConditionalEstimator, RatioEstimator)):
+            if not isinstance(estimator, ConditionalEstimator):
                 raise TypeError(
-                    "estimator must be ConditionalEstimator or RatioEstimator,"
+                    "estimator must be ConditionalEstimator,"
                     f" got {type(estimator).__name__}",
                 )
             # Otherwise, infer it from the device of the net parameters.
@@ -759,7 +758,7 @@ def _validate_posterior_parameters_consistency(

     def _create_posterior(
         self,
-        estimator: Union[RatioEstimator, ConditionalEstimator],
+        estimator: ConditionalEstimator,
         prior: Distribution,
         sample_with: Literal[
             "mcmc", "rejection", "vi", "importance", "direct", "sde", "ode"

sbi/inference/trainers/nle/mnle.py

Lines changed: 8 additions & 5 deletions
@@ -14,7 +14,7 @@
 )
 from sbi.inference.trainers.nle.nle_base import LikelihoodEstimatorTrainer
 from sbi.neural_nets.estimators import MixedDensityEstimator
-from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
+from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
 from sbi.sbi_types import TensorBoardSummaryWriter
 from sbi.utils.sbiutils import del_entries

@@ -34,7 +34,10 @@ class MNLE(LikelihoodEstimatorTrainer):
     def __init__(
         self,
         prior: Optional[Distribution] = None,
-        density_estimator: Union[Literal["mnle"], DensityEstimatorBuilder] = "mnle",
+        density_estimator: Union[
+            Literal["mnle"],
+            ConditionalEstimatorBuilder[MixedDensityEstimator],
+        ] = "mnle",
         device: str = "cpu",
         logging_level: Union[int, str] = "WARNING",
         summary_writer: Optional[TensorBoardSummaryWriter] = None,
@@ -49,11 +52,11 @@ def __init__(
             density_estimator: If it is a string, it must be "mnle" to use the
                 preconfiugred neural nets for MNLE. Alternatively, a function
                 that builds a custom neural network, which adheres to
-                `DensityEstimatorBuilder` protocol can be provided. The function will
-                be called with the first batch of simulations (theta, x), which can
+                `ConditionalEstimatorBuilder` protocol can be provided. The function
+                will be called with the first batch of simulations (theta, x), which can
                 thus be used for shape inference and potentially for z-scoring. The
                 density estimator needs to provide the methods `.log_prob` and
-                `.sample()`.
+                `.sample()` and must return a `MixedDensityEstimator`.
             device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
             logging_level: Minimum severity of messages to log. One of the strings
                 INFO, WARNING, DEBUG, ERROR and CRITICAL.

sbi/inference/trainers/nle/nle_a.py

Lines changed: 9 additions & 5 deletions
@@ -6,7 +6,10 @@
 from torch.distributions import Distribution

 from sbi.inference.trainers.nle.nle_base import LikelihoodEstimatorTrainer
-from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
+from sbi.neural_nets.estimators.base import (
+    ConditionalDensityEstimator,
+    ConditionalEstimatorBuilder,
+)
 from sbi.sbi_types import TensorBoardSummaryWriter
 from sbi.utils.sbiutils import del_entries

@@ -23,7 +26,8 @@ def __init__(
         self,
         prior: Optional[Distribution] = None,
         density_estimator: Union[
-            Literal["nsf", "maf", "mdn", "made"], DensityEstimatorBuilder
+            Literal["nsf", "maf", "mdn", "made"],
+            ConditionalEstimatorBuilder[ConditionalDensityEstimator],
         ] = "maf",
         device: str = "cpu",
         logging_level: Union[int, str] = "WARNING",
@@ -39,11 +43,11 @@ def __init__(
             density_estimator: If it is a string, use a pre-configured network of the
                 provided type (one of nsf, maf, mdn, made). Alternatively, a function
                 that builds a custom neural network, which adheres to
-                `DensityEstimatorBuilder` protocol can be provided. The function will
-                be called with the first batch of simulations (theta, x), which can
+                `ConditionalEstimatorBuilder` protocol can be provided. The function
+                will be called with the first batch of simulations (theta, x), which can
                 thus be used for shape inference and potentially for z-scoring. The
                 density estimator needs to provide the methods `.log_prob` and
-                `.sample()`.
+                `.sample()` and must return a `ConditionalDensityEstimator`.
             device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
             logging_level: Minimum severity of messages to log. One of the strings
                 INFO, WARNING, DEBUG, ERROR and CRITICAL.
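The reworded docstring describes the builder contract shared by all NLE trainers: a callable that receives the first batch of simulations and returns a `ConditionalDensityEstimator`. A hedged usage sketch that satisfies this contract by delegating to the `likelihood_nn` factory (the `NLE` alias and the `hidden_features` value are illustrative assumptions):

# Usage sketch, not taken from the diff: a custom builder that should satisfy
# ConditionalEstimatorBuilder[ConditionalDensityEstimator] by delegating to
# sbi's likelihood_nn factory. Argument values are illustrative only.
from torch import Tensor

from sbi.inference import NLE  # assumed alias for the likelihood trainer
from sbi.neural_nets import likelihood_nn
from sbi.neural_nets.estimators import ConditionalDensityEstimator


def my_likelihood_builder(theta: Tensor, x: Tensor) -> ConditionalDensityEstimator:
    # Called with the first batch of simulations, so shapes and z-scoring
    # can be inferred from theta and x.
    return likelihood_nn(model="maf", hidden_features=64)(theta, x)


trainer = NLE(density_estimator=my_likelihood_builder)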

sbi/inference/trainers/nle/nle_base.py

Lines changed: 6 additions & 5 deletions
@@ -26,7 +26,7 @@
 from sbi.inference.trainers.base import NeuralInference
 from sbi.neural_nets import likelihood_nn
 from sbi.neural_nets.estimators import ConditionalDensityEstimator
-from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
+from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
 from sbi.neural_nets.estimators.shape_handling import (
     reshape_to_batch_event,
 )
@@ -40,7 +40,8 @@ def __init__(
         self,
         prior: Optional[Distribution] = None,
         density_estimator: Union[
-            Literal["nsf", "maf", "mdn", "made"], DensityEstimatorBuilder
+            Literal["nsf", "maf", "mdn", "made"],
+            ConditionalEstimatorBuilder[ConditionalDensityEstimator],
         ] = "maf",
         device: str = "cpu",
         logging_level: Union[int, str] = "WARNING",
@@ -57,11 +58,11 @@ def __init__(
             density_estimator: If it is a string, use a pre-configured network of the
                 provided type (one of nsf, maf, mdn, made). Alternatively, a function
                 that builds a custom neural network, which adheres to
-                `DensityEstimatorBuilder` protocol can be provided. The function will
-                be called with the first batch of simulations (theta, x), which can
+                `ConditionalEstimatorBuilder` protocol can be provided. The function
+                will be called with the first batch of simulations (theta, x), which can
                 thus be used for shape inference and potentially for z-scoring. The
                 density estimator needs to provide the methods `.log_prob` and
-                `.sample()`.
+                `.sample()` and must return a `ConditionalDensityEstimator`.

             See docstring of `NeuralInference` class for all other arguments.
         """

sbi/inference/trainers/npe/mnpe.py

Lines changed: 8 additions & 5 deletions
@@ -15,7 +15,7 @@
 )
 from sbi.inference.trainers.npe.npe_c import NPE_C
 from sbi.neural_nets.estimators import MixedDensityEstimator
-from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
+from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
 from sbi.sbi_types import TensorBoardSummaryWriter
 from sbi.utils.sbiutils import del_entries

@@ -32,7 +32,10 @@ class MNPE(NPE_C):
     def __init__(
         self,
         prior: Optional[Distribution] = None,
-        density_estimator: Union[Literal["mnpe"], DensityEstimatorBuilder] = "mnpe",
+        density_estimator: Union[
+            Literal["mnpe"],
+            ConditionalEstimatorBuilder[MixedDensityEstimator],
+        ] = "mnpe",
         device: str = "cpu",
         logging_level: Union[int, str] = "WARNING",
         summary_writer: Optional[TensorBoardSummaryWriter] = None,
@@ -47,11 +50,11 @@ def __init__(
             density_estimator: If it is a string, it must be "mnpe" to use the
                 preconfigured neural nets for MNPE. Alternatively, a function
                 that builds a custom neural network, which adheres to
-                `DensityEstimatorBuilder` protocol can be provided. The function will
-                be called with the first batch of simulations (theta, x), which can
+                `ConditionalEstimatorBuilder` protocol can be provided. The function
+                will be called with the first batch of simulations (theta, x), which can
                 thus be used for shape inference and potentially for z-scoring. The
                 density estimator needs to provide the methods `.log_prob` and
-                `.sample()`.
+                `.sample()` and must return a `MixedDensityEstimator`.
             device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
             logging_level: Minimum severity of messages to log. One of the strings
                 INFO, WARNING, DEBUG, ERROR and CRITICAL.
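For the mixed-data trainers, MNLE above and MNPE here, the generic parameter narrows the builder's return type to `MixedDensityEstimator`. A sketch of a conforming builder, assuming `likelihood_nn` accepts "mnle" as a model string (as the trainer's default suggests) and that `MNLE` is exported from `sbi.inference`:

# Sketch only: the "mnle" model string and the MNLE import are assumptions
# based on the trainer defaults shown in this commit, not verified here.
from torch import Tensor

from sbi.inference import MNLE  # assumed export
from sbi.neural_nets import likelihood_nn
from sbi.neural_nets.estimators import MixedDensityEstimator


def my_mixed_builder(theta: Tensor, x: Tensor) -> MixedDensityEstimator:
    # x is expected to contain both continuous and discrete columns.
    return likelihood_nn(model="mnle")(theta, x)


trainer = MNLE(density_estimator=my_mixed_builder)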

sbi/inference/trainers/npe/npe_a.py

Lines changed: 10 additions & 8 deletions
@@ -18,7 +18,7 @@
 )
 from sbi.neural_nets.estimators.base import (
     ConditionalDensityEstimator,
-    DensityEstimatorBuilder,
+    ConditionalEstimatorBuilder,
 )
 from sbi.sbi_types import TensorBoardSummaryWriter
 from sbi.utils import torchutils
@@ -53,7 +53,8 @@ def __init__(
         self,
         prior: Optional[Distribution] = None,
         density_estimator: Union[
-            Literal["mdn_snpe_a"], DensityEstimatorBuilder
+            Literal["mdn_snpe_a"],
+            ConditionalEstimatorBuilder[ConditionalDensityEstimator],
         ] = "mdn_snpe_a",
         num_components: int = 10,
         device: str = "cpu",
@@ -71,14 +72,15 @@ def __init__(
             density_estimator: If it is a string (only "mdn_snpe_a" is valid), use a
                 pre-configured mixture of densities network. Alternatively, a function
                 that builds a custom neural network, which adheres to
-                `DensityEstimatorBuilder` protocol can be provided. The function will
-                be called with the first batch of simulations (theta, x), which can
+                `ConditionalEstimatorBuilder` protocol can be provided. The function
+                will be called with the first batch of simulations (theta, x), which can
                 thus be used for shape inference and potentially for z-scoring. The
                 density estimator needs to provide the methods `.log_prob` and
-                `.sample()`. Note that until the last round only a single (multivariate)
-                Gaussian component is used for training (seeAlgorithm 1 in [1]). In the
-                last round, this component is replicated `num_components` times, its
-                parameters are perturbed with a very small noise, and then the last
+                `.sample()` and must return a `ConditionalDensityEstimator`.
+                Note that until the last round only a single (multivariate) Gaussian
+                component is used for training (seeAlgorithm 1 in [1]). In the last
+                round, this component is replicated `num_components` times,
+                its parameters are perturbed with a very small noise, and then the last
                 training round is done with the expanded Gaussian mixture as estimator
                 for the proposal posterior.
             num_components: Number of components of the mixture of Gaussians in the

sbi/inference/trainers/npe/npe_b.py

Lines changed: 9 additions & 5 deletions
@@ -11,7 +11,10 @@
 from sbi.inference.trainers.npe.npe_base import (
     PosteriorEstimatorTrainer,
 )
-from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
+from sbi.neural_nets.estimators.base import (
+    ConditionalDensityEstimator,
+    ConditionalEstimatorBuilder,
+)
 from sbi.neural_nets.estimators.shape_handling import reshape_to_sample_batch_event
 from sbi.sbi_types import TensorBoardSummaryWriter
 from sbi.utils.sbiutils import del_entries
@@ -37,7 +40,8 @@ def __init__(
         self,
         prior: Optional[Distribution] = None,
         density_estimator: Union[
-            Literal["nsf", "maf", "mdn", "made"], DensityEstimatorBuilder
+            Literal["nsf", "maf", "mdn", "made"],
+            ConditionalEstimatorBuilder[ConditionalDensityEstimator],
         ] = "maf",
         device: str = "cpu",
         logging_level: Union[int, str] = "WARNING",
@@ -52,11 +56,11 @@ def __init__(
             density_estimator: If it is a string, use a pre-configured network of the
                 provided type (one of nsf, maf, mdn, made). Alternatively, a function
                 that builds a custom neural network, which adheres to
-                `DensityEstimatorBuilder` protocol can be provided. The function will
-                be called with the first batch of simulations (theta, x), which can
+                `ConditionalEstimatorBuilder` protocol can be provided. The function
+                will be called with the first batch of simulations (theta, x), which can
                 thus be used for shape inference and potentially for z-scoring. The
                 density estimator needs to provide the methods `.log_prob` and
-                `.sample()`.
+                `.sample()` and must return a `ConditionalDensityEstimator`.
             device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
             logging_level: Minimum severity of messages to log. One of the strings
                 INFO, WARNING, DEBUG, ERROR and CRITICAL.

sbi/inference/trainers/npe/npe_base.py

Lines changed: 7 additions & 8 deletions
@@ -15,9 +15,7 @@
 from torch.utils.tensorboard.writer import SummaryWriter
 from typing_extensions import Self

-from sbi.inference.posteriors import (
-    DirectPosterior,
-)
+from sbi.inference.posteriors import DirectPosterior
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
 from sbi.inference.posteriors.posterior_parameters import (
     DirectPosteriorParameters,
@@ -34,7 +32,7 @@
 )
 from sbi.neural_nets import posterior_nn
 from sbi.neural_nets.estimators import ConditionalDensityEstimator
-from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
+from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
 from sbi.neural_nets.estimators.shape_handling import (
     reshape_to_batch_event,
     reshape_to_sample_batch_event,
@@ -59,7 +57,8 @@ def __init__(
         self,
         prior: Optional[Distribution] = None,
         density_estimator: Union[
-            Literal["nsf", "maf", "mdn", "made"], DensityEstimatorBuilder
+            Literal["nsf", "maf", "mdn", "made"],
+            ConditionalEstimatorBuilder[ConditionalDensityEstimator],
         ] = "maf",
         device: str = "cpu",
         logging_level: Union[int, str] = "WARNING",
@@ -76,11 +75,11 @@ def __init__(
             density_estimator: If it is a string, use a pre-configured network of the
                 provided type (one of nsf, maf, mdn, made). Alternatively, a function
                 that builds a custom neural network, which adheres to
-                `DensityEstimatorBuilder` protocol can be provided. The function will
-                be called with the first batch of simulations (theta, x), which can
+                `ConditionalEstimatorBuilder` protocol can be provided. The function
+                will be called with the first batch of simulations (theta, x), which can
                 thus be used for shape inference and potentially for z-scoring. The
                 density estimator needs to provide the methods `.log_prob` and
-                `.sample()`.
+                `.sample()` and must return a `ConditionalDensityEstimator`.

             See docstring of `NeuralInference` class for all other arguments.
         """

sbi/inference/trainers/npe/npe_c.py

Lines changed: 6 additions & 6 deletions
@@ -15,7 +15,7 @@
 )
 from sbi.neural_nets.estimators.base import (
     ConditionalDensityEstimator,
-    DensityEstimatorBuilder,
+    ConditionalEstimatorBuilder,
 )
 from sbi.neural_nets.estimators.shape_handling import (
     reshape_to_batch_event,
@@ -73,7 +73,8 @@ def __init__(
         self,
         prior: Optional[Distribution] = None,
         density_estimator: Union[
-            Literal["nsf", "maf", "mdn", "made"], DensityEstimatorBuilder
+            Literal["nsf", "maf", "mdn", "made"],
+            ConditionalEstimatorBuilder[ConditionalDensityEstimator],
         ] = "maf",
         device: str = "cpu",
         logging_level: Union[int, str] = "WARNING",
@@ -88,12 +89,11 @@ def __init__(
             density_estimator: If it is a string, use a pre-configured network of the
                 provided type (one of nsf, maf, mdn, made). Alternatively, a function
                 that builds a custom neural network, which adheres to
-                `DensityEstimatorBuilder` protocol can be provided. The function will
-                be called with the first batch of simulations (theta, x), which can
+                `ConditionalEstimatorBuilder` protocol can be provided. The function
+                will be called with the first batch of simulations (theta, x), which can
                 thus be used for shape inference and potentially for z-scoring. The
                 density estimator needs to provide the methods `.log_prob` and
-                `.sample()`.
-
+                `.sample()` and must return a `ConditionalDensityEstimator`.
             device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
             logging_level: Minimum severity of messages to log. One of the strings
                 INFO, WARNING, DEBUG, ERROR and CRITICAL.
