From 1d201305c983c71276b9bcc6b4a16d90ee458754 Mon Sep 17 00:00:00 2001 From: Shreyas Date: Tue, 20 Aug 2024 20:12:51 -0400 Subject: [PATCH] re-fix of old silent error (loss shape broadcast mismatch caused by extra dim of size 1) (#1063) fixed shapes error in normalfactor to squeeze when dim(-1)=1 and no gating --- bliss/encoder/variational_dist.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bliss/encoder/variational_dist.py b/bliss/encoder/variational_dist.py index a323d7d77..51120868d 100644 --- a/bliss/encoder/variational_dist.py +++ b/bliss/encoder/variational_dist.py @@ -133,6 +133,8 @@ def compute_nll(self, params, true_tile_cat): target = torch.where(gating.unsqueeze(-1), target, 0) assert not torch.isnan(target).any() ungated_nll = -qk.log_prob(target) + if ungated_nll.dim() == target.dim(): # (b, w, h, 1) -> (b,w,h) silent error + ungated_nll = ungated_nll.squeeze(-1) return ungated_nll * gating