 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

-from typing import Optional
+from typing import Callable, Optional

 import torch
+from nflows.utils import torchutils
 from torch import Tensor, nn
 from torch.distributions import Categorical
-from torch.nn import Sigmoid, Softmax
+from torch.nn import functional as F

 from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
+from sbi.utils.nn_utils import MADEWrapper as MADE


-class CategoricalNet(nn.Module):
-    """Conditional density (mass) estimation for a categorical random variable.
+class CategoricalMADE(MADE):
+    """Conditional density (mass) estimation for an n-dim categorical random variable.

     Takes as input parameters theta and learns the parameters p of a Categorical.

@@ -21,108 +23,153 @@ class CategoricalNet(nn.Module):

     def __init__(
         self,
-        num_input: int,
-        num_categories: int,
-        num_hidden: int = 20,
-        num_layers: int = 2,
-        embedding_net: Optional[nn.Module] = None,
+        num_categories: Tensor,
+        num_hidden_features: int,
+        num_context_features: Optional[int] = None,
+        num_blocks: int = 2,
+        use_residual_blocks: bool = True,
+        random_mask: bool = False,
+        activation: Callable = F.relu,
+        dropout_probability: float = 0.0,
+        use_batch_norm: bool = False,
+        epsilon: float = 1e-2,
+        embedding_net: nn.Module = nn.Identity(),
     ):
         """Initialize the neural net.

         Args:
-            num_input: number of input units, i.e., dimensionality of the features.
-            num_categories: number of output units, i.e., number of categories.
-            num_hidden: number of hidden units per layer.
-            num_layers: number of hidden layers.
+            num_categories: number of categories for each variable.
+                len(num_categories) defines the number of input units, i.e., the
+                dimensionality of the features. max(num_categories) defines the
+                number of output units, i.e., the largest number of categories.
+                Can handle multiple variables with differing numbers of categories.
+            num_hidden_features: number of hidden units per layer.
+            num_context_features: number of context features.
+            num_blocks: number of masked blocks.
+            use_residual_blocks: whether to use residual blocks.
+            random_mask: whether to use a random mask.
+            activation: activation function. Default is ReLU.
+            dropout_probability: dropout probability. Default is 0.0.
+            use_batch_norm: whether to use batch normalization.
             embedding_net: embedding net for input.
         """
-        super().__init__()
-
-        self.num_hidden = num_hidden
-        self.num_input = num_input
-        self.activation = Sigmoid()
-        self.softmax = Softmax(dim=1)
-        self.num_categories = num_categories
-
-        # Maybe add embedding net in front.
-        if embedding_net is not None:
-            self.input_layer = nn.Sequential(
-                embedding_net, nn.Linear(num_input, num_hidden)
-            )
-        else:
-            self.input_layer = nn.Linear(num_input, num_hidden)
+        if use_residual_blocks and random_mask:
+            raise ValueError("Residual blocks can't be used with random masks.")
+
+        self.num_variables = len(num_categories)
+        self.num_categories = int(torch.max(num_categories))
+        self.mask = torch.zeros(self.num_variables, self.num_categories)
+        for i, c in enumerate(num_categories):
+            self.mask[i, :c] = 1

-        # Repeat hidden units hidden layers times.
-        self.hidden_layers = nn.ModuleList()
-        for _ in range(num_layers):
-            self.hidden_layers.append(nn.Linear(num_hidden, num_hidden))
+        super().__init__(
+            features=self.num_variables,
+            hidden_features=num_hidden_features,
+            context_features=num_context_features,
+            num_blocks=num_blocks,
+            output_multiplier=self.num_categories,
+            use_residual_blocks=use_residual_blocks,
+            random_mask=random_mask,
+            activation=activation,
+            dropout_probability=dropout_probability,
+            use_batch_norm=use_batch_norm,
+        )

-        self.output_layer = nn.Linear(num_hidden, num_categories)
+        self.embedding_net = embedding_net
+        self.hidden_features = num_hidden_features
+        self.epsilon = epsilon
+        self.context_features = num_context_features

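A minimal standalone sketch of the mask construction above, assuming a hypothetical three-variable setup with 2, 3, and 5 categories; each variable's row is padded up to `max(num_categories)`:

```python
import torch

# Hypothetical example: three variables with 2, 3, and 5 categories.
num_categories = torch.tensor([2, 3, 5])
num_variables = len(num_categories)
max_categories = int(torch.max(num_categories))

# Same construction as in `__init__` above: 1 marks a valid category.
mask = torch.zeros(num_variables, max_categories)
for i, c in enumerate(num_categories):
    mask[i, :c] = 1

print(mask)
# tensor([[1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1.]])
```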
-    def forward(self, condition: Tensor) -> Tensor:
-        """Return categorical probability predicted from a batch of inputs.
+    def forward(self, input: Tensor, condition: Optional[Tensor] = None) -> Tensor:
+        r"""Forward pass of the categorical density estimator to compute the
+        conditional logits.

         Args:
-            condition: batch of context parameters for the net.
+            input: Input datapoints of shape `(batch_size, *input_shape)`.
+            condition: Conditioning variable of shape `(batch_size, *condition_shape)`.

         Returns:
-            Tensor: batch of predicted categorical probabilities.
+            Predicted categorical logits of shape `(batch_size, *input_shape,
+            num_categories)`.
         """
-        # forward path
-        condition = self.activation(self.input_layer(condition))
+        embedded_condition = self.embedding_net.forward(condition)
+        out = super().forward(input, context=embedded_condition)
+        # Mask out logits for variables with num_categories < max(num_categories).
+        return out.masked_fill(~self.mask.bool().flatten(), float("-inf"))

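To see why the `-inf` fill is safe: `Categorical` normalizes logits with a softmax, so masked entries receive exactly zero probability and can never be sampled. A small sketch:

```python
import torch
from torch.distributions import Categorical

# One variable with only 2 valid categories out of a padded maximum of 5.
logits = torch.tensor([0.3, 1.2, float("-inf"), float("-inf"), float("-inf")])
dist = Categorical(logits=logits)

print(dist.probs)          # tensor([0.2891, 0.7109, 0.0000, 0.0000, 0.0000])
print(dist.sample((10,)))  # only ever returns index 0 or 1
```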
-        # iterate n hidden layers, input condition and calculate tanh activation
-        for layer in self.hidden_layers:
-            condition = self.activation(layer(condition))
+    def log_prob(self, input: Tensor, condition: Optional[Tensor] = None) -> Tensor:
+        r"""Return the log-probability of `input` given `condition`.

-        return self.softmax(self.output_layer(condition))
-
-    def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
-        """Return categorical log probability of categories input, given condition.
+        Evaluates `Categorical.log_prob`; the logits are given by the MADE.

         Args:
-            input: categories to evaluate.
-            condition: parameters.
+            input: Input datapoints of shape `(batch_size, *input_shape)`.
+            condition: Conditioning variable of shape `(batch_size, *condition_shape)`.

         Returns:
-            Tensor: log probs with shape (input.shape[0],)
+            Log-probabilities of shape `(batch_size,)`.
         """
-        # Predict categorical ps and evaluate.
-        ps = self.forward(condition)
-        # Squeeze the last dimension (event dim) because `Categorical` has
-        # `event_shape=()` but our data usually has an event_shape of `(1,)`.
-        return Categorical(probs=ps).log_prob(input.squeeze(dim=-1))
+        outputs = self.forward(input, condition=condition)
+
+        outputs = outputs.reshape(*input.shape, self.num_categories)
+        log_prob = Categorical(logits=outputs).log_prob(input).sum(dim=-1)
+
+        return log_prob

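The reshape-and-sum in `log_prob` implements the autoregressive chain rule, `log p(x) = sum_i log p(x_i | x_<i, c)`. Assuming 3 variables padded to 5 categories, the shapes work out as follows:

```python
# input:                     (batch_size, 3)      integer category indices
# self.forward(input, ...):  (batch_size, 3 * 5)  flattened masked logits
# after reshape:             (batch_size, 3, 5)   one logit row per variable
# Categorical(...).log_prob: (batch_size, 3)      per-variable log-probs
# .sum(dim=-1):              (batch_size,)        joint log-prob via the chain rule
```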
-    def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor:
-        """Returns samples from categorical random variable with probs predicted from
-        the neural net.
+    def sample(
+        self, sample_shape: torch.Size, context: Optional[Tensor] = None
+    ) -> Tensor:
+        """Sample autoregressively from the conditional categorical distribution.
+
+        Calls `Categorical.sample`; the logits are given by the MADE.

         Args:
-            sample_shape: number of samples to obtain.
-            condition: batch of parameters for prediction.
+            sample_shape: Shape of the samples.
+            context: Conditioning variable of shape `(batch_dim, *condition_shape)`.

         Returns:
-            Tensor: Samples with shape (num_samples, 1)
+            Samples of shape `(*sample_shape, batch_dim, num_variables)`.
         """
+        num_samples = int(torch.prod(torch.tensor(sample_shape)))
+
+        # Prepare the context: repeat a single condition across all samples, or
+        # fall back to an empty placeholder if no condition is given.
+        if context is not None:
+            batch_dim = context.shape[0]
+            if context.ndim == 2:
+                context = context.unsqueeze(0)
+            if batch_dim == 1:
+                context = torchutils.repeat_rows(context, num_samples)
+        else:
+            context_dim = 0 if self.context_features is None else self.context_features
+            context = torch.zeros(num_samples, context_dim)
+            batch_dim = 1

-        # Predict Categorical ps and sample.
-        ps = self.forward(condition)
-        return Categorical(probs=ps).sample(sample_shape=sample_shape)
+        # Autoregressively sample from the conditional categorical distribution:
+        # for i = 1, ..., num_variables:
+        #     x_i ~ Categorical(logits=f_i(x_1, ..., x_{i-1}, c))
+        with torch.no_grad():
+            samples = torch.randn(num_samples, batch_dim, self.num_variables)
+            for i in range(self.num_variables):
+                outputs = self.forward(samples, context)
+                outputs = outputs.reshape(*samples.shape, self.num_categories)
+                samples[:, :, : i + 1] = Categorical(
+                    logits=outputs[:, :, : i + 1]
+                ).sample()

+        return samples.reshape(*sample_shape, batch_dim, self.num_variables)

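A hypothetical end-to-end sketch of the new class, assuming `MADEWrapper` accepts the keyword arguments forwarded by `__init__` above; all names and shapes follow the docstrings in this diff:

```python
import torch

# Hypothetical setup: three variables with 2, 3, and 5 categories,
# conditioned on a 4-dimensional context vector.
net = CategoricalMADE(
    num_categories=torch.tensor([2, 3, 5]),
    num_hidden_features=32,
    num_context_features=4,
)

condition = torch.randn(1, 4)  # a single conditioning vector

# 100 autoregressive samples: shape (100, 1, 3).
samples = net.sample(torch.Size((100,)), context=condition)

# Log-probabilities of those samples under the same condition: shape (100,).
log_probs = net.log_prob(samples.squeeze(1), condition=condition.expand(100, -1))
```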
-class CategoricalMassEstimator(ConditionalDensityEstimator):
-    """Conditional density (mass) estimation for a categorical random variable.
-
-    The event_shape of this class is `()`.
-    """
+class CategoricalMassEstimator(ConditionalDensityEstimator):
+    """Conditional density (mass) estimation for a categorical random variable."""

     def __init__(
-        self, net: CategoricalNet, input_shape: torch.Size, condition_shape: torch.Size
+        self, net: CategoricalMADE, input_shape: torch.Size, condition_shape: torch.Size
     ) -> None:
         """Initialize the mass estimator.

         Args:
-            net: CategoricalNet.
+            net: CategoricalMADE.
             input_shape: Shape of the input data.
             condition_shape: Shape of the condition data.
         """
@@ -133,7 +180,7 @@ def __init__(
         self.num_categories = net.num_categories

     def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
-        """Return log-probability of samples.
+        """Return log-probability of samples under the categorical distribution.

         Args:
             input: Input datapoints of shape