|
8 | 8 | from nflows.utils import torchutils |
9 | 9 | from torch import Tensor, nn |
10 | 10 | from torch.distributions import Categorical |
11 | | -from torch.nn import Sigmoid, Softmax |
12 | 11 | from torch.nn import functional as F |
13 | 12 |
|
14 | 13 | from sbi.neural_nets.estimators.base import ConditionalDensityEstimator |
@@ -87,7 +86,7 @@ def forward(self, inputs: Tensor, context: Optional[Tensor] = None) -> Tensor: |
87 | 86 | condition: Conditioning variable. (batch_size, *condition_shape) |
88 | 87 |
|
89 | 88 | Returns: |
90 | | - Predicted categorical probabilities. (batch_size, *input_shape, |
| 89 | + Predicted categorical logits. (batch_size, *input_shape, |
91 | 90 | num_categories) |
92 | 91 | """ |
93 | 92 | embedded_context = self.embedding_net.forward(context) |
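The `forward` pass now returns raw logits rather than softmax probabilities. A minimal sketch (plain `torch`, not code from this PR) of why the two are interchangeable downstream: `torch.distributions.Categorical` normalizes logits internally via a log-softmax, so building the distribution from logits or from the corresponding probabilities gives identical log-probabilities.

```python
import torch
from torch.distributions import Categorical

# Sketch: Categorical(logits=...) matches
# Categorical(probs=softmax(logits)) up to numerical precision.
logits = torch.randn(4, 3)  # (batch_size, num_categories)
probs = torch.softmax(logits, dim=-1)

x = torch.randint(0, 3, (4,))  # one category per batch element
assert torch.allclose(
    Categorical(logits=logits).log_prob(x),
    Categorical(probs=probs).log_prob(x),
    atol=1e-6,
)
```

Working directly with logits is also the numerically safer choice, which is why the explicit `Softmax` output (and its import) is dropped.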
@@ -149,119 +148,19 @@ def _initialize(self): |
149 | 148 | pass |
150 | 149 |
|
151 | 150 |
|
152 | | -class CategoricalNet(nn.Module): |
153 | | - """Conditional density (mass) estimation for a categorical random variable. |
154 | | -
|
155 | | - Takes as input parameters theta and learns the parameters p of a Categorical. |
156 | | -
|
157 | | - Defines log prob and sample functions. |
158 | | - """ |
159 | | - |
160 | | - def __init__( |
161 | | - self, |
162 | | - num_input: int, |
163 | | - num_categories: int, |
164 | | - num_hidden: int = 20, |
165 | | - num_layers: int = 2, |
166 | | - embedding_net: Optional[nn.Module] = None, |
167 | | - ): |
168 | | - """Initialize the neural net. |
169 | | -
|
170 | | - Args: |
171 | | - num_input: number of input units, i.e., dimensionality of the features. |
172 | | - num_categories: number of output units, i.e., number of categories. |
173 | | - num_hidden: number of hidden units per layer. |
174 | | - num_layers: number of hidden layers. |
175 | | - embedding_net: embedding net for input. |
176 | | - """ |
177 | | - super().__init__() |
178 | | - |
179 | | - self.num_hidden = num_hidden |
180 | | - self.num_input = num_input |
181 | | - self.activation = Sigmoid() |
182 | | - self.softmax = Softmax(dim=1) |
183 | | - self.num_categories = num_categories |
184 | | - self.num_variables = 1 |
185 | | - |
186 | | - # Maybe add embedding net in front. |
187 | | - if embedding_net is not None: |
188 | | - self.input_layer = nn.Sequential( |
189 | | - embedding_net, nn.Linear(num_input, num_hidden) |
190 | | - ) |
191 | | - else: |
192 | | - self.input_layer = nn.Linear(num_input, num_hidden) |
193 | | - |
194 | | - # Repeat hidden units hidden layers times. |
195 | | - self.hidden_layers = nn.ModuleList() |
196 | | - for _ in range(num_layers): |
197 | | - self.hidden_layers.append(nn.Linear(num_hidden, num_hidden)) |
198 | | - |
199 | | - self.output_layer = nn.Linear(num_hidden, num_categories) |
200 | | - |
201 | | - def forward(self, condition: Tensor) -> Tensor: |
202 | | - """Return categorical probability predicted from a batch of inputs. |
203 | | -
|
204 | | - Args: |
205 | | - condition: batch of context parameters for the net. |
206 | | -
|
207 | | - Returns: |
208 | | - Tensor: batch of predicted categorical probabilities. |
209 | | - """ |
210 | | - # forward pass |
211 | | - condition = self.activation(self.input_layer(condition)) |
212 | | - |
213 | | - # Iterate over the hidden layers, applying the sigmoid activation. |
214 | | - for layer in self.hidden_layers: |
215 | | - condition = self.activation(layer(condition)) |
216 | | - |
217 | | - return self.softmax(self.output_layer(condition)) |
218 | | - |
219 | | - def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: |
220 | | - """Return categorical log probability of categories input, given condition. |
221 | | -
|
222 | | - Args: |
223 | | - input: categories to evaluate. |
224 | | - condition: parameters. |
225 | | -
|
226 | | - Returns: |
227 | | - Tensor: log probs with shape (input.shape[0],) |
228 | | - """ |
229 | | - # Predict categorical ps and evaluate. |
230 | | - ps = self.forward(condition) |
231 | | - # Squeeze the last dimension (event dim) because `Categorical` has |
232 | | - # `event_shape=()` but our data usually has an event_shape of `(1,)`. |
233 | | - return Categorical(probs=ps).log_prob(input.squeeze(dim=-1)) |
234 | | - |
235 | | - def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor: |
236 | | - """Returns samples from categorical random variable with probs predicted from |
237 | | - the neural net. |
238 | | -
|
239 | | - Args: |
240 | | - sample_shape: shape of the samples to obtain. |
241 | | - condition: batch of parameters for prediction. |
242 | | -
|
243 | | - Returns: |
244 | | - Tensor: Samples with shape (num_samples, 1) |
245 | | - """ |
246 | | - |
247 | | - # Predict Categorical ps and sample. |
248 | | - ps = self.forward(condition) |
249 | | - return Categorical(probs=ps).sample(sample_shape=sample_shape) |
250 | | - |
251 | | - |
252 | 151 | class CategoricalMassEstimator(ConditionalDensityEstimator): |
253 | 152 | """Conditional density (mass) estimation for a categorical random variable. |
254 | 153 |
|
255 | 154 | The event_shape of this class is `()`. |
256 | 155 | """ |
257 | 156 |
|
258 | 157 | def __init__( |
259 | | - self, net: CategoricalNet, input_shape: torch.Size, condition_shape: torch.Size |
| 158 | + self, net: CategoricalMADE, input_shape: torch.Size, condition_shape: torch.Size |
260 | 159 | ) -> None: |
261 | 160 | """Initialize the mass estimator. |
262 | 161 |
|
263 | 162 | Args: |
264 | | - net: CategoricalNet. |
| 163 | + net: CategoricalMADE. |
265 | 164 | input_shape: Shape of the input data. |
266 | 165 | condition_shape: Shape of the condition data |
267 | 166 | """ |