Skip to content

Commit 1d20130

Browse files
authored
re-fix of old silent error (loss shape broadcast mismatch caused by extra dim of size 1) (#1063)
fixed shapes error in normalfactor to squeeze when dim(-1)=1 and no gating
1 parent e9de7ce commit 1d20130

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

bliss/encoder/variational_dist.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ def compute_nll(self, params, true_tile_cat):
133133
target = torch.where(gating.unsqueeze(-1), target, 0)
134134
assert not torch.isnan(target).any()
135135
ungated_nll = -qk.log_prob(target)
136+
if ungated_nll.dim() == target.dim(): # (b, w, h, 1) -> (b,w,h) silent error
137+
ungated_nll = ungated_nll.squeeze(-1)
136138
return ungated_nll * gating
137139

138140

0 commit comments

Comments
 (0)