Skip to content

Commit 7064286

Browse files
abelabaARna06janfb
authored
refactor: Add protocol for estimator builder (#1633)
* changed type of writer and module * Implement custom density builder protocol and change related types * change the type of density estimators builder to the implemented protocol for all the childs * refactor(nle, npe): update DensityEstimatorBuilder return type and replace callable annotation * test(nle, npe): add tests for checking DensityEstimatorBuilder Protocol * chore(npe): update DensityEstimatorBuilder import file * chore: move DensityEstimatorBuilder to estimators * docs(nle, npe): update docstring for density_estimator parameter * test: move estimator builder test to density_estimator_test file * Revert "changed type of writer and module" This reverts commit 8ae70bc. * remove leftover import --------- Co-authored-by: ARna06 <72038543+Leopard005537@users.noreply.github.com> Co-authored-by: Jan <janfb@users.noreply.github.com> Co-authored-by: Jan <jan.boelts@mailbox.org>
1 parent 67d7e1c commit 7064286

File tree

12 files changed

+243
-126
lines changed

12 files changed

+243
-126
lines changed

sbi/inference/trainers/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,16 @@
77
from dataclasses import asdict
88
from datetime import datetime
99
from pathlib import Path
10-
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
10+
from typing import (
11+
Any,
12+
Callable,
13+
Dict,
14+
List,
15+
Literal,
16+
Optional,
17+
Tuple,
18+
Union,
19+
)
1120
from warnings import warn
1221

1322
import torch

sbi/inference/trainers/nle/mnle.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Any, Callable, Dict, Literal, Optional, Union
4+
from typing import Any, Dict, Literal, Optional, Union
55

66
from torch.distributions import Distribution
77

@@ -14,6 +14,7 @@
1414
)
1515
from sbi.inference.trainers.nle.nle_base import LikelihoodEstimatorTrainer
1616
from sbi.neural_nets.estimators import MixedDensityEstimator
17+
from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
1718
from sbi.sbi_types import TensorBoardSummaryWriter
1819
from sbi.utils.sbiutils import del_entries
1920

