@@ -67,7 +67,8 @@ def forward(
6767 ) -> torch .Tensor :
6868 """Forward pass through the model and head."""
6969 dev = next (self .wrapper .parameters ()).device
70- # classification: (B, D), segmentation: (B, H, W, D)
70+ # The wrapper requires both batch and labels as input, mainly for model
71+ # like AnySat that modify the shape of labels during training
7172 emb , labels = self .wrapper (batch , labels )
7273 emb = cast (torch .Tensor , emb )
7374 emb_dim = emb .shape [- 1 ]
@@ -105,7 +106,7 @@ def _eval_cls(
105106 label = label .to (device = device )
106107 masked = _to_device (masked , device )
107108 with torch .amp .autocast (device_type = device .type , dtype = torch .bfloat16 ):
108- logits , label = module (masked , label ) # (B, C)
109+ logits , _ = module (masked , label ) # (B, C)
109110 logits_all .append (logits .float ().cpu ())
110111 labels_all .append (label .cpu ())
111112 logits = torch .cat (logits_all , 0 )
@@ -138,7 +139,7 @@ def _eval_seg(
138139 label = label .to (device = device )
139140 masked = _to_device (masked , device )
140141 with torch .amp .autocast (device_type = device .type , dtype = torch .bfloat16 ):
141- logits , label = module (masked , label ) # (B, H, W, C*p*p)
142+ logits , _ = module (masked , label ) # (B, H, W, C*p*p)
142143 H , W = logits .shape [1 ], logits .shape [2 ]
143144 logits = rearrange (
144145 logits ,
0 commit comments