diff --git a/olmoearth_pretrain/evals/eval_wrapper.py b/olmoearth_pretrain/evals/eval_wrapper.py index 5c2c067e2..8736e373d 100644 --- a/olmoearth_pretrain/evals/eval_wrapper.py +++ b/olmoearth_pretrain/evals/eval_wrapper.py @@ -48,6 +48,7 @@ def __init__( pooling_type: PoolingType, concat_features: bool = False, use_pooled_tokens: bool = False, + use_center_token: bool = False, ): """Initialize the eval wrapper. @@ -58,7 +59,8 @@ def __init__( pooling_type: The pooling type to use for the model. concat_features: Whether to concatenate features across modalities. use_pooled_tokens: Whether to use pooled tokens. - is_train: whether this is being used on the training data. + use_center_token: Whether to use the center spatial patch embedding instead + of pooling across all patches for classification tasks. """ super().__init__() self.model = model @@ -68,6 +70,12 @@ def __init__( self.concat_features = concat_features self.spatial_pool = task_type == TaskType.SEGMENTATION self.use_pooled_tokens = use_pooled_tokens + self.use_center_token = use_center_token + if self.use_center_token and self.spatial_pool: + raise ValueError( + "use_center_token is only supported for classification tasks, " + "not segmentation (spatial_pool=True)" + ) if self.use_pooled_tokens: assert isinstance(self.model, EncodeEarlyAttnPool), ( "Pooled tokens are only supported for EncodeEarlyAttnPool" @@ -91,6 +99,19 @@ def __getattr__(self, name: str) -> Any: """Delegate attribute access to the underlying model if the attribute is not found on the wrapper.""" return getattr(self.model, name) + @staticmethod + def _extract_center_token(spatial_embeddings: torch.Tensor) -> torch.Tensor: + """Extract the center spatial patch embedding. + + Args: + spatial_embeddings: Tensor of shape (B, H, W, D). + + Returns: + Tensor of shape (B, D) from the center patch. 
+ """ + H, W = spatial_embeddings.shape[1], spatial_embeddings.shape[2] + return spatial_embeddings[:, H // 2, W // 2, :] + def __call__( self, masked_olmoearth_sample: MaskedOlmoEarthSample, @@ -116,11 +137,20 @@ def __call__( masked_olmoearth_sample, patch_size=self.patch_size, fast_pass=True )["tokens_and_masks"] # (bsz, dim) # Concat features across modalities in space averaged across time - batch_embeddings = batch_embeddings.pool_unmasked_tokens( - self.pooling_type, - spatial_pooling=self.spatial_pool, - concat_features=self.concat_features, - ) + if self.use_center_token: + # Get spatial embeddings (B, H, W, D) then take center patch + batch_embeddings = batch_embeddings.pool_unmasked_tokens( + self.pooling_type, + spatial_pooling=True, + concat_features=self.concat_features, + ) + batch_embeddings = self._extract_center_token(batch_embeddings) + else: + batch_embeddings = batch_embeddings.pool_unmasked_tokens( + self.pooling_type, + spatial_pooling=self.spatial_pool, + concat_features=self.concat_features, + ) else: pooled_tokens_dict = self.model( masked_olmoearth_sample, patch_size=self.patch_size, fast_pass=True @@ -138,8 +168,16 @@ def __call__( pooled_tokens = reduce( pooled_tokens, "b h w ... d -> b h w d", self.pooling_type ) + elif self.use_center_token: + # Pool time but keep spatial, then take center patch + if pooled_tokens.shape[1] == 1 and pooled_tokens.ndim == 3: + pooled_tokens = pooled_tokens.unsqueeze(1) + pooled_tokens = reduce( + pooled_tokens, "b h w ... d -> b h w d", self.pooling_type + ) + pooled_tokens = self._extract_center_token(pooled_tokens) else: - # Take the mean of all dims excetp the first and last + # Take the mean of all dims except the first and last pooled_tokens = reduce( pooled_tokens, "b ... 
d -> b d", self.pooling_type ) @@ -162,11 +200,14 @@ def __call__( is_train: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass through the model produces the embedding specified by initialization.""" + spatial_pool = self.spatial_pool or self.use_center_token batch_embeddings = self.model( masked_olmoearth_sample, pooling=self.pooling_type, - spatial_pool=self.spatial_pool, + spatial_pool=spatial_pool, ) + if self.use_center_token: + batch_embeddings = self._extract_center_token(batch_embeddings) return batch_embeddings, labels @@ -180,11 +221,13 @@ def __call__( is_train: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass through the model produces the embedding specified by initialization.""" - if self.spatial_pool: + if self.spatial_pool or self.use_center_token: # Intermediate features are not yet working because of some bug internal to the model batch_embeddings = self.model.forward_features( masked_olmoearth_sample, pooling=self.pooling_type ) + if self.use_center_token: + batch_embeddings = self._extract_center_token(batch_embeddings) else: batch_embeddings = self.model( masked_olmoearth_sample, pooling=self.pooling_type @@ -202,11 +245,14 @@ def __call__( is_train: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass through the model produces the embedding specified by initialization.""" + spatial_pool = self.spatial_pool or self.use_center_token embeddings = self.model( masked_olmoearth_sample, pooling=self.pooling_type, - spatial_pool=self.spatial_pool, + spatial_pool=spatial_pool, ) + if self.use_center_token: + embeddings = self._extract_center_token(embeddings) return embeddings, labels @@ -220,11 +266,15 @@ def __call__( is_train: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass through the model produces the embedding specified by initialization.""" + spatial_pool = self.spatial_pool or self.use_center_token embeddings = self.model( masked_olmoearth_sample, 
pooling=self.pooling_type, - spatial_pool=self.spatial_pool, + spatial_pool=spatial_pool, ) + if self.use_center_token: + embeddings = self._extract_center_token(embeddings) + return embeddings, labels if is_train and (self.task_type == TaskType.SEGMENTATION): # this is a special case for AnySat. Since it outputs per-pixel embeddings, # we subsample training pixels to keep the memory requirements reasonable. @@ -267,11 +317,14 @@ def __call__( is_train: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass through the model produces the embedding specified by initialization.""" + spatial_pool = self.spatial_pool or self.use_center_token embeddings = self.model( masked_olmoearth_sample, pooling=self.pooling_type, - spatial_pool=self.spatial_pool, + spatial_pool=spatial_pool, ) + if self.use_center_token: + embeddings = self._extract_center_token(embeddings) return embeddings, labels @@ -285,11 +338,14 @@ def __call__( is_train: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass through the model produces the embedding specified by initialization.""" + spatial_pool = self.spatial_pool or self.use_center_token batch_embeddings = self.model( masked_olmoearth_sample, pooling=self.pooling_type, - spatial_pool=self.spatial_pool, + spatial_pool=spatial_pool, ) + if self.use_center_token: + batch_embeddings = self._extract_center_token(batch_embeddings) return batch_embeddings, labels @@ -303,11 +359,14 @@ def __call__( is_train: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass through the model produces the embedding specified by initialization.""" + spatial_pool = self.spatial_pool or self.use_center_token batch_embeddings = self.model( masked_olmoearth_sample, pooling=self.pooling_type, - spatial_pool=self.spatial_pool, + spatial_pool=spatial_pool, ) + if self.use_center_token: + batch_embeddings = self._extract_center_token(batch_embeddings) return batch_embeddings, labels @@ -321,11 +380,14 @@ def __call__( is_train: 
bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass through the model produces the embedding specified by initialization.""" + spatial_pool = self.spatial_pool or self.use_center_token batch_embeddings = self.model( masked_olmoearth_sample, pooling=self.pooling_type, - spatial_pool=self.spatial_pool, + spatial_pool=spatial_pool, ) + if self.use_center_token: + batch_embeddings = self._extract_center_token(batch_embeddings) return batch_embeddings, labels @@ -340,12 +402,14 @@ def __call__( ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass through the model produces the embedding specified by initialization.""" # i need to do the apply imagenet normalizer thing in here - if self.spatial_pool: + if self.spatial_pool or self.use_center_token: # Intermediate features are not yet working because of some bug internal to the model batch_embeddings = self.model.forward_features( masked_olmoearth_sample, pooling=self.pooling_type, ) + if self.use_center_token: + batch_embeddings = self._extract_center_token(batch_embeddings) else: # should this call model ditectly batch_embeddings = self.model( @@ -365,11 +429,14 @@ def __call__( is_train: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass through the model produces the embedding specified by initialization.""" + spatial_pool = self.spatial_pool or self.use_center_token batch_embeddings = self.model( masked_olmoearth_sample, pooling=self.pooling_type, - spatial_pool=self.spatial_pool, + spatial_pool=spatial_pool, ) + if self.use_center_token: + batch_embeddings = self._extract_center_token(batch_embeddings) return batch_embeddings, labels @@ -383,11 +450,14 @@ def __call__( is_train: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass through the model produces the embedding specified by initialization.""" + spatial_pool = self.spatial_pool or self.use_center_token batch_embeddings = self.model( masked_olmoearth_sample, pooling=self.pooling_type, - 
spatial_pool=self.spatial_pool, + spatial_pool=spatial_pool, ) + if self.use_center_token: + batch_embeddings = self._extract_center_token(batch_embeddings) return batch_embeddings, labels diff --git a/olmoearth_pretrain/evals/linear_probe.py b/olmoearth_pretrain/evals/linear_probe.py index 3e728c130..79e6bb58d 100644 --- a/olmoearth_pretrain/evals/linear_probe.py +++ b/olmoearth_pretrain/evals/linear_probe.py @@ -33,6 +33,7 @@ class ProbeType(StrEnum): ATTNPOOL = "attnpool" LINEAR = "linear" + INTERPOLATE_LINEAR = "interpolate_linear" class AttnPoolLinearProbe(nn.Module): @@ -40,7 +41,9 @@ class AttnPoolLinearProbe(nn.Module): Args: in_dim (int): Input feature dimension. Must be divisible by 64. - out_dim (int): Output dimension (typically num_classes * patch_size * patch_size). + num_classes (int): Number of output classes. + task_type (TaskType): Must be SEGMENTATION. + num_output_pixels_per_side_of_patch (int | None): Number of output pixels per side of each patch. Attributes: query_token (nn.Parameter): Learnable query token for attention pooling. @@ -49,10 +52,25 @@ class AttnPoolLinearProbe(nn.Module): linear (nn.Linear): Final linear layer for output logits. 
""" - def __init__(self, in_dim: int, out_dim: int) -> None: + def __init__( + self, + in_dim: int, + num_classes: int, + task_type: TaskType, + num_output_pixels_per_side_of_patch: int | None = None, + ) -> None: """Initialize the attention pooling linear probe.""" super().__init__() + if task_type != TaskType.SEGMENTATION: + raise ValueError("AttnPoolLinearProbe only supports segmentation") + if num_output_pixels_per_side_of_patch is None: + raise ValueError( + "num_output_pixels_per_side_of_patch is required for AttnPoolLinearProbe" + ) assert in_dim % 64 == 0, "in_dim must be divisible by 64" + out_dim = num_classes * num_output_pixels_per_side_of_patch**2 + self.num_classes = num_classes + self.num_output_pixels_per_side_of_patch = num_output_pixels_per_side_of_patch self.query_token: nn.Parameter = nn.Parameter(torch.empty(in_dim)) self.num_heads: int = in_dim // 64 self.kv: nn.Linear = nn.Linear(in_dim, in_dim * 2) @@ -74,9 +92,9 @@ def forward(self, feat_tokens: torch.Tensor) -> dict: feat_tokens (torch.Tensor): Input feature tokens of shape (B, H, W, N, D). Returns: - tuple[torch.Tensor, torch.Tensor]: - - Output logits after linear layer, shape (B, H, W, out_dim). - - Attention weights, shape (B*H*W, num_heads, 1, N). + dict with: + - "logits": Output logits, shape (B, C, H_out, W_out). + - "attn_weights": Attention weights, shape (B*H*W, num_heads, 1, N). 
""" B, H, W, N, D = feat_tokens.shape feat_tokens = rearrange(feat_tokens, "b h w n d -> (b h w) n d") @@ -98,24 +116,115 @@ def forward(self, feat_tokens: torch.Tensor) -> dict: attn_weights = F.softmax(attn_scores, dim=-1) x = torch.matmul(attn_weights, v) # [B, head, 1, D_head] x = x.reshape(B, H, W, D) - return {"logits": self.linear(x), "attn_weights": attn_weights} + logits = self.linear(x) # (B, H, W, out_dim) + logits = rearrange( + logits, + "b h w (c i j) -> b c (h i) (w j)", + c=self.num_classes, + i=self.num_output_pixels_per_side_of_patch, + j=self.num_output_pixels_per_side_of_patch, + ) + return {"logits": logits, "attn_weights": attn_weights} + + +class InterpolateLinearProbe(nn.Module): + """Probe that bilinear-interpolates embeddings to full resolution then applies a per-pixel linear layer. + + For segmentation only. Takes (B, H_p, W_p, D) embeddings, upsamples to + (B, H_p * num_output_pixels_per_side_of_patch, ..., D), then applies Linear(D, num_classes). + """ + + def __init__( + self, + in_dim: int, + num_classes: int, + task_type: TaskType, + num_output_pixels_per_side_of_patch: int | None = None, + ) -> None: + """Initialize the interpolate linear probe.""" + super().__init__() + if task_type != TaskType.SEGMENTATION: + raise ValueError("InterpolateLinearProbe only supports segmentation") + if num_output_pixels_per_side_of_patch is None: + raise ValueError( + "num_output_pixels_per_side_of_patch is required for InterpolateLinearProbe" + ) + self.linear = nn.Linear(in_dim, num_classes) + self.num_output_pixels_per_side_of_patch = num_output_pixels_per_side_of_patch + + def forward(self, x: torch.Tensor) -> dict: + """Forward pass: bilinear upsample embeddings, then per-pixel linear. + + Args: + x: Embedding tensor of shape (B, H_p, W_p, D). + + Returns: + dict with "logits" of shape (B, C, H, W). 
+ """ + B, H_p, W_p, D = x.shape + target_hw = H_p * self.num_output_pixels_per_side_of_patch + x = rearrange(x, "b h w d -> b d h w") + x = F.interpolate( + x, + size=(target_hw, target_hw), + mode="bilinear", + align_corners=True, + ) + x = rearrange(x, "b d h w -> b h w d") + logits = self.linear(x) # (B, target_hw, target_hw, C) + logits = rearrange(logits, "b h w c -> b c h w") + return {"logits": logits} class LinearProbe(nn.Module): - """Linear Probe for classification tasks.""" + """Linear Probe for classification and segmentation tasks. + + For classification: applies BatchNorm1d then Linear(D, num_classes). + For segmentation: applies Linear(D, num_classes * ps^2) then rearranges to (B, C, H, W). + """ - def __init__(self, in_dim: int, out_dim: int, use_batchnorm: bool = False) -> None: + def __init__( + self, + in_dim: int, + num_classes: int, + task_type: TaskType, + num_output_pixels_per_side_of_patch: int | None = None, + ) -> None: """Initialize the linear probe.""" super().__init__() - self.linear = nn.Linear(in_dim, out_dim) - if use_batchnorm: - self.batchnorm = nn.BatchNorm1d(in_dim) + self.task_type = task_type + self.num_classes = num_classes + self.num_output_pixels_per_side_of_patch = num_output_pixels_per_side_of_patch + if task_type == TaskType.SEGMENTATION: + assert num_output_pixels_per_side_of_patch is not None, ( + "num_output_pixels_per_side_of_patch is required for segmentation" + ) + out_dim = num_classes * num_output_pixels_per_side_of_patch**2 + self.batchnorm: nn.Module = nn.Identity() else: - self.batchnorm = nn.Identity() + out_dim = num_classes + self.batchnorm = nn.BatchNorm1d(in_dim) + self.linear = nn.Linear(in_dim, out_dim) def forward(self, x: torch.Tensor) -> dict: """Forward pass for linear probe.""" - return {"logits": self.linear(self.batchnorm(x))} + logits = self.linear(self.batchnorm(x)) + if self.task_type == TaskType.SEGMENTATION: + logits = rearrange( + logits, + "b h w (c i j) -> b c (h i) (w j)", + 
c=self.num_classes, + i=self.num_output_pixels_per_side_of_patch, + j=self.num_output_pixels_per_side_of_patch, + ) + return {"logits": logits} + + +PROBE_TYPE_TO_CLASS: dict[ProbeType, type[nn.Module]] = { + ProbeType.LINEAR: LinearProbe, + ProbeType.ATTNPOOL: AttnPoolLinearProbe, + ProbeType.INTERPOLATE_LINEAR: InterpolateLinearProbe, +} def train_and_eval_probe( @@ -162,27 +271,14 @@ def train_and_eval_probe( output_pixels_per_side_of_patch = int( (config.height_width**2 / num_patches) ** 0.5 ) - num_output_pixels = config.num_classes * output_pixels_per_side_of_patch**2 - logits_per_patch = num_output_pixels - if probe_type == ProbeType.ATTNPOOL: - probe = AttnPoolLinearProbe( - in_dim=in_features, out_dim=logits_per_patch - ).to(device) - elif probe_type == ProbeType.LINEAR: - probe = LinearProbe( - in_dim=in_features, out_dim=logits_per_patch, use_batchnorm=False - ).to(device) - else: - raise ValueError(f"Probe type {probe_type} not supported for segmentation.") - else: - if probe_type == ProbeType.LINEAR: - probe = LinearProbe( - in_dim=in_features, out_dim=config.num_classes, use_batchnorm=True - ).to(device) - else: - raise ValueError( - f"Probe type {probe_type} not supported for classification." 
- ) + + probe_cls = PROBE_TYPE_TO_CLASS[probe_type] + probe = probe_cls( + in_dim=in_features, + num_classes=config.num_classes, + task_type=config.task_type, + num_output_pixels_per_side_of_patch=output_pixels_per_side_of_patch, + ).to(device) num_times_to_run_eval = math.ceil(epochs / eval_interval) val_results: list[EvalResult] = [] @@ -201,15 +297,12 @@ def train_and_eval_probe( end_epoch = min(start_epoch + eval_interval, epochs) probe = train_probe( - task_type=config.task_type, probe=probe, data_loader=data_loader, lr=lr, epochs=end_epoch, total_epochs=epochs, current_epoch=start_epoch, - num_classes=config.num_classes, - num_output_pixels_per_side_of_patch=output_pixels_per_side_of_patch, device=device, ) val_result = evaluate_probe( @@ -220,7 +313,6 @@ def train_and_eval_probe( ), probe=probe, num_classes=config.num_classes, - num_output_pixels_per_side_of_patch=output_pixels_per_side_of_patch, device=device, task_type=config.task_type, probe_type=probe_type, @@ -281,11 +373,8 @@ def train_and_eval_probe( all_preds, all_labels = get_probe_predictions( data_loader=test_data_loader, probe=probe, - num_classes=config.num_classes, device=device, - task_type=config.task_type, probe_type=probe_type, - num_output_pixels_per_side_of_patch=output_pixels_per_side_of_patch, ) if n_bootstrap > 0: @@ -362,12 +451,9 @@ def train_probe( current_epoch: int, epochs: int, total_epochs: int, - num_classes: int, device: torch.device, - task_type: TaskType, - num_output_pixels_per_side_of_patch: int | None = None, ) -> nn.Module: - """Train a linear probe on a segmentation task.""" + """Train a linear probe on a classification or segmentation task.""" opt = torch.optim.AdamW(probe.parameters(), lr=lr) probe = probe.train() @@ -375,40 +461,12 @@ def train_probe( start_epoch = current_epoch for epoch in range(start_epoch, epochs): for i, batch in enumerate(data_loader): - batch_emb, batch_labels = batch # (bsz, t_h, t_w, dim), (bsz, H, W) - spatial_patches_per_dim = 
batch_emb.shape[1] + batch_emb, batch_labels = batch batch_emb = batch_emb.to(device) with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16): - outputs = probe( - batch_emb - ) # (bsz, num_patches, logits_per_patch) or (bsz, n_cls) + outputs = probe(batch_emb) logits = outputs["logits"] - # logger.info(f"logits: {logits.shape}") - if task_type == TaskType.SEGMENTATION: - assert num_output_pixels_per_side_of_patch is not None, ( - "num_output_pixels_per_side_of_patch is required for segmentation" - ) - # This is effectively nearest neighbor interpolation - logits = rearrange( - logits, - "b h w (c i j) -> b c (h i) (w j)", - h=spatial_patches_per_dim, - w=spatial_patches_per_dim, - c=num_classes, - i=num_output_pixels_per_side_of_patch, - j=num_output_pixels_per_side_of_patch, - ) - if logits.shape[-2] != batch_labels.shape[-2]: - logger.debug( - f"Logits shape {logits.shape} does not match batch_labels shape {batch_labels.shape} interpolating to labels shape" - ) - logits = F.interpolate( - logits, - size=(batch_labels.shape[-2], batch_labels.shape[-1]), - mode="bilinear", - align_corners=True, - ) # (bsz, num_classes, H, W) loss = loss_function(logits, batch_labels.to(device)) loss.backward() @@ -430,11 +488,8 @@ def train_probe( def get_probe_predictions( data_loader: DataLoader, probe: nn.Module, - num_classes: int, device: torch.device, - task_type: TaskType, probe_type: ProbeType, - num_output_pixels_per_side_of_patch: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Get predictions from a trained linear probe. 
@@ -448,33 +503,12 @@ def get_probe_predictions( all_attn_weights = [] with torch.no_grad(): for batch in data_loader: - batch_emb, batch_labels = batch # (bsz, num_patches, dim), (bsz, H, W) + batch_emb, batch_labels = batch batch_emb = batch_emb.to(device) with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16): - outputs = probe(batch_emb) # (bsz, num_patches, logits_per_patch) + outputs = probe(batch_emb) logits = outputs["logits"] - if task_type == TaskType.SEGMENTATION: - assert num_output_pixels_per_side_of_patch is not None, ( - "num_output_pixels_per_side_of_patch is required for segmentation" - ) - spatial_patches_per_dim = batch_emb.shape[1] - logits = rearrange( - logits, - "b h w (c i j) -> b c (h i) (w j)", - h=spatial_patches_per_dim, - w=spatial_patches_per_dim, - c=num_classes, - i=num_output_pixels_per_side_of_patch, - j=num_output_pixels_per_side_of_patch, - ) - if logits.shape[-2] != batch_labels.shape[-2]: - logits = F.interpolate( - logits, - size=(batch_labels.shape[-2], batch_labels.shape[-1]), - mode="bilinear", - align_corners=True, - ) # (bsz, num_classes, H, W) preds = torch.argmax(logits, dim=1).cpu() all_preds.append(preds) @@ -527,7 +561,6 @@ def evaluate_probe( device: torch.device, task_type: TaskType, probe_type: ProbeType, - num_output_pixels_per_side_of_patch: int | None = None, ) -> EvalResult: """Evaluate a trained linear probe on a segmentation or classification task. 
@@ -537,10 +570,7 @@ def evaluate_probe( preds, labels = get_probe_predictions( data_loader=data_loader, probe=probe, - num_classes=num_classes, device=device, - task_type=task_type, probe_type=probe_type, - num_output_pixels_per_side_of_patch=num_output_pixels_per_side_of_patch, ) return compute_metric(preds, labels, num_classes, task_type) diff --git a/olmoearth_pretrain/train/callbacks/evaluator_callback.py b/olmoearth_pretrain/train/callbacks/evaluator_callback.py index 1fc9ad942..7324c561c 100644 --- a/olmoearth_pretrain/train/callbacks/evaluator_callback.py +++ b/olmoearth_pretrain/train/callbacks/evaluator_callback.py @@ -92,6 +92,9 @@ class DownstreamTaskConfig: eval_mode: EvalMode | None = None probe_type: ProbeType = ProbeType.LINEAR use_pooled_tokens: bool = False + # Use the center spatial patch embedding instead of pooling across all patches + # for classification tasks. Has no effect on segmentation tasks. + use_center_token: bool = False partition: str = field(default_factory=lambda: EvalDatasetPartition.TRAIN1X) norm_method: NormMethod = field(default_factory=lambda: NormMethod.NORM_NO_CLIP) select_final_test_miou_based_on_epoch_of_max_val_miou: bool = False @@ -152,6 +155,7 @@ def __init__( self.partition = task.partition self.norm_method = task.norm_method self.use_pooled_tokens = task.use_pooled_tokens + self.use_center_token = task.use_center_token self.select_final_test_miou_based_on_epoch_of_max_val_miou = ( task.select_final_test_miou_based_on_epoch_of_max_val_miou ) @@ -280,6 +284,7 @@ def _get_embeddings( "pooling_type": self.pooling_type, "concat_features": (self.probe_type == "attn_pool"), "use_pooled_tokens": self.use_pooled_tokens, + "use_center_token": self.use_center_token, } model = get_eval_wrapper(model, **wrapper_kwargs) return get_embeddings( diff --git a/scripts/archived/2026_02_19_eval_changes/base.py b/scripts/archived/2026_02_19_eval_changes/base.py new file mode 100644 index 000000000..365ddcbae --- /dev/null +++ 
b/scripts/archived/2026_02_19_eval_changes/base.py @@ -0,0 +1,67 @@ +"""Trying to prototype fitting everything into olmo core.""" + +import logging + +from script import ( + build_common_components, + build_dataloader_config, + build_dataset_config, + build_train_module_config, + build_trainer_config, + build_visualize_config, +) + +from olmoearth_pretrain.internal.experiment import CommonComponents, main +from olmoearth_pretrain.internal.utils import MODEL_SIZE_ARGS +from olmoearth_pretrain.nn.flexihelios import ( + EncoderConfig, + PredictorConfig, +) +from olmoearth_pretrain.nn.latent_mim import LatentMIMConfig + +logger = logging.getLogger(__name__) + +MAX_PATCH_SIZE = 8 +MIN_PATCH_SIZE = 1 + + +def build_model_config(common: CommonComponents) -> LatentMIMConfig: + """Build the model config for an experiment.""" + model_size = MODEL_SIZE_ARGS["base_shallow_decoder"] + + encoder_config = EncoderConfig( + embedding_size=model_size["encoder_embedding_size"], + num_heads=model_size["encoder_num_heads"], + depth=model_size["encoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + supported_modality_names=common.training_modalities, + max_patch_size=MAX_PATCH_SIZE, + drop_path=0.1, + max_sequence_length=12, + ) + decoder_config = PredictorConfig( + encoder_embedding_size=model_size["encoder_embedding_size"], + decoder_embedding_size=model_size["decoder_embedding_size"], + depth=model_size["decoder_depth"], + mlp_ratio=model_size["mlp_ratio"], + num_heads=model_size["decoder_num_heads"], + supported_modality_names=common.training_modalities, + max_sequence_length=12, + ) + model_config = LatentMIMConfig( + encoder_config=encoder_config, + decoder_config=decoder_config, + ) + return model_config + + +if __name__ == "__main__": + main( + common_components_builder=build_common_components, + model_config_builder=build_model_config, + train_module_config_builder=build_train_module_config, + dataset_config_builder=build_dataset_config, + 
dataloader_config_builder=build_dataloader_config, + trainer_config_builder=build_trainer_config, + visualize_config_builder=build_visualize_config, + ) diff --git a/scripts/archived/2026_02_19_eval_changes/eval_cls_center_token.py b/scripts/archived/2026_02_19_eval_changes/eval_cls_center_token.py new file mode 100644 index 000000000..8a4daf26e --- /dev/null +++ b/scripts/archived/2026_02_19_eval_changes/eval_cls_center_token.py @@ -0,0 +1,122 @@ +"""Evaluate the base model on classification tasks: baseline vs use_center_token. + +Usage: + torchrun scripts/archived/2026_02_19_eval_changes/eval_cls_center_token.py \ + evaluate eval-center-token-comparison local \ + --trainer.load_path=CHECKPOINT_DIR + +The load_path should point to the directory containing the base model checkpoint. +""" + +import logging + +from base import build_model_config +from olmo_core.train.callbacks import ( + BeakerCallback, + CheckpointerCallback, + ConfigSaverCallback, + GarbageCollectorCallback, + GPUMemoryMonitorCallback, +) +from olmo_core.train.checkpoint import CheckpointerConfig +from olmo_core.train.common import Duration, LoadStrategy +from olmo_core.train.config import TrainerConfig +from script import ( + build_common_components, + build_dataloader_config, + build_dataset_config, + build_train_module_config, + build_visualize_config, +) + +from olmoearth_pretrain.evals.datasets.normalize import NormMethod +from olmoearth_pretrain.internal.experiment import CommonComponents, main +from olmoearth_pretrain.nn.flexi_vit import PoolingType +from olmoearth_pretrain.train.callbacks import ( + DownstreamEvaluatorCallbackConfig, + OlmoEarthWandBCallback, +) +from olmoearth_pretrain.train.callbacks.evaluator_callback import DownstreamTaskConfig + +logger = logging.getLogger(__name__) + + +def build_trainer_config(common: CommonComponents) -> TrainerConfig: + """Build trainer config that runs classification evals and exits.""" + checkpointer_config = 
CheckpointerConfig(work_dir=common.save_folder) + wandb_callback = OlmoEarthWandBCallback( + name=common.run_name, + project="eval_cls_center_token", + entity="eai-ai2", + enabled=True, + upload_dataset_distribution_pre_train=False, + upload_modality_data_band_distribution_pre_train=False, + ) + + cls_datasets = [ + ("m_eurosat", "m-eurosat", 128, 8), + ("m_so2sat", "m-so2sat", 128, 8), + ("m_brick_kiln", "m-brick-kiln", 128, 8), + ("m_bigearthnet", "m-bigearthnet", 64, 8), + ] + + eval_tasks: dict[str, DownstreamTaskConfig] = {} + for use_center in [False, True]: + suffix = "center" if use_center else "baseline" + for name, dataset, batch_size, workers in cls_datasets: + eval_tasks[f"{name}_{suffix}"] = DownstreamTaskConfig( + dataset=dataset, + embedding_batch_size=batch_size, + num_workers=workers, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + use_center_token=use_center, + eval_interval=Duration.epochs(1), + ) + + trainer_config = ( + TrainerConfig( + work_dir=common.save_folder, + load_strategy=LoadStrategy.if_available, + save_folder=common.save_folder, + cancel_check_interval=1, + metrics_collect_interval=10, + max_duration=Duration.epochs(9999), + checkpointer=checkpointer_config, + ) + .with_callback("wandb", wandb_callback) + .with_callback("gpu_memory_monitor", GPUMemoryMonitorCallback()) + .with_callback("config_saver", ConfigSaverCallback()) + .with_callback( + "downstream_evaluator", + DownstreamEvaluatorCallbackConfig( + tasks=eval_tasks, + eval_on_startup=True, + cancel_after_first_eval=True, + run_on_test=True, + ), + ) + .with_callback("garbage_collector", GarbageCollectorCallback(gc_interval=1)) + .with_callback("beaker", BeakerCallback()) + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=5000, + ephemeral_save_interval=250, + ), + ) + ) + return trainer_config + + +if __name__ == "__main__": + main( + common_components_builder=build_common_components, 
+ model_config_builder=build_model_config, + train_module_config_builder=build_train_module_config, + dataset_config_builder=build_dataset_config, + dataloader_config_builder=build_dataloader_config, + trainer_config_builder=build_trainer_config, + visualize_config_builder=build_visualize_config, + ) diff --git a/scripts/archived/2026_02_19_eval_changes/eval_seg_probe_comparison.py b/scripts/archived/2026_02_19_eval_changes/eval_seg_probe_comparison.py new file mode 100644 index 000000000..a523c0dc9 --- /dev/null +++ b/scripts/archived/2026_02_19_eval_changes/eval_seg_probe_comparison.py @@ -0,0 +1,165 @@ +"""Evaluate the base model on segmentation tasks with LinearProbe vs InterpolateLinearProbe. + +Usage: + torchrun scripts/archived/2026_02_19_eval_changes/eval_seg_probe_comparison.py \ + evaluate eval-seg-probe-comparison local \ + --trainer.load_path=CHECKPOINT_DIR + +The load_path should point to the directory containing the base model checkpoint. +""" + +import logging + +from base import build_model_config +from olmo_core.train.callbacks import ( + BeakerCallback, + CheckpointerCallback, + ConfigSaverCallback, + GarbageCollectorCallback, + GPUMemoryMonitorCallback, +) +from olmo_core.train.checkpoint import CheckpointerConfig +from olmo_core.train.common import Duration, LoadStrategy +from olmo_core.train.config import TrainerConfig +from script import ( + build_common_components, + build_dataloader_config, + build_dataset_config, + build_train_module_config, + build_visualize_config, +) + +from olmoearth_pretrain.data.constants import Modality +from olmoearth_pretrain.evals.datasets.normalize import NormMethod +from olmoearth_pretrain.evals.linear_probe import ProbeType +from olmoearth_pretrain.internal.experiment import CommonComponents, main +from olmoearth_pretrain.nn.flexi_vit import PoolingType +from olmoearth_pretrain.train.callbacks import ( + DownstreamEvaluatorCallbackConfig, + OlmoEarthWandBCallback, +) +from 
olmoearth_pretrain.train.callbacks.evaluator_callback import DownstreamTaskConfig + +logger = logging.getLogger(__name__) + +PROBE_LR = 0.001 +PROBE_EPOCHS = 50 + + +def _seg_task( + dataset: str, + probe_type: ProbeType, + embedding_batch_size: int = 128, + probe_batch_size: int = 128, + num_workers: int = 8, + input_modalities: list[str] | None = None, +) -> DownstreamTaskConfig: + """Helper to build a segmentation linear probe task config.""" + return DownstreamTaskConfig( + dataset=dataset, + embedding_batch_size=embedding_batch_size, + probe_batch_size=probe_batch_size, + num_workers=num_workers, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + norm_method=NormMethod.NORM_NO_CLIP_2_STD, + probe_lr=PROBE_LR, + epochs=PROBE_EPOCHS, + eval_mode="LINEAR_PROBE", + probe_type=probe_type, + eval_interval=Duration.epochs(PROBE_EPOCHS), + input_modalities=input_modalities or [], + ) + + +def build_trainer_config(common: CommonComponents) -> TrainerConfig: + """Build trainer config that runs segmentation evals and exits.""" + checkpointer_config = CheckpointerConfig(work_dir=common.save_folder) + wandb_callback = OlmoEarthWandBCallback( + name=common.run_name, + project="eval_seg_probe_comparison", + entity="eai-ai2", + enabled=True, + upload_dataset_distribution_pre_train=False, + upload_modality_data_band_distribution_pre_train=False, + ) + + eval_tasks: dict[str, DownstreamTaskConfig] = {} + for probe_type in [ProbeType.LINEAR, ProbeType.INTERPOLATE_LINEAR]: + suffix = probe_type.value # "linear" or "interpolate_linear" + eval_tasks[f"mados_{suffix}"] = _seg_task( + "mados", + probe_type, + ) + eval_tasks[f"sen1floods11_{suffix}"] = _seg_task( + "sen1floods11", + probe_type, + ) + eval_tasks[f"m_cashew_plant_{suffix}"] = _seg_task( + "m-cashew-plant", + probe_type, + embedding_batch_size=32, + probe_batch_size=8, + num_workers=2, + ) + eval_tasks[f"m_sa_crop_type_{suffix}"] = _seg_task( + "m-sa-crop-type", + probe_type, + embedding_batch_size=32, 
+ probe_batch_size=8, + num_workers=2, + ) + eval_tasks[f"pastis_s2_{suffix}"] = _seg_task( + "pastis", + probe_type, + embedding_batch_size=32, + probe_batch_size=8, + num_workers=2, + input_modalities=[Modality.SENTINEL2_L2A.name], + ) + + trainer_config = ( + TrainerConfig( + work_dir=common.save_folder, + load_strategy=LoadStrategy.if_available, + save_folder=common.save_folder, + cancel_check_interval=1, + metrics_collect_interval=10, + max_duration=Duration.epochs(9999), + checkpointer=checkpointer_config, + ) + .with_callback("wandb", wandb_callback) + .with_callback("gpu_memory_monitor", GPUMemoryMonitorCallback()) + .with_callback("config_saver", ConfigSaverCallback()) + .with_callback( + "downstream_evaluator", + DownstreamEvaluatorCallbackConfig( + tasks=eval_tasks, + eval_on_startup=True, + cancel_after_first_eval=True, + run_on_test=True, + ), + ) + .with_callback("garbage_collector", GarbageCollectorCallback(gc_interval=1)) + .with_callback("beaker", BeakerCallback()) + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=5000, + ephemeral_save_interval=250, + ), + ) + ) + return trainer_config + + +if __name__ == "__main__": + main( + common_components_builder=build_common_components, + model_config_builder=build_model_config, + train_module_config_builder=build_train_module_config, + dataset_config_builder=build_dataset_config, + dataloader_config_builder=build_dataloader_config, + trainer_config_builder=build_trainer_config, + visualize_config_builder=build_visualize_config, + ) diff --git a/scripts/archived/2026_02_19_eval_changes/script.py b/scripts/archived/2026_02_19_eval_changes/script.py new file mode 100644 index 000000000..5ab5e65f6 --- /dev/null +++ b/scripts/archived/2026_02_19_eval_changes/script.py @@ -0,0 +1,275 @@ +"""Trying to prototype fitting everything into olmo core.""" + +import logging + +from olmo_core.config import DType +from olmo_core.distributed.parallel.data_parallel import ( + DataParallelConfig, + 
DataParallelType, +) +from olmo_core.optim import AdamWConfig +from olmo_core.optim.scheduler import CosWithWarmup +from olmo_core.train.callbacks import ( + BeakerCallback, + CheckpointerCallback, + ConfigSaverCallback, + GarbageCollectorCallback, + GPUMemoryMonitorCallback, +) +from olmo_core.train.checkpoint import CheckpointerConfig +from olmo_core.train.common import Duration, LoadStrategy +from olmo_core.train.config import TrainerConfig + +from olmoearth_pretrain.data.constants import Modality +from olmoearth_pretrain.data.dataloader import OlmoEarthDataLoaderConfig +from olmoearth_pretrain.data.dataset import OlmoEarthDatasetConfig +from olmoearth_pretrain.internal.common import ( + build_common_components as build_common_components_default, +) +from olmoearth_pretrain.internal.experiment import ( + CommonComponents, + OlmoEarthVisualizeConfig, + SubCmd, +) +from olmoearth_pretrain.nn.flexi_vit import ( + PoolingType, +) +from olmoearth_pretrain.train.callbacks import ( + DownstreamEvaluatorCallbackConfig, + OlmoEarthSpeedMonitorCallback, + OlmoEarthWandBCallback, +) +from olmoearth_pretrain.train.callbacks.evaluator_callback import DownstreamTaskConfig +from olmoearth_pretrain.train.loss import LossConfig +from olmoearth_pretrain.train.masking import MaskingConfig +from olmoearth_pretrain.train.train_module.contrastive_latentmim import ( + ContrastiveLatentMIMTrainModuleConfig, +) + +logger = logging.getLogger(__name__) + +MAX_PATCH_SIZE = 8 +MIN_PATCH_SIZE = 1 + + +def build_common_components( + script: str, cmd: SubCmd, run_name: str, cluster: str, overrides: list[str] +) -> CommonComponents: + """Build the common components for an experiment.""" + config = build_common_components_default(script, cmd, run_name, cluster, overrides) + config.training_modalities = [ + Modality.SENTINEL2_L2A.name, + Modality.SENTINEL1.name, + Modality.LANDSAT.name, + Modality.WORLDCOVER.name, + Modality.SRTM.name, + Modality.OPENSTREETMAP_RASTER.name, + 
Modality.WRI_CANOPY_HEIGHT_MAP.name, + Modality.CDL.name, + Modality.WORLDCEREAL.name, + ] + return config + + +def get_masking_config(common: CommonComponents) -> MaskingConfig: + """Get the masking configuration for the experiment. + + Args: + common: Common experiment components containing optional tokenization_config. + """ + return MaskingConfig( + strategy_config={ + "type": "modality_cross_random", + "encode_ratio": 0.5, + "decode_ratio": 0.5, + "allow_encoding_decoding_same_bandset": True, + "only_decode_modalities": [ + Modality.WORLDCOVER.name, + Modality.SRTM.name, + Modality.OPENSTREETMAP_RASTER.name, + Modality.WRI_CANOPY_HEIGHT_MAP.name, + Modality.CDL.name, + Modality.WORLDCEREAL.name, + ], + }, + tokenization_config=common.tokenization_config, + ) + + +def build_train_module_config( + common: CommonComponents, +) -> ContrastiveLatentMIMTrainModuleConfig: + """Build the train module config for an experiment. + + Args: + common: Common experiment components. + """ + # The train module still needs the masking_config for reference (e.g., for metric + # naming), but the actual masking happens in the dataloader workers. + return ContrastiveLatentMIMTrainModuleConfig( + optim_config=AdamWConfig(lr=0.0001, weight_decay=0.02, fused=False), + rank_microbatch_size=32, + masking_config=get_masking_config(common), + loss_config=LossConfig( + loss_config={ + "type": "modality_patch_discrimination_new", + "tau": 0.1, + } + ), + contrastive_config=LossConfig( + loss_config={ + "type": "InfoNCE", + "weight": 0.1, + } + ), + token_exit_cfg={modality: 0 for modality in common.training_modalities}, + max_grad_norm=1.0, + scheduler=CosWithWarmup(warmup_steps=8000), + ema_decay=(1.0, 1.0), + dp_config=DataParallelConfig( + name=DataParallelType.fsdp, + param_dtype=DType.bfloat16, + reduce_dtype=DType.float32, + ), + ) + + +def build_dataloader_config( + common: CommonComponents, +) -> OlmoEarthDataLoaderConfig: + """Build the dataloader config for an experiment. 
+ + Masking is performed in the dataloader workers (CPU) instead of in the train module + (GPU). This improves throughput by offloading CPU-bound masking operations to + dataloader workers. + + Args: + common: Common experiment components. + """ + return OlmoEarthDataLoaderConfig( + num_workers=12, + global_batch_size=512, + token_budget=2250, + prefetch_factor=2, + sampled_hw_p_list=list(range(1, 13)), # try only temporal tokens + min_patch_size=MIN_PATCH_SIZE, + max_patch_size=MAX_PATCH_SIZE, + work_dir=common.save_folder, + seed=3622, + num_masked_views=2, # ContrastiveLatentMIM needs 2 views + masking_config=get_masking_config(common), + # masking_config_b is not set, so both views use the same strategy + tokenization_config=common.tokenization_config, + ) + + +def build_dataset_config(common: CommonComponents) -> OlmoEarthDatasetConfig: + """Build the dataset config for an experiment.""" + return OlmoEarthDatasetConfig( + h5py_dir="/weka/dfive-default/helios/dataset/osm_sampling/h5py_data_w_missing_timesteps_zstd_3_128_x_4/cdl_gse_landsat_openstreetmap_raster_sentinel1_sentinel2_l2a_srtm_worldcereal_worldcover_worldpop_wri_canopy_height_map/1138828", + training_modalities=common.training_modalities, + ) + + +def build_trainer_config(common: CommonComponents) -> TrainerConfig: + """Build the trainer config for an experiment.""" + MAX_DURATION = Duration.epochs(300) + METRICS_COLLECT_INTERVAL = 10 + CANCEL_CHECK_INTERVAL = 25 + LOAD_STRATEGY = LoadStrategy.if_available + WANDB_USERNAME = "eai-ai2" # nosec + WANDB_PROJECT = "2025_10_02_phase2" + PERMANENT_SAVE_INTERVAL = 5000 + EPHERMERAL_SAVE_INTERVAL = 250 + checkpointer_config = CheckpointerConfig(work_dir=common.save_folder) + wandb_callback = OlmoEarthWandBCallback( + name=common.run_name, + project=WANDB_PROJECT, + entity=WANDB_USERNAME, + enabled=True, # set to False to avoid wandb errors + ) + # Safe to collect every step for now + garbage_collector_callback = GarbageCollectorCallback(gc_interval=1) + 
EVAL_TASKS = { + "m-eurosat": DownstreamTaskConfig( + dataset="m-eurosat", + embedding_batch_size=128, + num_workers=0, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + eval_interval=Duration.steps(4000), + ), + "m_so2sat": DownstreamTaskConfig( + dataset="m-so2sat", + embedding_batch_size=128, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + eval_interval=Duration.steps(20000), + ), + "mados": DownstreamTaskConfig( + dataset="mados", + embedding_batch_size=128, + probe_batch_size=128, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=False, + probe_lr=0.01, + epochs=50, + eval_interval=Duration.steps(4000), + ), + "pastis": DownstreamTaskConfig( + dataset="pastis", + embedding_batch_size=32, + probe_batch_size=8, + num_workers=8, + pooling_type=PoolingType.MEAN, + norm_stats_from_pretrained=True, + probe_lr=0.1, + eval_interval=Duration.steps(20000), + input_modalities=[Modality.SENTINEL2_L2A.name], + epochs=50, + ), + } + trainer_config = ( + TrainerConfig( + work_dir=common.save_folder, + load_strategy=LOAD_STRATEGY, + save_folder=common.save_folder, + cancel_check_interval=CANCEL_CHECK_INTERVAL, + metrics_collect_interval=METRICS_COLLECT_INTERVAL, + max_duration=MAX_DURATION, + checkpointer=checkpointer_config, + ) + .with_callback("wandb", wandb_callback) + .with_callback("speed_monitor", OlmoEarthSpeedMonitorCallback()) + .with_callback("gpu_memory_monitor", GPUMemoryMonitorCallback()) + .with_callback("config_saver", ConfigSaverCallback()) + .with_callback( + "downstream_evaluator", + DownstreamEvaluatorCallbackConfig( + tasks=EVAL_TASKS, + ), + ) + .with_callback("garbage_collector", garbage_collector_callback) + .with_callback( + "beaker", BeakerCallback() + ) # this should not be here, but for now it is + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=PERMANENT_SAVE_INTERVAL, + ephemeral_save_interval=EPHERMERAL_SAVE_INTERVAL, + ), + ) + ) + 
return trainer_config + + +def build_visualize_config(common: CommonComponents) -> OlmoEarthVisualizeConfig: + """Build the visualize config for an experiment.""" + return OlmoEarthVisualizeConfig( + num_samples=None, + output_dir=str(f"{common.save_folder}/visualizations"), + std_multiplier=2.0, + ) diff --git a/tests/unit/eval/test_eval_wrapper.py b/tests/unit/eval/test_eval_wrapper.py new file mode 100644 index 000000000..b9d9c8725 --- /dev/null +++ b/tests/unit/eval/test_eval_wrapper.py @@ -0,0 +1,33 @@ +"""Unit tests for eval wrapper.""" + +import torch + +from olmoearth_pretrain.evals.eval_wrapper import EvalWrapper + + +class TestExtractCenterToken: + """Tests for _extract_center_token static method.""" + + def test_odd_spatial_dims(self) -> None: + """Get center token for odd dimensions.""" + B, H, W, D = 2, 7, 7, 64 + x = torch.randn(B, H, W, D) + result = EvalWrapper._extract_center_token(x) + assert result.shape == (B, D) + assert torch.equal(result, x[:, 3, 3, :]) + + def test_even_spatial_dims(self) -> None: + """Get bottom-right of center for even dimensions.""" + B, H, W, D = 2, 8, 8, 64 + x = torch.randn(B, H, W, D) + result = EvalWrapper._extract_center_token(x) + assert result.shape == (B, D) + assert torch.equal(result, x[:, 4, 4, :]) + + def test_non_square(self) -> None: + """Correct center for non-square dimensions.""" + B, H, W, D = 3, 4, 6, 32 + x = torch.randn(B, H, W, D) + result = EvalWrapper._extract_center_token(x) + assert result.shape == (B, D) + assert torch.equal(result, x[:, 2, 3, :]) diff --git a/tests/unit/eval/test_linear_probe.py b/tests/unit/eval/test_linear_probe.py new file mode 100644 index 000000000..68df70534 --- /dev/null +++ b/tests/unit/eval/test_linear_probe.py @@ -0,0 +1,61 @@ +"""Unit tests for probe modules in linear_probe.py.""" + +import pytest +import torch + +from olmoearth_pretrain.evals.datasets.configs import TaskType +from olmoearth_pretrain.evals.linear_probe import ( + InterpolateLinearProbe, + 
LinearProbe, +) + + +class TestLinearProbeClassification: + """Tests for LinearProbe.""" + + def test_output_shape_classification(self) -> None: + """Classification probe: (B, D) -> (B, C).""" + probe = LinearProbe(in_dim=32, num_classes=5, task_type=TaskType.CLASSIFICATION) + x = torch.randn(4, 32) + logits = probe(x)["logits"] + assert logits.shape == (4, 5) + + def test_output_shape_segmentation(self) -> None: + """Segmentation probe: (B, H_p, W_p, D) -> (B, C, H, W).""" + probe = LinearProbe( + in_dim=32, + num_classes=5, + task_type=TaskType.SEGMENTATION, + num_output_pixels_per_side_of_patch=4, + ) + # 8 patches per dim * 4 pixels per patch = 32 output pixels per dim + x = torch.randn(2, 8, 8, 32) + logits = probe(x)["logits"] + assert logits.shape == (2, 5, 32, 32) + + +class TestInterpolateLinearProbe: + """Tests for InterpolateLinearProbe.""" + + def test_output_shape_segmentation(self) -> None: + """InterpolateLinearProbe: (B, H_p, W_p, D) -> (B, C, H, W) via bilinear upsample.""" + probe = InterpolateLinearProbe( + in_dim=32, + num_classes=5, + task_type=TaskType.SEGMENTATION, + num_output_pixels_per_side_of_patch=4, + ) + x = torch.randn(2, 8, 8, 32) + logits = probe(x)["logits"] + # 8 patches * 4 pixels per patch = 32 + assert logits.shape == (2, 5, 32, 32) + + def test_rejects_classification(self) -> None: + """InterpolateLinearProbe should reject non-segmentation tasks.""" + with pytest.raises(ValueError, match="only supports segmentation"): + InterpolateLinearProbe( + in_dim=32, + num_classes=5, + task_type=TaskType.CLASSIFICATION, + num_output_pixels_per_side_of_patch=4, + )