@@ -33,7 +34,7 @@ class MNLE(LikelihoodEstimatorTrainer):
3334
def __init__(
3435
self,
3536
prior: Optional[Distribution] = None,
36-
density_estimator: Union[str, Callable] = "mnle",
37+
density_estimator: Union[Literal["mnle"], DensityEstimatorBuilder] = "mnle",
3738
device: str = "cpu",
3839
logging_level: Union[int, str] = "WARNING",
3940
summary_writer: Optional[TensorBoardSummaryWriter] = None,
@@ -47,12 +48,12 @@ def __init__(
4748
prior must be passed to `.build_posterior()`.
4849
density_estimator: If it is a string, it must be "mnle" to use the
4950
preconfigured neural nets for MNLE. Alternatively, a function
50-
that builds a custom neural network can be provided. The function will
51+
that builds a custom neural network, which adheres to
52+
the `DensityEstimatorBuilder` protocol can be provided. The function will
5153
be called with the first batch of simulations (theta, x), which can
52-
thus be used for shape inference and potentially for z-scoring. It
53-
needs to return a PyTorch `nn.Module` implementing the density
54-
estimator. The density estimator needs to provide the methods
55-
`.log_prob`, `.log_prob_iid()` and `.sample()`.
54+
thus be used for shape inference and potentially for z-scoring. The
55+
density estimator needs to provide the methods `.log_prob` and
56+
`.sample()`.
5657
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
5758
logging_level: Minimum severity of messages to log. One of the strings
5859
INFO, WARNING, DEBUG, ERROR and CRITICAL.

sbi/inference/trainers/nle/nle_a.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Callable, Optional, Union
4+
from typing import Literal, Optional, Union
55

66
from torch.distributions import Distribution
77

88
from sbi.inference.trainers.nle.nle_base import LikelihoodEstimatorTrainer
9+
from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
910
from sbi.sbi_types import TensorBoardSummaryWriter
1011
from sbi.utils.sbiutils import del_entries
1112

@@ -21,7 +22,9 @@ class NLE_A(LikelihoodEstimatorTrainer):
2122
def __init__(
2223
self,
2324
prior: Optional[Distribution] = None,
24-
density_estimator: Union[str, Callable] = "maf",
25+
density_estimator: Union[
26+
Literal["nsf", "maf", "mdn", "made"], DensityEstimatorBuilder
27+
] = "maf",
2528
device: str = "cpu",
2629
logging_level: Union[int, str] = "WARNING",
2730
summary_writer: Optional[TensorBoardSummaryWriter] = None,
@@ -35,12 +38,12 @@ def __init__(
3538
prior must be passed to `.build_posterior()`.
3639
density_estimator: If it is a string, use a pre-configured network of the
3740
provided type (one of nsf, maf, mdn, made). Alternatively, a function
38-
that builds a custom neural network can be provided. The function will
41+
that builds a custom neural network, which adheres to
42+
the `DensityEstimatorBuilder` protocol can be provided. The function will
3943
be called with the first batch of simulations (theta, x), which can
40-
thus be used for shape inference and potentially for z-scoring. It
41-
needs to return a PyTorch `nn.Module` implementing the density
42-
estimator. The density estimator needs to provide the methods
43-
`.log_prob` and `.sample()`.
44+
thus be used for shape inference and potentially for z-scoring. The
45+
density estimator needs to provide the methods `.log_prob` and
46+
`.sample()`.
4447
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
4548
logging_level: Minimum severity of messages to log. One of the strings
4649
INFO, WARNING, DEBUG, ERROR and CRITICAL.

sbi/inference/trainers/nle/nle_base.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55
from abc import ABC
66
from copy import deepcopy
7-
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
7+
from typing import Any, Dict, Literal, Optional, Tuple, Union
88

99
import torch
1010
from torch import Tensor
@@ -26,6 +26,7 @@
2626
from sbi.inference.trainers.base import NeuralInference
2727
from sbi.neural_nets import likelihood_nn
2828
from sbi.neural_nets.estimators import ConditionalDensityEstimator
29+
from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
2930
from sbi.neural_nets.estimators.shape_handling import (
3031
reshape_to_batch_event,
3132
)
@@ -38,7 +39,9 @@ class LikelihoodEstimatorTrainer(NeuralInference, ABC):
3839
def __init__(
3940
self,
4041
prior: Optional[Distribution] = None,
41-
density_estimator: Union[str, Callable] = "maf",
42+
density_estimator: Union[
43+
Literal["nsf", "maf", "mdn", "made"], DensityEstimatorBuilder
44+
] = "maf",
4245
device: str = "cpu",
4346
logging_level: Union[int, str] = "WARNING",
4447
summary_writer: Optional[SummaryWriter] = None,
@@ -53,12 +56,12 @@ def __init__(
5356
distribution) can be used.
5457
density_estimator: If it is a string, use a pre-configured network of the
5558
provided type (one of nsf, maf, mdn, made). Alternatively, a function
56-
that builds a custom neural network can be provided. The function will
59+
that builds a custom neural network, which adheres to
60+
`DensityEstimatorBuilder` protocol can be provided. The function will
5761
be called with the first batch of simulations (theta, x), which can
58-
thus be used for shape inference and potentially for z-scoring. It
59-
needs to return a PyTorch `nn.Module` implementing the density
60-
estimator. The density estimator needs to provide the methods
61-
`.log_prob` and `.sample()`.
62+
thus be used for shape inference and potentially for z-scoring. The
63+
density estimator needs to provide the methods `.log_prob` and
64+
`.sample()`.
6265
6366
See docstring of `NeuralInference` class for all other arguments.
6467
"""

sbi/inference/trainers/npe/mnpe.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Any, Callable, Dict, Literal, Optional, Union
4+
from typing import Any, Dict, Literal, Optional, Union
55

66
from torch.distributions import Distribution
77

@@ -15,6 +15,7 @@
1515
)
1616
from sbi.inference.trainers.npe.npe_c import NPE_C
1717
from sbi.neural_nets.estimators import MixedDensityEstimator
18+
from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
1819
from sbi.sbi_types import TensorBoardSummaryWriter
1920
from sbi.utils.sbiutils import del_entries
2021

@@ -31,7 +32,7 @@ class MNPE(NPE_C):
3132
def __init__(
3233
self,
3334
prior: Optional[Distribution] = None,
34-
density_estimator: Union[str, Callable] = "mnpe",
35+
density_estimator: Union[Literal["mnpe"], DensityEstimatorBuilder] = "mnpe",
3536
device: str = "cpu",
3637
logging_level: Union[int, str] = "WARNING",
3738
summary_writer: Optional[TensorBoardSummaryWriter] = None,
@@ -45,12 +46,12 @@ def __init__(
4546
prior must be passed to `.build_posterior()`.
4647
density_estimator: If it is a string, it must be "mnpe" to use the
4748
preconfigured neural nets for MNPE. Alternatively, a function
48-
that builds a custom neural network can be provided. The function will
49+
that builds a custom neural network, which adheres to
50+
`DensityEstimatorBuilder` protocol can be provided. The function will
4951
be called with the first batch of simulations (theta, x), which can
50-
thus be used for shape inference and potentially for z-scoring. It
51-
needs to return a PyTorch `nn.Module` implementing the density
52-
estimator. The density estimator needs to provide the methods
53-
`.log_prob`, `.log_prob_iid()` and `.sample()`.
52+
thus be used for shape inference and potentially for z-scoring. The
53+
density estimator needs to provide the methods `.log_prob` and
54+
`.sample()`.
5455
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
5556
logging_level: Minimum severity of messages to log. One of the strings
5657
INFO, WARNING, DEBUG, ERROR and CRITICAL.

sbi/inference/trainers/npe/npe_a.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55
from copy import deepcopy
66
from functools import partial
7-
from typing import Any, Callable, Dict, Optional, Union
7+
from typing import Any, Callable, Dict, Literal, Optional, Union
88

99
import torch
1010
from pyknos.mdn.mdn import MultivariateGaussianMDN
@@ -16,7 +16,10 @@
1616
from sbi.inference.trainers.npe.npe_base import (
1717
PosteriorEstimatorTrainer,
1818
)
19-
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
19+
from sbi.neural_nets.estimators.base import (
20+
ConditionalDensityEstimator,
21+
DensityEstimatorBuilder,
22+
)
2023
from sbi.sbi_types import TensorBoardSummaryWriter
2124
from sbi.utils import torchutils
2225
from sbi.utils.sbiutils import (
@@ -49,7 +52,9 @@ class NPE_A(PosteriorEstimatorTrainer):
4952
def __init__(
5053
self,
5154
prior: Optional[Distribution] = None,
52-
density_estimator: Union[str, Callable] = "mdn_snpe_a",
55+
density_estimator: Union[
56+
Literal["mdn_snpe_a"], DensityEstimatorBuilder
57+
] = "mdn_snpe_a",
5358
num_components: int = 10,
5459
device: str = "cpu",
5560
logging_level: Union[int, str] = "WARNING",
@@ -65,17 +70,17 @@ def __init__(
6570
distribution) can be used.
6671
density_estimator: If it is a string (only "mdn_snpe_a" is valid), use a
6772
pre-configured mixture of densities network. Alternatively, a function
68-
that builds a custom neural network can be provided. The function will
73+
that builds a custom neural network, which adheres to
74+
the `DensityEstimatorBuilder` protocol can be provided. The function will
6975
be called with the first batch of simulations (theta, x), which can
70-
thus be used for shape inference and potentially for z-scoring. It
71-
needs to return a PyTorch `nn.Module` implementing the density
72-
estimator. The density estimator needs to provide the methods
73-
`.log_prob` and `.sample()`. Note that until the last round only a
74-
single (multivariate) Gaussian component is used for training (see
75-
Algorithm 1 in [1]). In the last round, this component is replicated
76-
`num_components` times, its parameters are perturbed with a very small
77-
noise, and then the last training round is done with the expanded
78-
Gaussian mixture as estimator for the proposal posterior.
76+
thus be used for shape inference and potentially for z-scoring. The
77+
density estimator needs to provide the methods `.log_prob` and
78+
`.sample()`. Note that until the last round only a single (multivariate)
79+
Gaussian component is used for training (see Algorithm 1 in [1]). In the
80+
last round, this component is replicated `num_components` times, its
81+
parameters are perturbed with a very small noise, and then the last
82+
training round is done with the expanded Gaussian mixture as estimator
83+
for the proposal posterior.
7984
num_components: Number of components of the mixture of Gaussians in the
8085
last round. This overrides the `num_components` value passed to
8186
`posterior_nn()`.

sbi/inference/trainers/npe/npe_b.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Any, Callable, Optional, Union
4+
from typing import Any, Literal, Optional, Union
55

66
import torch
77
from torch import Tensor
88
from torch.distributions import Distribution
99

1010
import sbi.utils as utils
11-
from sbi.inference.trainers.npe.npe_base import PosteriorEstimatorTrainer
11+
from sbi.inference.trainers.npe.npe_base import (
12+
PosteriorEstimatorTrainer,
13+
)
14+
from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
1215
from sbi.neural_nets.estimators.shape_handling import reshape_to_sample_batch_event
1316
from sbi.sbi_types import TensorBoardSummaryWriter
1417
from sbi.utils.sbiutils import del_entries
@@ -33,7 +36,9 @@ class NPE_B(PosteriorEstimatorTrainer):
3336
def __init__(
3437
self,
3538
prior: Optional[Distribution] = None,
36-
density_estimator: Union[str, Callable] = "maf",
39+
density_estimator: Union[
40+
Literal["nsf", "maf", "mdn", "made"], DensityEstimatorBuilder
41+
] = "maf",
3742
device: str = "cpu",
3843
logging_level: Union[int, str] = "WARNING",
3944
summary_writer: Optional[TensorBoardSummaryWriter] = None,
@@ -46,12 +51,12 @@ def __init__(
4651
parameters, e.g. which ranges are meaningful for them.
4752
density_estimator: If it is a string, use a pre-configured network of the
4853
provided type (one of nsf, maf, mdn, made). Alternatively, a function
49-
that builds a custom neural network can be provided. The function will
54+
that builds a custom neural network, which adheres to
55+
the `DensityEstimatorBuilder` protocol can be provided. The function will
5056
be called with the first batch of simulations (theta, x), which can
51-
thus be used for shape inference and potentially for z-scoring. It
52-
needs to return a PyTorch `nn.Module` implementing the density
53-
estimator. The density estimator needs to provide the methods
54-
`.log_prob` and `.sample()`.
57+
thus be used for shape inference and potentially for z-scoring. The
58+
density estimator needs to provide the methods `.log_prob` and
59+
`.sample()`.
5560
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
5661
logging_level: Minimum severity of messages to log. One of the strings
5762
INFO, WARNING, DEBUG, ERROR and CRITICAL.

sbi/inference/trainers/npe/npe_base.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,13 @@
2828
)
2929
from sbi.inference.potentials import posterior_estimator_based_potential
3030
from sbi.inference.potentials.posterior_based_potential import PosteriorBasedPotential
31-
from sbi.inference.trainers.base import NeuralInference, check_if_proposal_has_default_x
31+
from sbi.inference.trainers.base import (
32+
NeuralInference,
33+
check_if_proposal_has_default_x,
34+
)
3235
from sbi.neural_nets import posterior_nn
3336
from sbi.neural_nets.estimators import ConditionalDensityEstimator
37+
from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
3438
from sbi.neural_nets.estimators.shape_handling import (
3539
reshape_to_batch_event,
3640
reshape_to_sample_batch_event,
@@ -54,7 +58,9 @@ class PosteriorEstimatorTrainer(NeuralInference, ABC):
5458
def __init__(
5559
self,
5660
prior: Optional[Distribution] = None,
57-
density_estimator: Union[str, Callable] = "maf",
61+
density_estimator: Union[
62+
Literal["nsf", "maf", "mdn", "made"], DensityEstimatorBuilder
63+
] = "maf",
5864
device: str = "cpu",
5965
logging_level: Union[int, str] = "WARNING",
6066
summary_writer: Optional[SummaryWriter] = None,
@@ -69,12 +75,12 @@ def __init__(
6975
Args:
7076
density_estimator: If it is a string, use a pre-configured network of the
7177
provided type (one of nsf, maf, mdn, made). Alternatively, a function
72-
that builds a custom neural network can be provided. The function will
78+
that builds a custom neural network, which adheres to
79+
the `DensityEstimatorBuilder` protocol can be provided. The function will
7380
be called with the first batch of simulations (theta, x), which can
74-
thus be used for shape inference and potentially for z-scoring. It
75-
needs to return a PyTorch `nn.Module` implementing the density
76-
estimator. The density estimator needs to provide the methods
77-
`.log_prob` and `.sample()`.
81+
thus be used for shape inference and potentially for z-scoring. The
82+
density estimator needs to provide the methods `.log_prob` and
83+
`.sample()`.
7884
7985
See docstring of `NeuralInference` class for all other arguments.
8086
"""

0 commit comments

Comments
 (0)