Skip to content

Commit ea32b92

Browse files
ARna06ARna06
andauthored
fix: minor fix for unconditional estimator and LRU (#1556)
* change type of aggregate_fcn * change the type of sample * addressed reviewer's comments --------- Co-authored-by: ARna06 <72038543+Leopard005537@users.noreply.github.com>
1 parent d1d7845 commit ea32b92

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

sbi/diagnostics/misspecification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def calc_misspecification_logprob(
209209
log_prob_xo = estimator.log_prob(x_o).detach().item()
210210

211211
n_samples = x_val.shape[0]
212-
samples = estimator.sample((n_samples,))
212+
samples = estimator.sample(torch.Size((n_samples,)))
213213
try:
214214
check_c2st(x_val, samples, 'MarginalEstimator')
215215
except AssertionError as e:

sbi/neural_nets/embedding_nets/lru.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import warnings
22
from math import sqrt
3-
from typing import Callable, Optional, Tuple
3+
from typing import Callable, Optional, Tuple, Union
44

55
import numpy as np
66
import torch
@@ -36,7 +36,7 @@ def __init__(
3636
mode: str = "loop",
3737
dropout: float = 0.0,
3838
apply_input_normalization: bool = False,
39-
aggregate_fcn: [str, Callable] = "mean",
39+
aggregate_fcn: Union[str, Callable] = "mean",
4040
):
4141
"""
4242
Args:

sbi/neural_nets/estimators/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def log_prob(self, x: Tensor) -> Tensor:
553553
self._neural_net.eval()
554554
return self._neural_net.log_prob(x)
555555

556-
def sample(self, sample_shape: torch.Size()) -> Tensor:
556+
def sample(self, sample_shape: torch.Size) -> Tensor:
557557
r"""Return samples from the density estimator.
558558
559559
Args:

0 commit comments

Comments
 (0)