Skip to content

Commit 5914851

Browse files
authored
feat: add CategoricalMADE for estimating multiple discrete dimension (#1269)
* wip: first draft on categorical made * wip: forward, log_prob, sample working * wip: CategoricalMassEstimator can be build and MixedDensityEstimator too. log_prob has shape issues tho * wip: sampling and log_prob works for categorical_made. working on getting mixed_density estimator log_probs and sample to work as well * wip: build_mnle works and trains without log_transform. Add made as arg to categorical_model * fix: categorical_made now trains in 1D MNLE * fix: change net kwarg * fix: verify ND training is working with CatMADE. * fix: fix embedding net mistake * fix: address comments * wip: save dev nb * wip: update toy simulator * wip: save wip * rm: rm legacy CategoricalNet * fix: correct i/o shapes, updated tutorial * doc: fix input arg dostrings * wip: fixes from PR implemented * doc: fix docstring shapes * fix: fix MADEWrapper bug and replace made by made wrapper * fix: add multiple disc dims to tests * fix: rem tests * doc: add comments * chore: comments, doc, kwargs, last cleanup * fix: fix linter and kwarg issues
1 parent dd4aef7 commit 5914851

File tree

10 files changed

+216
-126
lines changed

10 files changed

+216
-126
lines changed
Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
22
from sbi.neural_nets.estimators.categorical_net import (
3+
CategoricalMADE,
34
CategoricalMassEstimator,
4-
CategoricalNet,
55
)
66
from sbi.neural_nets.estimators.flowmatching_estimator import FlowMatchingEstimator
7-
from sbi.neural_nets.estimators.mixed_density_estimator import (
8-
MixedDensityEstimator,
9-
)
7+
from sbi.neural_nets.estimators.mixed_density_estimator import MixedDensityEstimator
108
from sbi.neural_nets.estimators.nflows_flow import NFlowsFlow
119
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
1210
from sbi.neural_nets.estimators.zuko_flow import ZukoFlow

sbi/neural_nets/estimators/categorical_net.py

Lines changed: 117 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Optional
4+
from typing import Callable, Optional
55

66
import torch
7+
from nflows.utils import torchutils
78
from torch import Tensor, nn
89
from torch.distributions import Categorical
9-
from torch.nn import Sigmoid, Softmax
10+
from torch.nn import functional as F
1011

1112
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
13+
from sbi.utils.nn_utils import MADEWrapper as MADE
1214

1315

14-
class CategoricalNet(nn.Module):
15-
"""Conditional density (mass) estimation for a categorical random variable.
16+
class CategoricalMADE(MADE):
17+
"""Conditional density (mass) estimation for a n-dim categorical random variable.
1618
1719
Takes as input parameters theta and learns the parameters p of a Categorical.
1820
@@ -21,108 +23,153 @@ class CategoricalNet(nn.Module):
2123

2224
def __init__(
2325
self,
24-
num_input: int,
25-
num_categories: int,
26-
num_hidden: int = 20,
27-
num_layers: int = 2,
28-
embedding_net: Optional[nn.Module] = None,
26+
num_categories: Tensor,
27+
num_hidden_features: int,
28+
num_context_features: Optional[int] = None,
29+
num_blocks: int = 2,
30+
use_residual_blocks: bool = True,
31+
random_mask: bool = False,
32+
activation: Callable = F.relu,
33+
dropout_probability: float = 0.0,
34+
use_batch_norm: bool = False,
35+
epsilon: float = 1e-2,
36+
embedding_net: nn.Module = nn.Identity(),
2937
):
3038
"""Initialize the neural net.
3139
3240
Args:
33-
num_input: number of input units, i.e., dimensionality of the features.
34-
num_categories: number of output units, i.e., number of categories.
35-
num_hidden: number of hidden units per layer.
36-
num_layers: number of hidden layers.
41+
num_categories: number of categories for each variable. len(categories)
42+
defines the number of input units, i.e., dimensionality of the features.
43+
max(categories) defines the number of output units, i.e., the largest
44+
number of categories. Can handle multiple variables with differing
45+
numbers of choices.
46+
num_hidden_features: number of hidden units per layer.
47+
num_context_features: number of context features.
48+
num_blocks: number of masked blocks.
49+
use_residual_blocks: whether to use residual blocks.
50+
random_mask: whether to use a random mask.
51+
activation: activation function. default is ReLU.
52+
dropout_probability: dropout probability. default is 0.0.
53+
use_batch_norm: whether to use batch normalization.
3754
embedding_net: embedding net for input.
3855
"""
39-
super().__init__()
40-
41-
self.num_hidden = num_hidden
42-
self.num_input = num_input
43-
self.activation = Sigmoid()
44-
self.softmax = Softmax(dim=1)
45-
self.num_categories = num_categories
46-
47-
# Maybe add embedding net in front.
48-
if embedding_net is not None:
49-
self.input_layer = nn.Sequential(
50-
embedding_net, nn.Linear(num_input, num_hidden)
51-
)
52-
else:
53-
self.input_layer = nn.Linear(num_input, num_hidden)
56+
if use_residual_blocks and random_mask:
57+
raise ValueError("Residual blocks can't be used with random masks.")
58+
59+
self.num_variables = len(num_categories)
60+
self.num_categories = int(torch.max(num_categories))
61+
self.mask = torch.zeros(self.num_variables, self.num_categories)
62+
for i, c in enumerate(num_categories):
63+
self.mask[i, :c] = 1
5464

55-
# Repeat hidden units hidden layers times.
56-
self.hidden_layers = nn.ModuleList()
57-
for _ in range(num_layers):
58-
self.hidden_layers.append(nn.Linear(num_hidden, num_hidden))
65+
super().__init__(
66+
features=self.num_variables,
67+
hidden_features=num_hidden_features,
68+
context_features=num_context_features,
69+
num_blocks=num_blocks,
70+
output_multiplier=self.num_categories,
71+
use_residual_blocks=use_residual_blocks,
72+
random_mask=random_mask,
73+
activation=activation,
74+
dropout_probability=dropout_probability,
75+
use_batch_norm=use_batch_norm,
76+
)
5977

60-
self.output_layer = nn.Linear(num_hidden, num_categories)
78+
self.embedding_net = embedding_net
79+
self.hidden_features = num_hidden_features
80+
self.epsilon = epsilon
81+
self.context_features = num_context_features
6182

62-
def forward(self, condition: Tensor) -> Tensor:
63-
"""Return categorical probability predicted from a batch of inputs.
83+
def forward(self, input: Tensor, condition: Optional[Tensor] = None) -> Tensor:
84+
r"""Forward pass of the categorical density estimator network to compute the
85+
conditional density at a given time.
6486
6587
Args:
66-
condition: batch of context parameters for the net.
88+
input: Inputs datapoints of shape `(batch_size, *input_shape)`
89+
condition: Conditioning variable. `(batch_size, *condition_shape)`
6790
6891
Returns:
69-
Tensor: batch of predicted categorical probabilities.
92+
Predicted categorical logits. `(batch_size, *input_shape,
93+
num_categories)`
7094
"""
71-
# forward path
72-
condition = self.activation(self.input_layer(condition))
95+
embedded_condition = self.embedding_net.forward(condition)
96+
out = super().forward(input, context=embedded_condition)
97+
# masks out logits i.e. for variables with num_categories < max(num_categories)
98+
return out.masked_fill(~self.mask.bool().flatten(), float("-inf"))
7399

74-
# iterate n hidden layers, input condition and calculate tanh activation
75-
for layer in self.hidden_layers:
76-
condition = self.activation(layer(condition))
100+
def log_prob(self, input: Tensor, condition: Optional[Tensor] = None) -> Tensor:
101+
r"""Return log-probability of samples.
77102
78-
return self.softmax(self.output_layer(condition))
79-
80-
def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
81-
"""Return categorical log probability of categories input, given condition.
103+
Evaluates `Categorical.log_prob`. The logits are given by the MADE.
82104
83105
Args:
84-
input: categories to evaluate.
85-
condition: parameters.
106+
input: Input datapoints of shape `(batch_size, *input_shape)`.
107+
condition: Conditioning variable. `(batch_size, *condition_shape)`.
86108
87109
Returns:
88-
Tensor: log probs with shape (input.shape[0],)
110+
Log-probabilities of shape `(batch_size,)`.
89111
"""
90-
# Predict categorical ps and evaluate.
91-
ps = self.forward(condition)
92-
# Squeeze the last dimension (event dim) because `Categorical` has
93-
# `event_shape=()` but our data usually has an event_shape of `(1,)`.
94-
return Categorical(probs=ps).log_prob(input.squeeze(dim=-1))
112+
outputs = self.forward(input, condition=condition)
113+
114+
outputs = outputs.reshape(*input.shape, self.num_categories)
115+
log_prob = Categorical(logits=outputs).log_prob(input).sum(dim=-1)
116+
117+
return log_prob
95118

96-
def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor:
97-
"""Returns samples from categorical random variable with probs predicted from
98-
the neural net.
119+
def sample(
120+
self, sample_shape: torch.Size, context: Optional[Tensor] = None
121+
) -> Tensor:
122+
"""Sample from the conditional categorical distribution.
123+
124+
Autoregressively samples from the conditional categorical distribution.
125+
Calls `Categorical.sample`. The logits are given by the MADE.
99126
100127
Args:
101-
sample_shape: number of samples to obtain.
102-
condition: batch of parameters for prediction.
128+
sample_shape: Shape of samples.
129+
context: Conditioning variable. `(batch_dim, *condition_shape)`.
103130
104131
Returns:
105-
Tensor: Samples with shape (num_samples, 1)
132+
Samples of shape `(*sample_shape, batch_dim)`.
106133
"""
134+
num_samples = int(torch.prod(torch.tensor(sample_shape)))
135+
136+
# Prepare context
137+
if context is not None:
138+
batch_dim = context.shape[0]
139+
if context.ndim == 2:
140+
context = context.unsqueeze(0)
141+
if batch_dim == 1:
142+
context = torchutils.repeat_rows(context, num_samples)
143+
else:
144+
context_dim = 0 if self.context_features is None else self.context_features
145+
context = torch.zeros(num_samples, context_dim)
146+
batch_dim = 1
107147

108-
# Predict Categorical ps and sample.
109-
ps = self.forward(condition)
110-
return Categorical(probs=ps).sample(sample_shape=sample_shape)
148+
# Autoregressively sample from the conditional categorical distribution.
149+
# for i = 1, ..., num_variables:
150+
# x_i ~ Categorical(logits=f_i(x_1, ..., x_{i-1}, c))
151+
with torch.no_grad():
152+
samples = torch.randn(num_samples, batch_dim, self.num_variables)
153+
for i in range(self.num_variables):
154+
outputs = self.forward(samples, context)
155+
outputs = outputs.reshape(*samples.shape, self.num_categories)
156+
samples[:, :, : i + 1] = Categorical(
157+
logits=outputs[:, :, : i + 1]
158+
).sample()
111159

160+
return samples.reshape(*sample_shape, batch_dim, self.num_variables)
112161

113-
class CategoricalMassEstimator(ConditionalDensityEstimator):
114-
"""Conditional density (mass) estimation for a categorical random variable.
115162

116-
The event_shape of this class is `()`.
117-
"""
163+
class CategoricalMassEstimator(ConditionalDensityEstimator):
164+
"""Conditional density (mass) estimation for a categorical random variable."""
118165

119166
def __init__(
120-
self, net: CategoricalNet, input_shape: torch.Size, condition_shape: torch.Size
167+
self, net: CategoricalMADE, input_shape: torch.Size, condition_shape: torch.Size
121168
) -> None:
122169
"""Initialize the mass estimator.
123170
124171
Args:
125-
net: CategoricalNet.
172+
net: CategoricalMADE.
126173
input_shape: Shape of the input data.
127174
condition_shape: Shape of the condition data
128175
"""
@@ -133,7 +180,7 @@ def __init__(
133180
self.num_categories = net.num_categories
134181

135182
def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
136-
"""Return log-probability of samples.
183+
"""Return log-probability of samples under the categorical distribution.
137184
138185
Args:
139186
input: Input datapoints of shape

sbi/neural_nets/estimators/mixed_density_estimator.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,10 @@ def sample(
8080
sample_shape=sample_shape,
8181
condition=condition,
8282
)
83-
# Trailing `1` because `Categorical` has event_shape `()`.
84-
discrete_samples = discrete_samples.reshape(num_samples * batch_dim, 1)
83+
num_variables = self.discrete_net.net.num_variables
84+
discrete_samples = discrete_samples.reshape(
85+
num_samples * batch_dim, num_variables
86+
)
8587

8688
# repeat the batch of embedded condition to match number of choices.
8789
condition_event_dim = embedded_condition.dim() - 1
@@ -145,7 +147,8 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
145147
f"{input_batch_dim} do not match."
146148
)
147149

148-
cont_input, disc_input = _separate_input(input)
150+
num_discrete_variables = self.discrete_net.net.num_variables
151+
cont_input, disc_input = _separate_input(input, num_discrete_variables)
149152
# Embed continuous condition
150153
embedded_condition = self.condition_embedding(condition)
151154
# expand and repeat to match batch of inputs.
@@ -204,3 +207,8 @@ def _separate_input(
204207
Assumes the discrete data to live in the last columns of input.
205208
"""
206209
return input[..., :-num_discrete_columns], input[..., -num_discrete_columns:]
210+
211+
212+
def _is_discrete(input: Tensor) -> Tensor:
213+
"""Infer discrete columns in input data."""
214+
return torch.tensor([torch.allclose(col, col.round()) for col in input.T])

sbi/neural_nets/net_builders/categorial.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4+
import warnings
45
from typing import Optional
56

6-
from torch import Tensor, nn, unique
7+
from torch import Tensor, nn, tensor, unique
78

8-
from sbi.neural_nets.estimators import CategoricalMassEstimator, CategoricalNet
9-
from sbi.utils.nn_utils import get_numel
10-
from sbi.utils.sbiutils import (
11-
standardizing_net,
12-
z_score_parser,
9+
from sbi.neural_nets.estimators import (
10+
CategoricalMADE,
11+
CategoricalMassEstimator,
1312
)
13+
from sbi.neural_nets.estimators.mixed_density_estimator import _is_discrete
14+
from sbi.utils.nn_utils import get_numel
15+
from sbi.utils.sbiutils import standardizing_net, z_score_parser
1416
from sbi.utils.user_input_checks import check_data_device
1517

1618

@@ -21,6 +23,7 @@ def build_categoricalmassestimator(
2123
z_score_y: Optional[str] = "independent",
2224
num_hidden: int = 20,
2325
num_layers: int = 2,
26+
num_categories: Optional[Tensor] = None,
2427
embedding_net: nn.Module = nn.Identity(),
2528
):
2629
"""Returns a density estimator for a categorical random variable.
@@ -33,28 +36,39 @@ def build_categoricalmassestimator(
3336
num_hidden: Number of hidden units per layer.
3437
num_layers: Number of hidden layers.
3538
embedding_net: Embedding net for y.
39+
num_categories: number of categories for each variable.
3640
"""
3741

3842
if z_score_x != "none":
3943
raise ValueError("Categorical input should not be z-scored.")
44+
if num_categories is None:
45+
warnings.warn(
46+
"Inferring num_categories from batch_x. Ensure all categories are present.",
47+
stacklevel=2,
48+
)
4049

4150
check_data_device(batch_x, batch_y)
42-
if batch_x.shape[1] > 1:
43-
raise NotImplementedError("CategoricalMassEstimator only supports 1D input.")
44-
num_categories = unique(batch_x).numel()
45-
dim_condition = get_numel(batch_y, embedding_net=embedding_net)
4651

4752
z_score_y_bool, structured_y = z_score_parser(z_score_y)
53+
y_numel = get_numel(batch_y, embedding_net=embedding_net)
54+
4855
if z_score_y_bool:
4956
embedding_net = nn.Sequential(
5057
standardizing_net(batch_y, structured_y), embedding_net
5158
)
5259

53-
categorical_net = CategoricalNet(
54-
num_input=dim_condition,
60+
if num_categories is None:
61+
batch_x_discrete = batch_x[:, _is_discrete(batch_x)]
62+
inferred_categories = tensor([
63+
unique(col).numel() for col in batch_x_discrete.T
64+
])
65+
num_categories = inferred_categories
66+
67+
categorical_net = CategoricalMADE(
5568
num_categories=num_categories,
56-
num_hidden=num_hidden,
57-
num_layers=num_layers,
69+
num_hidden_features=num_hidden,
70+
num_context_features=y_numel,
71+
num_blocks=num_layers,
5872
embedding_net=embedding_net,
5973
)
6074

0 commit comments

Comments
 (0)