Skip to content

Commit 964e8d2

Browse files
committed
Fix classification labels for batch size 1
1 parent 1750555 commit 964e8d2

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

asparagus/modules/lightning_modules/clsreg_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def configure_metrics(self, prefix: str):
188188
)
189189

190190
def on_before_batch_transfer(self, batch, dataloader_idx):
191-
batch["CLSREG_label"] = batch["CLSREG_label"].squeeze().long()
191+
batch["CLSREG_label"] = batch["CLSREG_label"].view(-1).long()
192192
return batch
193193

194194
def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):

0 commit comments

Comments
 (0)