Skip to content

Commit 9ca4291

Browse files
authored
fix: bug in sbiutils.py to use one-dimensional batch (#1577)
This may be my mistake, but it seems the original: ``` len(batch_t > 1): ``` will always evaluate to True (unless batch_t has only one data point), even if there's only one batch. To use `standardizing_net()` as intended to load a pre-trained net with a dummy single-batch data, shouldn't it be: ``` len(batch_t) > 1: ``` Thanks !
1 parent 3e845bd commit 9ca4291

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

sbi/utils/sbiutils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def standardizing_net(
284284
# Compute per-dimension (independent) mean.
285285
t_mean = torch.mean(batch_t[is_valid_t], dim=0)
286286

287-
if len(batch_t > 1):
287+
if len(batch_t) > 1:
288288
if structured_dims:
289289
# Compute std per-sample first.
290290
sample_std = torch.std(batch_t[is_valid_t], dim=1)

0 commit comments

Comments
 (0)