Skip to content

Commit e472d28

Browse files
committed
remove labels from eval
1 parent ca1c9a6 commit e472d28

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

olmoearth_pretrain/evals/finetune.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def forward(
6767
) -> torch.Tensor:
6868
"""Forward pass through the model and head."""
6969
dev = next(self.wrapper.parameters()).device
70-
# classification: (B, D), segmentation: (B, H, W, D)
70+
# The wrapper requires both batch and labels as input, mainly for model
71+
# like AnySat that modify the shape of labels during training
7172
emb, labels = self.wrapper(batch, labels)
7273
emb = cast(torch.Tensor, emb)
7374
emb_dim = emb.shape[-1]
@@ -105,7 +106,7 @@ def _eval_cls(
105106
label = label.to(device=device)
106107
masked = _to_device(masked, device)
107108
with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
108-
logits, label = module(masked, label) # (B, C)
109+
logits, _ = module(masked, label) # (B, C)
109110
logits_all.append(logits.float().cpu())
110111
labels_all.append(label.cpu())
111112
logits = torch.cat(logits_all, 0)
@@ -138,7 +139,7 @@ def _eval_seg(
138139
label = label.to(device=device)
139140
masked = _to_device(masked, device)
140141
with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
141-
logits, label = module(masked, label) # (B, H, W, C*p*p)
142+
logits, _ = module(masked, label) # (B, H, W, C*p*p)
142143
H, W = logits.shape[1], logits.shape[2]
143144
logits = rearrange(
144145
logits,

0 commit comments

Comments (0)