Skip to content

Commit 046cdb0

Browse files
authored
fix: Update RatioEstimator classifier argument to use a Protocol (#1582)
* fix: Update RatioEstimator classifier argument to use a Protocol
* test: add tests for RatioEstimatorBuilder protocol
* Combine NRE classifier builder tests into a single function
* Fix formatting and import sorting
* Remove alias import for RatioEstimator
1 parent 164cf25 commit 046cdb0

File tree

6 files changed

+123
-39
lines changed

6 files changed

+123
-39
lines changed

sbi/inference/trainers/nre/bnre.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Callable, Dict, Optional, Union
4+
from typing import Dict, Optional, Union
55

66
import torch
77
from torch import Tensor, nn, ones
88
from torch.distributions import Distribution
99

1010
from sbi.inference.trainers.nre.nre_a import NRE_A
11+
from sbi.inference.trainers.nre.nre_base import RatioEstimatorBuilder
1112
from sbi.sbi_types import TensorboardSummaryWriter
1213
from sbi.utils.sbiutils import del_entries
1314
from sbi.utils.torchutils import assert_all_finite
@@ -28,7 +29,7 @@ class BNRE(NRE_A):
2829
def __init__(
2930
self,
3031
prior: Optional[Distribution] = None,
31-
classifier: Union[str, Callable] = "resnet",
32+
classifier: Union[str, RatioEstimatorBuilder] = "resnet",
3233
device: str = "cpu",
3334
logging_level: Union[int, str] = "warning",
3435
summary_writer: Optional[TensorboardSummaryWriter] = None,
@@ -42,11 +43,11 @@ def __init__(
4243
prior must be passed to `.build_posterior()`.
4344
classifier: Classifier trained to approximate likelihood ratios. If it is
4445
a string, use a pre-configured network of the provided type (one of
45-
linear, mlp, resnet). Alternatively, a function that builds a custom
46-
neural network can be provided. The function will be called with the
47-
first batch of simulations $(\theta, x)$, which can thus be used for
48-
shape inference and potentially for z-scoring. It needs to return a
49-
PyTorch `nn.Module` implementing the classifier.
46+
linear, mlp, resnet), or a callable that implements the
47+
`RatioEstimatorBuilder` protocol. The callable will be called with the
48+
first batch of simulations (theta, x), which can thus be used for
49+
shape inference and potentially for z-scoring. It returns a
50+
`RatioEstimator`.
5051
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
5152
logging_level: Minimum severity of messages to log. One of the strings
5253
INFO, WARNING, DEBUG, ERROR and CRITICAL.

sbi/inference/trainers/nre/nre_a.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Any, Callable, Dict, Optional, Union
4+
from typing import Any, Dict, Optional, Union
55

66
import torch
77
from torch import Tensor, nn, ones
88
from torch.distributions import Distribution
99

10-
from sbi.inference.trainers.nre.nre_base import RatioEstimatorTrainer
10+
from sbi.inference.trainers.nre.nre_base import (
11+
RatioEstimatorBuilder,
12+
RatioEstimatorTrainer,
13+
)
1114
from sbi.sbi_types import TensorboardSummaryWriter
1215
from sbi.utils.sbiutils import del_entries
1316
from sbi.utils.torchutils import assert_all_finite
@@ -23,7 +26,7 @@ class NRE_A(RatioEstimatorTrainer):
2326
def __init__(
2427
self,
2528
prior: Optional[Distribution] = None,
26-
classifier: Union[str, Callable] = "resnet",
29+
classifier: Union[str, RatioEstimatorBuilder] = "resnet",
2730
device: str = "cpu",
2831
logging_level: Union[int, str] = "warning",
2932
summary_writer: Optional[TensorboardSummaryWriter] = None,
@@ -37,11 +40,11 @@ def __init__(
3740
prior must be passed to `.build_posterior()`.
3841
classifier: Classifier trained to approximate likelihood ratios. If it is
3942
a string, use a pre-configured network of the provided type (one of
40-
linear, mlp, resnet). Alternatively, a function that builds a custom
41-
neural network can be provided. The function will be called with the
42-
first batch of simulations (theta, x), which can thus be used for shape
43-
inference and potentially for z-scoring. It needs to return a PyTorch
44-
`nn.Module` implementing the classifier.
43+
linear, mlp, resnet), or a callable that implements the
44+
`RatioEstimatorBuilder` protocol. The callable will be called with the
45+
first batch of simulations (theta, x), which can thus be used for
46+
shape inference and potentially for z-scoring. It returns a
47+
`RatioEstimator`.
4548
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
4649
logging_level: Minimum severity of messages to log. One of the strings
4750
INFO, WARNING, DEBUG, ERROR and CRITICAL.

sbi/inference/trainers/nre/nre_b.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Callable, Dict, Optional, Union
4+
from typing import Dict, Optional, Union
55

66
import torch
77
from torch import Tensor, nn
88
from torch.distributions import Distribution
99

10-
from sbi.inference.trainers.nre.nre_base import RatioEstimatorTrainer
10+
from sbi.inference.trainers.nre.nre_base import (
11+
RatioEstimatorBuilder,
12+
RatioEstimatorTrainer,
13+
)
1114
from sbi.sbi_types import TensorboardSummaryWriter
1215
from sbi.utils.sbiutils import del_entries
1316
from sbi.utils.torchutils import assert_all_finite
@@ -23,7 +26,7 @@ class NRE_B(RatioEstimatorTrainer):
2326
def __init__(
2427
self,
2528
prior: Optional[Distribution] = None,
26-
classifier: Union[str, Callable] = "resnet",
29+
classifier: Union[str, RatioEstimatorBuilder] = "resnet",
2730
device: str = "cpu",
2831
logging_level: Union[int, str] = "warning",
2932
summary_writer: Optional[TensorboardSummaryWriter] = None,
@@ -37,11 +40,11 @@ def __init__(
3740
prior must be passed to `.build_posterior()`.
3841
classifier: Classifier trained to approximate likelihood ratios. If it is
3942
a string, use a pre-configured network of the provided type (one of
40-
linear, mlp, resnet). Alternatively, a function that builds a custom
41-
neural network can be provided. The function will be called with the
42-
first batch of simulations (theta, x), which can thus be used for shape
43-
inference and potentially for z-scoring. It needs to return a PyTorch
44-
`nn.Module` implementing the classifier.
43+
linear, mlp, resnet), or a callable that implements the
44+
`RatioEstimatorBuilder` protocol. The callable will be called with the
45+
first batch of simulations (theta, x), which can thus be used for
46+
shape inference and potentially for z-scoring. It returns a
47+
`RatioEstimator`.
4548
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
4649
logging_level: Minimum severity of messages to log. One of the strings
4750
INFO, WARNING, DEBUG, ERROR and CRITICAL.

sbi/inference/trainers/nre/nre_base.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55
from abc import ABC, abstractmethod
66
from copy import deepcopy
7-
from typing import Any, Callable, Dict, Optional, Union
7+
from typing import Any, Dict, Optional, Protocol, Union
88

99
import torch
1010
from torch import Tensor, eye, nn, ones
@@ -18,6 +18,7 @@
1818
from sbi.inference.potentials import ratio_estimator_based_potential
1919
from sbi.inference.trainers.base import NeuralInference
2020
from sbi.neural_nets import classifier_nn
21+
from sbi.neural_nets.ratio_estimators import RatioEstimator
2122
from sbi.utils import (
2223
check_estimator_arg,
2324
check_prior,
@@ -26,11 +27,28 @@
2627
from sbi.utils.torchutils import repeat_rows
2728

2829

30+
class RatioEstimatorBuilder(Protocol):
    """Structural type for callables that build a ratio estimator from data.

    Any callable that accepts `theta` and `x` and returns a `RatioEstimator`
    matches this protocol.
    """

    def __call__(self, theta: Tensor, x: Tensor) -> RatioEstimator:
        """Construct a ratio estimator for the given data.

        The tensors mainly inform the shape of the input and of the condition
        to the neural network.

        Args:
            theta: Parameter sets.
            x: Simulation outputs.

        Returns:
            Ratio estimator.
        """
        ...
46+
2947
class RatioEstimatorTrainer(NeuralInference, ABC):
3048
def __init__(
3149
self,
3250
prior: Optional[Distribution] = None,
33-
classifier: Union[str, Callable] = "resnet",
51+
classifier: Union[str, RatioEstimatorBuilder] = "resnet",
3452
device: str = "cpu",
3553
logging_level: Union[int, str] = "warning",
3654
summary_writer: Optional[SummaryWriter] = None,
@@ -56,11 +74,11 @@ def __init__(
5674
Args:
5775
classifier: Classifier trained to approximate likelihood ratios. If it is
5876
a string, use a pre-configured network of the provided type (one of
59-
linear, mlp, resnet). Alternatively, a function that builds a custom
60-
neural network can be provided. The function will be called with the
61-
first batch of simulations (theta, x), which can thus be used for shape
62-
inference and potentially for z-scoring. It needs to return a PyTorch
63-
`nn.Module` implementing the classifier.
77+
linear, mlp, resnet), or a callable that implements the
78+
`RatioEstimatorBuilder` protocol. The callable will be called with the
79+
first batch of simulations (theta, x), which can thus be used for
80+
shape inference and potentially for z-scoring. It returns a
81+
`RatioEstimator`.
6482
6583
See docstring of `NeuralInference` class for all other arguments.
6684
"""

sbi/inference/trainers/nre/nre_c.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Callable, Dict, Optional, Tuple, Union
4+
from typing import Dict, Optional, Tuple, Union
55

66
import torch
77
from torch import Tensor, nn
88
from torch.distributions import Distribution
99

10-
from sbi.inference.trainers.nre.nre_base import RatioEstimatorTrainer
10+
from sbi.inference.trainers.nre.nre_base import (
11+
RatioEstimatorBuilder,
12+
RatioEstimatorTrainer,
13+
)
1114
from sbi.sbi_types import TensorboardSummaryWriter
1215
from sbi.utils.sbiutils import del_entries
1316
from sbi.utils.torchutils import assert_all_finite
@@ -37,7 +40,7 @@ class NRE_C(RatioEstimatorTrainer):
3740
def __init__(
3841
self,
3942
prior: Optional[Distribution] = None,
40-
classifier: Union[str, Callable] = "resnet",
43+
classifier: Union[str, RatioEstimatorBuilder] = "resnet",
4144
device: str = "cpu",
4245
logging_level: Union[int, str] = "warning",
4346
summary_writer: Optional[TensorboardSummaryWriter] = None,
@@ -51,11 +54,11 @@ def __init__(
5154
prior must be passed to `.build_posterior()`.
5255
classifier: Classifier trained to approximate likelihood ratios. If it is
5356
a string, use a pre-configured network of the provided type (one of
54-
linear, mlp, resnet). Alternatively, a function that builds a custom
55-
neural network can be provided. The function will be called with the
56-
first batch of simulations (theta, x), which can thus be used for shape
57-
inference and potentially for z-scoring. It needs to return a PyTorch
58-
`nn.Module` implementing the classifier.
57+
linear, mlp, resnet), or a callable that implements the
58+
`RatioEstimatorBuilder` protocol. The callable will be called with the
59+
first batch of simulations (theta, x), which can thus be used for
60+
shape inference and potentially for z-scoring. It returns a
61+
`RatioEstimator`.
5962
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
6063
logging_level: Minimum severity of messages to log. One of the strings
6164
INFO, WARNING, DEBUG, ERROR and CRITICAL.

tests/ratio_estimator_test.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55

66
import pytest
77
import torch
8-
from torch import eye, zeros
8+
from torch import Tensor, eye, zeros
99
from torch.distributions import MultivariateNormal
1010

11+
from sbi.inference import NRE
1112
from sbi.neural_nets.embedding_nets import CNNEmbedding
1213
from sbi.neural_nets.net_builders import build_linear_classifier
1314
from sbi.neural_nets.ratio_estimators import RatioEstimator
15+
from sbi.utils.torchutils import BoxUniform
1416

1517

1618
class EmbeddingNet(torch.nn.Module):
@@ -72,3 +74,57 @@ def test_api_ratio_estimator(ratio_estimator, theta_shape, x_shape):
7274
nsamples,
7375
), f"""unnormalized_log_ratio shape is not correct. It is of shape
7476
{unnormalized_log_ratio.shape}, but should be {(nsamples,)}"""
77+
78+
79+
def build_classifier(theta, x):
    """Build a minimal linear ratio estimator sized from the given batch.

    Args:
        theta: Batch of parameters; its second dimension sets part of the
            input width of the linear layer.
        x: Batch of simulation outputs; its second dimension sets the rest of
            the input width.

    Returns:
        A `RatioEstimator` wrapping a single linear layer with one output.
    """
    input_dim = theta.shape[1] + x.shape[1]
    linear = torch.nn.Linear(input_dim, 1)
    return RatioEstimator(
        net=linear,
        theta_shape=theta[0].shape,
        x_shape=x[0].shape,
    )
82+
83+
84+
def build_classifier_missing_args():
    """Invalid builder stub: accepts no arguments and returns nothing."""
    return None
86+
87+
88+
def build_classifier_missing_return(theta: Tensor, x: Tensor):
    """Invalid builder stub: takes the expected arguments but returns None."""
    return None
90+
91+
92+
@pytest.mark.parametrize(
    "classifier_builder",
    [
        build_classifier,
        pytest.param(
            build_classifier_missing_args,
            marks=pytest.mark.xfail(
                raises=TypeError,
                reason="Missing required parameters in classifier builder.",
                # strict: an unexpected pass must fail the suite instead of
                # being reported as XPASS.
                strict=True,
            ),
        ),
        pytest.param(
            build_classifier_missing_return,
            marks=pytest.mark.xfail(
                raises=AttributeError,
                reason="Missing return of RatioEstimator in classifier builder.",
                strict=True,
            ),
        ),
    ],
)
def test_nre_with_valid_and_invalid_classifier_builders(classifier_builder):
    r"""Test NRE works with valid classifier builders and fails with invalid ones.

    The valid builder must complete one training epoch. The invalid builders
    are expected to raise (`TypeError` for the builder without parameters,
    `AttributeError` for the builder that returns nothing); the xfail marks are
    strict so that an invalid builder that trains successfully is reported as a
    test failure.

    Args:
        classifier_builder: Function to build the classifier.
    """

    def simulator(theta):
        # Shifted identity with small Gaussian noise, on theta's device.
        return 1.0 + theta + torch.randn(theta.shape, device=theta.device) * 0.1

    num_dim = 3
    prior = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
    theta = prior.sample((300,))
    x = simulator(theta)

    inference = NRE(classifier=classifier_builder)
    inference.append_simulations(theta, x)

    inference.train(max_num_epochs=1)

0 commit comments

Comments
 (0)