Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 88 additions & 18 deletions olmoearth_pretrain/evals/eval_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
pooling_type: PoolingType,
concat_features: bool = False,
use_pooled_tokens: bool = False,
use_center_token: bool = False,
):
"""Initialize the eval wrapper.

Expand All @@ -58,7 +59,8 @@ def __init__(
pooling_type: The pooling type to use for the model.
concat_features: Whether to concatenate features across modalities.
use_pooled_tokens: Whether to use pooled tokens.
is_train: whether this is being used on the training data.
use_center_token: Whether to use the center spatial patch embedding instead
of pooling across all patches for classification tasks.
"""
super().__init__()
self.model = model
Expand All @@ -68,6 +70,12 @@ def __init__(
self.concat_features = concat_features
self.spatial_pool = task_type == TaskType.SEGMENTATION
self.use_pooled_tokens = use_pooled_tokens
self.use_center_token = use_center_token
if self.use_center_token and self.spatial_pool:
raise ValueError(
"use_center_token is only supported for classification tasks, "
"not segmentation (spatial_pool=True)"
)
if self.use_pooled_tokens:
assert isinstance(self.model, EncodeEarlyAttnPool), (
"Pooled tokens are only supported for EncodeEarlyAttnPool"
Expand All @@ -91,6 +99,19 @@ def __getattr__(self, name: str) -> Any:
"""Delegate attribute access to the underlying model if the attribute is not found on the wrapper."""
return getattr(self.model, name)

@staticmethod
def _extract_center_token(spatial_embeddings: torch.Tensor) -> torch.Tensor:
"""Extract the center spatial patch embedding.

Args:
spatial_embeddings: Tensor of shape (B, H, W, D).

Returns:
Tensor of shape (B, D) from the center patch.
"""
H, W = spatial_embeddings.shape[1], spatial_embeddings.shape[2]
return spatial_embeddings[:, H // 2, W // 2, :]

def __call__(
self,
masked_olmoearth_sample: MaskedOlmoEarthSample,
Expand All @@ -116,11 +137,20 @@ def __call__(
masked_olmoearth_sample, patch_size=self.patch_size, fast_pass=True
)["tokens_and_masks"] # (bsz, dim)
# Concat features across modalities in space averaged across time
batch_embeddings = batch_embeddings.pool_unmasked_tokens(
self.pooling_type,
spatial_pooling=self.spatial_pool,
concat_features=self.concat_features,
)
if self.use_center_token:
# Get spatial embeddings (B, H, W, D) then take center patch
batch_embeddings = batch_embeddings.pool_unmasked_tokens(
self.pooling_type,
spatial_pooling=True,
concat_features=self.concat_features,
)
batch_embeddings = self._extract_center_token(batch_embeddings)
else:
batch_embeddings = batch_embeddings.pool_unmasked_tokens(
self.pooling_type,
spatial_pooling=self.spatial_pool,
concat_features=self.concat_features,
)
else:
pooled_tokens_dict = self.model(
masked_olmoearth_sample, patch_size=self.patch_size, fast_pass=True
Expand All @@ -138,8 +168,16 @@ def __call__(
pooled_tokens = reduce(
pooled_tokens, "b h w ... d -> b h w d", self.pooling_type
)
elif self.use_center_token:
# Pool time but keep spatial, then take center patch
if pooled_tokens.shape[1] == 1 and pooled_tokens.ndim == 3:
pooled_tokens = pooled_tokens.unsqueeze(1)
pooled_tokens = reduce(
pooled_tokens, "b h w ... d -> b h w d", self.pooling_type
)
pooled_tokens = self._extract_center_token(pooled_tokens)
else:
# Take the mean of all dims excetp the first and last
# Take the mean of all dims except the first and last
pooled_tokens = reduce(
pooled_tokens, "b ... d -> b d", self.pooling_type
)
Expand All @@ -162,11 +200,14 @@ def __call__(
is_train: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the model produces the embedding specified by initialization."""
spatial_pool = self.spatial_pool or self.use_center_token
batch_embeddings = self.model(
masked_olmoearth_sample,
pooling=self.pooling_type,
spatial_pool=self.spatial_pool,
spatial_pool=spatial_pool,
)
if self.use_center_token:
batch_embeddings = self._extract_center_token(batch_embeddings)
return batch_embeddings, labels


Expand All @@ -180,11 +221,13 @@ def __call__(
is_train: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the model produces the embedding specified by initialization."""
if self.spatial_pool:
if self.spatial_pool or self.use_center_token:
# Intermediate features are not yet working because of some bug internal to the model
batch_embeddings = self.model.forward_features(
masked_olmoearth_sample, pooling=self.pooling_type
)
if self.use_center_token:
batch_embeddings = self._extract_center_token(batch_embeddings)
else:
batch_embeddings = self.model(
masked_olmoearth_sample, pooling=self.pooling_type
Expand All @@ -202,11 +245,14 @@ def __call__(
is_train: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the model produces the embedding specified by initialization."""
spatial_pool = self.spatial_pool or self.use_center_token
embeddings = self.model(
masked_olmoearth_sample,
pooling=self.pooling_type,
spatial_pool=self.spatial_pool,
spatial_pool=spatial_pool,
)
if self.use_center_token:
embeddings = self._extract_center_token(embeddings)
return embeddings, labels


Expand All @@ -220,11 +266,15 @@ def __call__(
is_train: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the model produces the embedding specified by initialization."""
spatial_pool = self.spatial_pool or self.use_center_token
embeddings = self.model(
masked_olmoearth_sample,
pooling=self.pooling_type,
spatial_pool=self.spatial_pool,
spatial_pool=spatial_pool,
)
if self.use_center_token:
embeddings = self._extract_center_token(embeddings)
return embeddings, labels
if is_train and (self.task_type == TaskType.SEGMENTATION):
# this is a special case for AnySat. Since it outputs per-pixel embeddings,
# we subsample training pixels to keep the memory requirements reasonable.
Expand Down Expand Up @@ -267,11 +317,14 @@ def __call__(
is_train: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the model produces the embedding specified by initialization."""
spatial_pool = self.spatial_pool or self.use_center_token
embeddings = self.model(
masked_olmoearth_sample,
pooling=self.pooling_type,
spatial_pool=self.spatial_pool,
spatial_pool=spatial_pool,
)
if self.use_center_token:
embeddings = self._extract_center_token(embeddings)
return embeddings, labels


Expand All @@ -285,11 +338,14 @@ def __call__(
is_train: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the model produces the embedding specified by initialization."""
spatial_pool = self.spatial_pool or self.use_center_token
batch_embeddings = self.model(
masked_olmoearth_sample,
pooling=self.pooling_type,
spatial_pool=self.spatial_pool,
spatial_pool=spatial_pool,
)
if self.use_center_token:
batch_embeddings = self._extract_center_token(batch_embeddings)
return batch_embeddings, labels


Expand All @@ -303,11 +359,14 @@ def __call__(
is_train: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the model produces the embedding specified by initialization."""
spatial_pool = self.spatial_pool or self.use_center_token
batch_embeddings = self.model(
masked_olmoearth_sample,
pooling=self.pooling_type,
spatial_pool=self.spatial_pool,
spatial_pool=spatial_pool,
)
if self.use_center_token:
batch_embeddings = self._extract_center_token(batch_embeddings)
return batch_embeddings, labels


Expand All @@ -321,11 +380,14 @@ def __call__(
is_train: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the model produces the embedding specified by initialization."""
spatial_pool = self.spatial_pool or self.use_center_token
batch_embeddings = self.model(
masked_olmoearth_sample,
pooling=self.pooling_type,
spatial_pool=self.spatial_pool,
spatial_pool=spatial_pool,
)
if self.use_center_token:
batch_embeddings = self._extract_center_token(batch_embeddings)
return batch_embeddings, labels


Expand All @@ -340,12 +402,14 @@ def __call__(
) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the model produces the embedding specified by initialization."""
# i need to do the apply imagenet normalizer thing in here
if self.spatial_pool:
if self.spatial_pool or self.use_center_token:
# Intermediate features are not yet working because of some bug internal to the model
batch_embeddings = self.model.forward_features(
masked_olmoearth_sample,
pooling=self.pooling_type,
)
if self.use_center_token:
batch_embeddings = self._extract_center_token(batch_embeddings)
else:
# should this call model ditectly
batch_embeddings = self.model(
Expand All @@ -365,11 +429,14 @@ def __call__(
is_train: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the model produces the embedding specified by initialization."""
spatial_pool = self.spatial_pool or self.use_center_token
batch_embeddings = self.model(
masked_olmoearth_sample,
pooling=self.pooling_type,
spatial_pool=self.spatial_pool,
spatial_pool=spatial_pool,
)
if self.use_center_token:
batch_embeddings = self._extract_center_token(batch_embeddings)
return batch_embeddings, labels


Expand All @@ -383,11 +450,14 @@ def __call__(
is_train: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the model produces the embedding specified by initialization."""
spatial_pool = self.spatial_pool or self.use_center_token
batch_embeddings = self.model(
masked_olmoearth_sample,
pooling=self.pooling_type,
spatial_pool=self.spatial_pool,
spatial_pool=spatial_pool,
)
if self.use_center_token:
batch_embeddings = self._extract_center_token(batch_embeddings)
return batch_embeddings, labels


Expand Down
Loading